From 652bba17ea35f07356e4ee1b96a0217dff46bf49 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 21 Aug 2023 16:30:34 +0100 Subject: [PATCH] Remove `static_input` from `AbstractSignature` (#429) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #419 Need to add some logic to report the port for the static input if it exists. Matches existing templated logic for "other_input" etc. Boilerplate methods on Signature that ignore static_input can therefore be removed. --------- Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- specification/hugr.md | 5 +- src/builder/cfg.rs | 4 +- src/builder/circuit.rs | 6 +- src/builder/conditional.rs | 8 +- src/builder/dataflow.rs | 34 +++---- src/builder/module.rs | 8 +- src/builder/tail_loop.rs | 4 +- src/extension/op_def.rs | 2 +- src/extension/type_def.rs | 2 +- src/hugr/hugrmut.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 34 ++----- src/hugr/serialize.rs | 16 +-- src/hugr/validate.rs | 26 ++--- src/hugr/views/hierarchy.rs | 4 +- src/macros.rs | 2 +- src/ops.rs | 36 +++++-- src/ops/constant.rs | 4 +- src/ops/controlflow.rs | 6 +- src/ops/dataflow.rs | 34 +++++-- src/ops/leaf.rs | 10 +- src/ops/validate.rs | 2 +- src/types/signature.rs | 157 ++++++----------------------- 22 files changed, 161 insertions(+), 245 deletions(-) diff --git a/specification/hugr.md b/specification/hugr.md index 66ec2e038..977946016 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -157,9 +157,8 @@ the node). Incoming ports are associated with exactly one edge. All edges associ with a port have the same type; thus a port has a well-defined type, matching that of its adjoining edges. The incoming and outgoing ports of a node are (separately) ordered. -The sequences of incoming and outgoing port types of a node constitute its -_signature_. This signature may include the types of both `Value` and `Static` -edges, with `Static` edges following `Value` edges in the ordering. +The sequences of incoming and outgoing port types (carried on `Value` edges) of a node constitute its +_signature_. Note that the locality is not fixed or even specified by the signature. diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index b00d5db8d..2ab8fe8c6 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -241,7 +241,7 @@ impl + AsRef> BlockBuilder { let predicate_type = Type::new_predicate(predicate_variants); let mut node_outputs = vec![predicate_type]; node_outputs.extend_from_slice(&other_outputs); - let signature = AbstractSignature::new_df(inputs, TypeRow::from(node_outputs)); + let signature = AbstractSignature::new(inputs, TypeRow::from(node_outputs)); let db = DFGBuilder::create_with_io(base, block_n, signature, None)?; Ok(BlockBuilder::from_dfg_builder(db)) } @@ -307,7 +307,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut func_builder = module_builder.define_function( "main", - AbstractSignature::new_df(vec![NAT], type_row![NAT]).pure(), + AbstractSignature::new(vec![NAT], type_row![NAT]).pure(), )?; let _f_id = { let [int] = func_builder.input_wires_arr(); diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index 47a09e890..e1bb510cb 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -149,7 +149,7 @@ mod test { #[test] fn simple_linear() { let build_res = build_main( - AbstractSignature::new_df(type_row![QB, QB], type_row![QB, QB]).pure(), + AbstractSignature::new(type_row![QB, QB], type_row![QB, QB]).pure(), |mut f_build| { let wires = f_build.input_wires().collect(); @@ -181,12 +181,12 @@ mod test { "MyOp", "unknown op".to_string(), vec![], - Some(AbstractSignature::new(vec![QB, NAT], vec![QB], vec![])), + Some(AbstractSignature::new(vec![QB, NAT], vec![QB])), )) .into(), ); let build_res = build_main( - AbstractSignature::new_df(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(), + AbstractSignature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(), |mut f_build| { let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 0d6550cfc..56ec1a70d 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -117,7 +117,7 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { - signature: AbstractSignature::new_df(inputs.clone(), outputs.clone()), + signature: AbstractSignature::new(inputs.clone(), outputs.clone()), }; let case_node = // add case before any existing subsequent cases @@ -134,7 +134,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - AbstractSignature::new_df(inputs, outputs), + AbstractSignature::new(inputs, outputs), None, )?; @@ -186,7 +186,7 @@ impl CaseBuilder { pub fn new(input: impl Into, output: impl Into) -> Result { let input = input.into(); let output = output.into(); - let signature = AbstractSignature::new_df(input, output); + let signature = AbstractSignature::new(input, output); let op = ops::Case { signature: signature.clone(), }; @@ -232,7 +232,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(), )?; let tru_const = fbuild.add_constant(Const::true_val())?; let _fdef = { diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index f104c4ed2..5050a5d6a 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -242,7 +242,7 @@ mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![NAT, QB], type_row![NAT, QB]).pure(), + AbstractSignature::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), )?; let [int, qb] = func_builder.input_wires_arr(); @@ -250,7 +250,7 @@ mod test { let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; let inner_builder = func_builder.dfg_builder( - AbstractSignature::new_df(type_row![NAT], type_row![NAT]), + AbstractSignature::new(type_row![NAT], type_row![NAT]), // TODO: This should be None Some(ExtensionSet::new()), [int], @@ -277,7 +277,7 @@ mod test { let f_build = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).pure(), + AbstractSignature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).pure(), )?; f(f_build)?; @@ -327,7 +327,7 @@ mod test { let f_build = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![QB], type_row![QB, QB]).pure(), + AbstractSignature::new(type_row![QB], type_row![QB, QB]).pure(), )?; let [q1] = f_build.input_wires_arr(); @@ -344,7 +344,7 @@ mod test { let builder = || -> Result { let mut f_build = FunctionBuilder::new( "main", - AbstractSignature::new_df(type_row![BIT], type_row![BIT]).pure(), + AbstractSignature::new(type_row![BIT], type_row![BIT]).pure(), )?; let [i1] = f_build.input_wires_arr(); @@ -352,7 +352,7 @@ mod test { let i1 = noop.out_wire(0); let mut nested = f_build.dfg_builder( - AbstractSignature::new_df(type_row![], type_row![BIT]), + AbstractSignature::new(type_row![], type_row![BIT]), None, [], )?; @@ -371,18 +371,15 @@ mod test { fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> { let mut f_build = FunctionBuilder::new( "main", - AbstractSignature::new_df(type_row![QB], type_row![QB]).pure(), + AbstractSignature::new(type_row![QB], type_row![QB]).pure(), )?; let [i1] = f_build.input_wires_arr(); let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1])?; let i1 = noop.out_wire(0); - let mut nested = f_build.dfg_builder( - AbstractSignature::new_df(type_row![], type_row![QB]), - None, - [], - )?; + let mut nested = + f_build.dfg_builder(AbstractSignature::new(type_row![], type_row![QB]), None, [])?; let id_res = nested.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1]); @@ -400,8 +397,7 @@ mod test { #[test] fn dfg_hugr() -> Result<(), BuildError> { - let dfg_builder = - DFGBuilder::new(AbstractSignature::new_df(type_row![BIT], type_row![BIT]))?; + let dfg_builder = DFGBuilder::new(AbstractSignature::new(type_row![BIT], type_row![BIT]))?; let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1])?; @@ -416,7 +412,7 @@ mod test { fn insert_hugr() -> Result<(), BuildError> { // Create a simple DFG let mut dfg_builder = - DFGBuilder::new(AbstractSignature::new_df(type_row![BIT], type_row![BIT]))?; + DFGBuilder::new(AbstractSignature::new(type_row![BIT], type_row![BIT]))?; let [i1] = dfg_builder.input_wires_arr(); dfg_builder.set_metadata(json!(42)); let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?; @@ -427,7 +423,7 @@ mod test { { let mut f_build = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![BIT], type_row![BIT]).pure(), + AbstractSignature::new(type_row![BIT], type_row![BIT]).pure(), )?; let [i1] = f_build.input_wires_arr(); @@ -448,20 +444,20 @@ mod test { let c_extensions = ExtensionSet::singleton(&"C".into()); let abc_extensions = ab_extensions.clone().union(&c_extensions); - let parent_sig = AbstractSignature::new_df(type_row![BIT], type_row![BIT]) + let parent_sig = AbstractSignature::new(type_row![BIT], type_row![BIT]) .with_extension_delta(&abc_extensions); let mut parent = module_builder.define_function( "parent", parent_sig.with_input_extensions(ExtensionSet::new()), )?; - let add_c_sig = AbstractSignature::new_df(type_row![BIT], type_row![BIT]) + let add_c_sig = AbstractSignature::new(type_row![BIT], type_row![BIT]) .with_extension_delta(&c_extensions) .with_input_extensions(ab_extensions.clone()); let [w] = parent.input_wires_arr(); - let add_ab_sig = AbstractSignature::new_df(type_row![BIT], type_row![BIT]) + let add_ab_sig = AbstractSignature::new(type_row![BIT], type_row![BIT]) .with_extension_delta(&ab_extensions); // A box which adds extensions A and B, via child Lift nodes diff --git a/src/builder/module.rs b/src/builder/module.rs index 2a1c62e7b..4693b6278 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -188,7 +188,7 @@ mod test { let f_id = module_builder.declare( "main", - AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(), )?; let mut f_build = module_builder.define_declaration(&f_id)?; @@ -211,7 +211,7 @@ mod test { let f_build = module_builder.define_function( "main", - AbstractSignature::new_df( + AbstractSignature::new( vec![qubit_state_type.get_alias_type()], vec![qubit_state_type.get_alias_type()], ) @@ -231,11 +231,11 @@ mod test { let mut f_build = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(), )?; let local_build = f_build.define_function( "local", - AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(), )?; let [wire] = local_build.input_wires_arr(); let f_id = local_build.finish_with_outputs([wire])?; diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 726649d80..a006a45ff 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -21,7 +21,7 @@ impl + AsRef> TailLoopBuilder { tail_loop: &ops::TailLoop, ) -> Result { let signature = - AbstractSignature::new_df(tail_loop.body_input_row(), tail_loop.body_output_row()); + AbstractSignature::new(tail_loop.body_input_row(), tail_loop.body_output_row()); let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature, None)?; Ok(TailLoopBuilder::from_dfg_builder(dfg_build)) @@ -127,7 +127,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![BIT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![BIT], type_row![NAT]).pure(), )?; let _fdef = { let [b1] = fbuild.input_wires_arr(); diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index a3c4f0a5c..15095e7fc 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -227,7 +227,7 @@ impl OpDef { // TODO bring this assert back once resource inference is done? // https://github.com/CQCL-DEV/hugr/issues/425 // assert!(res.contains(self.extension())); - Ok(AbstractSignature::new_df(ins, outs).with_extension_delta(&res)) + Ok(AbstractSignature::new(ins, outs).with_extension_delta(&res)) } /// Optional description of the ports in the signature. diff --git a/src/extension/type_def.rs b/src/extension/type_def.rs index 1b8438e39..faab0904b 100644 --- a/src/extension/type_def.rs +++ b/src/extension/type_def.rs @@ -165,7 +165,7 @@ mod test { }; let typ = Type::new_extension( def.instantiate_concrete(vec![TypeArg::Type(Type::new_graph( - AbstractSignature::new_df(vec![], vec![]), + AbstractSignature::new(vec![], vec![]), ))]) .unwrap(), ); diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 112260284..ac856ab71 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -501,7 +501,7 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: AbstractSignature::new_df(type_row![NAT], type_row![NAT, NAT]), + signature: AbstractSignature::new(type_row![NAT], type_row![NAT, NAT]), }, ) .expect("Failed to add function definition node"); diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 86dedcec5..ee2ca443b 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -77,18 +77,6 @@ impl Rewrite for SimpleReplacement { .collect::>(); // slice of nodes omitting Input and Output: let replacement_inner_nodes = &replacement_nodes[2..]; - for &node in replacement_inner_nodes { - // Check there are no const inputs. - if !self - .replacement - .get_optype(node) - .signature() - .static_input() - .is_empty() - { - return Err(SimpleReplacementError::InvalidReplacementNode()); - } - } let self_output_node = h.children(self.parent).nth(1).unwrap(); let replacement_output_node = *replacement_nodes.get(1).unwrap(); for &node in replacement_inner_nodes { @@ -237,7 +225,7 @@ mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![QB, QB, QB], type_row![QB, QB, QB]).pure(), + AbstractSignature::new(type_row![QB, QB, QB], type_row![QB, QB, QB]).pure(), )?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -245,7 +233,7 @@ mod test { let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?; let mut inner_builder = func_builder.dfg_builder( - AbstractSignature::new_df(type_row![QB, QB], type_row![QB, QB]), + AbstractSignature::new(type_row![QB, QB], type_row![QB, QB]), None, [qb0, qb1], )?; @@ -273,10 +261,8 @@ mod test { /// ┤ H ├┤ X ├ /// └───┘└───┘ fn make_dfg_hugr() -> Result { - let mut dfg_builder = DFGBuilder::new(AbstractSignature::new_df( - type_row![QB, QB], - type_row![QB, QB], - ))?; + let mut dfg_builder = + DFGBuilder::new(AbstractSignature::new(type_row![QB, QB], type_row![QB, QB]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; @@ -291,10 +277,8 @@ mod test { /// ┤ H ├ /// └───┘ fn make_dfg_hugr2() -> Result { - let mut dfg_builder = DFGBuilder::new(AbstractSignature::new_df( - type_row![QB, QB], - type_row![QB, QB], - ))?; + let mut dfg_builder = + DFGBuilder::new(AbstractSignature::new(type_row![QB, QB], type_row![QB, QB]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire2out = wire2.outputs().exactly_one().unwrap(); @@ -469,7 +453,7 @@ mod test { #[test] fn test_replace_cx_cross() { let q_row: Vec = vec![QB, QB]; - let mut builder = DFGBuilder::new(AbstractSignature::new_df(q_row.clone(), q_row)).unwrap(); + let mut builder = DFGBuilder::new(AbstractSignature::new(q_row.clone(), q_row)).unwrap(); let mut circ = builder.as_circuit(builder.input_wires().collect()); circ.append(cx_gate(), [0, 1]).unwrap(); circ.append(cx_gate(), [1, 0]).unwrap(); @@ -516,7 +500,7 @@ mod test { let two_bit = type_row![BOOL_T, BOOL_T]; let mut builder = - DFGBuilder::new(AbstractSignature::new_df(one_bit.clone(), one_bit.clone())).unwrap(); + DFGBuilder::new(AbstractSignature::new(one_bit.clone(), one_bit.clone())).unwrap(); let inw = builder.input_wires().exactly_one().unwrap(); let outw = builder .add_dataflow_op(and_op(), [inw, inw]) @@ -525,7 +509,7 @@ mod test { let [input, _] = builder.io(); let mut h = builder.finish_hugr_with_outputs(outw).unwrap(); - let mut builder = DFGBuilder::new(AbstractSignature::new_df(two_bit, one_bit)).unwrap(); + let mut builder = DFGBuilder::new(AbstractSignature::new(two_bit, one_bit)).unwrap(); let inw = builder.input_wires(); let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs(); let [repl_input, repl_output] = builder.io(); diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index f2d7fb7c7..96a1f1a67 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -303,7 +303,7 @@ pub mod test { let outputs = g.num_outputs(node); match (inputs == 0, outputs == 0) { (false, false) => DFG { - signature: AbstractSignature::new_df(vec![NAT; inputs - 1], vec![NAT; outputs - 1]), + signature: AbstractSignature::new(vec![NAT; inputs - 1], vec![NAT; outputs - 1]), } .into(), (true, false) => Input::new(vec![NAT; outputs - 1]).into(), @@ -360,10 +360,7 @@ pub mod test { let t_row = vec![Type::new_sum(vec![NAT, QB])]; let mut f_build = module_builder - .define_function( - "main", - AbstractSignature::new_df(t_row.clone(), t_row).pure(), - ) + .define_function("main", AbstractSignature::new(t_row.clone(), t_row).pure()) .unwrap(); let outputs = f_build @@ -398,10 +395,7 @@ pub mod test { let mut module_builder = ModuleBuilder::new(); let t_row = vec![Type::new_sum(vec![NAT, QB])]; let mut f_build = module_builder - .define_function( - "main", - AbstractSignature::new_df(t_row.clone(), t_row).pure(), - ) + .define_function("main", AbstractSignature::new(t_row.clone(), t_row).pure()) .unwrap(); let outputs = f_build @@ -432,7 +426,7 @@ pub mod test { #[test] fn dfg_roundtrip() -> Result<(), Box> { let tp: Vec = vec![BOOL_T; 2]; - let mut dfg = DFGBuilder::new(AbstractSignature::new_df(tp.clone(), tp))?; + let mut dfg = DFGBuilder::new(AbstractSignature::new(tp.clone(), tp))?; let mut params: [_; 2] = dfg.input_wires_arr(); for p in params.iter_mut() { *p = dfg @@ -459,7 +453,7 @@ pub mod test { #[test] fn hierarchy_order() { - let dfg = DFGBuilder::new(AbstractSignature::new_df(vec![QB], vec![QB])).unwrap(); + let dfg = DFGBuilder::new(AbstractSignature::new(vec![QB], vec![QB])).unwrap(); let [old_in, out] = dfg.io(); let w = dfg.input_wires(); let mut hugr = dfg.finish_hugr_with_outputs(w).unwrap(); diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 4ab4da0ef..8c9406c4d 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -692,7 +692,7 @@ mod test { fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: AbstractSignature::new_df(type_row![BOOL_T], vec![BOOL_T; copies]), + signature: AbstractSignature::new(type_row![BOOL_T], vec![BOOL_T; copies]), } .into(); @@ -836,7 +836,7 @@ mod test { .unwrap(); // Add a definition without children - let def_sig = AbstractSignature::new_df(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]); + let def_sig = AbstractSignature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]); let new_def = b .add_op_with_parent( root, @@ -1008,7 +1008,7 @@ mod test { #[test] fn test_ext_edge() -> Result<(), HugrError> { let mut h = Hugr::new(NodeType::pure(ops::DFG { - signature: AbstractSignature::new_df(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]), + signature: AbstractSignature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]), })); let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T, BOOL_T]))?; let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; @@ -1050,7 +1050,7 @@ mod test { #[test] fn test_local_const() -> Result<(), HugrError> { let mut h = Hugr::new(NodeType::pure(ops::DFG { - signature: AbstractSignature::new_df(type_row![BOOL_T], type_row![BOOL_T]), + signature: AbstractSignature::new(type_row![BOOL_T], type_row![BOOL_T]), })); let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T]))?; let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; @@ -1091,11 +1091,11 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut main = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), + AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(), )?; let [main_input] = main.input_wires_arr(); - let inner_sig = AbstractSignature::new_df(type_row![NAT], type_row![NAT]) + let inner_sig = AbstractSignature::new(type_row![NAT], type_row![NAT]) // Inner DFG has resource requirements that the wire wont satisfy .with_input_extensions(ExtensionSet::from_iter(["A".into(), "BOOL_T".into()])); @@ -1127,12 +1127,12 @@ mod test { fn too_many_resources() -> Result<(), BuildError> { let mut module_builder = ModuleBuilder::new(); - let main_sig = AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(); + let main_sig = AbstractSignature::new(type_row![NAT], type_row![NAT]).pure(); let mut main = module_builder.define_function("main", main_sig)?; let [main_input] = main.input_wires_arr(); - let inner_sig = AbstractSignature::new_df(type_row![NAT], type_row![NAT]) + let inner_sig = AbstractSignature::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&"A".into())) .with_input_extensions(ExtensionSet::new()); @@ -1165,19 +1165,19 @@ mod test { let all_rs = ExtensionSet::from_iter(["A".into(), "BOOL_T".into()]); - let main_sig = AbstractSignature::new_df(type_row![], type_row![NAT]) + let main_sig = AbstractSignature::new(type_row![], type_row![NAT]) .with_extension_delta(&all_rs) .with_input_extensions(ExtensionSet::new()); let mut main = module_builder.define_function("main", main_sig)?; - let inner_left_sig = AbstractSignature::new_df(type_row![], type_row![NAT]) + let inner_left_sig = AbstractSignature::new(type_row![], type_row![NAT]) .with_input_extensions(ExtensionSet::singleton(&"A".into())); - let inner_right_sig = AbstractSignature::new_df(type_row![], type_row![NAT]) + let inner_right_sig = AbstractSignature::new(type_row![], type_row![NAT]) .with_input_extensions(ExtensionSet::singleton(&"BOOL_T".into())); - let inner_mult_sig = AbstractSignature::new_df(type_row![NAT, NAT], type_row![NAT]) + let inner_mult_sig = AbstractSignature::new(type_row![NAT, NAT], type_row![NAT]) .with_input_extensions(all_rs); let [left_wire] = main @@ -1219,7 +1219,7 @@ mod test { #[test] fn parent_signature_mismatch() -> Result<(), BuildError> { - let main_signature = AbstractSignature::new_df(type_row![NAT], type_row![NAT]) + let main_signature = AbstractSignature::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&"R".into())); let builder = DFGBuilder::new(main_signature)?; diff --git a/src/hugr/views/hierarchy.rs b/src/hugr/views/hierarchy.rs index 271726116..684df4a2b 100644 --- a/src/hugr/views/hierarchy.rs +++ b/src/hugr/views/hierarchy.rs @@ -493,7 +493,7 @@ mod test { let (f_id, inner_id) = { let mut func_builder = module_builder.define_function( "main", - AbstractSignature::new_df(type_row![NAT, QB], type_row![NAT, QB]).pure(), + AbstractSignature::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), )?; let [int, qb] = func_builder.input_wires_arr(); @@ -502,7 +502,7 @@ mod test { let inner_id = { let inner_builder = func_builder.dfg_builder( - AbstractSignature::new_df(type_row![NAT], type_row![NAT]), + AbstractSignature::new(type_row![NAT], type_row![NAT]), None, [int], )?; diff --git a/src/macros.rs b/src/macros.rs index 8cca4b26c..e9787f34f 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -47,7 +47,7 @@ pub(crate) use impl_box_clone; /// const U: Type = Type::UNIT; /// let static_row: TypeRow = type_row![U, U]; /// let dynamic_row: TypeRow = vec![U, U, U].into(); -/// let sig = AbstractSignature::new_df(static_row, dynamic_row).pure(); +/// let sig = AbstractSignature::new(static_row, dynamic_row).pure(); /// /// let repeated_row: TypeRow = type_row![U; 3]; /// assert_eq!(repeated_row, *sig.output()); diff --git a/src/ops.rs b/src/ops.rs index d9c79d1f9..61ed3a285 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,7 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; -use crate::types::{AbstractSignature, EdgeKind, SignatureDescription}; +use crate::types::{AbstractSignature, EdgeKind, SignatureDescription, Type}; use crate::{Direction, Port}; use portgraph::NodeIndex; @@ -76,9 +76,17 @@ impl OpType { let signature = self.signature(); let port = port.into(); let dir = port.direction(); - match port.index() < signature.port_count(dir) { - true => signature.get(port), - false => self.other_port(dir), + + let port_count = signature.port_count(dir); + if port.index() < port_count { + signature.get(port).cloned().map(EdgeKind::Value) + } else if port.index() == port_count + && dir == Direction::Incoming + && self.static_input().is_some() + { + self.static_input().map(EdgeKind::Static) + } else { + self.other_port(dir) } } @@ -89,7 +97,14 @@ impl OpType { pub fn other_port_index(&self, dir: Direction) -> Option { let non_df_count = self.validity_flags().non_df_port_count(dir).unwrap_or(1); if self.other_port(dir).is_some() && non_df_count == 1 { - Some(Port::new(dir, self.signature().port_count(dir))) + // if there is a static input it comes before the non_df_ports + let static_input = + (dir == Direction::Incoming && self.static_input().is_some()) as usize; + + Some(Port::new( + dir, + self.signature().port_count(dir) + static_input, + )) } else { None } @@ -103,7 +118,9 @@ impl OpType { .validity_flags() .non_df_port_count(dir) .unwrap_or(has_other_ports as usize); - signature.port_count(dir) + non_df_count + // if there is a static input it comes before the non_df_ports + let static_input = (dir == Direction::Incoming && self.static_input().is_some()) as usize; + signature.port_count(dir) + non_df_count + static_input } /// Returns the number of inputs ports for the operation. @@ -170,6 +187,13 @@ pub trait OpTrait { Default::default() } + /// Get the static input type of this operation if it has one (only Some for + /// [`LoadConstant`] and [`Call`]) + #[inline] + fn static_input(&self) -> Option { + None + } + /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 1fbaf8602..015ae2e56 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -143,7 +143,7 @@ mod test { let pred_rows = vec![type_row![EQ_T, COPYABLE_T], type_row![]]; let pred_ty = Type::new_predicate(pred_rows.clone()); - let mut b = DFGBuilder::new(AbstractSignature::new_df( + let mut b = DFGBuilder::new(AbstractSignature::new( type_row![], TypeRow::from(vec![pred_ty.clone()]), ))?; @@ -155,7 +155,7 @@ mod test { let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w]).unwrap(); - let mut b = DFGBuilder::new(AbstractSignature::new_df( + let mut b = DFGBuilder::new(AbstractSignature::new( type_row![], TypeRow::from(vec![pred_ty]), ))?; diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 8b20ce5cd..9acc76c07 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -31,7 +31,7 @@ impl DataflowOpTrait for TailLoop { fn signature(&self) -> AbstractSignature { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| predicate_first(row, &self.rest)); - AbstractSignature::new_df(inputs, outputs) + AbstractSignature::new(inputs, outputs) } } @@ -75,7 +75,7 @@ impl DataflowOpTrait for Conditional { 0, Type::new_predicate(self.predicate_inputs.clone().into_iter()), ); - AbstractSignature::new_df(inputs, self.outputs.clone()) + AbstractSignature::new(inputs, self.outputs.clone()) } } @@ -107,7 +107,7 @@ impl DataflowOpTrait for CFG { } fn signature(&self) -> AbstractSignature { - AbstractSignature::new_df(self.inputs.clone(), self.outputs.clone()) + AbstractSignature::new(self.inputs.clone(), self.outputs.clone()) } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 9faca281a..e541dfeba 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -10,6 +10,11 @@ pub(super) trait DataflowOpTrait { const TAG: OpTag; fn description(&self) -> &str; fn signature(&self) -> AbstractSignature; + + /// Get the static input type of this operation if it has one. + fn static_input(&self) -> Option { + None + } /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// @@ -81,7 +86,7 @@ impl DataflowOpTrait for Input { } fn signature(&self) -> AbstractSignature { - AbstractSignature::new_df(TypeRow::new(), self.types.clone()) + AbstractSignature::new(TypeRow::new(), self.types.clone()) .with_extension_delta(&ExtensionSet::new()) } } @@ -95,7 +100,7 @@ impl DataflowOpTrait for Output { // Note: We know what the input extensions should be, so we *could* give an // instantiated Signature instead fn signature(&self) -> AbstractSignature { - AbstractSignature::new_df(self.types.clone(), TypeRow::new()) + AbstractSignature::new(self.types.clone(), TypeRow::new()) } fn other_output(&self) -> Option { @@ -120,6 +125,10 @@ impl OpTrait for T { fn other_output(&self) -> Option { DataflowOpTrait::other_output(self) } + + fn static_input(&self) -> Option { + DataflowOpTrait::static_input(self) + } } impl StaticTag for T { const TAG: OpTag = T::TAG; @@ -145,10 +154,12 @@ impl DataflowOpTrait for Call { } fn signature(&self) -> AbstractSignature { - AbstractSignature { - static_input: vec![Type::new_graph(self.signature.clone())].into(), - ..self.signature.clone() - } + self.signature.clone() + } + + #[inline] + fn static_input(&self) -> Option { + Some(Type::new_graph(self.signature.clone())) } } @@ -191,11 +202,12 @@ impl DataflowOpTrait for LoadConstant { } fn signature(&self) -> AbstractSignature { - AbstractSignature::new( - TypeRow::new(), - vec![self.datatype.clone()], - vec![self.datatype.clone()], - ) + AbstractSignature::new(TypeRow::new(), vec![self.datatype.clone()]) + } + + #[inline] + fn static_input(&self) -> Option { + Some(self.datatype.clone()) } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 8050d3418..a618e1da3 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -99,23 +99,23 @@ impl OpTrait for LeafOp { match self { LeafOp::Noop { ty: typ } => { - AbstractSignature::new_df(vec![typ.clone()], vec![typ.clone()]) + AbstractSignature::new(vec![typ.clone()], vec![typ.clone()]) } LeafOp::CustomOp(ext) => ext.signature(), LeafOp::MakeTuple { tys: types } => { - AbstractSignature::new_df(types.clone(), vec![Type::new_tuple(types.clone())]) + AbstractSignature::new(types.clone(), vec![Type::new_tuple(types.clone())]) } LeafOp::UnpackTuple { tys: types } => { - AbstractSignature::new_df(vec![Type::new_tuple(types.clone())], types.clone()) + AbstractSignature::new(vec![Type::new_tuple(types.clone())], types.clone()) } - LeafOp::Tag { tag, variants } => AbstractSignature::new_df( + LeafOp::Tag { tag, variants } => AbstractSignature::new( vec![variants.get(*tag).expect("Not a valid tag").clone()], vec![Type::new_sum(variants.clone())], ), LeafOp::Lift { type_row, new_extension, - } => AbstractSignature::new_df(type_row.clone(), type_row.clone()) + } => AbstractSignature::new(type_row.clone(), type_row.clone()) .with_extension_delta(&ExtensionSet::singleton(new_extension)), } } diff --git a/src/ops/validate.rs b/src/ops/validate.rs index e2cb16df9..04b192ba5 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -518,6 +518,6 @@ impl_validate_op!(Input); impl_validate_op!(Output); impl_validate_op!(Const); impl_validate_op!(Call); -impl_validate_op!(CallIndirect); impl_validate_op!(LoadConstant); +impl_validate_op!(CallIndirect); impl_validate_op!(LeafOp); diff --git a/src/types/signature.rs b/src/types/signature.rs index 489872262..108f405d1 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -7,18 +7,14 @@ use std::ops::Index; use smol_str::SmolStr; -use crate::utils::display_list; - use std::fmt::{self, Display, Write}; use crate::hugr::Direction; -use super::{EdgeKind, Type, TypeRow}; +use super::{Type, TypeRow}; use crate::hugr::Port; -use crate::type_row; - use crate::extension::ExtensionSet; use delegate::delegate; @@ -31,8 +27,6 @@ pub struct AbstractSignature { pub input: TypeRow, /// Value outputs of the function. pub output: TypeRow, - /// Possible static input (for call / load-constant). - pub static_input: TypeRow, /// The extension requirements which are added by the operation pub extension_reqs: ExtensionSet, } @@ -47,20 +41,6 @@ pub struct Signature { } impl AbstractSignature { - /// Create a new signature. - pub fn new( - input: impl Into, - output: impl Into, - static_input: impl Into, - ) -> Self { - Self { - input: input.into(), - output: output.into(), - static_input: static_input.into(), - extension_reqs: ExtensionSet::new(), - } - } - /// Builder method, add extension_reqs to an AbstractSignature pub fn with_extension_delta(mut self, rs: &ExtensionSet) -> Self { self.extension_reqs = self.extension_reqs.union(rs); @@ -101,38 +81,30 @@ impl AbstractSignature { /// The number of wires in the signature. #[inline(always)] pub fn is_empty(&self) -> bool { - self.static_input.is_empty() && self.input.is_empty() && self.output.is_empty() + self.input.is_empty() && self.output.is_empty() } } impl AbstractSignature { - /// Create a new signature with only dataflow inputs and outputs. - pub fn new_df(input: impl Into, output: impl Into) -> Self { - Self::new(input, output, type_row![]) + /// Create a new signature with specified inputs and outputs. + pub fn new(input: impl Into, output: impl Into) -> Self { + // TODO rename to just "new" + Self { + input: input.into(), + output: output.into(), + extension_reqs: ExtensionSet::new(), + } } /// Create a new signature with the same input and output types. pub fn new_linear(linear: impl Into) -> Self { let linear = linear.into(); - Self::new_df(linear.clone(), linear) - } - - /// Returns the type of a [`Port`]. Returns `None` if the port is out of bounds. - pub fn get(&self, port: Port) -> Option { - if port.direction() == Direction::Incoming && port.index() >= self.input.len() { - Some(EdgeKind::Static( - self.static_input - .get(port.index() - self.input.len())? - .clone(), - )) - } else { - self.get_df(port).cloned().map(EdgeKind::Value) - } + Self::new(linear.clone(), linear) } /// Returns the type of a value [`Port`]. Returns `None` if the port is out - /// of bounds or if it is not a value. + /// of bounds. #[inline] - pub fn get_df(&self, port: Port) -> Option<&Type> { + pub fn get(&self, port: Port) -> Option<&Type> { match port.direction() { Direction::Incoming => self.input.get(port.index()), Direction::Outgoing => self.output.get(port.index()), @@ -140,64 +112,55 @@ impl AbstractSignature { } /// Returns the type of a value [`Port`]. Returns `None` if the port is out - /// of bounds or if it is not a value. + /// of bounds. #[inline] - pub fn get_df_mut(&mut self, port: Port) -> Option<&mut Type> { + pub fn get_mut(&mut self, port: Port) -> Option<&mut Type> { match port.direction() { Direction::Incoming => self.input.get_mut(port.index()), Direction::Outgoing => self.output.get_mut(port.index()), } } - /// Returns the number of value and static ports in the signature. + /// Returns the number of ports in the signature. #[inline] pub fn port_count(&self, dir: Direction) -> usize { match dir { - Direction::Incoming => self.input.len() + self.static_input.len(), + Direction::Incoming => self.input.len(), Direction::Outgoing => self.output.len(), } } - /// Returns the number of input value and static ports in the signature. + /// Returns the number of input ports in the signature. #[inline] pub fn input_count(&self) -> usize { self.port_count(Direction::Incoming) } - /// Returns the number of output value and static ports in the signature. + /// Returns the number of output ports in the signature. #[inline] pub fn output_count(&self) -> usize { self.port_count(Direction::Outgoing) } - /// Returns the number of value ports in the signature. - #[inline] - pub fn df_port_count(&self, dir: Direction) -> usize { - match dir { - Direction::Incoming => self.input.len(), - Direction::Outgoing => self.output.len(), - } - } - - /// Returns a slice of the value types for the given direction. + /// Returns a slice of the types for the given direction. #[inline] - pub fn df_types(&self, dir: Direction) -> &[Type] { + pub fn types(&self, dir: Direction) -> &[Type] { match dir { Direction::Incoming => &self.input, Direction::Outgoing => &self.output, } } - /// Returns a slice of the input value types. + /// Returns a slice of the input types. #[inline] - pub fn input_df_types(&self) -> &[Type] { - self.df_types(Direction::Incoming) + pub fn input_types(&self) -> &[Type] { + self.types(Direction::Incoming) } - /// Returns a slice of the output value types. + /// Returns a slice of the output types. #[inline] - pub fn output_df_types(&self) -> &[Type] { - self.df_types(Direction::Outgoing) + pub fn output_types(&self) -> &[Type] { + self.types(Direction::Outgoing) } #[inline] @@ -211,12 +174,6 @@ impl AbstractSignature { pub fn output(&self) -> &TypeRow { &self.output } - - #[inline] - /// Returns the row of static inputs - pub fn static_input(&self) -> &TypeRow { - &self.static_input - } } impl AbstractSignature { @@ -237,24 +194,6 @@ impl AbstractSignature { self.input.iter().filter(|t| !t.copyable()) } - /// Returns the value `Port`s in the signature for a given direction. - #[inline] - pub fn ports_df(&self, dir: Direction) -> impl Iterator { - (0..self.df_port_count(dir)).map(move |i| Port::new(dir, i)) - } - - /// Returns the incoming value `Port`s in the signature. - #[inline] - pub fn input_ports_df(&self) -> impl Iterator { - self.ports_df(Direction::Incoming) - } - - /// Returns the outgoing value `Port`s in the signature. - #[inline] - pub fn output_ports_df(&self) -> impl Iterator { - self.ports_df(Direction::Outgoing) - } - /// Returns the `Port`s in the signature for a given direction. #[inline] pub fn ports(&self, dir: Direction) -> impl Iterator { @@ -290,22 +229,14 @@ impl Signature { pub fn input(&self) -> &TypeRow; /// Outputs of the abstract signature pub fn output(&self) -> &TypeRow; - /// Static inputs of the abstract signature - pub fn static_input(&self) -> &TypeRow; } } } impl Display for AbstractSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let has_inputs = !(self.static_input.is_empty() && self.input.is_empty()); - if has_inputs { + if !self.input.is_empty() { self.input.fmt(f)?; - if !self.static_input.is_empty() { - f.write_char('<')?; - display_list(&self.static_input, f)?; - f.write_char('>')?; - } f.write_str(" -> ")?; } f.write_char('[')?; @@ -333,8 +264,6 @@ pub struct SignatureDescription { pub input: Vec, /// Output of the function. pub output: Vec, - /// Static data references used by the function. - pub static_input: Vec, } #[cfg_attr(feature = "pyo3", pymethods)] @@ -342,37 +271,23 @@ impl SignatureDescription { /// The number of wires in the signature. #[inline(always)] pub fn is_empty(&self) -> bool { - self.static_input.is_empty() && self.input.is_empty() && self.output.is_empty() + self.input.is_empty() && self.output.is_empty() } } impl SignatureDescription { /// Create a new signature. - pub fn new( - input: impl Into>, - output: impl Into>, - static_input: impl Into>, - ) -> Self { + pub fn new(input: impl Into>, output: impl Into>) -> Self { Self { input: input.into(), output: output.into(), - static_input: static_input.into(), } } - /// Create a new signature with only linear dataflow inputs and outputs. + /// Create a new signature with only linear inputs and outputs. pub fn new_linear(linear: impl Into>) -> Self { let linear = linear.into(); - SignatureDescription::new_df(linear.clone(), linear) - } - - /// Create a new signature with only dataflow inputs and outputs. - pub fn new_df(input: impl Into>, output: impl Into>) -> Self { - Self { - input: input.into(), - output: output.into(), - ..Default::default() - } + SignatureDescription::new(linear.clone(), linear) } pub(crate) fn row_zip<'a>( @@ -406,14 +321,6 @@ impl SignatureDescription { ) -> impl Iterator { Self::row_zip(signature.output(), &self.output) } - - /// Iterate over the static input wires of the signature and their names. - pub fn static_input_zip<'a>( - &'a self, - signature: &'a Signature, - ) -> impl Iterator { - Self::row_zip(signature.static_input(), &self.static_input) - } } impl Index for SignatureDescription {