diff --git a/src/hugr/views.rs b/src/hugr/views.rs index bead9168b..d882eca42 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -3,6 +3,9 @@ pub mod hierarchy; pub mod sibling; +#[cfg(test)] +mod tests; + pub use hierarchy::{DescendantsGraph, HierarchyView, SiblingGraph}; pub use sibling::SiblingSubgraph; @@ -52,6 +55,11 @@ pub trait HugrView: sealed::HugrInternals { where Self: 'a; + /// Iterator over the links between two nodes. + type NodeConnections<'a>: Iterator + where + Self: 'a; + /// Return the root node of this view. #[inline] fn root(&self) -> Node { @@ -113,6 +121,9 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over the nodes and ports connected to a port. fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_>; + /// Iterator the links between two nodes. + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_>; + /// Returns whether a port is connected. fn is_linked(&self, node: Node, port: Port) -> bool { self.linked_ports(node, port).next().is_some() @@ -240,6 +251,8 @@ where where Self: 'a; + type NodeConnections<'a> = MapWithCtx,&'a Hugr, [Port; 2]> where Self: 'a; + #[inline] fn contains_node(&self, node: Node) -> bool { self.as_ref().graph.contains_node(node.index) @@ -300,6 +313,21 @@ where }) } + #[inline] + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + let hugr = self.as_ref(); + + hugr.graph + .get_connections(node.index, other.index) + .with_context(hugr) + .map_with_context(|(p1, p2), hugr| { + [p1, p2].map(|link| { + let offset = hugr.graph.port_offset(link.port()).unwrap(); + offset.into() + }) + }) + } + #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { self.as_ref().graph.num_ports(node.index, dir) diff --git a/src/hugr/views/hierarchy.rs b/src/hugr/views/hierarchy.rs index 98eedd5b4..d9f412bf7 100644 --- a/src/hugr/views/hierarchy.rs +++ b/src/hugr/views/hierarchy.rs @@ -93,6 +93,13 @@ where > where Self: 'a; + type NodeConnections<'a> = MapWithCtx< + as LinkView>::NodeConnections<'a>, + &'a Self, + [Port; 2], + > where + Self: 'a; + #[inline] fn contains_node(&self, node: Node) -> bool { self.graph.contains_node(node.index) @@ -165,6 +172,18 @@ where }) } + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + self.graph + .get_connections(node.index, other.index) + .with_context(self) + .map_with_context(|(p1, p2), hugr| { + [p1, p2].map(|link| { + let offset = hugr.graph.port_offset(link).unwrap(); + offset.into() + }) + }) + } + #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { self.graph.num_ports(node.index, dir) @@ -268,6 +287,13 @@ where > where Self: 'a; + type NodeConnections<'a> = MapWithCtx< + as LinkView>::NodeConnections<'a>, + &'a Self, + [Port; 2], + > where + Self: 'a; + #[inline] fn contains_node(&self, node: Node) -> bool { self.graph.contains_node(node.index) @@ -334,6 +360,18 @@ where }) } + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + self.graph + .get_connections(node.index, other.index) + .with_context(self) + .map_with_context(|(p1, p2), hugr| { + [p1, p2].map(|link| { + let offset = hugr.graph.port_offset(link).unwrap(); + offset.into() + }) + }) + } + #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { self.graph.num_ports(node.index, dir) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs new file mode 100644 index 000000000..2fac7724e --- /dev/null +++ b/src/hugr/views/tests.rs @@ -0,0 +1,44 @@ +use portgraph::PortOffset; + +use crate::{ + builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::QB_T, + ops::handle::NodeHandle, + std_extensions::quantum::test::cx_gate, + type_row, + types::FunctionType, + HugrView, +}; + +#[test] +fn node_connections() -> Result<(), BuildError> { + let mut dfg = DFGBuilder::new(FunctionType::new( + type_row![QB_T, QB_T], + type_row![QB_T, QB_T], + ))?; + + let [q1, q2] = dfg.input_wires_arr(); + + let n1 = dfg.add_dataflow_op(cx_gate(), [q1, q2])?; + let [q1, q2] = n1.outputs_arr(); + let n2 = dfg.add_dataflow_op(cx_gate(), [q2, q1])?; + + let h = dfg.finish_hugr_with_outputs(n2.outputs())?; + + let connections: Vec<_> = h.node_connections(n1.node(), n2.node()).collect(); + + assert_eq!( + &connections[..], + &[ + [ + PortOffset::new_outgoing(0).into(), + PortOffset::new_incoming(1).into() + ], + [ + PortOffset::new_outgoing(1).into(), + PortOffset::new_incoming(0).into() + ], + ] + ); + Ok(()) +}