Skip to content

Commit

Permalink
Parameterize TypeRow by element type (#256)
Browse files Browse the repository at this point in the history
* PrimType trait requires Debug, Clone, 'static; remove const CLASSIC: bool
* Use PrimType constraint for all TypeRow's; turn off pyo3 generation
* Add SimpleRow + ClassicRow type aliases
* Add SerializableType trait/constraint within types/simple/serialize.rs (with CLASSIC)
* Add TypeRow::{try_convert_elems, into_owned, map_into}
* Constant Sums must be all-classic
* Remove ConstTypeError::LinearTypeDisallowed (now statically impossible)
* Serialization: SerSimpleType now contains only SerSimpleTypes
    --> recurse (building whole SerSimpleType structure) rather than serializing layers iteratively
* Drop some now-unnecessary into's
* Remove now-unnecessary hoop-jumping in Signature::get for 'other' ports
* Add classic_row! macro paralleling type_row!
* Test new_tuple and new_sum
  • Loading branch information
acl-cqc authored Jul 25, 2023
1 parent e9b1ae3 commit 2045c97
Show file tree
Hide file tree
Showing 17 changed files with 432 additions and 270 deletions.
24 changes: 12 additions & 12 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
types::EdgeKind,
};

use crate::types::{Signature, SimpleType, TypeRow};
use crate::types::{ClassicRow, ClassicType, Signature, SimpleRow, SimpleType};

use itertools::Itertools;

