diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 50cc91c3c..144a9d5d7 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -210,9 +210,8 @@ pub trait Dataflow: Container { let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); let node = self.add_hugr(hugr)?; - let [inp, _] = self.io(); let inputs = input_wires.into_iter().collect(); - wire_up_inputs(inputs, node, self, inp)?; + wire_up_inputs(inputs, node, self)?; Ok((node, num_outputs).into()) } @@ -232,9 +231,8 @@ pub trait Dataflow: Container { let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); let node = self.add_hugr_view(hugr)?; - let [inp, _] = self.io(); let inputs = input_wires.into_iter().collect(); - wire_up_inputs(inputs, node, self, inp)?; + wire_up_inputs(inputs, node, self)?; Ok((node, num_outputs).into()) } @@ -248,8 +246,8 @@ pub trait Dataflow: Container { &mut self, output_wires: impl IntoIterator, ) -> Result<(), BuildError> { - let [inp, out] = self.io(); - wire_up_inputs(output_wires.into_iter().collect_vec(), out, self, inp) + let [_, out] = self.io(); + wire_up_inputs(output_wires.into_iter().collect_vec(), out, self) } /// Return an array of the input wires. @@ -605,12 +603,10 @@ fn add_node_with_wires( nodetype: NodeType, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - let [inp, _] = data_builder.io(); - let op_node = data_builder.add_child_node(nodetype.clone())?; let sig = nodetype.op_signature(); - wire_up_inputs(inputs, op_node, data_builder, inp)?; + wire_up_inputs(inputs, op_node, data_builder)?; Ok((op_node, sig.output().len())) } @@ -619,11 +615,9 @@ fn wire_up_inputs( inputs: Vec, op_node: Node, data_builder: &mut T, - inp: Node, ) -> Result<(), BuildError> { - let mut any_local_df_inputs = false; for (dst_port, wire) in inputs.into_iter().enumerate() { - any_local_df_inputs |= wire_up( + wire_up( data_builder, wire.node(), wire.source().index(), @@ -631,14 +625,6 @@ fn wire_up_inputs( dst_port, )?; } - let base = data_builder.hugr_mut(); - let op = base.get_optype(op_node); - let some_df_outputs = !op.signature().output.is_empty(); - if !any_local_df_inputs && some_df_outputs { - // If op has no inputs add a StateOrder edge from input to place in - // causal cone of Input node - data_builder.add_other_wire(inp, op_node)?; - }; Ok(()) } diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index c0e2226c1..c95d145b3 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -5,7 +5,7 @@ use std::iter; use itertools::Itertools; use petgraph::algo::dominators::{self, Dominators}; -use petgraph::visit::{DfsPostOrder, Walker}; +use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; @@ -341,9 +341,7 @@ impl<'a> ValidationContext<'a> { Ok(()) } - /// Ensure that the children of a node form a direct acyclic graph with a - /// single source and source. That is, their edges do not form cycles in the - /// graph and there are no dangling nodes. + /// Ensure that the children of a node form a directed acyclic graph. /// /// Inter-graph edges are ignored. Only internal dataflow, constant, or /// state order edges are considered. @@ -354,18 +352,11 @@ impl<'a> ValidationContext<'a> { }; let region: SiblingGraph = SiblingGraph::new(self.hugr, parent); - let entry_node = self.hugr.children(parent).next().unwrap(); - - let postorder = DfsPostOrder::new(®ion, entry_node); + let postorder = Topo::new(®ion); let nodes_visited = postorder.iter(®ion).filter(|n| *n != parent).count(); - // Local ScopedDefn's should not be reachable from the Input node, so discount them - let non_defn_count = self - .hugr - .children(parent) - .filter(|n| !OpTag::ScopedDefn.is_superset(self.hugr.get_optype(*n).tag())) - .count(); - if nodes_visited != non_defn_count { - return Err(ValidationError::NotABoundedDag { + let node_count = self.hugr.children(parent).count(); + if nodes_visited != node_count { + return Err(ValidationError::NotADag { node: parent, optype: op_type.clone(), }); @@ -603,9 +594,9 @@ pub enum ValidationError { /// The node must have children, but has none. #[error("The node {node:?} with optype {optype:?} must have children, but has none.")] ContainerWithoutChildren { node: Node, optype: OpType }, - /// The children of a node do not form a dag with single source and sink. - #[error("The children of an operation {optype:?} must form a dag with single source and sink. Loops are not allowed, nor are dangling nodes not in the path between the input and output. In node {node:?}.")] - NotABoundedDag { node: Node, optype: OpType }, + /// The children of a node do not form a DAG. + #[error("The children of an operation {optype:?} must form a DAG. Loops are not allowed. In node {node:?}.")] + NotADag { node: Node, optype: OpType }, /// There are invalid inter-graph edges. #[error(transparent)] InterGraphEdgeError(#[from] InterGraphEdgeError), @@ -699,7 +690,7 @@ mod test { use crate::ops::dataflow::IOTrait; use crate::ops::{self, LeafOp, OpType}; use crate::std_extensions::logic; - use crate::std_extensions::logic::test::and_op; + use crate::std_extensions::logic::test::{and_op, not_op}; use crate::types::{FunctionType, Type}; use crate::Direction; use crate::{type_row, Node}; @@ -1096,10 +1087,7 @@ mod test { let lcst = h.add_op_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T })?; h.connect(cst, 0, lcst, 0)?; h.connect(lcst, 0, and, 1)?; - // We are missing the edge from Input to LoadConstant, hence: - assert_matches!(h.validate(), Err(ValidationError::NotABoundedDag { .. })); - // Now include the LoadConstant node in the causal cone - h.add_other_edge(input, lcst)?; + // There is no edge from Input to LoadConstant, but that's OK: h.validate().unwrap(); Ok(()) } @@ -1256,4 +1244,24 @@ mod test { ); Ok(()) } + + #[test] + fn dfg_with_cycles() -> Result<(), HugrError> { + let mut h = Hugr::new(NodeType::pure(ops::DFG { + signature: FunctionType::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]), + })); + let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T, BOOL_T]))?; + let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; + let and = h.add_op_with_parent(h.root(), and_op())?; + let not1 = h.add_op_with_parent(h.root(), not_op())?; + let not2 = h.add_op_with_parent(h.root(), not_op())?; + h.connect(input, 0, and, 0)?; + h.connect(and, 0, not1, 0)?; + h.connect(not1, 0, and, 1)?; + h.connect(input, 1, not2, 0)?; + h.connect(not2, 0, output, 0)?; + // The graph contains a cycle: + assert_matches!(h.validate(), Err(ValidationError::NotADag { .. })); + Ok(()) + } } diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 565bca4f4..178c9504d 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -110,7 +110,7 @@ lazy_static! { pub(crate) mod test { use crate::{extension::prelude::BOOL_T, ops::LeafOp, types::type_param::TypeArg, Extension}; - use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, TRUE_NAME}; + use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, TRUE_NAME}; #[test] fn test_logic_extension() { @@ -131,11 +131,19 @@ pub(crate) mod test { } } - /// Generate a logic extension and operation over [`crate::prelude::BOOL_T`] + /// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`] pub(crate) fn and_op() -> LeafOp { EXTENSION .instantiate_extension_op(AND_NAME, [TypeArg::BoundedNat(2)]) .unwrap() .into() } + + /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] + pub(crate) fn not_op() -> LeafOp { + EXTENSION + .instantiate_extension_op(NOT_NAME, []) + .unwrap() + .into() + } }