diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index d94a0a5c8..83fc07af2 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -15,7 +15,7 @@ use pyo3::prelude::*; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; use crate::ops::{OpTag, OpTrait, OpType, ValidateOp}; use crate::resource::ResourceSet; -use crate::types::{ClassicType, EdgeKind, SimpleType}; +use crate::types::{EdgeKind, SimpleType}; use crate::{Direction, Hugr, Node, Port}; use super::hierarchical_views::{HierarchyView, SiblingGraph}; @@ -734,7 +734,7 @@ pub enum InterGraphEdgeError { InvalidConstSrc { from: Node, from_offset: Port, - typ: ClassicType, + typ: SimpleType, }, } @@ -823,12 +823,7 @@ mod test { .unwrap(); let tag_def = b.add_op_with_parent(b.root(), const_op).unwrap(); let tag = b - .add_op_with_parent( - parent, - ops::LoadConstant { - datatype: tag_type.try_into().unwrap(), - }, - ) + .add_op_with_parent(parent, ops::LoadConstant { datatype: tag_type }) .unwrap(); b.connect(tag_def, 0, tag, 0).unwrap(); @@ -1149,7 +1144,7 @@ mod test { let lcst = h.add_op_with_parent( h.root(), ops::LoadConstant { - datatype: ClassicType::int::<1>(), + datatype: SimpleType::int::<1>(), }, )?; h.connect(cst, 0, lcst, 0)?; diff --git a/src/ops/constant.rs b/src/ops/constant.rs index bc09c4b52..f861fa897 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -4,7 +4,9 @@ use std::any::Any; use crate::{ macros::impl_box_clone, - types::{simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType}, + types::{ + simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType, SimpleType, + }, values::{ map_container_type, ConstTypeError, ContainerValue, CustomCheckFail, HashableValue, ValueOfType, @@ -21,12 +23,12 @@ use super::{OpName, OpTrait, StaticTag}; #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct Const { value: ConstValue, - typ: ClassicType, + typ: SimpleType, } impl Const { /// Creates a new Const, type-checking the value. - pub fn new(value: ConstValue, typ: ClassicType) -> Result { + pub fn new(value: ConstValue, typ: SimpleType) -> Result { value.check_type(&typ)?; Ok(Self { value, typ }) } @@ -37,7 +39,7 @@ impl Const { } /// Returns a reference to the type of this [`Const`]. - pub fn const_type(&self) -> &ClassicType { + pub fn const_type(&self) -> &SimpleType { &self.typ } @@ -48,7 +50,7 @@ impl Const { value: ConstValue, variant_rows: impl IntoIterator, ) -> Result { - let typ = ClassicType::new_predicate(variant_rows); + let typ = SimpleType::new_predicate(variant_rows); Self::new(ConstValue::sum(tag, value), typ) } @@ -56,7 +58,7 @@ impl Const { pub fn simple_predicate(tag: usize, size: usize) -> Self { Self { value: ConstValue::simple_predicate(tag), - typ: ClassicType::new_simple_predicate(size), + typ: SimpleType::new_simple_predicate(size), } } @@ -64,7 +66,7 @@ impl Const { pub fn simple_unary_predicate() -> Self { Self { value: ConstValue::simple_unary_predicate(), - typ: ClassicType::new_simple_predicate(1), + typ: SimpleType::new_simple_predicate(1), } } @@ -82,7 +84,7 @@ impl Const { pub fn int(value: HugrIntValueStore) -> Result { Self::new( ConstValue::Hashable(HashableValue::Int(value)), - ClassicType::int::(), + SimpleType::int::(), ) } @@ -93,11 +95,11 @@ impl Const { /// Tuple of values pub fn new_tuple(items: impl IntoIterator) -> Self { - let (values, types): (Vec, Vec) = items + let (values, types): (Vec, Vec) = items .into_iter() .map(|Const { value, typ }| (value, typ)) .unzip(); - Self::new(ConstValue::sequence(&values), ClassicType::new_tuple(types)).unwrap() + Self::new(ConstValue::sequence(&values), SimpleType::new_tuple(types)).unwrap() } } @@ -153,7 +155,7 @@ impl PartialEq for dyn CustomConst { } impl ValueOfType for ConstValue { - type T = ClassicType; + type T = SimpleType; fn name(&self) -> String { match self { @@ -164,42 +166,54 @@ impl ValueOfType for ConstValue { } } - fn check_type(&self, ty: &ClassicType) -> Result<(), ConstTypeError> { + fn check_type(&self, ty: &SimpleType) -> Result<(), ConstTypeError> { match self { ConstValue::F64(_) => { - if let ClassicType::F64 = ty { + if ty == &SimpleType::Classic(ClassicType::F64) { return Ok(()); } } ConstValue::Hashable(hv) => { - match ty { - ClassicType::Hashable(exp) => return hv.check_type(exp), - ClassicType::Container(cty) => { - // A "hashable" value might be an instance of a non-hashable type: - // e.g. an empty list is hashable, yet can be checked against a classic element type! - if let HashableValue::Container(ctr) = hv { - return ctr.map_vals(&ConstValue::Hashable).check_container(cty); + if let SimpleType::Classic(ClassicType::Hashable(typ)) = ty { + return hv.check_type(typ); + } + if let HashableValue::Container(ctr) = hv { + // An empty list is a hashable value, but could be an instance of a non-hashable list type + // such as List or even List ! + let mapped_cty = || ctr.map_vals(&ConstValue::Hashable); + match ty { + SimpleType::Qontainer(cty) => return mapped_cty().check_container(cty), + SimpleType::Classic(ClassicType::Container(cty)) => { + return mapped_cty() + .check_container(&map_container_type(cty, &SimpleType::Classic)) } - } - _ => (), + _ => (), + }; } } ConstValue::Container(vals) => { match ty { - ClassicType::Container(cty) => return vals.check_container(cty), - // We might also fail to deduce a container *value* was hashable, + SimpleType::Qontainer(cty) => return vals.check_container(cty), + SimpleType::Classic(ClassicType::Container(cty)) => { + return vals.check_container(&map_container_type(cty, &SimpleType::Classic)) + } + // We might also fail to deduce/represent a container *value* was hashable, // because it contains opaque values whose tag is unknown. - ClassicType::Hashable(HashableType::Container(cty)) => { - return vals - .check_container(&map_container_type(cty, &ClassicType::Hashable)) + SimpleType::Classic(ClassicType::Hashable(HashableType::Container(cty))) => { + return vals.check_container(&map_container_type(cty, &|elemty| { + SimpleType::Classic(ClassicType::Hashable(elemty)) + })) } _ => (), }; } ConstValue::Opaque((val,)) => { let maybe_cty = match ty { - ClassicType::Container(Container::Opaque(t)) => Some(t), - ClassicType::Hashable(HashableType::Container(Container::Opaque(t))) => Some(t), + SimpleType::Qontainer(Container::Opaque(t)) => Some(t), + SimpleType::Classic(ClassicType::Container(Container::Opaque(t))) => Some(t), + SimpleType::Classic(ClassicType::Hashable(HashableType::Container( + Container::Opaque(t), + ))) => Some(t), _ => None, }; if let Some(cu_ty) = maybe_cty { @@ -211,10 +225,10 @@ impl ValueOfType for ConstValue { } fn container_error( - typ: Container, + typ: Container, vals: ContainerValue, ) -> ConstTypeError { - ConstTypeError::ValueCheckFail(ClassicType::Container(typ), ConstValue::Container(vals)) + ConstTypeError::ValueCheckFail(SimpleType::Qontainer(typ), ConstValue::Container(vals)) } } @@ -394,19 +408,20 @@ mod test { #[test] fn test_constant_values() { - const T_INT: ClassicType = ClassicType::int::<64>(); + const T_INT: SimpleType = SimpleType::int::<64>(); const V_INT: ConstValue = ConstValue::Hashable(HashableValue::Int(257)); + const T_F64: SimpleType = SimpleType::Classic(ClassicType::F64); V_INT.check_type(&T_INT).unwrap(); assert_eq!( - V_INT.check_type(&ClassicType::int::<8>()), + V_INT.check_type(&SimpleType::int::<8>()), Err(ConstTypeError::Int(ConstIntError::IntTooLarge(8, 257))) ); - ConstValue::F64(17.4).check_type(&ClassicType::F64).unwrap(); + ConstValue::F64(17.4).check_type(&T_F64).unwrap(); assert_matches!( - V_INT.check_type(&ClassicType::F64), - Err(ConstTypeError::ValueCheckFail(ClassicType::F64, v)) => v == V_INT + V_INT.check_type(&T_F64), + Err(ConstTypeError::ValueCheckFail(T_F64, v)) => v == V_INT ); - let tuple_ty = ClassicType::new_tuple(classic_row![T_INT, ClassicType::F64]); + let tuple_ty = SimpleType::new_tuple(type_row![T_INT, T_F64]); let tuple_val = ConstValue::sequence(&[V_INT, ConstValue::F64(5.1)]); tuple_val.check_type(&tuple_ty).unwrap(); let tuple_val2 = ConstValue::sequence(&[ConstValue::F64(5.1), V_INT]); @@ -433,14 +448,13 @@ mod test { typ: typ_int.clone(), value: Value::Number(6.into()), }),)); - let SimpleType::Classic(classic_t) = typ_int.clone().into() - else {panic!("Hashable CustomType returned as non-Classic");}; - assert_matches!(classic_t, ClassicType::Hashable(_)); - val.check_type(&classic_t).unwrap(); + let simp_t: SimpleType = typ_int.clone().into(); + assert_matches!(simp_t, SimpleType::Classic(ClassicType::Hashable(_))); + val.check_type(&simp_t).unwrap(); // This misrepresents the CustomType, so doesn't really "have to work". // But just as documentation of current behaviour: - val.check_type(&ClassicType::Container(Container::Opaque(typ_int.clone()))) + val.check_type(&SimpleType::Qontainer(Container::Opaque(typ_int.clone()))) .unwrap(); let typ_qb = CustomType::new( @@ -450,7 +464,7 @@ mod test { TypeTag::Hashable, ); let t: SimpleType = typ_qb.clone().into(); - assert_matches!(val.check_type(&t.try_into().unwrap()), + assert_matches!(val.check_type(&t), Err(ConstTypeError::CustomCheckFail(CustomCheckFail::TypeMismatch(a, b))) => a == typ_int && b == typ_qb); assert_eq!(val, val); diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 5b68b7e51..f945bafa5 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -146,7 +146,7 @@ impl DataflowOpTrait for Call { fn signature(&self) -> AbstractSignature { AbstractSignature { - static_input: vec![ClassicType::graph_from_sig(self.signature.clone())].into(), + static_input: vec![ClassicType::graph_from_sig(self.signature.clone()).into()].into(), ..self.signature.clone() } } @@ -181,7 +181,7 @@ impl DataflowOpTrait for CallIndirect { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct LoadConstant { /// Constant type - pub datatype: ClassicType, + pub datatype: SimpleType, } impl_op_name!(LoadConstant); impl DataflowOpTrait for LoadConstant { @@ -194,7 +194,7 @@ impl DataflowOpTrait for LoadConstant { fn signature(&self) -> AbstractSignature { AbstractSignature::new( SimpleRow::new(), - vec![SimpleType::Classic(self.datatype.clone())], + vec![self.datatype.clone()], vec![self.datatype.clone()], ) } diff --git a/src/ops/module.rs b/src/ops/module.rs index 373e116f2..ebf1b8bcb 100644 --- a/src/ops/module.rs +++ b/src/ops/module.rs @@ -53,9 +53,9 @@ impl OpTrait for FuncDefn { } fn other_output(&self) -> Option { - Some(EdgeKind::Static(ClassicType::graph_from_sig( - self.signature.clone(), - ))) + Some(EdgeKind::Static( + ClassicType::graph_from_sig(self.signature.clone()).into(), + )) } } @@ -82,9 +82,9 @@ impl OpTrait for FuncDecl { } fn other_output(&self) -> Option { - Some(EdgeKind::Static(ClassicType::graph_from_sig( - self.signature.clone(), - ))) + Some(EdgeKind::Static( + ClassicType::graph_from_sig(self.signature.clone()).into(), + )) } } diff --git a/src/types.rs b/src/types.rs index 42429e1d2..00bdcd44c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -34,7 +34,7 @@ pub enum EdgeKind { /// Data edges of a DDG region, also known as "wires". Value(SimpleType), /// A reference to a static value definition. - Static(ClassicType), + Static(SimpleType), /// Explicitly enforce an ordering between nodes in a DDG. StateOrder, } @@ -59,7 +59,7 @@ pub struct AbstractSignature { /// Value outputs of the function. pub output: SimpleRow, /// Possible static input (for call / load-constant). - pub static_input: ClassicRow, + pub static_input: SimpleRow, /// The resource requirements which are added by the operation pub resource_reqs: ResourceSet, } @@ -78,7 +78,7 @@ impl AbstractSignature { pub fn new( input: impl Into, output: impl Into, - static_input: impl Into, + static_input: impl Into, ) -> Self { Self { input: input.into(), @@ -241,7 +241,7 @@ impl AbstractSignature { #[inline] /// Returns the row of static inputs - pub fn static_input(&self) -> &ClassicRow { + pub fn static_input(&self) -> &SimpleRow { &self.static_input } } @@ -318,7 +318,7 @@ impl Signature { /// Outputs of the abstract signature pub fn output(&self) -> &SimpleRow; /// Static inputs of the abstract signature - pub fn static_input(&self) -> &ClassicRow; + pub fn static_input(&self) -> &SimpleRow; } } } @@ -438,7 +438,7 @@ impl SignatureDescription { pub fn static_input_zip<'a>( &'a self, signature: &'a Signature, - ) -> impl Iterator { + ) -> impl Iterator { Self::row_zip(signature.static_input(), &self.static_input) } } diff --git a/src/types/simple.rs b/src/types/simple.rs index 220a8d552..ab6eb4656 100644 --- a/src/types/simple.rs +++ b/src/types/simple.rs @@ -367,6 +367,12 @@ impl SimpleType { pub fn new_simple_predicate(size: usize) -> Self { Self::Classic(ClassicType::new_simple_predicate(size)) } + + /// Returns a new integer type with the given number of bits. + #[inline] + pub const fn int() -> Self { + Self::Classic(ClassicType::int::()) + } } impl From for SimpleType { diff --git a/src/values.rs b/src/values.rs index 74bf4c346..5ce29ec2d 100644 --- a/src/values.rs +++ b/src/values.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::ops::constant::{HugrIntWidthStore, HUGR_MAX_INT_WIDTH}; -use crate::types::{ClassicType, Container, CustomType, HashableType, PrimType}; +use crate::types::{Container, CustomType, HashableType, PrimType, SimpleType}; use crate::{ ops::constant::{ConstValue, HugrIntValueStore}, types::TypeRow, @@ -73,7 +73,7 @@ impl ValueOfType for HashableValue { } } Err(ConstTypeError::ValueCheckFail( - ClassicType::Hashable(ty.clone()), + ty.clone().into(), ConstValue::Hashable(self.clone()), )) } @@ -83,7 +83,7 @@ impl ValueOfType for HashableValue { vals: ContainerValue, ) -> ConstTypeError { ConstTypeError::ValueCheckFail( - ClassicType::Hashable(HashableType::Container(typ)), + HashableType::Container(typ).into(), ConstValue::Hashable(HashableValue::Container(vals)), ) } @@ -284,7 +284,7 @@ pub enum ConstTypeError { InvalidSumTag, /// A mismatch between the type expected and the value. #[error("Value {1:?} does not match expected type {0}")] - ValueCheckFail(ClassicType, ConstValue), + ValueCheckFail(SimpleType, ConstValue), /// Error when checking a custom value. #[error("Error when checking custom type: {0:?}")] CustomCheckFail(#[from] CustomCheckFail),