From e2213ba0f5814b508904cd0f4c64f96d2ede31db Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Sep 2023 09:51:42 +0100 Subject: [PATCH] feat: insert_hugr/insert_view return node map (#535) Also corresponding builder methods. `insert_hugr_internal` already returns this, so just wrap in a new `struct InsertionResult` translating the NodeIndex's into Nodes. --- src/builder/build_traits.rs | 9 ++++--- src/hugr/hugrmut.rs | 46 ++++++++++++++++++++++++++------- src/hugr/rewrite/outline_cfg.rs | 36 +++++++++++++------------- 3 files changed, 59 insertions(+), 32 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index f444262de..1f90d142d 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -1,3 +1,4 @@ +use crate::hugr::hugrmut::InsertionResult; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; use crate::hugr::{Node, NodeMetadata, Port, ValidationError}; @@ -108,13 +109,13 @@ pub trait Container { } /// Insert a HUGR as a child of the container. - fn add_hugr(&mut self, child: Hugr) -> Result { + fn add_hugr(&mut self, child: Hugr) -> Result { let parent = self.container_node(); Ok(self.hugr_mut().insert_hugr(parent, child)?) } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> Result { + fn add_hugr_view(&mut self, child: &impl HugrView) -> Result { let parent = self.container_node(); Ok(self.hugr_mut().insert_from_view(parent, child)?) } @@ -230,7 +231,7 @@ pub trait Dataflow: Container { input_wires: impl IntoIterator, ) -> Result, BuildError> { let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); - let node = self.add_hugr(hugr)?; + let node = self.add_hugr(hugr)?.new_root; let inputs = input_wires.into_iter().collect(); wire_up_inputs(inputs, node, self)?; @@ -251,7 +252,7 @@ pub trait Dataflow: Container { input_wires: impl IntoIterator, ) -> Result, BuildError> { let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); - let node = self.add_hugr_view(hugr)?; + let node = self.add_hugr_view(hugr)?.new_root; let inputs = input_wires.into_iter().collect(); wire_up_inputs(inputs, node, self)?; diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 352ef1164..0070ef92c 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -141,19 +141,19 @@ pub trait HugrMut: HugrView + HugrMutInternals { } /// Insert another hugr into this one, under a given root node. - /// - /// Returns the root node of the inserted hugr. #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> Result { + fn insert_hugr(&mut self, root: Node, other: Hugr) -> Result { self.valid_node(root)?; self.hugr_mut().insert_hugr(root, other) } /// Copy another hugr into this one, under a given root node. - /// - /// Returns the root node of the inserted hugr. #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> Result { + fn insert_from_view( + &mut self, + root: Node, + other: &impl HugrView, + ) -> Result { self.valid_node(root)?; self.hugr_mut().insert_from_view(root, other) } @@ -167,6 +167,26 @@ pub trait HugrMut: HugrView + HugrMutInternals { } } +/// Records the result of inserting a Hugr or view +/// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view] +pub struct InsertionResult { + /// The node, after insertion, that was the root of the inserted Hugr. + /// (That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root])) + pub new_root: Node, + /// Map from nodes in the Hugr/view that was inserted, to their new + /// positions in the Hugr into which said was inserted. + pub node_map: HashMap, +} + +impl InsertionResult { + fn translating_indices(new_root: Node, node_map: HashMap) -> Self { + Self { + new_root, + node_map: HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))), + } + } +} + /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. impl HugrMut for T where @@ -258,7 +278,7 @@ where Ok((src_port, dst_port)) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> Result { + fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> Result { let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other)?; // Update the optypes and metadata, taking them from the other graph. for (&node, &new_node) in node_map.iter() { @@ -267,10 +287,15 @@ where let meta = other.metadata.take(node); self.as_mut().set_metadata(node.into(), meta).unwrap(); } - Ok(other_root) + debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index)); + Ok(InsertionResult::translating_indices(other_root, node_map)) } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> Result { + fn insert_from_view( + &mut self, + root: Node, + other: &impl HugrView, + ) -> Result { let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, other)?; // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { @@ -281,7 +306,8 @@ where .set_metadata(node.into(), meta.clone()) .unwrap(); } - Ok(other_root) + debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index)); + Ok(InsertionResult::translating_indices(other_root, node_map)) } } diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index a93ee56e4..b55b215cc 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -9,7 +9,8 @@ use crate::extension::ExtensionSet; use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; use crate::ops; -use crate::ops::{BasicBlock, OpTag, OpTrait, OpType}; +use crate::ops::handle::NodeHandle; +use crate::ops::{BasicBlock, OpTrait, OpType}; use crate::{type_row, Node}; /// Moves part of a Control-flow Sibling Graph into a new CFG-node @@ -114,7 +115,7 @@ impl Rewrite for OutlineCfg { let outer_entry = h.children(outer_cfg).next().unwrap(); // 2. new_block contains input node, sub-cfg, exit node all connected - let new_block = { + let (new_block, cfg_node) = { let input_extensions = h.get_nodetype(entry).input_extensions().cloned(); let mut new_block_bldr = BlockBuilder::new( inputs.clone(), @@ -130,26 +131,24 @@ impl Rewrite for OutlineCfg { let cfg = new_block_bldr .cfg_builder(wires_in, input_extensions, outputs, extension_delta) .unwrap(); - let cfg_outputs = cfg.finish_sub_container().unwrap().outputs(); + let cfg = cfg.finish_sub_container().unwrap(); let predicate = new_block_bldr .add_constant(ops::Const::simple_unary_predicate(), ExtensionSet::new()) .unwrap(); let pred_wire = new_block_bldr.load_const(&predicate).unwrap(); - new_block_bldr.set_outputs(pred_wire, cfg_outputs).unwrap(); - h.insert_hugr(outer_cfg, new_block_bldr.hugr().clone()) - .unwrap() + new_block_bldr + .set_outputs(pred_wire, cfg.outputs()) + .unwrap(); + let ins_res = h + .insert_hugr(outer_cfg, new_block_bldr.hugr().clone()) + .unwrap(); + ( + ins_res.new_root, + *ins_res.node_map.get(&cfg.node()).unwrap(), + ) }; - // 3. Extract Cfg node created above (it moved when we called insert_hugr) - let cfg_node = h - .children(new_block) - .filter(|n| h.get_optype(*n).tag() == OpTag::Cfg) - .exactly_one() - .ok() // HugrMut::Children is not Debug - .unwrap(); - let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); - - // 4. Entry edges. Change any edges into entry_block from outside, to target new_block + // 3. Entry edges. Change any edges into entry_block from outside, to target new_block let preds: Vec<_> = h .linked_ports(entry, h.node_inputs(entry).exactly_one().ok().unwrap()) .collect(); @@ -165,7 +164,8 @@ impl Rewrite for OutlineCfg { h.move_before_sibling(new_block, outer_entry).unwrap(); } - // 5. Children of new CFG. + // 4. Children of new CFG. + let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); // Entry node must be first h.move_before_sibling(entry, inner_exit).unwrap(); // And remaining nodes @@ -176,7 +176,7 @@ impl Rewrite for OutlineCfg { } } - // 6. Exit edges. + // 5. Exit edges. // Retarget edge from exit_node (that used to target outside) to inner_exit let exit_port = h .node_outputs(exit)