Expand Down Expand Up @@ -277,11 +277,11 @@ pub trait Dataflow: Container {
fn cfg_builder(
&mut self,
inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
output_types: SimpleRow,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<SimpleType>, Vec<Wire>) = inputs.into_iter().unzip();

let inputs: TypeRow = input_types.into();
let inputs: SimpleRow = input_types.into();

let (cfg_node, _) = add_op_with_wires(
self,
Expand Down Expand Up @@ -334,11 +334,11 @@ pub trait Dataflow: Container {
/// the [`ops::TailLoop`] node.
fn tail_loop_builder(
&mut self,
just_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
just_inputs: impl IntoIterator<Item = (ClassicType, Wire)>,
inputs_outputs: impl IntoIterator<Item = (SimpleType, Wire)>,
just_out_types: TypeRow,
just_out_types: ClassicRow,
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
let (input_types, mut input_wires): (Vec<SimpleType>, Vec<Wire>) =
let (input_types, mut input_wires): (Vec<ClassicType>, Vec<Wire>) =
just_inputs.into_iter().unzip();
let (rest_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
inputs_outputs.into_iter().unzip();
Expand Down Expand Up @@ -368,16 +368,16 @@ pub trait Dataflow: Container {
/// the Conditional node.
fn conditional_builder(
&mut self,
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = TypeRow>, Wire),
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = ClassicRow>, Wire),
other_inputs: impl IntoIterator<Item = (SimpleType, Wire)>,
output_types: TypeRow,
output_types: SimpleRow,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![predicate_wire];
let (input_types, rest_input_wires): (Vec<SimpleType>, Vec<Wire>) =
other_inputs.into_iter().unzip();

input_wires.extend(rest_input_wires);
let inputs: TypeRow = input_types.into();
let inputs: SimpleRow = input_types.into();
let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect();
let n_cases = predicate_inputs.len();
let n_out_wires = output_types.len();
Expand Down Expand Up @@ -452,7 +452,7 @@ pub trait Dataflow: Container {
fn make_tag(
&mut self,
tag: usize,
variants: impl Into<TypeRow>,
variants: impl Into<SimpleRow>,
value: Wire,
) -> Result<Wire, BuildError> {
let make_op = self.add_dataflow_op(
Expand All @@ -470,11 +470,11 @@ pub trait Dataflow: Container {
fn make_predicate(
&mut self,
tag: usize,
predicate_variants: impl IntoIterator<Item = TypeRow>,
predicate_variants: impl IntoIterator<Item = ClassicRow>,
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
let tuple = self.make_tuple(values)?;
let variants = TypeRow::predicate_variants_row(predicate_variants);
let variants = ClassicRow::predicate_variants_row(predicate_variants).map_into();
let make_op = self.add_dataflow_op(LeafOp::Tag { tag, variants }, vec![tuple])?;
Ok(make_op.out_wire(0))
}
Expand Down
67 changes: 36 additions & 31 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,20 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::{hugr::view::HugrView, type_row, types::SimpleType};

use crate::hugr::view::HugrView;
use crate::hugr::HugrMut;
use crate::ops::handle::NodeHandle;
use crate::ops::{self, BasicBlock, OpType};
use crate::types::Signature;

use crate::Node;
use crate::{hugr::HugrMut, types::TypeRow, Hugr};
use crate::types::{ClassicRow, Signature, SimpleRow, SimpleType};
use crate::{type_row, Hugr, Node};

/// Builder for a [`crate::ops::CFG`] child control
/// flow graph
#[derive(Debug, PartialEq)]
pub struct CFGBuilder<T> {
pub(super) base: T,
pub(super) cfg_node: Node,
pub(super) inputs: Option<TypeRow>,
pub(super) inputs: Option<SimpleRow>,
pub(super) exit_node: Node,
pub(super) n_out_wires: usize,
}
Expand Down Expand Up @@ -54,7 +52,10 @@ impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {

impl CFGBuilder<Hugr> {
/// New CFG rooted HUGR builder
pub fn new(input: impl Into<TypeRow>, output: impl Into<TypeRow>) -> Result<Self, BuildError> {
pub fn new(
input: impl Into<SimpleRow>,
output: impl Into<SimpleRow>,
) -> Result<Self, BuildError> {
let input = input.into();
let output = output.into();
let cfg_op = ops::CFG {
Expand All @@ -79,8 +80,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
pub(super) fn create(
mut base: B,
cfg_node: Node,
input: TypeRow,
output: TypeRow,
input: SimpleRow,
output: SimpleRow,
) -> Result<Self, BuildError> {
let n_out_wires = output.len();
let exit_block_type = OpType::BasicBlock(BasicBlock::Exit {
Expand Down Expand Up @@ -122,18 +123,18 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if there is an error adding the node.
pub fn block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
inputs: SimpleRow,
predicate_variants: Vec<ClassicRow>,
other_outputs: SimpleRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, predicate_variants, other_outputs, false)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
inputs: SimpleRow,
predicate_variants: Vec<ClassicRow>,
other_outputs: SimpleRow,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let op = OpType::BasicBlock(BasicBlock::DFB {
Expand Down Expand Up @@ -166,8 +167,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if there is an error adding the node.
pub fn simple_block_builder(
&mut self,
inputs: TypeRow,
outputs: TypeRow,
inputs: SimpleRow,
outputs: SimpleRow,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(inputs, vec![type_row![]; n_cases], outputs)
Expand All @@ -182,8 +183,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if an entry block has already been built.
pub fn entry_builder(
&mut self,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
predicate_variants: Vec<ClassicRow>,
other_outputs: SimpleRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
Expand All @@ -200,7 +201,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if there is an error adding the node.
pub fn simple_entry_builder(
&mut self,
outputs: TypeRow,
outputs: SimpleRow,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs)
Expand Down Expand Up @@ -244,15 +245,15 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
fn create(
base: B,
block_n: Node,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
inputs: TypeRow,
predicate_variants: Vec<ClassicRow>,
other_outputs: SimpleRow,
inputs: SimpleRow,
) -> Result<Self, BuildError> {
// The node outputs a predicate before the data outputs of the block node
let predicate_type = SimpleType::new_predicate(predicate_variants);
let mut node_outputs = vec![predicate_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = Signature::new_df(inputs, TypeRow::from(node_outputs));
let signature = Signature::new_df(inputs, SimpleRow::from(node_outputs));
let db = DFGBuilder::create_with_io(base, block_n, signature)?;
Ok(BlockBuilder::from_dfg_builder(db))
}
Expand All @@ -275,9 +276,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
impl BlockBuilder<Hugr> {
/// Initialize a [`BasicBlock::DFB`] rooted HUGR builder
pub fn new(
inputs: impl Into<TypeRow>,
predicate_variants: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
inputs: impl Into<SimpleRow>,
predicate_variants: impl IntoIterator<Item = ClassicRow>,
other_outputs: impl Into<SimpleRow>,
) -> Result<Self, BuildError> {
let inputs = inputs.into();
let predicate_variants: Vec<_> = predicate_variants.into_iter().collect();
Expand All @@ -298,11 +299,12 @@ impl BlockBuilder<Hugr> {
mod test {
use std::collections::HashSet;

use cool_asserts::assert_matches;

use crate::builder::build_traits::HugrBuilder;
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::macros::classic_row;
use crate::types::ClassicType;
use crate::{builder::test::NAT, ops::ConstValue, type_row, types::Signature};
use cool_asserts::assert_matches;

use super::*;
#[test]
Expand Down Expand Up @@ -372,7 +374,10 @@ mod test {
fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let sum2_variants = vec![
classic_row![ClassicType::i64()],
classic_row![ClassicType::i64()],
];
let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?;
let entry = {
let [inw] = entry_b.input_wires_arr();
Expand Down
13 changes: 8 additions & 5 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::hugr::view::HugrView;
use crate::types::{Signature, TypeRow};
use crate::types::{ClassicRow, Signature, SimpleRow};

use crate::ops;
use crate::ops::handle::CaseID;
Expand Down Expand Up @@ -146,9 +146,9 @@ impl HugrBuilder for ConditionalBuilder<Hugr> {
impl ConditionalBuilder<Hugr> {
/// Initialize a Conditional rooted HUGR builder
pub fn new(
predicate_inputs: impl IntoIterator<Item = TypeRow>,
other_inputs: impl Into<TypeRow>,
outputs: impl Into<TypeRow>,
predicate_inputs: impl IntoIterator<Item = ClassicRow>,
other_inputs: impl Into<SimpleRow>,
outputs: impl Into<SimpleRow>,
) -> Result<Self, BuildError> {
let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect();
let other_inputs = other_inputs.into();
Expand Down Expand Up @@ -176,7 +176,10 @@ impl ConditionalBuilder<Hugr> {

impl CaseBuilder<Hugr> {
/// Initialize a Case rooted HUGR
pub fn new(input: impl Into<TypeRow>, output: impl Into<TypeRow>) -> Result<Self, BuildError> {
pub fn new(
input: impl Into<SimpleRow>,
output: impl Into<SimpleRow>,
) -> Result<Self, BuildError> {
let input = input.into();
let output = output.into();
let signature = Signature::new_df(input, output);
Expand Down
6 changes: 3 additions & 3 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::marker::PhantomData;
use crate::hugr::{HugrView, ValidationError};
use crate::ops;

use crate::types::{Signature, TypeRow};
use crate::types::{Signature, SimpleRow};

use crate::Node;
use crate::{hugr::HugrMut, Hugr};
Expand Down Expand Up @@ -60,8 +60,8 @@ impl DFGBuilder<Hugr> {
///
/// Error in adding DFG child nodes.
pub fn new(
input: impl Into<TypeRow>,
output: impl Into<TypeRow>,
input: impl Into<SimpleRow>,
output: impl Into<SimpleRow>,
) -> Result<DFGBuilder<Hugr>, BuildError> {
let input = input.into();
let output = output.into();
Expand Down
22 changes: 13 additions & 9 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::ops::{self, OpType};

use crate::hugr::view::HugrView;
use crate::types::{Signature, TypeRow};
use crate::types::{ClassicRow, Signature, SimpleRow};
use crate::{Hugr, Node};

use super::build_traits::SubContainer;
Expand Down Expand Up @@ -49,7 +49,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
}

/// The output types of the child graph, including the predicate as the first.
pub fn internal_output_row(&self) -> Result<TypeRow, BuildError> {
pub fn internal_output_row(&self) -> Result<SimpleRow, BuildError> {
self.loop_signature().map(ops::TailLoop::body_output_row)
}
}
Expand All @@ -72,9 +72,9 @@ impl<H: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<H> {
impl TailLoopBuilder<Hugr> {
/// Initialize new builder for a [`ops::TailLoop`] rooted HUGR
pub fn new(
just_inputs: impl Into<TypeRow>,
inputs_outputs: impl Into<TypeRow>,
just_outputs: impl Into<TypeRow>,
just_inputs: impl Into<ClassicRow>,
inputs_outputs: impl Into<SimpleRow>,
just_outputs: impl Into<ClassicRow>,
) -> Result<Self, BuildError> {
let tail_loop = ops::TailLoop {
just_inputs: just_inputs.into(),
Expand All @@ -96,18 +96,19 @@ mod test {
test::{BIT, NAT},
DataflowSubContainer, HugrBuilder, ModuleBuilder,
},
classic_row,
hugr::ValidationError,
ops::ConstValue,
type_row,
types::Signature,
types::{ClassicType, Signature},
Hugr,
};

use super::*;
#[test]
fn basic_loop() -> Result<(), BuildError> {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], type_row![NAT])?;
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![ClassicType::i64()])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstValue::i64(1))?;

Expand All @@ -129,8 +130,11 @@ mod test {
let _fdef = {
let [b1] = fbuild.input_wires_arr();
let loop_id = {
let mut loop_b =
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let mut loop_b = fbuild.tail_loop_builder(
vec![(ClassicType::bit(), b1)],
vec![],
classic_row![ClassicType::i64()],
)?;
let signature = loop_b.loop_signature()?.clone();
let const_wire = loop_b.add_load_const(ConstValue::true_val())?;
let [b1] = loop_b.input_wires_arr();
Expand Down
4 changes: 2 additions & 2 deletions src/extensions/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use pyo3::prelude::*;
use crate::ops::constant::CustomConst;
use crate::resource::{OpDef, ResourceSet, TypeDef};
use crate::types::type_param::TypeArg;
use crate::types::{CustomType, SimpleType, TypeRow};
use crate::types::{CustomType, SimpleRow, SimpleType};
use crate::Resource;

pub const fn resource_id() -> SmolStr {
Expand All @@ -34,7 +34,7 @@ pub fn resource() -> Resource {
vec![],
HashMap::default(),
|_arg_values: &[TypeArg]| {
let t: TypeRow = vec![SimpleType::Classic(Type::Angle.custom_type().into())].into();
let t: SimpleRow = vec![SimpleType::Classic(Type::Angle.custom_type().into())].into();
Ok((t.clone(), t, ResourceSet::default()))
},
);
Expand Down
Loading

0 comments on commit 2045c97

Please sign in to comment.