Skip to content

Commit

Permalink
refactor: Avoid using builder methods in inference test
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Sep 1, 2023
1 parent 8839c09 commit 4f116c9
Showing 1 changed file with 65 additions and 23 deletions.
88 changes: 65 additions & 23 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,11 @@ mod test {
use std::error::Error;

use super::*;
use crate::builder::{
BuildError, CaseBuilder, ConditionalBuilder, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer,
};
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::{ExtensionSet, EMPTY_REG};
use crate::hugr::HugrMut;
use crate::hugr::{validate::ValidationError, Hugr, HugrView, NodeType};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::type_row;
use crate::types::{FunctionType, Type};

Expand Down Expand Up @@ -1012,46 +1009,91 @@ mod test {
Ok(())
}

fn create_with_io(
hugr: &mut Hugr,
parent: Node,
op: impl Into<OpType>,
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();
let input_types = op.signature().input;
let output_types = op.signature().output;

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
node,
NodeType::open_extensions(ops::Input { types: input_types }),
)?;
let output = hugr.add_node_with_parent(
node,
NodeType::open_extensions(ops::Output {
types: output_types,
}),
)?;
Ok([node, input, output])
}

#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
mut case_builder: CaseBuilder<&mut Hugr>,
hugr: &mut Hugr,
conditional_node: Node,
op: ops::Case,
first_ext: ExtensionId,
second_ext: ExtensionId,
) -> Result<Node, Box<dyn Error>> {
let [w] = case_builder.input_wires_arr();
let lift1 = case_builder.add_dataflow_node(
let [case, case_in, case_out] = create_with_io(hugr, conditional_node, op)?;

let lift1 = hugr.add_node_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
[w],
)?;
let [w] = lift1.outputs_arr();
let lift2 = case_builder.add_dataflow_node(

let lift2 = hugr.add_node_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
[w],
)?;
let [w] = lift2.outputs_arr();
let handle = case_builder.finish_with_outputs([w])?;
Ok(handle.node())

hugr.connect(case_in, 0, lift1, 0)?;
hugr.connect(lift1, 0, lift2, 0)?;
hugr.connect(lift2, 0, case_out, 0)?;

Ok(case)
}

let predicate_inputs = vec![type_row![]; 2];
let rs = ExtensionSet::from_iter(["A".into(), "B".into()]);
let mut conditional_builder =
ConditionalBuilder::new(predicate_inputs, type_row![NAT], type_row![NAT], rs)?;
let case_builder = conditional_builder.case_builder(0)?;
let case0_node = build_case(case_builder, "A".into(), "B".into())?;

let case_builder = conditional_builder.case_builder(1)?;
let case1_node = build_case(case_builder, "B".into(), "A".into())?;
let inputs = type_row![NAT];
let outputs = type_row![NAT];

let op = ops::Conditional {
predicate_inputs,
other_inputs: inputs.clone(),
outputs: outputs.clone(),
extension_delta: rs.clone(),
};

let mut hugr = Hugr::new(NodeType::pure(op));
let conditional_node = hugr.root();

let case_op = ops::Case {
signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs),
};
let case0_node = build_case(
&mut hugr,
conditional_node,
case_op.clone(),
"A".into(),
"B".into(),
)?;

let conditional_node = conditional_builder.container_node();
let hugr = conditional_builder.hugr_mut();
let case1_node = build_case(&mut hugr, conditional_node, case_op, "B".into(), "A".into())?;

hugr.infer_extensions()?;

Expand Down

0 comments on commit 4f116c9

Please sign in to comment.