Skip to content

Commit

Permalink
fix: Make unification logic less strict (#538)
Browse files Browse the repository at this point in the history
The `Plus` constraint has up till now been too strict in extension
unification. So far we have been deriving things like the following:
```
a := Plus("C", b)
a := {"A", "B", "C"}
implies
b := {"A", "B"}
```
(as in `minus_test`, which this PR removes)
However, plus should be meant as the union of the the singleton set that
it specifies with the extension set of another metavariable. Hence in
the above example the solution for `b` could be either `{"A", "B"}` or
`{"A", "B", "C"}`.

This means that if we have chains of circuit operations which all add
the "quantum" resource, they can all be represented as `Plus`
constraints
  • Loading branch information
croyzor authored Sep 19, 2023
1 parent 5b4786f commit 7ff032b
Showing 1 changed file with 75 additions and 57 deletions.
132 changes: 75 additions & 57 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ impl UnificationContext {
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
// constraint only signifies a preorder.
// I.e. if meta = other_meta + 'R', it's still possible that the
// solution is meta = other_meta because we could be adding 'R'
// to a set which already contained it.
Constraint::Plus(r, other_meta) => {
if let Some(rs) = self.get_solution(other_meta) {
let mut rrs = rs.clone();
Expand All @@ -516,10 +522,6 @@ impl UnificationContext {
solved = true;
}
};
} else if let Some(superset) = self.get_solution(&meta) {
let subset = ExtensionSet::singleton(r).missing_from(superset);
self.add_solution(self.resolve(*other_meta), subset);
solved = true;
};
}
}
Expand Down Expand Up @@ -670,14 +672,15 @@ mod test {
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::{ExtensionSet, EMPTY_REG};
use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::type_row;
use crate::types::{FunctionType, Type};

use cool_asserts::assert_matches;
use itertools::Itertools;
use portgraph::NodeIndex;

const NAT: Type = crate::extension::prelude::USIZE_T;
Expand Down Expand Up @@ -939,58 +942,6 @@ mod test {
Ok(())
}

#[test]
fn minus_test() -> Result<(), Box<dyn Error>> {
let const_true = ops::Const::true_val();
const BOOLEAN: Type = Type::new_simple_predicate(2);
let just_bool = type_row![BOOLEAN];

let abc = ExtensionSet::from_iter([A, B, C]);

// Parent graph is closed
let mut hugr = closed_dfg_root_hugr(
FunctionType::new(type_row![], just_bool.clone()).with_extension_delta(&abc),
);

let [_, output] = hugr.get_io(hugr.root()).unwrap();

let root = hugr.root();
let [child, _, ochild] = create_with_io(
&mut hugr,
root,
ops::DFG {
signature: FunctionType::new(type_row![], just_bool.clone())
.with_extension_delta(&abc),
},
)?;

let const_node = hugr.add_node_with_parent(child, NodeType::open_extensions(const_true))?;
let lift_node = hugr.add_node_with_parent(
child,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: just_bool,
new_extension: C,
}),
)?;

hugr.connect(const_node, 0, lift_node, 0)?;
hugr.connect(lift_node, 0, ochild, 0)?;
hugr.connect(child, 0, output, 0)?;

hugr.infer_extensions()?;

// The solution for the const node should be {A, B}!
assert_eq!(
hugr.get_nodetype(const_node)
.signature()
.unwrap()
.output_extensions(),
ExtensionSet::from_iter([A, B])
);

Ok(())
}

fn create_with_io(
hugr: &mut Hugr,
parent: Node,
Expand Down Expand Up @@ -1091,4 +1042,71 @@ mod test {
}
Ok(())
}

#[test]
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
root,
NodeType::open_extensions(ops::Input {
types: type_row![NAT],
}),
)?;
let output = hugr.add_node_with_parent(
root,
NodeType::open_extensions(ops::Output {
types: type_row![NAT],
}),
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
let df_nodes: Vec<Node> = vec![A, A, B, B, A, B]
.into_iter()
.map(|ext| {
let [node, input, output] = create_with_io(
&mut hugr,
root,
ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::singleton(&ext)),
},
)
.unwrap();

let lift = hugr
.add_node_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
)
.unwrap();

hugr.connect(input, 0, lift, 0).unwrap();
hugr.connect(lift, 0, output, 0).unwrap();

node
})
.collect();

// Connect nodes in order (0 -> 1 -> 2 ...)
let nodes = [vec![input], df_nodes, vec![output]].concat();
for (src, tgt) in nodes.into_iter().tuple_windows() {
hugr.connect(src, 0, tgt, 0)?;
}

hugr.infer_and_validate(&PRELUDE_REGISTRY)?;

Ok(())
}
}

0 comments on commit 7ff032b

Please sign in to comment.