From 7ff032b4ff73c496cb7b7d4fee87bebc697b6ade Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 19 Sep 2023 11:49:30 +0100 Subject: [PATCH] fix: Make unification logic less strict (#538) The `Plus` constraint has up till now been too strict in extension unification. So far we have been deriving things like the following: ``` a := Plus("C", b) a := {"A", "B", "C"} implies b := {"A", "B"} ``` (as in `minus_test`, which this PR removes) However, plus should be meant as the union of the the singleton set that it specifies with the extension set of another metavariable. Hence in the above example the solution for `b` could be either `{"A", "B"}` or `{"A", "B", "C"}`. This means that if we have chains of circuit operations which all add the "quantum" resource, they can all be represented as `Plus` constraints --- src/extension/infer.rs | 132 +++++++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 57 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 593726a0c..191adf278 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -495,6 +495,12 @@ impl UnificationContext { Constraint::Equal(other_meta) => { self.eq_graph.register_eq(meta, *other_meta); } + // N.B. If `meta` is already solved, we can't use that + // information to solve `other_meta`. This is because the Plus + // constraint only signifies a preorder. + // I.e. if meta = other_meta + 'R', it's still possible that the + // solution is meta = other_meta because we could be adding 'R' + // to a set which already contained it. Constraint::Plus(r, other_meta) => { if let Some(rs) = self.get_solution(other_meta) { let mut rrs = rs.clone(); @@ -516,10 +522,6 @@ impl UnificationContext { solved = true; } }; - } else if let Some(superset) = self.get_solution(&meta) { - let subset = ExtensionSet::singleton(r).missing_from(superset); - self.add_solution(self.resolve(*other_meta), subset); - solved = true; }; } } @@ -670,7 +672,7 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::{ExtensionSet, EMPTY_REG}; + use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; @@ -678,6 +680,7 @@ mod test { use crate::types::{FunctionType, Type}; use cool_asserts::assert_matches; + use itertools::Itertools; use portgraph::NodeIndex; const NAT: Type = crate::extension::prelude::USIZE_T; @@ -939,58 +942,6 @@ mod test { Ok(()) } - #[test] - fn minus_test() -> Result<(), Box> { - let const_true = ops::Const::true_val(); - const BOOLEAN: Type = Type::new_simple_predicate(2); - let just_bool = type_row![BOOLEAN]; - - let abc = ExtensionSet::from_iter([A, B, C]); - - // Parent graph is closed - let mut hugr = closed_dfg_root_hugr( - FunctionType::new(type_row![], just_bool.clone()).with_extension_delta(&abc), - ); - - let [_, output] = hugr.get_io(hugr.root()).unwrap(); - - let root = hugr.root(); - let [child, _, ochild] = create_with_io( - &mut hugr, - root, - ops::DFG { - signature: FunctionType::new(type_row![], just_bool.clone()) - .with_extension_delta(&abc), - }, - )?; - - let const_node = hugr.add_node_with_parent(child, NodeType::open_extensions(const_true))?; - let lift_node = hugr.add_node_with_parent( - child, - NodeType::open_extensions(ops::LeafOp::Lift { - type_row: just_bool, - new_extension: C, - }), - )?; - - hugr.connect(const_node, 0, lift_node, 0)?; - hugr.connect(lift_node, 0, ochild, 0)?; - hugr.connect(child, 0, output, 0)?; - - hugr.infer_extensions()?; - - // The solution for the const node should be {A, B}! - assert_eq!( - hugr.get_nodetype(const_node) - .signature() - .unwrap() - .output_extensions(), - ExtensionSet::from_iter([A, B]) - ); - - Ok(()) - } - fn create_with_io( hugr: &mut Hugr, parent: Node, @@ -1091,4 +1042,71 @@ mod test { } Ok(()) } + + #[test] + fn extension_adding_sequence() -> Result<(), Box> { + let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]); + + let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG { + signature: df_sig + .clone() + .with_extension_delta(&ExtensionSet::from_iter([A, B])), + })); + + let root = hugr.root(); + let input = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::Input { + types: type_row![NAT], + }), + )?; + let output = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::Output { + types: type_row![NAT], + }), + )?; + + // Make identical dataflow nodes which add extension requirement "A" or "B" + let df_nodes: Vec = vec![A, A, B, B, A, B] + .into_iter() + .map(|ext| { + let [node, input, output] = create_with_io( + &mut hugr, + root, + ops::DFG { + signature: df_sig + .clone() + .with_extension_delta(&ExtensionSet::singleton(&ext)), + }, + ) + .unwrap(); + + let lift = hugr + .add_node_with_parent( + node, + NodeType::open_extensions(ops::LeafOp::Lift { + type_row: type_row![NAT], + new_extension: ext, + }), + ) + .unwrap(); + + hugr.connect(input, 0, lift, 0).unwrap(); + hugr.connect(lift, 0, output, 0).unwrap(); + + node + }) + .collect(); + + // Connect nodes in order (0 -> 1 -> 2 ...) + let nodes = [vec![input], df_nodes, vec![output]].concat(); + for (src, tgt) in nodes.into_iter().tuple_windows() { + hugr.connect(src, 0, tgt, 0)?; + } + + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + + Ok(()) + } }