Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Insert/extract subgraphs from a HugrView #552

Merged
merged 5 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr(hugr)?.new_root;
let node = self.add_hugr(hugr)?.new_root.unwrap();

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand All @@ -252,7 +252,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr_view(hugr)?.new_root;
let node = self.add_hugr_view(hugr)?.new_root.unwrap();

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand Down
107 changes: 100 additions & 7 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::collections::HashMap;
use std::ops::Range;

use portgraph::view::{NodeFilter, NodeFiltered};
use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap};

use crate::hugr::{Direction, HugrError, HugrView, Node, NodeType};
Expand All @@ -12,6 +13,7 @@ use crate::{Hugr, Port};

use self::sealed::HugrMutInternals;

use super::views::SiblingSubgraph;
use super::{NodeMetadata, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
Expand Down Expand Up @@ -158,6 +160,26 @@ pub trait HugrMut: HugrView + HugrMutInternals {
self.hugr_mut().insert_from_view(root, other)
}

/// Copy a subgraph from another hugr into this one, under a given root node.
///
/// Sibling order is not preserved.
///
/// The returned `InsertionResult` does not contain a `new_root` value, since
/// a subgraph may not have a defined root.
//
// TODO: Try to preserve the order when possible? We cannot always ensure
// it, since the subgraph may have arbitrary nodes without including their
// parent.
fn insert_subgraph(
&mut self,
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
self.valid_node(root)?;
self.hugr_mut().insert_subgraph(root, other, subgraph)
}

/// Applies a rewrite to the graph.
fn apply_rewrite<R, E>(&mut self, rw: impl Rewrite<ApplyResult = R, Error = E>) -> Result<R, E>
where
Expand All @@ -171,15 +193,21 @@ pub trait HugrMut: HugrView + HugrMutInternals {
/// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]
pub struct InsertionResult {
/// The node, after insertion, that was the root of the inserted Hugr.
/// (That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]))
pub new_root: Node,
///
/// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]
///
/// When inserting a subgraph, this value is `None`.
pub new_root: Option<Node>,
/// Map from nodes in the Hugr/view that was inserted, to their new
/// positions in the Hugr into which said was inserted.
pub node_map: HashMap<Node, Node>,
}

impl InsertionResult {
fn translating_indices(new_root: Node, node_map: HashMap<NodeIndex, NodeIndex>) -> Self {
fn translating_indices(
new_root: Option<Node>,
node_map: HashMap<NodeIndex, NodeIndex>,
) -> Self {
Self {
new_root,
node_map: HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))),
Expand Down Expand Up @@ -276,10 +304,13 @@ where
let optype = other.op_types.take(node);
self.as_mut().op_types.set(new_node, optype);
let meta = other.metadata.take(node);
self.as_mut().set_metadata(node.into(), meta).unwrap();
self.as_mut().set_metadata(new_node.into(), meta).unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
}

fn insert_from_view(
Expand All @@ -294,11 +325,40 @@ where
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.get_metadata(node.into());
self.as_mut()
.set_metadata(node.into(), meta.clone())
.set_metadata(new_node.into(), meta.clone())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was done in an earlier PR, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, that was outline_cfg. but same error 😲

Copy link
Collaborator Author

@aborgna-q aborgna-q Sep 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, seems like an easy to miss bug :/
And we never do anything with the metadata, so the error is never triggered.

.unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
}

fn insert_subgraph(
&mut self,
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
// Create a portgraph view with the explicit list of nodes defined by the subgraph.
let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
NodeFiltered::new_node_filtered(
other.portgraph(),
|node, ctx| ctx.contains(&node.into()),
subgraph.nodes(),
);
let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph)?;
// Update the optypes and metadata, copying them from the other graph.
for (&node, &new_node) in node_map.iter() {
let nodetype = other.get_nodetype(node.into());
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.get_metadata(node.into());
self.as_mut()
.set_metadata(new_node.into(), meta.clone())
.unwrap();
}
Ok(InsertionResult::translating_indices(None, node_map))
}
}

Expand Down Expand Up @@ -341,6 +401,39 @@ fn insert_hugr_internal(
Ok((other_root.into(), node_map))
}

/// Internal implementation of the `insert_subgraph` method for AsMut<Hugr>.
///
/// Returns a mapping from the nodes in the inserted graph to their new indices
/// in `hugr`.
///
/// This function does not update the optypes of the inserted nodes, so the
/// caller must do that.
///
/// In contrast to `insert_hugr_internal`, this function does not preserve
/// sibling order in the hierarchy. This is due to the subgraph not necessarily
/// having a single root, so the logic for reconstructing the hierarchy is not
/// able to just do a BFS.
fn insert_subgraph_internal(
hugr: &mut Hugr,
root: Node,
other: &impl HugrView,
portgraph: &impl portgraph::LinkView,
) -> Result<HashMap<NodeIndex, NodeIndex>, HugrError> {
let node_map = hugr.graph.insert_graph(&portgraph)?;

// A map for nodes that we inserted before their parent, so we couldn't
// update the hierarchy with their new id.
for (&node, &new_node) in node_map.iter() {
let new_parent = other
.get_parent(node.into())
.and_then(|parent| node_map.get(&parent.index).copied())
.unwrap_or(root.index);
hugr.hierarchy.push_child(new_node, new_parent)?;
}

Ok(node_map)
}

