Skip to content

Commit

Permalink
feat: insert_hugr/insert_view return node map (#535)
Browse files Browse the repository at this point in the history
Also corresponding builder methods. `insert_hugr_internal` already
returns this, so just wrap in a new `struct InsertionResult` translating
the NodeIndex's into Nodes.
  • Loading branch information
acl-cqc authored Sep 14, 2023
1 parent 79de213 commit e2213ba
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 32 deletions.
9 changes: 5 additions & 4 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{Node, NodeMetadata, Port, ValidationError};
Expand Down Expand Up @@ -108,13 +109,13 @@ pub trait Container {
}

/// Insert a HUGR as a child of the container.
fn add_hugr(&mut self, child: Hugr) -> Result<Node, BuildError> {
fn add_hugr(&mut self, child: Hugr) -> Result<InsertionResult, BuildError> {
let parent = self.container_node();
Ok(self.hugr_mut().insert_hugr(parent, child)?)
}

/// Insert a copy of a HUGR as a child of the container.
fn add_hugr_view(&mut self, child: &impl HugrView) -> Result<Node, BuildError> {
fn add_hugr_view(&mut self, child: &impl HugrView) -> Result<InsertionResult, BuildError> {
let parent = self.container_node();
Ok(self.hugr_mut().insert_from_view(parent, child)?)
}
Expand Down Expand Up @@ -230,7 +231,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr(hugr)?;
let node = self.add_hugr(hugr)?.new_root;

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand All @@ -251,7 +252,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr_view(hugr)?;
let node = self.add_hugr_view(hugr)?.new_root;

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand Down
46 changes: 36 additions & 10 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,19 @@ pub trait HugrMut: HugrView + HugrMutInternals {
}

/// Insert another hugr into this one, under a given root node.
///
/// Returns the root node of the inserted hugr.
#[inline]
fn insert_hugr(&mut self, root: Node, other: Hugr) -> Result<Node, HugrError> {
fn insert_hugr(&mut self, root: Node, other: Hugr) -> Result<InsertionResult, HugrError> {
self.valid_node(root)?;
self.hugr_mut().insert_hugr(root, other)
}

/// Copy another hugr into this one, under a given root node.
///
/// Returns the root node of the inserted hugr.
#[inline]
fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> Result<Node, HugrError> {
fn insert_from_view(
&mut self,
root: Node,
other: &impl HugrView,
) -> Result<InsertionResult, HugrError> {
self.valid_node(root)?;
self.hugr_mut().insert_from_view(root, other)
}
Expand All @@ -167,6 +167,26 @@ pub trait HugrMut: HugrView + HugrMutInternals {
}
}

/// Records the result of inserting a Hugr or view
/// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]
pub struct InsertionResult {
/// The node, after insertion, that was the root of the inserted Hugr.
/// (That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]))
pub new_root: Node,
/// Map from nodes in the Hugr/view that was inserted, to their new
/// positions in the Hugr into which said was inserted.
pub node_map: HashMap<Node, Node>,
}

impl InsertionResult {
fn translating_indices(new_root: Node, node_map: HashMap<NodeIndex, NodeIndex>) -> Self {
Self {
new_root,
node_map: HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))),
}
}
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
impl<T> HugrMut for T
where
Expand Down Expand Up @@ -258,7 +278,7 @@ where
Ok((src_port, dst_port))
}

fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> Result<Node, HugrError> {
fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> Result<InsertionResult, HugrError> {
let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other)?;
// Update the optypes and metadata, taking them from the other graph.
for (&node, &new_node) in node_map.iter() {
Expand All @@ -267,10 +287,15 @@ where
let meta = other.metadata.take(node);
self.as_mut().set_metadata(node.into(), meta).unwrap();
}
Ok(other_root)
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
}

fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> Result<Node, HugrError> {
fn insert_from_view(
&mut self,
root: Node,
other: &impl HugrView,
) -> Result<InsertionResult, HugrError> {
let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, other)?;
// Update the optypes and metadata, copying them from the other graph.
for (&node, &new_node) in node_map.iter() {
Expand All @@ -281,7 +306,8 @@ where
.set_metadata(node.into(), meta.clone())
.unwrap();
}
Ok(other_root)
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
}
}

Expand Down
36 changes: 18 additions & 18 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use crate::extension::ExtensionSet;
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::{BasicBlock, OpTag, OpTrait, OpType};
use crate::ops::handle::NodeHandle;
use crate::ops::{BasicBlock, OpTrait, OpType};
use crate::{type_row, Node};

/// Moves part of a Control-flow Sibling Graph into a new CFG-node
Expand Down Expand Up @@ -114,7 +115,7 @@ impl Rewrite for OutlineCfg {
let outer_entry = h.children(outer_cfg).next().unwrap();

// 2. new_block contains input node, sub-cfg, exit node all connected
let new_block = {
let (new_block, cfg_node) = {
let input_extensions = h.get_nodetype(entry).input_extensions().cloned();
let mut new_block_bldr = BlockBuilder::new(
inputs.clone(),
Expand All @@ -130,26 +131,24 @@ impl Rewrite for OutlineCfg {
let cfg = new_block_bldr
.cfg_builder(wires_in, input_extensions, outputs, extension_delta)
.unwrap();
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let cfg = cfg.finish_sub_container().unwrap();
let predicate = new_block_bldr
.add_constant(ops::Const::simple_unary_predicate(), ExtensionSet::new())
.unwrap();
let pred_wire = new_block_bldr.load_const(&predicate).unwrap();
new_block_bldr.set_outputs(pred_wire, cfg_outputs).unwrap();
h.insert_hugr(outer_cfg, new_block_bldr.hugr().clone())
.unwrap()
new_block_bldr
.set_outputs(pred_wire, cfg.outputs())
.unwrap();
let ins_res = h
.insert_hugr(outer_cfg, new_block_bldr.hugr().clone())
.unwrap();
(
ins_res.new_root,
*ins_res.node_map.get(&cfg.node()).unwrap(),
)
};

// 3. Extract Cfg node created above (it moved when we called insert_hugr)
let cfg_node = h
.children(new_block)
.filter(|n| h.get_optype(*n).tag() == OpTag::Cfg)
.exactly_one()
.ok() // HugrMut::Children is not Debug
.unwrap();
let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();

// 4. Entry edges. Change any edges into entry_block from outside, to target new_block
// 3. Entry edges. Change any edges into entry_block from outside, to target new_block
let preds: Vec<_> = h
.linked_ports(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
.collect();
Expand All @@ -165,7 +164,8 @@ impl Rewrite for OutlineCfg {
h.move_before_sibling(new_block, outer_entry).unwrap();
}

// 5. Children of new CFG.
// 4. Children of new CFG.
let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
// Entry node must be first
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
Expand All @@ -176,7 +176,7 @@ impl Rewrite for OutlineCfg {
}
}

// 6. Exit edges.
// 5. Exit edges.
// Retarget edge from exit_node (that used to target outside) to inner_exit
let exit_port = h
.node_outputs(exit)
Expand Down

0 comments on commit e2213ba

Please sign in to comment.