Skip to content

Commit

Permalink
Merge branch 'main' into fix/identity-parent
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 authored Aug 31, 2023
2 parents c25390f + fa55478 commit 4177593
Show file tree
Hide file tree
Showing 29 changed files with 531 additions and 209 deletions.
11 changes: 6 additions & 5 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,14 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
pub(crate) mod test {
use super::*;
use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder};
use crate::extension::prelude::USIZE_T;

use crate::hugr::views::{HierarchyView, SiblingGraph};
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::Const;
use crate::types::test::EQ_T;
use crate::types::Type;
use crate::{type_row, Hugr};
const NAT: Type = EQ_T;
const NAT: Type = USIZE_T;

pub fn group_by<E: Eq + Hash + Ord, V: Eq + Hash>(h: HashMap<E, V>) -> HashSet<Vec<E>> {
let mut res = HashMap::new();
Expand Down Expand Up @@ -442,7 +443,7 @@ pub(crate) mod test {
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_hugr()?;
let h = cfg_builder.finish_prelude_hugr()?;

let (entry, exit) = (entry.node(), exit.node());
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
Expand Down Expand Up @@ -697,7 +698,7 @@ pub(crate) mod test {
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_hugr()?;
let h = cfg_builder.finish_prelude_hugr()?;
Ok((h, merge, tail))
}

Expand Down Expand Up @@ -734,7 +735,7 @@ pub(crate) mod test {
cfg_builder.branch(&entry, 0, &head)?;
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_hugr()?;
let h = cfg_builder.finish_prelude_hugr()?;
Ok((h, head, tail))
}
}
6 changes: 3 additions & 3 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub(crate) mod test {
use super::{DataflowSubContainer, HugrBuilder};

pub(super) const NAT: Type = crate::extension::prelude::USIZE_T;
pub(super) const BIT: Type = crate::extension::prelude::USIZE_T;
pub(super) const BIT: Type = crate::extension::prelude::BOOL_T;
pub(super) const QB: Type = crate::extension::prelude::QB_T;

/// Wire up inputs of a Dataflow container to the outputs.
Expand All @@ -120,14 +120,14 @@ pub(crate) mod test {
let f_builder = module_builder.define_function("main", signature)?;

f(f_builder)?;
Ok(module_builder.finish_hugr()?)
Ok(module_builder.finish_prelude_hugr()?)
}

#[fixture]
pub(crate) fn simple_dfg_hugr() -> Hugr {
let dfg_builder =
DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.finish_hugr_with_outputs([i1]).unwrap()
dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap()
}
}
40 changes: 33 additions & 7 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
types::EdgeKind,
};

use crate::extension::ExtensionSet;
use crate::extension::{prelude_registry, ExtensionRegistry, ExtensionSet};
use crate::types::{FunctionType, Signature, Type, TypeRow};

