diff --git a/src/extension/infer.rs b/src/extension/infer.rs index fb037d543..05ae623f9 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -665,14 +665,11 @@ mod test { use std::error::Error; use super::*; - use crate::builder::{ - BuildError, CaseBuilder, ConditionalBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, - DataflowSubContainer, - }; + use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::{ExtensionSet, EMPTY_REG}; use crate::hugr::HugrMut; use crate::hugr::{validate::ValidationError, Hugr, HugrView, NodeType}; - use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle}; + use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -1012,46 +1009,91 @@ mod test { Ok(()) } + fn create_with_io( + hugr: &mut Hugr, + parent: Node, + op: impl Into, + ) -> Result<[Node; 3], Box> { + let op: OpType = op.into(); + let input_types = op.signature().input; + let output_types = op.signature().output; + + let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?; + let input = hugr.add_node_with_parent( + node, + NodeType::open_extensions(ops::Input { types: input_types }), + )?; + let output = hugr.add_node_with_parent( + node, + NodeType::open_extensions(ops::Output { + types: output_types, + }), + )?; + Ok([node, input, output]) + } + #[test] fn test_conditional_inference() -> Result<(), Box> { fn build_case( - mut case_builder: CaseBuilder<&mut Hugr>, + hugr: &mut Hugr, + conditional_node: Node, + op: ops::Case, first_ext: ExtensionId, second_ext: ExtensionId, ) -> Result> { - let [w] = case_builder.input_wires_arr(); - let lift1 = case_builder.add_dataflow_node( + let [case, case_in, case_out] = create_with_io(hugr, conditional_node, op)?; + + let lift1 = hugr.add_node_with_parent( + case, NodeType::open_extensions(ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: first_ext, }), - [w], )?; - let [w] = lift1.outputs_arr(); - let lift2 = case_builder.add_dataflow_node( + + let lift2 = hugr.add_node_with_parent( + case, NodeType::open_extensions(ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: second_ext, }), - [w], )?; - let [w] = lift2.outputs_arr(); - let handle = case_builder.finish_with_outputs([w])?; - Ok(handle.node()) + + hugr.connect(case_in, 0, lift1, 0)?; + hugr.connect(lift1, 0, lift2, 0)?; + hugr.connect(lift2, 0, case_out, 0)?; + + Ok(case) } let predicate_inputs = vec![type_row![]; 2]; let rs = ExtensionSet::from_iter(["A".into(), "B".into()]); - let mut conditional_builder = - ConditionalBuilder::new(predicate_inputs, type_row![NAT], type_row![NAT], rs)?; - let case_builder = conditional_builder.case_builder(0)?; - let case0_node = build_case(case_builder, "A".into(), "B".into())?; - let case_builder = conditional_builder.case_builder(1)?; - let case1_node = build_case(case_builder, "B".into(), "A".into())?; + let inputs = type_row![NAT]; + let outputs = type_row![NAT]; + + let op = ops::Conditional { + predicate_inputs, + other_inputs: inputs.clone(), + outputs: outputs.clone(), + extension_delta: rs.clone(), + }; + + let mut hugr = Hugr::new(NodeType::pure(op)); + let conditional_node = hugr.root(); + + let case_op = ops::Case { + signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs), + }; + let case0_node = build_case( + &mut hugr, + conditional_node, + case_op.clone(), + "A".into(), + "B".into(), + )?; - let conditional_node = conditional_builder.container_node(); - let hugr = conditional_builder.hugr_mut(); + let case1_node = build_case(&mut hugr, conditional_node, case_op, "B".into(), "A".into())?; hugr.infer_extensions()?;