Skip to content

Commit

Permalink
feat: add HugrView::node_connections to get all links between nodes (
Browse files Browse the repository at this point in the history
…#460)

Uses portgraph `get_connections`
  • Loading branch information
ss2165 authored Aug 29, 2023
1 parent ca4434e commit ae81e42
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
pub mod hierarchy;
pub mod sibling;

#[cfg(test)]
mod tests;

pub use hierarchy::{DescendantsGraph, HierarchyView, SiblingGraph};
pub use sibling::SiblingSubgraph;

Expand Down Expand Up @@ -52,6 +55,11 @@ pub trait HugrView: sealed::HugrInternals {
where
Self: 'a;

/// Iterator over the links between two nodes.
type NodeConnections<'a>: Iterator<Item = [Port; 2]>
where
Self: 'a;

/// Return the root node of this view.
#[inline]
fn root(&self) -> Node {
Expand Down Expand Up @@ -113,6 +121,9 @@ pub trait HugrView: sealed::HugrInternals {
/// Iterator over the nodes and ports connected to a port.
fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_>;

/// Iterator the links between two nodes.
fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_>;

/// Returns whether a port is connected.
fn is_linked(&self, node: Node, port: Port) -> bool {
self.linked_ports(node, port).next().is_some()
Expand Down Expand Up @@ -240,6 +251,8 @@ where
where
Self: 'a;

type NodeConnections<'a> = MapWithCtx<multiportgraph::NodeConnections<'a>,&'a Hugr, [Port; 2]> where Self: 'a;

#[inline]
fn contains_node(&self, node: Node) -> bool {
self.as_ref().graph.contains_node(node.index)
Expand Down Expand Up @@ -300,6 +313,21 @@ where
})
}

#[inline]
fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> {
let hugr = self.as_ref();

hugr.graph
.get_connections(node.index, other.index)
.with_context(hugr)
.map_with_context(|(p1, p2), hugr| {
[p1, p2].map(|link| {
let offset = hugr.graph.port_offset(link.port()).unwrap();
offset.into()
})
})
}

#[inline]
fn num_ports(&self, node: Node, dir: Direction) -> usize {
self.as_ref().graph.num_ports(node.index, dir)
Expand Down
38 changes: 38 additions & 0 deletions src/hugr/views/hierarchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ where
> where
Self: 'a;

type NodeConnections<'a> = MapWithCtx<
<FlatRegionGraph<'g> as LinkView>::NodeConnections<'a>,
&'a Self,
[Port; 2],
> where
Self: 'a;

#[inline]
fn contains_node(&self, node: Node) -> bool {
self.graph.contains_node(node.index)
Expand Down Expand Up @@ -165,6 +172,18 @@ where
})
}

fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> {
self.graph
.get_connections(node.index, other.index)
.with_context(self)
.map_with_context(|(p1, p2), hugr| {
[p1, p2].map(|link| {
let offset = hugr.graph.port_offset(link).unwrap();
offset.into()
})
})
}

#[inline]
fn num_ports(&self, node: Node, dir: Direction) -> usize {
self.graph.num_ports(node.index, dir)
Expand Down Expand Up @@ -268,6 +287,13 @@ where
> where
Self: 'a;

type NodeConnections<'a> = MapWithCtx<
<RegionGraph<'g> as LinkView>::NodeConnections<'a>,
&'a Self,
[Port; 2],
> where
Self: 'a;

#[inline]
fn contains_node(&self, node: Node) -> bool {
self.graph.contains_node(node.index)
Expand Down Expand Up @@ -334,6 +360,18 @@ where
})
}

fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> {
self.graph
.get_connections(node.index, other.index)
.with_context(self)
.map_with_context(|(p1, p2), hugr| {
[p1, p2].map(|link| {
let offset = hugr.graph.port_offset(link).unwrap();
offset.into()
})
})
}

#[inline]
fn num_ports(&self, node: Node, dir: Direction) -> usize {
self.graph.num_ports(node.index, dir)
Expand Down
44 changes: 44 additions & 0 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use portgraph::PortOffset;

use crate::{
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
ops::handle::NodeHandle,
std_extensions::quantum::test::cx_gate,
type_row,
types::FunctionType,
HugrView,
};

#[test]
fn node_connections() -> Result<(), BuildError> {
let mut dfg = DFGBuilder::new(FunctionType::new(
type_row![QB_T, QB_T],
type_row![QB_T, QB_T],
))?;

let [q1, q2] = dfg.input_wires_arr();

let n1 = dfg.add_dataflow_op(cx_gate(), [q1, q2])?;
let [q1, q2] = n1.outputs_arr();
let n2 = dfg.add_dataflow_op(cx_gate(), [q2, q1])?;

let h = dfg.finish_hugr_with_outputs(n2.outputs())?;

let connections: Vec<_> = h.node_connections(n1.node(), n2.node()).collect();

assert_eq!(
&connections[..],
&[
[
PortOffset::new_outgoing(0).into(),
PortOffset::new_incoming(1).into()
],
[
PortOffset::new_outgoing(1).into(),
PortOffset::new_incoming(0).into()
],
]
);
Ok(())
}

0 comments on commit ae81e42

Please sign in to comment.