use itertools::Itertools;
Expand Down Expand Up @@ -128,7 +128,18 @@ pub trait Container {
/// (with varying root node types)
pub trait HugrBuilder: Container {
/// Finish building the HUGR, perform any validation checks and return it.
fn finish_hugr(self) -> Result<Hugr, ValidationError>;
fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result<Hugr, ValidationError>;

/// Finish building the HUGR (as [HugrBuilder::finish_hugr]),
/// validating against the [prelude] extension only
///
/// [prelude]: crate::extension::prelude
fn finish_prelude_hugr(self) -> Result<Hugr, ValidationError>
where
Self: Sized,
{
self.finish_hugr(&prelude_registry())
}
}

/// Types implementing this trait build a container graph region by borrowing a HUGR
Expand Down Expand Up @@ -282,8 +293,7 @@ pub trait Dataflow: Container {
signature: signature.clone(),
};
let nodetype = match &input_extensions {
// TODO: Make this NodeType::open_extensions
None => NodeType::pure(op),
None => NodeType::open_extensions(op),
Some(rs) => NodeType::new(op, rs.clone()),
};
let (dfg_n, _) = add_node_with_wires(self, nodetype, input_wires.into_iter().collect())?;
Expand Down Expand Up @@ -701,19 +711,35 @@ fn wire_up<T: Dataflow + ?Sized>(

/// Trait implemented by builders of Dataflow Hugrs
pub trait DataflowHugr: HugrBuilder + Dataflow {
/// Set outputs of dataflow HUGR and return HUGR
/// Set outputs of dataflow HUGR and return validated HUGR
/// # Errors
///
/// This function will return an error if there is an error when setting outputs.
/// * if there is an error when setting outputs
/// * if the Hugr does not validate
fn finish_hugr_with_outputs(
mut self,
outputs: impl IntoIterator<Item = Wire>,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, BuildError>
where
Self: Sized,
{
self.set_outputs(outputs)?;
Ok(self.finish_hugr()?)
Ok(self.finish_hugr(extension_registry)?)
}

/// Sets the outputs of a dataflow Hugr, validates against
/// the [prelude] extension only, and return the Hugr
///
/// [prelude]: crate::extension::prelude
fn finish_prelude_hugr_with_outputs(
self,
outputs: impl IntoIterator<Item = Wire>,
) -> Result<Hugr, BuildError>
where
Self: Sized,
{
self.finish_hugr_with_outputs(outputs, &prelude_registry())
}
}

Expand Down
18 changes: 12 additions & 6 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{
};

use crate::ops::{self, BasicBlock, OpType};
use crate::types::FunctionType;
use crate::{extension::ExtensionRegistry, types::FunctionType};
use crate::{hugr::views::HugrView, types::TypeRow};
use crate::{ops::handle::NodeHandle, types::Type};

Expand Down Expand Up @@ -70,8 +70,11 @@ impl CFGBuilder<Hugr> {
}

impl HugrBuilder for CFGBuilder<Hugr> {
fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.validate()?;
fn finish_hugr(
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down Expand Up @@ -287,16 +290,19 @@ impl BlockBuilder<Hugr> {
mut self,
branch_wire: Wire,
outputs: impl IntoIterator<Item = Wire>,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, BuildError> {
self.set_outputs(branch_wire, outputs)?;
self.finish_hugr().map_err(BuildError::InvalidHUGR)
self.finish_hugr(extension_registry)
.map_err(BuildError::InvalidHUGR)
}
}

#[cfg(test)]
mod test {
use crate::builder::build_traits::HugrBuilder;
use crate::builder::{DataflowSubContainer, ModuleBuilder};

use crate::{builder::test::NAT, type_row};
use cool_asserts::assert_matches;

Expand All @@ -320,7 +326,7 @@ mod test {

func_builder.finish_with_outputs(cfg_id.outputs())?
};
module_builder.finish_hugr()
module_builder.finish_prelude_hugr()
};

assert_eq!(build_result.err(), None);
Expand All @@ -331,7 +337,7 @@ mod test {
fn basic_cfg_hugr() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
build_basic_cfg(&mut cfg_builder)?;
assert_matches!(cfg_builder.finish_hugr(), Ok(_));
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Ok(())
}
Expand Down
11 changes: 7 additions & 4 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::extension::ExtensionRegistry;
use crate::hugr::views::HugrView;
use crate::types::{FunctionType, TypeRow};

Expand Down Expand Up @@ -146,9 +147,11 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
}

impl HugrBuilder for ConditionalBuilder<Hugr> {
fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.infer_extensions()?;
self.base.validate()?;
fn finish_hugr(
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down Expand Up @@ -265,7 +268,7 @@ mod test {
let [int] = conditional_id.outputs_arr();
fbuild.finish_with_outputs([int])?
};
Ok(module_builder.finish_hugr()?)
Ok(module_builder.finish_prelude_hugr()?)
};

assert_matches!(build_result, Ok(_));
Expand Down
34 changes: 18 additions & 16 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::ops;

use crate::types::{FunctionType, Signature};

use crate::extension::ExtensionSet;
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::Node;
use crate::{hugr::HugrMut, Hugr};

Expand Down Expand Up @@ -52,16 +52,15 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
base.as_mut().add_node_with_parent(
parent,
match &input_extensions {
// TODO: Make this NodeType::open_extensions
None => NodeType::pure(input),
None => NodeType::open_extensions(input),
Some(rs) => NodeType::new(input, rs.clone()),
},
)?;
base.as_mut().add_node_with_parent(
parent,
match input_extensions.map(|inp| inp.union(&signature.extension_reqs)) {
// TODO: Make this NodeType::open_extensions
None => NodeType::new(output, signature.extension_reqs),
None => NodeType::open_extensions(output),
Some(rs) => NodeType::new(output, rs),
},
)?;
Expand Down Expand Up @@ -96,9 +95,11 @@ impl DFGBuilder<Hugr> {
}

impl HugrBuilder for DFGBuilder<Hugr> {
fn finish_hugr(mut self) -> Result<Hugr, ValidationError> {
let closure = self.base.infer_extensions()?;
self.base.validate_with_extension_closure(closure)?;
fn finish_hugr(
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, ValidationError> {
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down Expand Up @@ -207,8 +208,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>, T: From<BuildHandle<DfgID>>> SubContainer for
}

impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
fn finish_hugr(self) -> Result<Hugr, ValidationError> {
self.0.finish_hugr()
fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result<Hugr, ValidationError> {
self.0.finish_hugr(extension_registry)
}
}

Expand All @@ -221,6 +222,7 @@ pub(crate) mod test {
use crate::builder::build_traits::DataflowHugr;
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::BOOL_T;
use crate::extension::EMPTY_REG;
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::{handle::NodeHandle, LeafOp, OpTag};

Expand Down Expand Up @@ -262,7 +264,7 @@ pub(crate) mod test {

func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?
};
module_builder.finish_hugr()
module_builder.finish_prelude_hugr()
};

assert_eq!(build_result.err(), None);
Expand All @@ -285,7 +287,7 @@ pub(crate) mod test {

f(f_build)?;

module_builder.finish_hugr()
module_builder.finish_hugr(&EMPTY_REG)
};
assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

Expand Down Expand Up @@ -336,7 +338,7 @@ pub(crate) mod test {
let [q1] = f_build.input_wires_arr();
f_build.finish_with_outputs([q1, q1])?;

Ok(module_builder.finish_hugr()?)
Ok(module_builder.finish_prelude_hugr()?)
};

assert_eq!(builder(), Err(BuildError::NoCopyLinear(QB)));
Expand All @@ -361,7 +363,7 @@ pub(crate) mod test {

let nested = nested.finish_with_outputs([id.out_wire(0)])?;

f_build.finish_hugr_with_outputs([nested.out_wire(0)])
f_build.finish_hugr_with_outputs([nested.out_wire(0)], &EMPTY_REG)
};

assert_matches!(builder(), Ok(_));
Expand Down Expand Up @@ -407,7 +409,7 @@ pub(crate) mod test {
let mut dfg_builder = DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT]))?;
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.set_metadata(json!(42));
let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?;
let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1], &EMPTY_REG)?;

// Create a module, and insert the DFG into it
let mut module_builder = ModuleBuilder::new();
Expand All @@ -423,7 +425,7 @@ pub(crate) mod test {
f_build.finish_with_outputs([id.out_wire(0)])?;
}

assert_eq!(module_builder.finish_hugr()?.node_count(), 7);
assert_eq!(module_builder.finish_hugr(&EMPTY_REG)?.node_count(), 7);

Ok(())
}
Expand Down Expand Up @@ -494,7 +496,7 @@ pub(crate) mod test {

let add_c = add_c.finish_with_outputs(wires)?;
let [w] = add_c.outputs_arr();
parent.finish_hugr_with_outputs([w])?;
parent.finish_hugr_with_outputs([w], &EMPTY_REG)?;

Ok(())
}
Expand Down
Loading

0 comments on commit 4177593

Please sign in to comment.