diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 0db5d39d0..5b31ab4fe 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -1,13 +1,16 @@ //! Read-only access into HUGR graphs and subgraphs. -pub mod hierarchy; +pub mod descendants; +mod petgraph; pub mod sibling; +pub mod sibling_subgraph; #[cfg(test)] mod tests; -pub use hierarchy::{DescendantsGraph, HierarchyView, SiblingGraph}; -pub use sibling::SiblingSubgraph; +pub use descendants::DescendantsGraph; +pub use sibling::SiblingGraph; +pub use sibling_subgraph::SiblingSubgraph; use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; use itertools::{Itertools, MapInto}; @@ -18,8 +21,8 @@ use super::{Hugr, NodeMetadata, NodeType}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG}; use crate::types::{EdgeKind, FunctionType}; -use crate::Direction; -use crate::{Node, Port}; +use crate::{Direction, Node, Port}; +use ::petgraph::visit as pv; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. @@ -228,6 +231,27 @@ pub trait HugrView: sealed::HugrInternals { } } +/// A common trait for views of a HUGR hierarchical subgraph. +pub trait HierarchyView<'a>: + HugrView + + pv::GraphBase + + pv::GraphProp + + pv::NodeCount + + pv::NodeIndexable + + pv::EdgeCount + + pv::Visitable + + pv::GetAdjacencyMatrix + + pv::Visitable +where + for<'g> &'g Self: pv::IntoNeighborsDirected + pv::IntoNodeIdentifiers, +{ + /// The base from which the subgraph is derived. + type Base; + + /// Create a hierarchical view of a HUGR given a root node. + fn new(hugr: &'a Self::Base, root: Node) -> Self; +} + impl HugrView for T where T: AsRef, diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs new file mode 100644 index 000000000..cf0262be3 --- /dev/null +++ b/src/hugr/views/descendants.rs @@ -0,0 +1,316 @@ +//! DescendantsGraph: view onto the subgraph of the HUGR starting from a root +//! (all descendants at all depths). + +use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; +use itertools::{Itertools, MapInto}; +use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; + +use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; +use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port}; + +use super::{sealed::HugrInternals, HierarchyView, HugrView, NodeMetadata}; + +type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; + +/// View of a HUGR descendants graph. +/// +/// Includes the root node (which uniquely has no parent) and all its descendants. +/// +/// See [`SiblingGraph`] for a view that includes only the root and +/// its immediate children. Prefer using [`SiblingGraph`] when possible, +/// as it is more efficient. +/// +/// Implements the [`HierarchyView`] trait, as well as [`HugrView`] and petgraph's +/// _visit_ traits, so can be used interchangeably with [`SiblingGraph`]. +/// +/// [`SiblingGraph`]: super::SiblingGraph +pub struct DescendantsGraph<'g, Root = Node, Base = Hugr> +where + Base: HugrInternals, +{ + /// The chosen root node. + root: Node, + + /// The graph encoding the adjacency structure of the HUGR. + graph: RegionGraph<'g>, + + /// The node hierarchy. + hugr: &'g Base, + + /// The operation handle of the root node. + _phantom: std::marker::PhantomData, +} + +impl<'g, Root, Base: Clone> Clone for DescendantsGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals + HugrView, +{ + fn clone(&self) -> Self { + DescendantsGraph::new(self.hugr, self.root) + } +} + +impl<'g, Root, Base> HugrView for DescendantsGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals + HugrView, +{ + type RootHandle = Root; + + type Nodes<'a> = MapInto< as PortView>::Nodes<'a>, Node> + where + Self: 'a; + + type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> + where + Self: 'a; + + type Children<'a> = MapInto, Node> + where + Self: 'a; + + type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> + where + Self: 'a; + + type PortLinks<'a> = MapWithCtx< + as LinkView>::PortLinks<'a>, + &'a Self, + (Node, Port), + > where + Self: 'a; + + type NodeConnections<'a> = MapWithCtx< + as LinkView>::NodeConnections<'a>, + &'a Self, + [Port; 2], + > where + Self: 'a; + + #[inline] + fn contains_node(&self, node: Node) -> bool { + self.graph.contains_node(node.index) + } + + #[inline] + fn get_parent(&self, node: Node) -> Option { + self.hugr + .get_parent(node) + .filter(|&parent| self.graph.contains_node(parent.index)) + .map(Into::into) + } + + #[inline] + fn get_optype(&self, node: Node) -> &OpType { + self.hugr.get_optype(node) + } + + #[inline] + fn get_nodetype(&self, node: Node) -> &NodeType { + self.hugr.get_nodetype(node) + } + + #[inline] + fn get_metadata(&self, node: Node) -> &NodeMetadata { + self.hugr.get_metadata(node) + } + + #[inline] + fn node_count(&self) -> usize { + self.graph.node_count() + } + + #[inline] + fn edge_count(&self) -> usize { + self.graph.link_count() + } + + #[inline] + fn nodes(&self) -> Self::Nodes<'_> { + self.graph.nodes_iter().map_into() + } + + #[inline] + fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + self.graph.port_offsets(node.index, dir).map_into() + } + + #[inline] + fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + self.graph.all_port_offsets(node.index).map_into() + } + + fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + let port = self.graph.port_index(node.index, port.offset).unwrap(); + self.graph + .port_links(port) + .with_context(self) + .map_with_context(|(_, link), region| { + let port: PortIndex = link.into(); + let node = region.graph.port_node(port).unwrap(); + let offset = region.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) + } + + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + self.graph + .get_connections(node.index, other.index) + .with_context(self) + .map_with_context(|(p1, p2), hugr| { + [p1, p2].map(|link| { + let offset = hugr.graph.port_offset(link).unwrap(); + offset.into() + }) + }) + } + + #[inline] + fn num_ports(&self, node: Node, dir: Direction) -> usize { + self.graph.num_ports(node.index, dir) + } + + #[inline] + fn children(&self, node: Node) -> Self::Children<'_> { + match self.graph.contains_node(node.index) { + true => self.base_hugr().hierarchy.children(node.index).map_into(), + false => portgraph::hierarchy::Children::default().map_into(), + } + } + + #[inline] + fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + self.graph.neighbours(node.index, dir).map_into() + } + + #[inline] + fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + self.graph.all_neighbours(node.index).map_into() + } + + #[inline] + fn get_io(&self, node: Node) -> Option<[Node; 2]> { + self.base_hugr().get_io(node) + } + + fn get_function_type(&self) -> Option<&crate::types::FunctionType> { + self.base_hugr().get_function_type() + } +} + +impl<'a, Root, Base> HierarchyView<'a> for DescendantsGraph<'a, Root, Base> +where + Root: NodeHandle, + Base: HugrView, +{ + type Base = Base; + + fn new(hugr: &'a Base, root: Node) -> Self { + let root_tag = hugr.get_optype(root).tag(); + if !Root::TAG.is_superset(root_tag) { + // TODO: Return an error + panic!("Root node must have the correct operation type tag.") + } + Self { + root, + graph: RegionGraph::new_region( + &hugr.base_hugr().graph, + &hugr.base_hugr().hierarchy, + root.index, + ), + hugr, + _phantom: std::marker::PhantomData, + } + } +} + +impl<'g, Root, Base> super::sealed::HugrInternals for DescendantsGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals, +{ + type Portgraph<'p> = &'p RegionGraph<'g> where Self: 'p; + + #[inline] + fn portgraph(&self) -> Self::Portgraph<'_> { + &self.graph + } + + #[inline] + fn base_hugr(&self) -> &Hugr { + self.hugr.base_hugr() + } + + #[inline] + fn root_node(&self) -> Node { + self.root + } +} + +#[cfg(test)] +pub(super) mod test { + use crate::{ + builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, + ops::handle::NodeHandle, + std_extensions::quantum::test::h_gate, + type_row, + types::{FunctionType, Type}, + }; + + use super::*; + + const NAT: Type = crate::extension::prelude::USIZE_T; + const QB: Type = crate::extension::prelude::QB_T; + + /// Make a module hugr with a fn definition containing an inner dfg node. + /// + /// Returns the hugr, the fn node id, and the nested dgf node id. + pub(in crate::hugr::views) fn make_module_hgr( + ) -> Result<(Hugr, Node, Node), Box> { + let mut module_builder = ModuleBuilder::new(); + + let (f_id, inner_id) = { + let mut func_builder = module_builder.define_function( + "main", + FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), + )?; + + let [int, qb] = func_builder.input_wires_arr(); + + let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; + + let inner_id = { + let inner_builder = func_builder.dfg_builder( + FunctionType::new(type_row![NAT], type_row![NAT]), + None, + [int], + )?; + let w = inner_builder.input_wires(); + inner_builder.finish_with_outputs(w) + }?; + + let f_id = + func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; + (f_id, inner_id) + }; + let hugr = module_builder.finish_prelude_hugr()?; + Ok((hugr, f_id.handle().node(), inner_id.handle().node())) + } + + #[test] + fn full_region() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; + + let region: DescendantsGraph = DescendantsGraph::new(&hugr, def); + + assert_eq!(region.node_count(), 7); + assert!(region.nodes().all(|n| n == def + || hugr.get_parent(n) == Some(def) + || hugr.get_parent(n) == Some(inner))); + assert_eq!(region.children(inner).count(), 2); + + Ok(()) + } +} diff --git a/src/hugr/views/hierarchy.rs b/src/hugr/views/hierarchy.rs deleted file mode 100644 index 76d67ecff..000000000 --- a/src/hugr/views/hierarchy.rs +++ /dev/null @@ -1,605 +0,0 @@ -//! Hierarchical views for HUGR. -//! -//! Views into subgraphs of HUGRs that are based on the hierarchical relationship -//! of the HUGR nodes. Such a subgraph includes a root node and some of its -//! descendants. The root node is the only node in the view that has no parent. -//! -//! There are currently 2 hierarchical views: -//! - [`SiblingGraph`]: A view of the subgraph induced by the children -//! of the root node. -//! - [`DescendantsGraph`]: A view of the subgraph induced by all the -//! descendants of the root node. -//! -//! Both views implement the [`HierarchyView`] trait, so they can be used -//! interchangeably. They implement [`HugrView`] as well as petgraph's _visit_ -//! traits. - -pub mod petgraph; - -use std::iter; - -use ::petgraph::visit as pv; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; -use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; - -use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; -use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port}; - -use super::{sealed::HugrInternals, HugrView, NodeMetadata}; - -type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; - -/// View of a HUGR sibling graph. -/// -/// Includes only the root node and its direct children. -/// -/// For a view that includes all the descendants of the root, see [`DescendantsGraph`]. -pub struct SiblingGraph<'g, Root = Node, Base = Hugr> -where - Base: HugrInternals, -{ - /// The chosen root node. - root: Node, - - /// The filtered portgraph encoding the adjacency structure of the HUGR. - graph: FlatRegionGraph<'g>, - - /// The rest of the HUGR. - hugr: &'g Base, - - /// The operation type of the root node. - _phantom: std::marker::PhantomData, -} - -impl<'g, Root, Base> Clone for SiblingGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals + HugrView, -{ - fn clone(&self) -> Self { - SiblingGraph::new(self.hugr, self.root) - } -} - -impl<'g, Root, Base> HugrView for SiblingGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals + HugrView, -{ - type RootHandle = Root; - - type Nodes<'a> = iter::Chain, MapInto, Node>> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - - #[inline] - fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(node.index) - } - - #[inline] - fn get_parent(&self, node: Node) -> Option { - self.hugr.get_parent(node).filter(|&n| n == self.root) - } - - #[inline] - fn get_optype(&self, node: Node) -> &OpType { - self.hugr.get_optype(node) - } - - #[inline] - fn get_nodetype(&self, node: Node) -> &NodeType { - self.hugr.get_nodetype(node) - } - - #[inline] - fn get_metadata(&self, node: Node) -> &NodeMetadata { - self.hugr.get_metadata(node) - } - - #[inline] - fn node_count(&self) -> usize { - self.base_hugr().hierarchy.child_count(self.root.index) + 1 - } - - #[inline] - fn edge_count(&self) -> usize { - // Faster implementation than filtering all the nodes in the internal graph. - self.nodes() - .map(|n| self.output_neighbours(n).count()) - .sum() - } - - #[inline] - fn nodes(&self) -> Self::Nodes<'_> { - // Faster implementation than filtering all the nodes in the internal graph. - let children = self - .base_hugr() - .hierarchy - .children(self.root.index) - .map_into(); - iter::once(self.root).chain(children) - } - - #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { - self.graph.port_offsets(node.index, dir).map_into() - } - - #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { - self.graph.all_port_offsets(node.index).map_into() - } - - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { - let port = self.graph.port_index(node.index, port.offset).unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) - } - - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { - self.graph - .get_connections(node.index, other.index) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); - offset.into() - }) - }) - } - - #[inline] - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(node.index, dir) - } - - #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { - match node == self.root { - true => self.base_hugr().hierarchy.children(node.index).map_into(), - false => portgraph::hierarchy::Children::default().map_into(), - } - } - - #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { - self.graph.neighbours(node.index, dir).map_into() - } - - #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { - self.graph.all_neighbours(node.index).map_into() - } - - #[inline] - fn get_io(&self, node: Node) -> Option<[Node; 2]> { - if node == self.root() { - self.base_hugr().get_io(node) - } else { - None - } - } - - fn get_function_type(&self) -> Option<&crate::types::FunctionType> { - self.base_hugr().get_function_type() - } -} - -type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; - -/// View of a HUGR descendants graph. -/// -/// Includes the root node and all its descendants nodes at any depth. -/// -/// For a view that includes only the direct children of the root, see -/// [`SiblingGraph`]. Prefer using [`SiblingGraph`] over this type when -/// possible, as it is more efficient. -pub struct DescendantsGraph<'g, Root = Node, Base = Hugr> -where - Base: HugrInternals, -{ - /// The chosen root node. - root: Node, - - /// The graph encoding the adjacency structure of the HUGR. - graph: RegionGraph<'g>, - - /// The node hierarchy. - hugr: &'g Base, - - /// The operation handle of the root node. - _phantom: std::marker::PhantomData, -} - -impl<'g, Root, Base: Clone> Clone for DescendantsGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals + HugrView, -{ - fn clone(&self) -> Self { - DescendantsGraph::new(self.hugr, self.root) - } -} - -impl<'g, Root, Base> HugrView for DescendantsGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals + HugrView, -{ - type RootHandle = Root; - - type Nodes<'a> = MapInto< as PortView>::Nodes<'a>, Node> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - - #[inline] - fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(node.index) - } - - #[inline] - fn get_parent(&self, node: Node) -> Option { - self.hugr - .get_parent(node) - .filter(|&parent| self.graph.contains_node(parent.index)) - .map(Into::into) - } - - #[inline] - fn get_optype(&self, node: Node) -> &OpType { - self.hugr.get_optype(node) - } - - #[inline] - fn get_nodetype(&self, node: Node) -> &NodeType { - self.hugr.get_nodetype(node) - } - - #[inline] - fn get_metadata(&self, node: Node) -> &NodeMetadata { - self.hugr.get_metadata(node) - } - - #[inline] - fn node_count(&self) -> usize { - self.graph.node_count() - } - - #[inline] - fn edge_count(&self) -> usize { - self.graph.link_count() - } - - #[inline] - fn nodes(&self) -> Self::Nodes<'_> { - self.graph.nodes_iter().map_into() - } - - #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { - self.graph.port_offsets(node.index, dir).map_into() - } - - #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { - self.graph.all_port_offsets(node.index).map_into() - } - - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { - let port = self.graph.port_index(node.index, port.offset).unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) - } - - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { - self.graph - .get_connections(node.index, other.index) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); - offset.into() - }) - }) - } - - #[inline] - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(node.index, dir) - } - - #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { - match self.graph.contains_node(node.index) { - true => self.base_hugr().hierarchy.children(node.index).map_into(), - false => portgraph::hierarchy::Children::default().map_into(), - } - } - - #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { - self.graph.neighbours(node.index, dir).map_into() - } - - #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { - self.graph.all_neighbours(node.index).map_into() - } - - #[inline] - fn get_io(&self, node: Node) -> Option<[Node; 2]> { - self.base_hugr().get_io(node) - } - - fn get_function_type(&self) -> Option<&crate::types::FunctionType> { - self.base_hugr().get_function_type() - } -} - -/// A common trait for views of a HUGR hierarchical subgraph. -pub trait HierarchyView<'a>: - HugrView - + pv::GraphBase - + pv::GraphProp - + pv::NodeCount - + pv::NodeIndexable - + pv::EdgeCount - + pv::Visitable - + pv::GetAdjacencyMatrix - + pv::Visitable -where - for<'g> &'g Self: pv::IntoNeighborsDirected + pv::IntoNodeIdentifiers, -{ - /// The base from which the subgraph is derived. - type Base; - - /// Create a hierarchical view of a HUGR given a root node. - fn new(hugr: &'a Self::Base, root: Node) -> Self; -} - -impl<'a, Root, Base> HierarchyView<'a> for SiblingGraph<'a, Root, Base> -where - Root: NodeHandle, - Base: HugrView, -{ - type Base = Base; - - fn new(hugr: &'a Base, root: Node) -> Self { - let root_tag = hugr.get_optype(root).tag(); - if !Root::TAG.is_superset(root_tag) { - // TODO: Return an error - panic!("Root node must have the correct operation type tag.") - } - Self { - root, - graph: FlatRegionGraph::new_flat_region( - &hugr.base_hugr().graph, - &hugr.base_hugr().hierarchy, - root.index, - ), - hugr, - _phantom: std::marker::PhantomData, - } - } -} - -impl<'a, Root, Base> HierarchyView<'a> for DescendantsGraph<'a, Root, Base> -where - Root: NodeHandle, - Base: HugrView, -{ - type Base = Base; - - fn new(hugr: &'a Base, root: Node) -> Self { - let root_tag = hugr.get_optype(root).tag(); - if !Root::TAG.is_superset(root_tag) { - // TODO: Return an error - panic!("Root node must have the correct operation type tag.") - } - Self { - root, - graph: RegionGraph::new_region( - &hugr.base_hugr().graph, - &hugr.base_hugr().hierarchy, - root.index, - ), - hugr, - _phantom: std::marker::PhantomData, - } - } -} - -impl<'g, Root, Base> HugrInternals for SiblingGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals, -{ - type Portgraph<'p> = &'p FlatRegionGraph<'g> where Self: 'p; - - #[inline] - fn portgraph(&self) -> Self::Portgraph<'_> { - &self.graph - } - - #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr.base_hugr() - } - - #[inline] - fn root_node(&self) -> Node { - self.root - } -} - -impl<'g, Root, Base> super::sealed::HugrInternals for DescendantsGraph<'g, Root, Base> -where - Root: NodeHandle, - Base: HugrInternals, -{ - type Portgraph<'p> = &'p RegionGraph<'g> where Self: 'p; - - #[inline] - fn portgraph(&self) -> Self::Portgraph<'_> { - &self.graph - } - - #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr.base_hugr() - } - - #[inline] - fn root_node(&self) -> Node { - self.root - } -} - -#[cfg(test)] -mod test { - use crate::{ - builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, - ops::handle::NodeHandle, - std_extensions::quantum::test::h_gate, - type_row, - types::{FunctionType, Type}, - }; - - use super::*; - - const NAT: Type = crate::extension::prelude::USIZE_T; - const QB: Type = crate::extension::prelude::QB_T; - - /// Make a module hugr with a fn definition containing an inner dfg node. - /// - /// Returns the hugr, the fn node id, and the nested dgf node id. - fn make_module_hgr() -> Result<(Hugr, Node, Node), Box> { - let mut module_builder = ModuleBuilder::new(); - - let (f_id, inner_id) = { - let mut func_builder = module_builder.define_function( - "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), - )?; - - let [int, qb] = func_builder.input_wires_arr(); - - let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; - - let inner_id = { - let inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), - None, - [int], - )?; - let w = inner_builder.input_wires(); - inner_builder.finish_with_outputs(w) - }?; - - let f_id = - func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; - (f_id, inner_id) - }; - let hugr = module_builder.finish_prelude_hugr()?; - Ok((hugr, f_id.handle().node(), inner_id.handle().node())) - } - - #[test] - fn flat_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; - - let region: SiblingGraph = SiblingGraph::new(&hugr, def); - - assert_eq!(region.node_count(), 5); - assert!(region - .nodes() - .all(|n| n == def || hugr.get_parent(n) == Some(def))); - assert_eq!(region.children(inner).count(), 0); - - Ok(()) - } - - #[test] - fn full_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; - - let region: DescendantsGraph = DescendantsGraph::new(&hugr, def); - - assert_eq!(region.node_count(), 7); - assert!(region.nodes().all(|n| n == def - || hugr.get_parent(n) == Some(def) - || hugr.get_parent(n) == Some(inner))); - assert_eq!(region.children(inner).count(), 2); - - Ok(()) - } -} diff --git a/src/hugr/views/hierarchy/petgraph.rs b/src/hugr/views/petgraph.rs similarity index 100% rename from src/hugr/views/hierarchy/petgraph.rs rename to src/hugr/views/petgraph.rs diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index c344be6b2..84532831a 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -1,790 +1,281 @@ -//! Views for HUGR sibling subgraphs. -//! -//! Views into convex subgraphs of HUGRs within a single level of the -//! hierarchy, i.e. within a sibling graph. Convex subgraph are always -//! induced subgraphs, i.e. they are defined by a subset of the sibling nodes. -//! -//! Sibling subgraphs complement [`super::HierarchyView`]s in the sense that the -//! latter provide views for subgraphs defined by hierarchical relationships, -//! while the former provide views for subgraphs within a single level of the -//! hierarchy. - -use std::collections::HashSet; - -use itertools::Itertools; -use portgraph::{view::Subgraph, Direction, PortView}; -use thiserror::Error; - -use crate::{ - ops::{ - handle::{ContainerHandle, DataflowOpID}, - OpTag, OpTrait, - }, - types::{FunctionType, Type}, - Hugr, Node, Port, SimpleReplacement, -}; - -use super::HugrView; - -/// A non-empty convex subgraph of a HUGR sibling graph. +//! SiblingGraph: view onto a sibling subgraph of the HUGR. + +use std::iter; + +use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; +use itertools::{Itertools, MapInto}; +use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; + +use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; +use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port}; + +use super::{sealed::HugrInternals, HierarchyView, HugrView, NodeMetadata}; + +type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; + +/// View of a HUGR sibling graph. +/// +/// Includes only the root node and its direct children, but no deeper descendants. +/// Uniquely, the root node has no parent. /// -/// A HUGR region in which all nodes share the same parent. Unlike -/// [`super::SiblingGraph`], not all nodes of the sibling graph must be -/// included. A convex subgraph is always an induced subgraph, i.e. it is defined -/// by a set of nodes and all edges between them. - -/// The incoming boundary (resp. outgoing boundary) is given by the input (resp. -/// output) ports of the subgraph that are linked to nodes outside of the subgraph. -/// The signature of the subgraph is then given by the types of the incoming -/// and outgoing boundary ports. Given a replacement with the same signature, -/// a [`SimpleReplacement`] can be constructed to rewrite the subgraph with the -/// replacement. +/// See [`DescendantsGraph`] for a view that includes all descendants of the root. /// -/// The ordering of the nodes in the subgraph is irrelevant to define the convex -/// subgraph, but it determines the ordering of the boundary signature. +/// Implements the [`HierarchyView`] trait, as well as [`HugrView`] and petgraph's +/// _visit_ traits, so can be used interchangeably with [`DescendantsGraph`]. /// -/// At the moment we do not support state order edges at the subgraph boundary. -/// The `boundary_port` and `signature` methods will panic if any are found. -/// State order edges are also unsupported in replacements in -/// `create_simple_replacement`. -#[derive(Clone, Debug)] -pub struct SiblingSubgraph<'g, Base> { - /// The underlying Hugr. - base: &'g Base, - /// The nodes of the induced subgraph. - nodes: Vec, - /// The input ports of the subgraph. - /// - /// Grouped by input parameter. Each port must be unique and belong to a - /// node in `nodes`. - inputs: Vec>, - /// The output ports of the subgraph. - /// - /// Repeated ports are allowed and correspond to copying the output. Every - /// port must belong to a node in `nodes`. - outputs: Vec<(Node, Port)>, +/// [`DescendantsGraph`]: super::DescendantsGraph +pub struct SiblingGraph<'g, Root = Node, Base = Hugr> +where + Base: HugrInternals, +{ + /// The chosen root node. + root: Node, + + /// The filtered portgraph encoding the adjacency structure of the HUGR. + graph: FlatRegionGraph<'g>, + + /// The rest of the HUGR. + hugr: &'g Base, + + /// The operation type of the root node. + _phantom: std::marker::PhantomData, } -/// The type of the incoming boundary of [`SiblingSubgraph`]. -pub type IncomingPorts = Vec>; -/// The type of the outgoing boundary of [`SiblingSubgraph`]. -pub type OutgoingPorts = Vec<(Node, Port)>; - -impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { - /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR. - /// - /// The subgraph is given by the nodes between the input and output - /// children nodes of the parent node. If you wish to create a subgraph - /// from another root, wrap the `region` argument in a [`super::SiblingGraph`]. - /// - /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the - /// subgraph is empty. - pub fn try_from_dataflow_graph(dfg_graph: &'g Base) -> Result - where - Base: Clone + HugrView, - Root: ContainerHandle, - { - let parent = dfg_graph.root(); - let nodes = dfg_graph.children(parent).skip(2).collect_vec(); - let (inputs, outputs) = get_input_output_ports(dfg_graph); - - validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?; - - if nodes.is_empty() { - Err(InvalidSubgraph::EmptySubgraph) - } else { - Ok(Self { - base: dfg_graph, - nodes, - inputs, - outputs, - }) - } +impl<'g, Root, Base> Clone for SiblingGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals + HugrView, +{ + fn clone(&self) -> Self { + SiblingGraph::new(self.hugr, self.root) } +} - /// Create a new sibling subgraph from some boundary edges. - /// - /// Any sibling subgraph can be defined using two sets of boundary edges - /// $B_I$ and $B_O$, the incoming and outgoing boundary edges respectively. - /// Intuitively, the sibling subgraph is all the edges and nodes "between" - /// an edge of $B_I$ and an edge of $B_O$. - /// - /// ## Definition - /// - /// More formally, the sibling subgraph of a graph $G = (V, E)$ given - /// by sets of incoming and outoing boundary edges $B_I, B_O \subseteq E$ - /// is the graph given by the connected components of the graph - /// $G' = (V, E \ B_I \ B_O)$ that contain at least one node that is either - /// - the target of an incoming boundary edge, or - /// - the source of an outgoing boundary edge. - /// - /// A subgraph is well-formed if for every edge in the HUGR - /// - it is in $B_I$ if and only if it has a source outside of the subgraph - /// and a target inside of it, and - /// - it is in $B_O$ if and only if it has a source inside of the subgraph - /// and a target outside of it. - /// - /// ## Arguments - /// - /// The `incoming` and `outgoing` arguments give $B_I$ and $B_O$ respectively. - /// Incoming edges must be given by incoming ports and outgoing edges by - /// outgoing ports. The ordering of the incoming and outgoing ports defines - /// the signature of the subgraph. - /// - /// Incoming boundary ports must be unique and partitioned by input - /// parameter: two ports within the same set of the partition must be - /// copyable and will result in the input being copied. Outgoing - /// boundary ports are given in a list and can appear multiple times if - /// they are copyable, in which case the output will be copied. - /// - /// ## Errors - /// - /// This function fails if the subgraph is not convex, if the nodes - /// do not share a common parent or if the subgraph is empty. - pub fn try_from_boundary_ports( - base: &'g Base, - incoming: IncomingPorts, - outgoing: OutgoingPorts, - ) -> Result - where - Base: Clone + HugrView, - { - let mut checker = ConvexChecker::new(base); - Self::try_from_boundary_ports_with_checker(base, incoming, outgoing, &mut checker) - } +impl<'g, Root, Base> HugrView for SiblingGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals + HugrView, +{ + type RootHandle = Root; - /// Create a new sibling subgraph from some boundary edges. - /// - /// Provide a [`ConvexChecker`] instance to avoid constructing one for - /// faster convexity check. If you do not have one, use - /// [`SiblingSubgraph::try_from_boundary_ports`]. - /// - /// Refer to [`SiblingSubgraph::try_from_boundary_ports`] for the full - /// documentation. - pub fn try_from_boundary_ports_with_checker( - base: &'g Base, - inputs: IncomingPorts, - outputs: OutgoingPorts, - checker: &mut ConvexChecker<'g, Base>, - ) -> Result + type Nodes<'a> = iter::Chain, MapInto, Node>> where - Base: Clone + HugrView, - { - let pg = base.portgraph(); - let to_pg = |(n, p): (Node, Port)| pg.port_index(n.index, p.offset).expect("invalid port"); - - // Ordering of the edges here is preserved and becomes ordering of the signature. - let subpg = Subgraph::new_subgraph( - pg.clone(), - inputs - .iter() - .flatten() - .copied() - .chain(outputs.iter().copied()) - .map(to_pg), - ); - let nodes = subpg.nodes_iter().map_into().collect_vec(); - - validate_subgraph(base, &nodes, &inputs, &outputs)?; - - if !subpg.is_convex_with_checker(&mut checker.0) { - return Err(InvalidSubgraph::NotConvex); - } + Self: 'a; - Ok(Self { - base, - nodes, - inputs, - outputs, - }) - } + type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> + where + Self: 'a; - /// Create a new convex sibling subgraph from a set of nodes. - /// - /// This fails if the set of nodes is not convex, nodes do not share a - /// common parent or the subgraph is empty. - pub fn try_new( - base: &'g Base, - nodes: Vec, - inputs: IncomingPorts, - outputs: OutgoingPorts, - ) -> Result + type Children<'a> = MapInto, Node> where - Base: HugrView, - { - let mut checker = ConvexChecker::new(base); - Self::try_new_with_checker(base, nodes, inputs, outputs, &mut checker) - } + Self: 'a; - /// Create a new convex sibling subgraph from a set of nodes. - /// - /// Provide a [`ConvexChecker`] instance to avoid constructing one for - /// faster convexity check. If you do not have one, use [`SiblingSubgraph::try_new`]. - /// - /// This fails if the set of nodes is not convex, nodes do not share a - /// common parent or the subgraph is empty. - pub fn try_new_with_checker( - base: &'g Base, - nodes: Vec, - inputs: IncomingPorts, - outputs: OutgoingPorts, - checker: &mut ConvexChecker<'g, Base>, - ) -> Result + type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> where - Base: HugrView, - { - validate_subgraph(base, &nodes, &inputs, &outputs)?; + Self: 'a; - if !checker.0.is_node_convex(nodes.iter().map(|n| n.index)) { - return Err(InvalidSubgraph::NotConvex); - } + type PortLinks<'a> = MapWithCtx< + as LinkView>::PortLinks<'a>, + &'a Self, + (Node, Port), + > where + Self: 'a; - Ok(Self { - base, - nodes, - inputs, - outputs, - }) - } + type NodeConnections<'a> = MapWithCtx< + as LinkView>::NodeConnections<'a>, + &'a Self, + [Port; 2], + > where + Self: 'a; - /// An iterator over the nodes in the subgraph. - pub fn nodes(&self) -> &[Node] { - &self.nodes + #[inline] + fn contains_node(&self, node: Node) -> bool { + self.graph.contains_node(node.index) } - /// The signature of the subgraph. - pub fn signature(&self) -> FunctionType - where - Base: HugrView, - { - let input = self - .inputs - .iter() - .map(|part| { - let &(n, p) = part.iter().next().expect("is non-empty"); - let sig = self.base.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") - }) - .collect_vec(); - let output = self - .outputs - .iter() - .map(|&(n, p)| { - let sig = self.base.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") - }) - .collect_vec(); - FunctionType::new(input, output) + #[inline] + fn get_parent(&self, node: Node) -> Option { + self.hugr.get_parent(node).filter(|&n| n == self.root) } - /// The parent of the sibling subgraph. - pub fn get_parent(&self) -> Node - where - Base: HugrView, - { - self.base - .get_parent(self.nodes[0]) - .expect("invalid subgraph") + #[inline] + fn get_optype(&self, node: Node) -> &OpType { + self.hugr.get_optype(node) } - /// Construct a [`SimpleReplacement`] to replace `self` with `replacement`. - /// - /// `replacement` must be a hugr with DFG root and its signature must - /// match the signature of the subgraph. - /// - /// May return one of the following five errors - /// - [`InvalidReplacement::InvalidDataflowGraph`]: the replacement - /// graph is not a [`crate::ops::OpTag::DataflowParent`]-rooted graph, - /// - [`InvalidReplacement::InvalidDataflowParent`]: the replacement does - /// not have an input and output node, - /// - [`InvalidReplacement::InvalidSignature`]: the signature of the - /// replacement DFG does not match the subgraph signature, or - /// - [`InvalidReplacement::NonConvexSubgraph`]: the sibling subgraph is not - /// convex. - /// - /// At the moment we do not support state order edges. If any are found in - /// the replacement graph, this will panic. - pub fn create_simple_replacement( - &self, - replacement: Hugr, - ) -> Result - where - Base: HugrView, - { - let removal = self.nodes().iter().copied().collect(); - - let rep_root = replacement.root(); - let dfg_optype = replacement.get_optype(rep_root); - if !OpTag::Dfg.is_superset(dfg_optype.tag()) { - return Err(InvalidReplacement::InvalidDataflowGraph); - } - let Some((rep_input, rep_output)) = replacement.children(rep_root).take(2).collect_tuple() - else { - return Err(InvalidReplacement::InvalidDataflowParent); - }; - if dfg_optype.signature() != self.signature() { - return Err(InvalidReplacement::InvalidSignature); - } - - // TODO: handle state order edges. For now panic if any are present. - // See https://github.com/CQCL-DEV/hugr/discussions/432 - let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); - let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = - rep_inputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = - rep_outputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); - let mut order_ports = in_order_ports.into_iter().chain(out_order_ports); - if order_ports.any(|(n, p)| is_order_edge(&replacement, n, p)) { - unimplemented!("Found state order edges in replacement graph"); - } - - let nu_inp = rep_inputs - .into_iter() - .zip_eq(&self.inputs) - .flat_map(|((rep_source_n, rep_source_p), self_targets)| { - replacement - .linked_ports(rep_source_n, rep_source_p) - .flat_map(move |rep_target| { - self_targets - .iter() - .map(move |&self_target| (rep_target, self_target)) - }) - }) - .collect(); - let nu_out = self - .outputs - .iter() - .zip_eq(rep_outputs) - .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| { - self.base - .linked_ports(self_source_n, self_source_p) - .map(move |self_target| (self_target, rep_target_p)) - }) - .collect(); - - Ok(SimpleReplacement::new( - self.get_parent(), - removal, - replacement, - nu_inp, - nu_out, - )) + #[inline] + fn get_nodetype(&self, node: Node) -> &NodeType { + self.hugr.get_nodetype(node) } -} -/// Precompute convexity information for a HUGR. -/// -/// This can be used when constructing multiple sibling subgraphs to speed up -/// convexity checking. -pub struct ConvexChecker<'g, Base: 'g + HugrView>( - portgraph::algorithms::ConvexChecker>, -); - -impl<'g, Base: HugrView> ConvexChecker<'g, Base> { - /// Create a new convexity checker. - pub fn new(base: &'g Base) -> Self { - let pg = base.portgraph(); - Self(portgraph::algorithms::ConvexChecker::new(pg)) + #[inline] + fn get_metadata(&self, node: Node) -> &NodeMetadata { + self.hugr.get_metadata(node) } -} -/// The type of all ports in the iterator. -/// -/// If the array is empty or a port does not exist, returns `None`. -fn get_edge_type(hugr: &H, ports: &[(Node, Port)]) -> Option { - let &(n, p) = ports.first()?; - let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); - ports - .iter() - .all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t)) - .then_some(edge_t) -} - -/// Whether a subgraph is valid. -/// -/// Does NOT check for convexity. -fn validate_subgraph( - hugr: &H, - nodes: &[Node], - inputs: &IncomingPorts, - outputs: &OutgoingPorts, -) -> Result<(), InvalidSubgraph> { - // Check nodes is not empty - if nodes.is_empty() { - return Err(InvalidSubgraph::EmptySubgraph); - } - // Check all nodes share parent - if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() { - return Err(InvalidSubgraph::NoSharedParent); + #[inline] + fn node_count(&self) -> usize { + self.base_hugr().hierarchy.child_count(self.root.index) + 1 } - // Check there are no linked "other" ports - if inputs - .iter() - .flatten() - .chain(outputs) - .any(|&(n, p)| is_order_edge(hugr, n, p)) - { - unimplemented!("Linked other ports not supported at boundary") + #[inline] + fn edge_count(&self) -> usize { + // Faster implementation than filtering all the nodes in the internal graph. + self.nodes() + .map(|n| self.output_neighbours(n).count()) + .sum() } - // Check inputs are incoming ports and outputs are outgoing ports - if inputs - .iter() - .flatten() - .any(|(_, p)| p.direction() == Direction::Outgoing) - { - return Err(InvalidSubgraph::InvalidBoundary); - } - if outputs - .iter() - .any(|(_, p)| p.direction() == Direction::Incoming) - { - return Err(InvalidSubgraph::InvalidBoundary); + #[inline] + fn nodes(&self) -> Self::Nodes<'_> { + // Faster implementation than filtering all the nodes in the internal graph. + let children = self + .base_hugr() + .hierarchy + .children(self.root.index) + .map_into(); + iter::once(self.root).chain(children) } - let mut ports_inside = inputs.iter().flatten().chain(outputs).copied(); - let mut ports_outside = ports_inside - .clone() - .flat_map(|(n, p)| hugr.linked_ports(n, p)); - // Check incoming & outgoing ports have target resp. source inside - let nodes = nodes.iter().copied().collect::>(); - if ports_inside.any(|(n, _)| !nodes.contains(&n)) { - return Err(InvalidSubgraph::InvalidBoundary); - } - // Check incoming & outgoing ports have source resp. target outside - if ports_outside.any(|(n, _)| nodes.contains(&n)) { - return Err(InvalidSubgraph::NotConvex); + #[inline] + fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + self.graph.port_offsets(node.index, dir).map_into() } - // Check inputs are unique - if !inputs.iter().flatten().all_unique() { - return Err(InvalidSubgraph::InvalidBoundary); + #[inline] + fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + self.graph.all_port_offsets(node.index).map_into() } - // Check no incoming partition is empty - if inputs.iter().any(|p| p.is_empty()) { - return Err(InvalidSubgraph::InvalidBoundary); + fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + let port = self.graph.port_index(node.index, port.offset).unwrap(); + self.graph + .port_links(port) + .with_context(self) + .map_with_context(|(_, link), region| { + let port: PortIndex = link.into(); + let node = region.graph.port_node(port).unwrap(); + let offset = region.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) } - // Check edge types are equal within partition and copyable if partition size > 1 - if !inputs.iter().all(|ports| { - let Some(edge_t) = get_edge_type(hugr, ports) else { - return false; - }; - let require_copy = ports.len() > 1; - !require_copy || edge_t.copyable() - }) { - return Err(InvalidSubgraph::InvalidBoundary); + fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + self.graph + .get_connections(node.index, other.index) + .with_context(self) + .map_with_context(|(p1, p2), hugr| { + [p1, p2].map(|link| { + let offset = hugr.graph.port_offset(link).unwrap(); + offset.into() + }) + }) } - Ok(()) -} - -fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPorts) { - let (inp, out) = hugr - .children(hugr.root()) - .take(2) - .collect_tuple() - .expect("invalid DFG"); - if has_other_edge(hugr, inp, Direction::Outgoing) { - unimplemented!("Non-dataflow output not supported at input node") + #[inline] + fn num_ports(&self, node: Node, dir: Direction) -> usize { + self.graph.num_ports(node.index, dir) } - let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); - if has_other_edge(hugr, out, Direction::Incoming) { - unimplemented!("Non-dataflow input not supported at output node") - } - let dfg_outputs = hugr.get_optype(out).signature().input_ports(); - let inputs = dfg_inputs - .into_iter() - .map(|p| hugr.linked_ports(inp, p).collect()) - .collect(); - let outputs = dfg_outputs - .into_iter() - .map(|p| { - hugr.linked_ports(out, p) - .exactly_one() - .ok() - .expect("invalid DFG") - }) - .collect(); - (inputs, outputs) -} - -/// Whether a port is linked to a state order edge. -fn is_order_edge(hugr: &H, node: Node, port: Port) -> bool { - let op = hugr.get_optype(node); - op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port) -} - -/// Whether node has a non-df linked port in the given direction. -fn has_other_edge(hugr: &H, node: Node, dir: Direction) -> bool { - let op = hugr.get_optype(node); - op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap()) -} - -/// Errors that can occur while constructing a [`SimpleReplacement`]. -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum InvalidReplacement { - /// No DataflowParent root in replacement graph. - #[error("No DataflowParent root in replacement graph.")] - InvalidDataflowGraph, - /// Malformed DataflowParent in replacement graph. - #[error("Malformed DataflowParent in replacement graph.")] - InvalidDataflowParent, - /// Replacement graph boundary size mismatch. - #[error("Replacement graph boundary size mismatch.")] - InvalidSignature, - /// SiblingSubgraph is not convex. - #[error("SiblingSubgraph is not convex.")] - NonConvexSubgraph, -} - -/// Errors that can occur while constructing a [`SiblingSubgraph`]. -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum InvalidSubgraph { - /// The subgraph is not convex. - #[error("The subgraph is not convex.")] - NotConvex, - /// Not all nodes have the same parent. - #[error("Not a sibling subgraph.")] - NoSharedParent, - /// Empty subgraphs are not supported. - #[error("Empty subgraphs are not supported.")] - EmptySubgraph, - /// An invalid boundary port was found. - #[error("Invalid boundary port.")] - InvalidBoundary, -} - -#[cfg(test)] -mod tests { - use crate::{ - builder::{ - BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, - }, - extension::{ - prelude::{BOOL_T, QB_T}, - EMPTY_REG, - }, - hugr::views::{HierarchyView, SiblingGraph}, - ops::{ - handle::{FuncID, NodeHandle}, - OpType, - }, - std_extensions::{logic::test::and_op, quantum::test::cx_gate}, - type_row, - }; - use super::*; - - impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { - /// A sibling subgraph from a HUGR. - /// - /// The subgraph is given by the sibling graph of the root. If you wish to - /// create a subgraph from another root, wrap the argument `region` in a - /// [`super::SiblingGraph`]. - /// - /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the - /// subgraph is empty. - fn from_sibling_graph(sibling_graph: &'g Base) -> Result - where - Base: HugrView, - { - let root = sibling_graph.root(); - let nodes = sibling_graph.children(root).collect_vec(); - if nodes.is_empty() { - Err(InvalidSubgraph::EmptySubgraph) - } else { - Ok(Self { - base: sibling_graph, - nodes, - inputs: Vec::new(), - outputs: Vec::new(), - }) - } + #[inline] + fn children(&self, node: Node) -> Self::Children<'_> { + match node == self.root { + true => self.base_hugr().hierarchy.children(node.index).map_into(), + false => portgraph::hierarchy::Children::default().map_into(), } } - fn build_hugr() -> Result<(Hugr, Node), BuildError> { - let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - FunctionType::new_linear(type_row![QB_T, QB_T]).pure(), - )?; - let func_id = { - let mut dfg = mod_builder.define_declaration(&func)?; - let outs = dfg.add_dataflow_op(cx_gate(), dfg.input_wires())?; - dfg.finish_with_outputs(outs.outputs())? - }; - let hugr = mod_builder - .finish_prelude_hugr() - .map_err(|e| -> BuildError { e.into() })?; - Ok((hugr, func_id.node())) + #[inline] + fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + self.graph.neighbours(node.index, dir).map_into() } - /// A HUGR with a copy - fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { - let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).pure(), - )?; - let func_id = { - let mut dfg = mod_builder.define_declaration(&func)?; - let in_wire = dfg.input_wires().exactly_one().unwrap(); - let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?; - dfg.finish_with_outputs(outs.outputs())? - }; - let hugr = mod_builder - .finish_hugr(&EMPTY_REG) - .map_err(|e| -> BuildError { e.into() })?; - Ok((hugr, func_id.node())) + #[inline] + fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + self.graph.all_neighbours(node.index).map_into() } - #[test] - fn construct_subgraph() -> Result<(), InvalidSubgraph> { - let (hugr, func_root) = build_hugr().unwrap(); - let sibling_graph: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); - let from_root = SiblingSubgraph::from_sibling_graph(&sibling_graph)?; - let region: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); - let from_region = SiblingSubgraph::from_sibling_graph(®ion)?; - assert_eq!(from_root.get_parent(), from_region.get_parent()); - assert_eq!(from_root.signature(), from_region.signature()); - Ok(()) + #[inline] + fn get_io(&self, node: Node) -> Option<[Node; 2]> { + if node == self.root() { + self.base_hugr().get_io(node) + } else { + None + } } - #[test] - fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { - let (mut hugr, func_root) = build_hugr().unwrap(); - let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); - let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; - - let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T, QB_T])).unwrap(); - let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() - }; - - let rep = sub.create_simple_replacement(empty_dfg).unwrap(); + fn get_function_type(&self) -> Option<&crate::types::FunctionType> { + self.base_hugr().get_function_type() + } +} - assert_eq!(rep.removal.len(), 1); +impl<'a, Root, Base> HierarchyView<'a> for SiblingGraph<'a, Root, Base> +where + Root: NodeHandle, + Base: HugrView, +{ + type Base = Base; + + fn new(hugr: &'a Base, root: Node) -> Self { + let root_tag = hugr.get_optype(root).tag(); + if !Root::TAG.is_superset(root_tag) { + // TODO: Return an error + panic!("Root node must have the correct operation type tag.") + } + Self { + root, + graph: FlatRegionGraph::new_flat_region( + &hugr.base_hugr().graph, + &hugr.base_hugr().hierarchy, + root.index, + ), + hugr, + _phantom: std::marker::PhantomData, + } + } +} - assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out - hugr.apply_rewrite(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out +impl<'g, Root, Base> HugrInternals for SiblingGraph<'g, Root, Base> +where + Root: NodeHandle, + Base: HugrInternals, +{ + type Portgraph<'p> = &'p FlatRegionGraph<'g> where Self: 'p; - Ok(()) + #[inline] + fn portgraph(&self) -> Self::Portgraph<'_> { + &self.graph } - #[test] - fn test_signature() -> Result<(), InvalidSubgraph> { - let (hugr, dfg) = build_hugr().unwrap(); - let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, dfg); - let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; - assert_eq!( - sub.signature(), - FunctionType::new_linear(type_row![QB_T, QB_T]) - ); - Ok(()) + #[inline] + fn base_hugr(&self) -> &Hugr { + self.hugr.base_hugr() } - #[test] - fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> { - let (hugr, dfg) = build_hugr().unwrap(); - let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, dfg); - let sub = SiblingSubgraph::from_sibling_graph(&func)?; - - let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T])).unwrap(); - let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() - }; - - assert_eq!( - sub.create_simple_replacement(empty_dfg).unwrap_err(), - InvalidReplacement::InvalidSignature - ); - Ok(()) + #[inline] + fn root_node(&self) -> Node { + self.root } +} - #[test] - fn convex_subgraph() { - let (hugr, func_root) = build_hugr().unwrap(); - let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); - assert_eq!( - SiblingSubgraph::try_from_dataflow_graph(&func) - .unwrap() - .nodes() - .len(), - 1 - ) - } +#[cfg(test)] +mod test { + use super::super::descendants::test::make_module_hgr; + use super::*; #[test] - fn convex_subgraph_2() { - let (hugr, func_root) = build_hugr().unwrap(); - let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); - let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); - // All graph except input/output nodes - SiblingSubgraph::try_from_boundary_ports( - &func, - hugr.node_outputs(inp) - .map(|p| hugr.linked_ports(inp, p).collect_vec()) - .filter(|ps| !ps.is_empty()) - .collect(), - hugr.node_inputs(out) - .filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok()) - .collect(), - ) - .unwrap(); - } + fn flat_region() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; - #[test] - fn degen_boundary() { - let (hugr, func_root) = build_hugr().unwrap(); - let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); - let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap(); - let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); - // All graph but one edge - assert!(matches!( - SiblingSubgraph::try_from_boundary_ports( - &func, - vec![hugr.linked_ports(inp, first_cx_edge).collect()], - vec![(inp, first_cx_edge)], - ), - Err(InvalidSubgraph::NotConvex) - )); - } + let region: SiblingGraph = SiblingGraph::new(&hugr, def); - #[test] - fn non_convex_subgraph() { - let (hugr, func_root) = build_hugr().unwrap(); - let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); - let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); - let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); - let snd_cx_edge = hugr.node_inputs(out).next().unwrap(); - // All graph but one edge - assert!(matches!( - SiblingSubgraph::try_from_boundary_ports( - &func, - vec![vec![(out, snd_cx_edge)]], - vec![(inp, first_cx_edge)], - ), - Err(InvalidSubgraph::NotConvex) - )); - } + assert_eq!(region.node_count(), 5); + assert!(region + .nodes() + .all(|n| n == def || hugr.get_parent(n) == Some(def))); + assert_eq!(region.children(inner).count(), 0); - #[test] - fn preserve_signature() { - let (hugr, func_root) = build_hugr_classical().unwrap(); - let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); - let func = SiblingSubgraph::try_from_dataflow_graph(&func).unwrap(); - let OpType::FuncDefn(func_defn) = hugr.get_optype(func_root) else { - panic!() - }; - assert_eq!(func_defn.signature, func.signature()) + Ok(()) } } diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs new file mode 100644 index 000000000..c344be6b2 --- /dev/null +++ b/src/hugr/views/sibling_subgraph.rs @@ -0,0 +1,790 @@ +//! Views for HUGR sibling subgraphs. +//! +//! Views into convex subgraphs of HUGRs within a single level of the +//! hierarchy, i.e. within a sibling graph. Convex subgraph are always +//! induced subgraphs, i.e. they are defined by a subset of the sibling nodes. +//! +//! Sibling subgraphs complement [`super::HierarchyView`]s in the sense that the +//! latter provide views for subgraphs defined by hierarchical relationships, +//! while the former provide views for subgraphs within a single level of the +//! hierarchy. + +use std::collections::HashSet; + +use itertools::Itertools; +use portgraph::{view::Subgraph, Direction, PortView}; +use thiserror::Error; + +use crate::{ + ops::{ + handle::{ContainerHandle, DataflowOpID}, + OpTag, OpTrait, + }, + types::{FunctionType, Type}, + Hugr, Node, Port, SimpleReplacement, +}; + +use super::HugrView; + +/// A non-empty convex subgraph of a HUGR sibling graph. +/// +/// A HUGR region in which all nodes share the same parent. Unlike +/// [`super::SiblingGraph`], not all nodes of the sibling graph must be +/// included. A convex subgraph is always an induced subgraph, i.e. it is defined +/// by a set of nodes and all edges between them. + +/// The incoming boundary (resp. outgoing boundary) is given by the input (resp. +/// output) ports of the subgraph that are linked to nodes outside of the subgraph. +/// The signature of the subgraph is then given by the types of the incoming +/// and outgoing boundary ports. Given a replacement with the same signature, +/// a [`SimpleReplacement`] can be constructed to rewrite the subgraph with the +/// replacement. +/// +/// The ordering of the nodes in the subgraph is irrelevant to define the convex +/// subgraph, but it determines the ordering of the boundary signature. +/// +/// At the moment we do not support state order edges at the subgraph boundary. +/// The `boundary_port` and `signature` methods will panic if any are found. +/// State order edges are also unsupported in replacements in +/// `create_simple_replacement`. +#[derive(Clone, Debug)] +pub struct SiblingSubgraph<'g, Base> { + /// The underlying Hugr. + base: &'g Base, + /// The nodes of the induced subgraph. + nodes: Vec, + /// The input ports of the subgraph. + /// + /// Grouped by input parameter. Each port must be unique and belong to a + /// node in `nodes`. + inputs: Vec>, + /// The output ports of the subgraph. + /// + /// Repeated ports are allowed and correspond to copying the output. Every + /// port must belong to a node in `nodes`. + outputs: Vec<(Node, Port)>, +} + +/// The type of the incoming boundary of [`SiblingSubgraph`]. +pub type IncomingPorts = Vec>; +/// The type of the outgoing boundary of [`SiblingSubgraph`]. +pub type OutgoingPorts = Vec<(Node, Port)>; + +impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { + /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR. + /// + /// The subgraph is given by the nodes between the input and output + /// children nodes of the parent node. If you wish to create a subgraph + /// from another root, wrap the `region` argument in a [`super::SiblingGraph`]. + /// + /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the + /// subgraph is empty. + pub fn try_from_dataflow_graph(dfg_graph: &'g Base) -> Result + where + Base: Clone + HugrView, + Root: ContainerHandle, + { + let parent = dfg_graph.root(); + let nodes = dfg_graph.children(parent).skip(2).collect_vec(); + let (inputs, outputs) = get_input_output_ports(dfg_graph); + + validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?; + + if nodes.is_empty() { + Err(InvalidSubgraph::EmptySubgraph) + } else { + Ok(Self { + base: dfg_graph, + nodes, + inputs, + outputs, + }) + } + } + + /// Create a new sibling subgraph from some boundary edges. + /// + /// Any sibling subgraph can be defined using two sets of boundary edges + /// $B_I$ and $B_O$, the incoming and outgoing boundary edges respectively. + /// Intuitively, the sibling subgraph is all the edges and nodes "between" + /// an edge of $B_I$ and an edge of $B_O$. + /// + /// ## Definition + /// + /// More formally, the sibling subgraph of a graph $G = (V, E)$ given + /// by sets of incoming and outoing boundary edges $B_I, B_O \subseteq E$ + /// is the graph given by the connected components of the graph + /// $G' = (V, E \ B_I \ B_O)$ that contain at least one node that is either + /// - the target of an incoming boundary edge, or + /// - the source of an outgoing boundary edge. + /// + /// A subgraph is well-formed if for every edge in the HUGR + /// - it is in $B_I$ if and only if it has a source outside of the subgraph + /// and a target inside of it, and + /// - it is in $B_O$ if and only if it has a source inside of the subgraph + /// and a target outside of it. + /// + /// ## Arguments + /// + /// The `incoming` and `outgoing` arguments give $B_I$ and $B_O$ respectively. + /// Incoming edges must be given by incoming ports and outgoing edges by + /// outgoing ports. The ordering of the incoming and outgoing ports defines + /// the signature of the subgraph. + /// + /// Incoming boundary ports must be unique and partitioned by input + /// parameter: two ports within the same set of the partition must be + /// copyable and will result in the input being copied. Outgoing + /// boundary ports are given in a list and can appear multiple times if + /// they are copyable, in which case the output will be copied. + /// + /// ## Errors + /// + /// This function fails if the subgraph is not convex, if the nodes + /// do not share a common parent or if the subgraph is empty. + pub fn try_from_boundary_ports( + base: &'g Base, + incoming: IncomingPorts, + outgoing: OutgoingPorts, + ) -> Result + where + Base: Clone + HugrView, + { + let mut checker = ConvexChecker::new(base); + Self::try_from_boundary_ports_with_checker(base, incoming, outgoing, &mut checker) + } + + /// Create a new sibling subgraph from some boundary edges. + /// + /// Provide a [`ConvexChecker`] instance to avoid constructing one for + /// faster convexity check. If you do not have one, use + /// [`SiblingSubgraph::try_from_boundary_ports`]. + /// + /// Refer to [`SiblingSubgraph::try_from_boundary_ports`] for the full + /// documentation. + pub fn try_from_boundary_ports_with_checker( + base: &'g Base, + inputs: IncomingPorts, + outputs: OutgoingPorts, + checker: &mut ConvexChecker<'g, Base>, + ) -> Result + where + Base: Clone + HugrView, + { + let pg = base.portgraph(); + let to_pg = |(n, p): (Node, Port)| pg.port_index(n.index, p.offset).expect("invalid port"); + + // Ordering of the edges here is preserved and becomes ordering of the signature. + let subpg = Subgraph::new_subgraph( + pg.clone(), + inputs + .iter() + .flatten() + .copied() + .chain(outputs.iter().copied()) + .map(to_pg), + ); + let nodes = subpg.nodes_iter().map_into().collect_vec(); + + validate_subgraph(base, &nodes, &inputs, &outputs)?; + + if !subpg.is_convex_with_checker(&mut checker.0) { + return Err(InvalidSubgraph::NotConvex); + } + + Ok(Self { + base, + nodes, + inputs, + outputs, + }) + } + + /// Create a new convex sibling subgraph from a set of nodes. + /// + /// This fails if the set of nodes is not convex, nodes do not share a + /// common parent or the subgraph is empty. + pub fn try_new( + base: &'g Base, + nodes: Vec, + inputs: IncomingPorts, + outputs: OutgoingPorts, + ) -> Result + where + Base: HugrView, + { + let mut checker = ConvexChecker::new(base); + Self::try_new_with_checker(base, nodes, inputs, outputs, &mut checker) + } + + /// Create a new convex sibling subgraph from a set of nodes. + /// + /// Provide a [`ConvexChecker`] instance to avoid constructing one for + /// faster convexity check. If you do not have one, use [`SiblingSubgraph::try_new`]. + /// + /// This fails if the set of nodes is not convex, nodes do not share a + /// common parent or the subgraph is empty. + pub fn try_new_with_checker( + base: &'g Base, + nodes: Vec, + inputs: IncomingPorts, + outputs: OutgoingPorts, + checker: &mut ConvexChecker<'g, Base>, + ) -> Result + where + Base: HugrView, + { + validate_subgraph(base, &nodes, &inputs, &outputs)?; + + if !checker.0.is_node_convex(nodes.iter().map(|n| n.index)) { + return Err(InvalidSubgraph::NotConvex); + } + + Ok(Self { + base, + nodes, + inputs, + outputs, + }) + } + + /// An iterator over the nodes in the subgraph. + pub fn nodes(&self) -> &[Node] { + &self.nodes + } + + /// The signature of the subgraph. + pub fn signature(&self) -> FunctionType + where + Base: HugrView, + { + let input = self + .inputs + .iter() + .map(|part| { + let &(n, p) = part.iter().next().expect("is non-empty"); + let sig = self.base.get_optype(n).signature(); + sig.get(p).cloned().expect("must be dataflow edge") + }) + .collect_vec(); + let output = self + .outputs + .iter() + .map(|&(n, p)| { + let sig = self.base.get_optype(n).signature(); + sig.get(p).cloned().expect("must be dataflow edge") + }) + .collect_vec(); + FunctionType::new(input, output) + } + + /// The parent of the sibling subgraph. + pub fn get_parent(&self) -> Node + where + Base: HugrView, + { + self.base + .get_parent(self.nodes[0]) + .expect("invalid subgraph") + } + + /// Construct a [`SimpleReplacement`] to replace `self` with `replacement`. + /// + /// `replacement` must be a hugr with DFG root and its signature must + /// match the signature of the subgraph. + /// + /// May return one of the following five errors + /// - [`InvalidReplacement::InvalidDataflowGraph`]: the replacement + /// graph is not a [`crate::ops::OpTag::DataflowParent`]-rooted graph, + /// - [`InvalidReplacement::InvalidDataflowParent`]: the replacement does + /// not have an input and output node, + /// - [`InvalidReplacement::InvalidSignature`]: the signature of the + /// replacement DFG does not match the subgraph signature, or + /// - [`InvalidReplacement::NonConvexSubgraph`]: the sibling subgraph is not + /// convex. + /// + /// At the moment we do not support state order edges. If any are found in + /// the replacement graph, this will panic. + pub fn create_simple_replacement( + &self, + replacement: Hugr, + ) -> Result + where + Base: HugrView, + { + let removal = self.nodes().iter().copied().collect(); + + let rep_root = replacement.root(); + let dfg_optype = replacement.get_optype(rep_root); + if !OpTag::Dfg.is_superset(dfg_optype.tag()) { + return Err(InvalidReplacement::InvalidDataflowGraph); + } + let Some((rep_input, rep_output)) = replacement.children(rep_root).take(2).collect_tuple() + else { + return Err(InvalidReplacement::InvalidDataflowParent); + }; + if dfg_optype.signature() != self.signature() { + return Err(InvalidReplacement::InvalidSignature); + } + + // TODO: handle state order edges. For now panic if any are present. + // See https://github.com/CQCL-DEV/hugr/discussions/432 + let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); + let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = + rep_inputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = + rep_outputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); + let mut order_ports = in_order_ports.into_iter().chain(out_order_ports); + if order_ports.any(|(n, p)| is_order_edge(&replacement, n, p)) { + unimplemented!("Found state order edges in replacement graph"); + } + + let nu_inp = rep_inputs + .into_iter() + .zip_eq(&self.inputs) + .flat_map(|((rep_source_n, rep_source_p), self_targets)| { + replacement + .linked_ports(rep_source_n, rep_source_p) + .flat_map(move |rep_target| { + self_targets + .iter() + .map(move |&self_target| (rep_target, self_target)) + }) + }) + .collect(); + let nu_out = self + .outputs + .iter() + .zip_eq(rep_outputs) + .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| { + self.base + .linked_ports(self_source_n, self_source_p) + .map(move |self_target| (self_target, rep_target_p)) + }) + .collect(); + + Ok(SimpleReplacement::new( + self.get_parent(), + removal, + replacement, + nu_inp, + nu_out, + )) + } +} + +/// Precompute convexity information for a HUGR. +/// +/// This can be used when constructing multiple sibling subgraphs to speed up +/// convexity checking. +pub struct ConvexChecker<'g, Base: 'g + HugrView>( + portgraph::algorithms::ConvexChecker>, +); + +impl<'g, Base: HugrView> ConvexChecker<'g, Base> { + /// Create a new convexity checker. + pub fn new(base: &'g Base) -> Self { + let pg = base.portgraph(); + Self(portgraph::algorithms::ConvexChecker::new(pg)) + } +} + +/// The type of all ports in the iterator. +/// +/// If the array is empty or a port does not exist, returns `None`. +fn get_edge_type(hugr: &H, ports: &[(Node, Port)]) -> Option { + let &(n, p) = ports.first()?; + let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); + ports + .iter() + .all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t)) + .then_some(edge_t) +} + +/// Whether a subgraph is valid. +/// +/// Does NOT check for convexity. +fn validate_subgraph( + hugr: &H, + nodes: &[Node], + inputs: &IncomingPorts, + outputs: &OutgoingPorts, +) -> Result<(), InvalidSubgraph> { + // Check nodes is not empty + if nodes.is_empty() { + return Err(InvalidSubgraph::EmptySubgraph); + } + // Check all nodes share parent + if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() { + return Err(InvalidSubgraph::NoSharedParent); + } + + // Check there are no linked "other" ports + if inputs + .iter() + .flatten() + .chain(outputs) + .any(|&(n, p)| is_order_edge(hugr, n, p)) + { + unimplemented!("Linked other ports not supported at boundary") + } + + // Check inputs are incoming ports and outputs are outgoing ports + if inputs + .iter() + .flatten() + .any(|(_, p)| p.direction() == Direction::Outgoing) + { + return Err(InvalidSubgraph::InvalidBoundary); + } + if outputs + .iter() + .any(|(_, p)| p.direction() == Direction::Incoming) + { + return Err(InvalidSubgraph::InvalidBoundary); + } + + let mut ports_inside = inputs.iter().flatten().chain(outputs).copied(); + let mut ports_outside = ports_inside + .clone() + .flat_map(|(n, p)| hugr.linked_ports(n, p)); + // Check incoming & outgoing ports have target resp. source inside + let nodes = nodes.iter().copied().collect::>(); + if ports_inside.any(|(n, _)| !nodes.contains(&n)) { + return Err(InvalidSubgraph::InvalidBoundary); + } + // Check incoming & outgoing ports have source resp. target outside + if ports_outside.any(|(n, _)| nodes.contains(&n)) { + return Err(InvalidSubgraph::NotConvex); + } + + // Check inputs are unique + if !inputs.iter().flatten().all_unique() { + return Err(InvalidSubgraph::InvalidBoundary); + } + + // Check no incoming partition is empty + if inputs.iter().any(|p| p.is_empty()) { + return Err(InvalidSubgraph::InvalidBoundary); + } + + // Check edge types are equal within partition and copyable if partition size > 1 + if !inputs.iter().all(|ports| { + let Some(edge_t) = get_edge_type(hugr, ports) else { + return false; + }; + let require_copy = ports.len() > 1; + !require_copy || edge_t.copyable() + }) { + return Err(InvalidSubgraph::InvalidBoundary); + } + + Ok(()) +} + +fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPorts) { + let (inp, out) = hugr + .children(hugr.root()) + .take(2) + .collect_tuple() + .expect("invalid DFG"); + if has_other_edge(hugr, inp, Direction::Outgoing) { + unimplemented!("Non-dataflow output not supported at input node") + } + let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); + if has_other_edge(hugr, out, Direction::Incoming) { + unimplemented!("Non-dataflow input not supported at output node") + } + let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + let inputs = dfg_inputs + .into_iter() + .map(|p| hugr.linked_ports(inp, p).collect()) + .collect(); + let outputs = dfg_outputs + .into_iter() + .map(|p| { + hugr.linked_ports(out, p) + .exactly_one() + .ok() + .expect("invalid DFG") + }) + .collect(); + (inputs, outputs) +} + +/// Whether a port is linked to a state order edge. +fn is_order_edge(hugr: &H, node: Node, port: Port) -> bool { + let op = hugr.get_optype(node); + op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port) +} + +/// Whether node has a non-df linked port in the given direction. +fn has_other_edge(hugr: &H, node: Node, dir: Direction) -> bool { + let op = hugr.get_optype(node); + op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap()) +} + +/// Errors that can occur while constructing a [`SimpleReplacement`]. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum InvalidReplacement { + /// No DataflowParent root in replacement graph. + #[error("No DataflowParent root in replacement graph.")] + InvalidDataflowGraph, + /// Malformed DataflowParent in replacement graph. + #[error("Malformed DataflowParent in replacement graph.")] + InvalidDataflowParent, + /// Replacement graph boundary size mismatch. + #[error("Replacement graph boundary size mismatch.")] + InvalidSignature, + /// SiblingSubgraph is not convex. + #[error("SiblingSubgraph is not convex.")] + NonConvexSubgraph, +} + +/// Errors that can occur while constructing a [`SiblingSubgraph`]. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum InvalidSubgraph { + /// The subgraph is not convex. + #[error("The subgraph is not convex.")] + NotConvex, + /// Not all nodes have the same parent. + #[error("Not a sibling subgraph.")] + NoSharedParent, + /// Empty subgraphs are not supported. + #[error("Empty subgraphs are not supported.")] + EmptySubgraph, + /// An invalid boundary port was found. + #[error("Invalid boundary port.")] + InvalidBoundary, +} + +#[cfg(test)] +mod tests { + use crate::{ + builder::{ + BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, + }, + extension::{ + prelude::{BOOL_T, QB_T}, + EMPTY_REG, + }, + hugr::views::{HierarchyView, SiblingGraph}, + ops::{ + handle::{FuncID, NodeHandle}, + OpType, + }, + std_extensions::{logic::test::and_op, quantum::test::cx_gate}, + type_row, + }; + + use super::*; + + impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { + /// A sibling subgraph from a HUGR. + /// + /// The subgraph is given by the sibling graph of the root. If you wish to + /// create a subgraph from another root, wrap the argument `region` in a + /// [`super::SiblingGraph`]. + /// + /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the + /// subgraph is empty. + fn from_sibling_graph(sibling_graph: &'g Base) -> Result + where + Base: HugrView, + { + let root = sibling_graph.root(); + let nodes = sibling_graph.children(root).collect_vec(); + if nodes.is_empty() { + Err(InvalidSubgraph::EmptySubgraph) + } else { + Ok(Self { + base: sibling_graph, + nodes, + inputs: Vec::new(), + outputs: Vec::new(), + }) + } + } + } + + fn build_hugr() -> Result<(Hugr, Node), BuildError> { + let mut mod_builder = ModuleBuilder::new(); + let func = mod_builder.declare( + "test", + FunctionType::new_linear(type_row![QB_T, QB_T]).pure(), + )?; + let func_id = { + let mut dfg = mod_builder.define_declaration(&func)?; + let outs = dfg.add_dataflow_op(cx_gate(), dfg.input_wires())?; + dfg.finish_with_outputs(outs.outputs())? + }; + let hugr = mod_builder + .finish_prelude_hugr() + .map_err(|e| -> BuildError { e.into() })?; + Ok((hugr, func_id.node())) + } + + /// A HUGR with a copy + fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { + let mut mod_builder = ModuleBuilder::new(); + let func = mod_builder.declare( + "test", + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).pure(), + )?; + let func_id = { + let mut dfg = mod_builder.define_declaration(&func)?; + let in_wire = dfg.input_wires().exactly_one().unwrap(); + let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?; + dfg.finish_with_outputs(outs.outputs())? + }; + let hugr = mod_builder + .finish_hugr(&EMPTY_REG) + .map_err(|e| -> BuildError { e.into() })?; + Ok((hugr, func_id.node())) + } + + #[test] + fn construct_subgraph() -> Result<(), InvalidSubgraph> { + let (hugr, func_root) = build_hugr().unwrap(); + let sibling_graph: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + let from_root = SiblingSubgraph::from_sibling_graph(&sibling_graph)?; + let region: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + let from_region = SiblingSubgraph::from_sibling_graph(®ion)?; + assert_eq!(from_root.get_parent(), from_region.get_parent()); + assert_eq!(from_root.signature(), from_region.signature()); + Ok(()) + } + + #[test] + fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { + let (mut hugr, func_root) = build_hugr().unwrap(); + let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); + let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; + + let empty_dfg = { + let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T, QB_T])).unwrap(); + let inputs = builder.input_wires(); + builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + }; + + let rep = sub.create_simple_replacement(empty_dfg).unwrap(); + + assert_eq!(rep.removal.len(), 1); + + assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out + hugr.apply_rewrite(rep).unwrap(); + assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + + Ok(()) + } + + #[test] + fn test_signature() -> Result<(), InvalidSubgraph> { + let (hugr, dfg) = build_hugr().unwrap(); + let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, dfg); + let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; + assert_eq!( + sub.signature(), + FunctionType::new_linear(type_row![QB_T, QB_T]) + ); + Ok(()) + } + + #[test] + fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> { + let (hugr, dfg) = build_hugr().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, dfg); + let sub = SiblingSubgraph::from_sibling_graph(&func)?; + + let empty_dfg = { + let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T])).unwrap(); + let inputs = builder.input_wires(); + builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + }; + + assert_eq!( + sub.create_simple_replacement(empty_dfg).unwrap_err(), + InvalidReplacement::InvalidSignature + ); + Ok(()) + } + + #[test] + fn convex_subgraph() { + let (hugr, func_root) = build_hugr().unwrap(); + let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); + assert_eq!( + SiblingSubgraph::try_from_dataflow_graph(&func) + .unwrap() + .nodes() + .len(), + 1 + ) + } + + #[test] + fn convex_subgraph_2() { + let (hugr, func_root) = build_hugr().unwrap(); + let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + // All graph except input/output nodes + SiblingSubgraph::try_from_boundary_ports( + &func, + hugr.node_outputs(inp) + .map(|p| hugr.linked_ports(inp, p).collect_vec()) + .filter(|ps| !ps.is_empty()) + .collect(), + hugr.node_inputs(out) + .filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok()) + .collect(), + ) + .unwrap(); + } + + #[test] + fn degen_boundary() { + let (hugr, func_root) = build_hugr().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); + // All graph but one edge + assert!(matches!( + SiblingSubgraph::try_from_boundary_ports( + &func, + vec![hugr.linked_ports(inp, first_cx_edge).collect()], + vec![(inp, first_cx_edge)], + ), + Err(InvalidSubgraph::NotConvex) + )); + } + + #[test] + fn non_convex_subgraph() { + let (hugr, func_root) = build_hugr().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); + let snd_cx_edge = hugr.node_inputs(out).next().unwrap(); + // All graph but one edge + assert!(matches!( + SiblingSubgraph::try_from_boundary_ports( + &func, + vec![vec![(out, snd_cx_edge)]], + vec![(inp, first_cx_edge)], + ), + Err(InvalidSubgraph::NotConvex) + )); + } + + #[test] + fn preserve_signature() { + let (hugr, func_root) = build_hugr_classical().unwrap(); + let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); + let func = SiblingSubgraph::try_from_dataflow_graph(&func).unwrap(); + let OpType::FuncDefn(func_defn) = hugr.get_optype(func_root) else { + panic!() + }; + assert_eq!(func_defn.signature, func.signature()) + } +}