Skip to content

Commit

Permalink
Expose HugrMut; add pub trait Buildable to parametrize all Builder cl…
Browse files Browse the repository at this point in the history
…asses
  • Loading branch information
acl-cqc committed Jul 24, 2023
1 parent 51701cc commit 65e02f1
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 138 deletions.
12 changes: 7 additions & 5 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,9 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
#[cfg(test)]
pub(crate) mod test {
use super::*;
use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder};
use crate::builder::{
BuildError, Buildable, CFGBuilder, Container, DataflowSubContainer, HugrBuilder,
};
use crate::hugr::region::{FlatRegionView, Region};
use crate::ops::{
handle::{BasicBlockID, ConstID, NodeHandle},
Expand Down Expand Up @@ -577,7 +579,7 @@ pub(crate) mod test {
dataflow_builder.finish_with_outputs([u].into_iter().chain(w))
}

fn build_if_then_else_merge<T: AsMut<Hugr> + AsRef<Hugr>>(
fn build_if_then_else_merge<T: Buildable>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
unit_const: &ConstID,
Expand All @@ -590,7 +592,7 @@ pub(crate) mod test {
Ok((split, merge))
}

fn build_then_else_merge_from_if<T: AsMut<Hugr> + AsRef<Hugr>>(
fn build_then_else_merge_from_if<T: Buildable>(
cfg: &mut CFGBuilder<T>,
unit_const: &ConstID,
split: BasicBlockID,
Expand All @@ -615,7 +617,7 @@ pub(crate) mod test {
}

// Returns loop tail - caller must link header to tail, and provide 0th successor of tail
fn build_loop_from_header<T: AsMut<Hugr> + AsRef<Hugr>>(
fn build_loop_from_header<T: Buildable>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
header: BasicBlockID,
Expand All @@ -629,7 +631,7 @@ pub(crate) mod test {
}

// Result is header and tail. Caller must provide 0th successor of header (linking to tail), and 0th successor of tail.
fn build_loop<T: AsMut<Hugr> + AsRef<Hugr>>(
fn build_loop<T: Buildable>(
cfg: &mut CFGBuilder<T>,
const_pred: &ConstID,
unit_const: &ConstID,
Expand Down
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use handle::BuildHandle;

mod build_traits;
pub use build_traits::{
Container, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer,
Buildable, Container, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer,
};

mod dataflow;
Expand Down
77 changes: 56 additions & 21 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,59 @@ use crate::Hugr;

use crate::hugr::HugrMut;

pub trait Buildable: HugrMut {
type BaseMut<'a>: Buildable
where
Self: 'a;
type BaseView<'a>: HugrView
where
Self: 'a;
/// The underlying [`Hugr`] being built
fn hugr_mut(&mut self) -> Self::BaseMut<'_>;
/// Immutable reference to HUGR being built
fn hugr(&self) -> Self::BaseView<'_>;
}

impl Buildable for Hugr {
type BaseMut<'a> = &'a mut Hugr where Self: 'a;

type BaseView<'a> = &'a Hugr where Self: 'a;

fn hugr_mut(&mut self) -> Self::BaseMut<'_> {
self
}

fn hugr(&self) -> Self::BaseView<'_> {
&self
}
}

impl<H: HugrMut + HugrView> Buildable for &mut H {
type BaseMut<'a> = &'a mut H where Self: 'a;

type BaseView<'a> = &'a H where Self: 'a;

fn hugr_mut(&mut self) -> Self::BaseMut<'_> {
self
}

fn hugr(&self) -> Self::BaseView<'_> {
&self
}
}

/// Trait for HUGR container builders.
/// Containers are nodes that are parents of sibling graphs.
/// Implementations of this trait allow the child sibling graph to be added to
/// the HUGR.
pub trait Container {
type BaseMut<'a>: HugrMut where Self: 'a;
type BaseView<'a>: HugrView where Self: 'a;
type Base: Buildable;
/// The container node.
fn container_node(&self) -> Node;
/// The underlying [`Hugr`] being built
fn hugr_mut(&mut self) -> Self::BaseMut<'_>;
fn hugr_mut(&mut self) -> <Self::Base as Buildable>::BaseMut<'_>;
/// Immutable reference to HUGR being built
fn hugr(&self) -> Self::BaseView<'_>;
fn hugr(&self) -> <Self::Base as Buildable>::BaseView<'_>;
/// Add an [`OpType`] as the final child of the container.
fn add_child_op(&mut self, op: impl Into<OpType>) -> Result<Node, BuildError> {
let parent = self.container_node();
Expand Down Expand Up @@ -81,7 +121,7 @@ pub trait Container {
&mut self,
name: impl Into<String>,
signature: Signature,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
) -> Result<FunctionBuilder<<Self::Base as Buildable>::BaseMut<'_>>, BuildError> {
let f_node = self.add_child_op(ops::FuncDefn {
name: name.into(),
signature: signature.clone(),
Expand Down Expand Up @@ -254,7 +294,7 @@ pub trait Dataflow: Container {
&mut self,
signature: Signature,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
) -> Result<DFGBuilder<<Self::Base as Buildable>::BaseMut<'_>>, BuildError> {
let (dfg_n, _) = add_op_with_wires(
self,
ops::DFG {
Expand All @@ -280,7 +320,7 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
) -> Result<CFGBuilder<<Self::Base as Buildable>::BaseMut<'_>>, BuildError> {
let (input_types, input_wires): (Vec<SimpleType>, Vec<Wire>) = inputs.into_iter().unzip();

let inputs: TypeRow = input_types.into();
Expand Down Expand Up @@ -339,7 +379,7 @@ pub trait Dataflow: Container {
just_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
inputs_outputs: impl IntoIterator<Item = (SimpleType, Wire)>,
just_out_types: TypeRow,
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
) -> Result<TailLoopBuilder<<Self::Base as Buildable>::BaseMut<'_>>, BuildError> {
let (input_types, mut input_wires): (Vec<SimpleType>, Vec<Wire>) =
just_inputs.into_iter().unzip();
let (rest_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
Expand Down Expand Up @@ -373,7 +413,7 @@ pub trait Dataflow: Container {
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
) -> Result<ConditionalBuilder<<Self::Base as Buildable>::BaseMut<'_>>, BuildError> {
let mut input_wires = vec![predicate_wire];
let (input_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
other_inputs.into_iter().unzip();
Expand Down Expand Up @@ -528,9 +568,7 @@ pub trait Dataflow: Container {
function: &FuncID<DEFINED>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let hugr = self.hugr();
let def_op = hugr.get_optype(function.node());
let signature = match def_op {
let signature = match self.hugr().get_optype(function.node()) {
OpType::FuncDefn(ops::FuncDefn { signature, .. })
| OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
_ => {
Expand All @@ -542,7 +580,7 @@ pub trait Dataflow: Container {
};
let const_in_port = signature.output.len();
let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?;
let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
let src_port = self.hugr().num_outputs(function.node()) - 1;

self.hugr_mut()
.connect(function.node(), src_port, op_id.node(), const_in_port)?;
Expand Down Expand Up @@ -588,9 +626,10 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
dst_port,
)?;
}
let base = data_builder.hugr_mut();
let base = data_builder.hugr();
let op = base.get_optype(op_node);
let some_df_outputs = !op.signature().output.is_empty();
drop(base);
if !any_local_df_inputs && some_df_outputs {
// If op has no inputs add a StateOrder edge from input to place in
// causal cone of Input node
Expand All @@ -607,7 +646,7 @@ fn wire_up<T: Dataflow + ?Sized>(
dst: Node,
dst_port: usize,
) -> Result<bool, BuildError> {
let base = data_builder.hugr_mut();
let mut base = data_builder.hugr_mut();
let src_offset = Port::new_outgoing(src_port);

let src_parent = base.get_parent(src);
Expand Down Expand Up @@ -654,14 +693,10 @@ fn wire_up<T: Dataflow + ?Sized>(
}
}

data_builder
.hugr_mut()
.connect(src, src_port, dst, dst_port)?;
base.connect(src, src_port, dst, dst_port)?;
Ok(local_source
&& matches!(
data_builder
.hugr_mut()
.get_optype(dst)
base.get_optype(dst)
.port_kind(Port::new_incoming(dst_port))
.unwrap(),
EdgeKind::Value(_)
Expand Down
Loading

0 comments on commit 65e02f1

Please sign in to comment.