pub(crate) mod sealed {
use super::*;

Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl Rewrite for OutlineCfg {
.insert_hugr(outer_cfg, new_block_bldr.hugr().clone())
.unwrap();
(
ins_res.new_root,
ins_res.new_root.unwrap(),
*ins_res.node_map.get(&cfg.node()).unwrap(),
)
};
Expand Down
76 changes: 75 additions & 1 deletion src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ use itertools::Itertools;
use portgraph::{view::Subgraph, Direction, PortView};
use thiserror::Error;

use crate::builder::{Dataflow, DataflowHugr, FunctionBuilder};
use crate::extension::{ExtensionSet, PRELUDE_REGISTRY};
use crate::hugr::{HugrError, HugrMut};
use crate::types::Signature;
use crate::{
ops::{
handle::{ContainerHandle, DataflowOpID},
Expand Down Expand Up @@ -122,7 +126,7 @@ impl SiblingSubgraph {
/// ## Definition
///
/// More formally, the sibling subgraph of a graph $G = (V, E)$ given
/// by sets of incoming and outoing boundary edges $B_I, B_O \subseteq E$
/// by sets of incoming and outgoing boundary edges $B_I, B_O \subseteq E$
/// is the graph given by the connected components of the graph
/// $G' = (V, E \ B_I \ B_O)$ that contain at least one node that is either
/// - the target of an incoming boundary edge, or
Expand Down Expand Up @@ -281,6 +285,16 @@ impl SiblingSubgraph {
self.nodes.len()
}

/// Returns the computed [`IncomingPorts`] of the subgraph.
pub fn incoming_ports(&self) -> &IncomingPorts {
&self.inputs
}

/// Returns the computed [`OutgoingPorts`] of the subgraph.
pub fn outgoing_ports(&self) -> &OutgoingPorts {
&self.outputs
}

/// The signature of the subgraph.
pub fn signature(&self, hugr: &impl HugrView) -> FunctionType {
let input = self
Expand Down Expand Up @@ -386,6 +400,51 @@ impl SiblingSubgraph {
nu_out,
))
}

/// Create a new Hugr containing only the subgraph.
///
/// The new Hugr will contain a function root wth the same signature as the
/// subgraph and the specified `input_extensions`.
pub fn extract_subgraph(
&self,
hugr: &impl HugrView,
name: impl Into<String>,
input_extensions: ExtensionSet,
) -> Result<Hugr, HugrError> {
let signature = Signature {
signature: self.signature(hugr),
input_extensions,
};
let builder = FunctionBuilder::new(name, signature).unwrap();
let inputs = builder.input_wires();
let mut extracted = builder
.finish_hugr_with_outputs(inputs, &PRELUDE_REGISTRY)
.unwrap();
let node_map = extracted
.insert_subgraph(extracted.root(), hugr, self)?
.node_map;

// Disconnect the input and output nodes, and connect the inserted nodes
// in-between.
let [inp, out] = extracted.get_io(extracted.root()).unwrap();
for (inp_port, repl_ports) in extracted
.node_ports(inp, Direction::Outgoing)
.zip(self.inputs.iter())
{
extracted.disconnect(inp, inp_port)?;
for (repl_node, repl_port) in repl_ports {
extracted.connect(inp, inp_port, node_map[repl_node], *repl_port)?;
}
}
for (out_port, (repl_node, repl_port)) in extracted
.node_ports(out, Direction::Incoming)
.zip(self.outputs.iter())
{
extracted.connect(node_map[repl_node], *repl_port, out, out_port)?;
}

Ok(extracted)
}
}

/// Precompute convexity information for a HUGR.
Expand Down Expand Up @@ -590,6 +649,8 @@ pub enum InvalidSubgraph {

#[cfg(test)]
mod tests {
use std::error::Error;

use crate::{
builder::{
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
Expand Down Expand Up @@ -821,4 +882,17 @@ mod tests {
};
assert_eq!(func_defn.signature, func.signature(&func_graph))
}

#[test]
fn extract_subgraph() -> Result<(), Box<dyn Error>> {
let (hugr, func_root) = build_hugr().unwrap();
let func_graph: SiblingGraph<'_, FuncID<true>> =
SiblingGraph::try_new(&hugr, func_root).unwrap();
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
let extracted = subgraph.extract_subgraph(&hugr, "region", ExtensionSet::new())?;

extracted.validate(&PRELUDE_REGISTRY).unwrap();

Ok(())
}
}