Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make unification logic less strict #538

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 75 additions & 57 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ impl UnificationContext {
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
// constraint only signifies a preorder.
// I.e. if meta = other_meta + 'R', it's still possible that the
// solution is meta = other_meta because we could be adding 'R'
// to a set which already contained it.
Constraint::Plus(r, other_meta) => {
if let Some(rs) = self.get_solution(other_meta) {
let mut rrs = rs.clone();
Expand All @@ -516,10 +522,6 @@ impl UnificationContext {
solved = true;
}
};
} else if let Some(superset) = self.get_solution(&meta) {
let subset = ExtensionSet::singleton(r).missing_from(superset);
self.add_solution(self.resolve(*other_meta), subset);
solved = true;
};
}
}
Expand Down Expand Up @@ -670,14 +672,15 @@ mod test {
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::{ExtensionSet, EMPTY_REG};
use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::type_row;
use crate::types::{FunctionType, Type};

use cool_asserts::assert_matches;
use itertools::Itertools;
use portgraph::NodeIndex;

const NAT: Type = crate::extension::prelude::USIZE_T;
Expand Down Expand Up @@ -939,58 +942,6 @@ mod test {
Ok(())
}

#[test]
fn minus_test() -> Result<(), Box<dyn Error>> {
let const_true = ops::Const::true_val();
const BOOLEAN: Type = Type::new_simple_predicate(2);
let just_bool = type_row![BOOLEAN];

let abc = ExtensionSet::from_iter([A, B, C]);

// Parent graph is closed
let mut hugr = closed_dfg_root_hugr(
FunctionType::new(type_row![], just_bool.clone()).with_extension_delta(&abc),
);

let [_, output] = hugr.get_io(hugr.root()).unwrap();

let root = hugr.root();
let [child, _, ochild] = create_with_io(
&mut hugr,
root,
ops::DFG {
signature: FunctionType::new(type_row![], just_bool.clone())
.with_extension_delta(&abc),
},
)?;

let const_node = hugr.add_node_with_parent(child, NodeType::open_extensions(const_true))?;
let lift_node = hugr.add_node_with_parent(
child,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: just_bool,
new_extension: C,
}),
)?;

hugr.connect(const_node, 0, lift_node, 0)?;
hugr.connect(lift_node, 0, ochild, 0)?;
hugr.connect(child, 0, output, 0)?;

hugr.infer_extensions()?;

// The solution for the const node should be {A, B}!
assert_eq!(
hugr.get_nodetype(const_node)
.signature()
.unwrap()
.output_extensions(),
ExtensionSet::from_iter([A, B])
);

Ok(())
}

fn create_with_io(
hugr: &mut Hugr,
parent: Node,
Expand Down Expand Up @@ -1091,4 +1042,71 @@ mod test {
}
Ok(())
}

#[test]
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
root,
NodeType::open_extensions(ops::Input {
types: type_row![NAT],
}),
)?;
let output = hugr.add_node_with_parent(
root,
NodeType::open_extensions(ops::Output {
types: type_row![NAT],
}),
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
let df_nodes: Vec<Node> = vec![A, A, B, B, A, B]
.into_iter()
.map(|ext| {
let [node, input, output] = create_with_io(
&mut hugr,
root,
ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::singleton(&ext)),
},
)
.unwrap();

let lift = hugr
.add_node_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
)
.unwrap();

hugr.connect(input, 0, lift, 0).unwrap();
hugr.connect(lift, 0, output, 0).unwrap();

node
})
.collect();

// Connect nodes in order (0 -> 1 -> 2 ...)
let nodes = [vec![input], df_nodes, vec![output]].concat();
for (src, tgt) in nodes.into_iter().tuple_windows() {
hugr.connect(src, 0, tgt, 0)?;
}

hugr.infer_and_validate(&PRELUDE_REGISTRY)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - so without the fix, this line would break, yes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case LGTM ;-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly!


Ok(())
}
}