diff --git a/src/builder.rs b/src/builder.rs index 17b6f81b0..0444b1932 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -6,9 +6,9 @@ use thiserror::Error; use pyo3::prelude::*; use crate::hugr::{HugrError, Node, ValidationError, Wire}; -use crate::ops::constant::typecheck::ConstTypeError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::types::SimpleType; +use crate::values::ConstTypeError; pub mod handle; pub use handle::BuildHandle; diff --git a/src/extensions/rotation.rs b/src/extensions/rotation.rs index c2baa1a73..8d18e2840 100644 --- a/src/extensions/rotation.rs +++ b/src/extensions/rotation.rs @@ -11,11 +11,11 @@ use std::collections::HashMap; #[cfg(feature = "pyo3")] use pyo3::prelude::*; -use crate::ops::constant::typecheck::CustomCheckFail; use crate::ops::constant::CustomConst; use crate::resource::{OpDef, ResourceSet, TypeDef}; use crate::types::type_param::TypeArg; use crate::types::{CustomType, SimpleRow, TypeTag}; +use crate::values::CustomCheckFail; use crate::Resource; pub const fn resource_id() -> SmolStr { diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 11d50145c..4c70c1879 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -13,8 +13,7 @@ use thiserror::Error; use pyo3::prelude::*; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::OpTag; -use crate::ops::{OpTrait, OpType, ValidateOp}; +use crate::ops::{OpTag, OpTrait, OpType, ValidateOp}; use crate::resource::ResourceSet; use crate::types::{ClassicType, EdgeKind, SimpleType}; use crate::{Direction, Hugr, Node, Port}; @@ -748,7 +747,7 @@ mod test { use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder}; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::ops::dataflow::IOTrait; - use crate::ops::{self, ConstValue, LeafOp, OpType}; + use crate::ops::{self, LeafOp, OpType}; use crate::types::{AbstractSignature, ClassicType}; use crate::Direction; use crate::{type_row, Node}; @@ -1146,10 +1145,7 @@ mod test { }) ); // Second input of Xor from a constant - let cst = h.add_op_with_parent( - h.root(), - ops::Const::new(ConstValue::Int(1), ClassicType::int::<1>()).unwrap(), - )?; + let cst = h.add_op_with_parent(h.root(), ops::Const::int::<1>(1).unwrap())?; let lcst = h.add_op_with_parent( h.root(), ops::LoadConstant { diff --git a/src/lib.rs b/src/lib.rs index af6bcde7b..a4bf0c57b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ pub mod ops; pub mod resource; pub mod types; mod utils; +pub mod values; pub use crate::hugr::{Direction, Hugr, HugrView, Node, Port, SimpleReplacement, Wire}; pub use crate::resource::Resource; diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 5b9fb0e4b..b051ac7d2 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -4,14 +4,16 @@ use std::any::Any; use crate::{ macros::impl_box_clone, - types::{ClassicRow, ClassicType, CustomType, EdgeKind}, + types::{simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType}, + values::{ + map_container_type, ConstTypeError, ContainerValue, CustomCheckFail, HashableValue, + ValueOfType, + }, }; use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; -use self::typecheck::{typecheck_const, ConstTypeError, CustomCheckFail}; - use super::OpTag; use super::{OpName, OpTrait, StaticTag}; @@ -48,8 +50,7 @@ impl Const { variant_rows: impl IntoIterator, ) -> Result { let typ = ClassicType::new_predicate(variant_rows); - - Self::new(ConstValue::Sum(tag, Box::new(value)), typ) + Self::new(ConstValue::sum(tag, value), typ) } /// Constant Sum over units, used as predicates. @@ -80,18 +81,30 @@ impl Const { /// Fixed width integer pub fn int(value: HugrIntValueStore) -> Result { - Self::new(ConstValue::Int(value), ClassicType::int::()) + Self::new( + ConstValue::Hashable(HashableValue::Int(value)), + ClassicType::int::(), + ) } /// 64-bit integer pub fn i64(value: i64) -> Result { Self::int::<64>(value as HugrIntValueStore) } + + /// Tuple of values + pub fn new_tuple(items: impl IntoIterator) -> Self { + let (values, types): (Vec, Vec) = items + .into_iter() + .map(|Const { value, typ }| (value, typ)) + .unzip(); + Self::new(ConstValue::sequence(&values), ClassicType::new_tuple(types)).unwrap() + } } impl OpName for Const { fn name(&self) -> SmolStr { - self.value.name() + self.value.name().into() } } impl StaticTag for Const { @@ -116,23 +129,17 @@ pub(crate) type HugrIntWidthStore = u8; pub(crate) const HUGR_MAX_INT_WIDTH: HugrIntWidthStore = HugrIntValueStore::BITS as HugrIntWidthStore; -/// Value constants -/// -/// TODO: Add more constants -/// TODO: bigger/smaller integers. +/// Value constants. (This could be "ClassicValue" to parallel [HashableValue]) #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[allow(missing_docs)] pub enum ConstValue { - /// An arbitrary length integer constant. - Int(HugrIntValueStore), + Hashable(HashableValue), + /// A collection of constant values (at least some of which are not [ConstValue::Hashable]) + Container(ContainerValue), /// Double precision float F64(f64), - /// A constant specifying a variant of a Sum type. - Sum(usize, Box), - /// A tuple of constant values. - Tuple(Vec), - /// An opaque constant value, with cached type + /// An opaque constant value, with cached type. TODO put this into ContainerValue. // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 Opaque((Box,)), } @@ -143,59 +150,129 @@ impl PartialEq for dyn CustomConst { } } -impl Default for ConstValue { - fn default() -> Self { - Self::Int(0) - } -} +impl ValueOfType for ConstValue { + type T = ClassicType; -impl ConstValue { - /// Returns the datatype of the constant. - pub fn check_type(&self, typ: &ClassicType) -> Result<(), ConstTypeError> { - typecheck_const(typ, self) + fn name(&self) -> String { + match self { + ConstValue::F64(f) => format!("const:float:{}", f), + ConstValue::Hashable(hv) => hv.name(), + ConstValue::Container(ctr) => ctr.desc(), + ConstValue::Opaque((v,)) => format!("const:custom:{}", v.name()), + } } - /// Unique name of the constant. - pub fn name(&self) -> SmolStr { + + fn check_type(&self, ty: &ClassicType) -> Result<(), ConstTypeError> { match self { - Self::Int(value) => format!("const:int{value}"), - Self::F64(f) => format!("const:float:{f}"), - Self::Opaque((v,)) => format!("const:{}", v.name()), - Self::Sum(tag, val) => { - format!("const:sum:{{tag:{tag}, val:{}}}", val.name()) + ConstValue::F64(_) => { + if let ClassicType::F64 = ty { + return Ok(()); + } } - Self::Tuple(vals) => { - let valstr: Vec<_> = vals.iter().map(|v| v.name()).collect(); - let valstr = valstr.join(", "); - format!("const:tuple:{{{valstr}}}") + ConstValue::Hashable(hv) => { + match ty { + ClassicType::Hashable(exp) => return hv.check_type(exp), + ClassicType::Container(cty) => { + // A "hashable" value might be an instance of a non-hashable type: + // e.g. an empty list is hashable, yet can be checked against a classic element type! + if let HashableValue::Container(ctr) = hv { + // Note if ctr is a ContainerValue::Opaque, this means we can check that + // against a Container::Opaque, which is perhaps unnecessary, but harmless. + return ctr.map_vals(&ConstValue::Hashable).check_container(cty); + } + } + _ => (), + } } - } - .into() + ConstValue::Container(vals) => { + match ty { + ClassicType::Container(cty) => return vals.check_container(cty), + // We might also fail to deduce a container *value* was hashable, + // because it contains opaque values whose tag is unknown. + ClassicType::Hashable(HashableType::Container(cty)) => { + return vals + .check_container(&map_container_type(cty, &ClassicType::Hashable)) + } + _ => (), + }; + } + ConstValue::Opaque((val,)) => { + let maybe_cty = match ty { + ClassicType::Container(Container::Opaque(t)) => Some(t), + ClassicType::Hashable(HashableType::Container(Container::Opaque(t))) => Some(t), + _ => None, + }; + if let Some(cu_ty) = maybe_cty { + return val.check_custom_type(cu_ty).map_err(ConstTypeError::from); + } + } + }; + Err(ConstTypeError::ValueCheckFail(ty.clone(), self.clone())) + } + + fn container_error( + typ: Container, + vals: ContainerValue, + ) -> ConstTypeError { + ConstTypeError::ValueCheckFail(ClassicType::Container(typ), ConstValue::Container(vals)) } +} +impl ConstValue { /// Description of the constant. pub fn description(&self) -> &str { "Constant value" } /// Constant unit type (empty Tuple). - pub const fn unit() -> ConstValue { - ConstValue::Tuple(vec![]) + pub const fn unit() -> Self { + Self::Hashable(HashableValue::Container(ContainerValue::Sequence(vec![]))) } /// Constant Sum over units, used as predicates. pub fn simple_predicate(tag: usize) -> Self { - Self::predicate(tag, Self::unit()) - } - - /// Constant Sum over Tuples, used as predicates. - pub fn predicate(tag: usize, val: ConstValue) -> Self { - ConstValue::Sum(tag, Box::new(val)) + Self::sum(tag, Self::unit()) } /// Constant Sum over Tuples with just one variant of unit type pub fn simple_unary_predicate() -> Self { Self::simple_predicate(0) } + + /// Sequence of values (could be a tuple, list or array) + pub fn sequence(items: &[ConstValue]) -> Self { + // Keep Hashable at the outside (if all values are) + match items + .iter() + .map(|item| match item { + ConstValue::Hashable(h) => Some(h), + _ => None, + }) + .collect::>>() + { + Some(hashables) => ConstValue::Hashable(HashableValue::Container( + ContainerValue::Sequence(hashables.into_iter().cloned().collect()), + )), + None => ConstValue::Container(ContainerValue::Sequence(items.to_vec())), + } + } + + /// Sum value (could be of any compatible type, e.g. a predicate) + pub fn sum(tag: usize, value: ConstValue) -> Self { + // Keep Hashable as outermost constructor + match value { + ConstValue::Hashable(hv) => { + HashableValue::Container(ContainerValue::Sum(tag, Box::new(hv))).into() + } + _ => ConstValue::Container(ContainerValue::Sum(tag, Box::new(value))), + } + } +} + +impl From for ConstValue { + fn from(hv: HashableValue) -> Self { + Self::Hashable(hv) + } } impl From for ConstValue { @@ -232,12 +309,12 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); mod test { use cool_asserts::assert_matches; - use super::ConstValue; - use super::{typecheck::ConstTypeError, Const}; + use super::{typecheck::ConstIntError, Const, ConstValue}; use crate::{ builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr}, classic_row, type_row, types::{ClassicType, SimpleRow, SimpleType}, + values::{ConstTypeError, HashableValue, ValueOfType}, }; #[test] @@ -251,7 +328,10 @@ mod test { let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty.clone()]))?; let c = b.add_constant(Const::predicate( 0, - ConstValue::Tuple(vec![ConstValue::Int(3), ConstValue::F64(3.15)]), + ConstValue::sequence(&[ + ConstValue::Hashable(HashableValue::Int(3)), + ConstValue::F64(3.15), + ]), pred_rows.clone(), )?)?; let w = b.load_const(&c)?; @@ -272,7 +352,37 @@ mod test { type_row![], ]; - let res = Const::predicate(0, ConstValue::Tuple(vec![]), pred_rows); + let res = Const::predicate(0, ConstValue::sequence(&[]), pred_rows); assert_matches!(res, Err(ConstTypeError::TupleWrongLength)); } + + #[test] + fn test_constant_values() { + const T_INT: ClassicType = ClassicType::int::<64>(); + const V_INT: ConstValue = ConstValue::Hashable(HashableValue::Int(257)); + V_INT.check_type(&T_INT).unwrap(); + assert_eq!( + V_INT.check_type(&ClassicType::int::<8>()), + Err(ConstTypeError::Int(ConstIntError::IntTooLarge(8, 257))) + ); + ConstValue::F64(17.4).check_type(&ClassicType::F64).unwrap(); + assert_matches!( + V_INT.check_type(&ClassicType::F64), + Err(ConstTypeError::ValueCheckFail(ClassicType::F64, v)) => v == V_INT + ); + let tuple_ty = ClassicType::new_tuple(classic_row![T_INT, ClassicType::F64]); + let tuple_val = ConstValue::sequence(&[V_INT, ConstValue::F64(5.1)]); + tuple_val.check_type(&tuple_ty).unwrap(); + let tuple_val2 = ConstValue::sequence(&[ConstValue::F64(5.1), V_INT]); + assert_matches!( + tuple_val2.check_type(&tuple_ty), + Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2 + ); + let tuple_val3 = + ConstValue::sequence(&vec![V_INT, ConstValue::F64(3.3), ConstValue::F64(2.0)]); + assert_eq!( + tuple_val3.check_type(&tuple_ty), + Err(ConstTypeError::TupleWrongLength) + ); + } } diff --git a/src/ops/constant/typecheck.rs b/src/ops/constant/typecheck.rs index ded6a644c..fae7d2bdb 100644 --- a/src/ops/constant/typecheck.rs +++ b/src/ops/constant/typecheck.rs @@ -1,6 +1,9 @@ -//! Simple type checking - takes a hugr and some extra info and checks whether -//! the types at the sources of each wire match those of the targets - +//! Simple type checking for int constants - currently this is just the bits that are +//! shared between the old [TypeArg] and the new [ConstValue]/[HashableValue]. +//! +//! [TypeArg]: crate::types::type_param::TypeArg +//! [ConstValue]: crate::ops::constant::ConstValue +//! [HashableValue]: crate::values::HashableValue use lazy_static::lazy_static; use std::collections::HashSet; @@ -8,9 +11,6 @@ use std::collections::HashSet; use thiserror::Error; // For static typechecking -use crate::ops::ConstValue; -use crate::types::{ClassicType, Container, HashableType, PrimType, TypeRow}; - use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WIDTH}; /// An error in fitting an integer constant into its size @@ -28,48 +28,6 @@ pub enum ConstIntError { IntWidthInvalid(HugrIntWidthStore), } -/// Struct for custom type check fails. -#[derive(Clone, Debug, PartialEq, Error)] -#[error("Error when checking custom type.")] -pub struct CustomCheckFail(String); - -impl CustomCheckFail { - /// Creates a new [`CustomCheckFail`]. - pub fn new(message: String) -> Self { - Self(message) - } -} - -/// Errors that arise from typechecking constants -#[derive(Clone, Debug, PartialEq, Error)] -pub enum ConstTypeError { - /// This case hasn't been implemented. Possibly because we don't have value - /// constructors to check against it - #[error("Unimplemented: there are no constants of type {0}")] - Unimplemented(ClassicType), - /// There was some problem fitting a const int into its declared size - #[error("Error with int constant")] - Int(#[from] ConstIntError), - /// Expected width (packed with const int) doesn't match type - #[error("Type mismatch for int: expected I{0}, but found I{1}")] - IntWidthMismatch(HugrIntWidthStore, HugrIntWidthStore), - /// Found a Var type constructor when we're checking a const val - #[error("Type of a const value can't be Var")] - ConstCantBeVar, - /// The length of the tuple value doesn't match the length of the tuple type - #[error("Tuple of wrong length")] - TupleWrongLength, - /// Tag for a sum value exceeded the number of variants - #[error("Tag of Sum value is invalid")] - InvalidSumTag, - /// A mismatch between the type expected and the value. - #[error("Value {1:?} does not match expected type {0}")] - ValueCheckFail(ClassicType, ConstValue), - /// Error when checking a custom value. - #[error("Custom value type check error: {0:?}")] - CustomCheckFail(#[from] CustomCheckFail), -} - lazy_static! { static ref VALID_WIDTHS: HashSet = HashSet::from_iter((0..8).map(|a| HugrIntWidthStore::pow(2, a))); @@ -99,126 +57,3 @@ pub(crate) fn check_int_fits_in_width( Err(ConstIntError::IntWidthInvalid(width)) } } - -fn map_vals( - container: Container, - f: &impl Fn(T) -> T2, -) -> Container { - fn map_row( - row: TypeRow, - f: &impl Fn(T) -> T2, - ) -> Box> { - Box::new(TypeRow::from( - row.into_owned().into_iter().map(f).collect::>(), - )) - } - match container { - Container::List(elem) => Container::List(Box::new(f(*elem))), - Container::Map(kv) => { - let (k, v) = *kv; - Container::Map(Box::new((k, f(v)))) - } - Container::Tuple(elems) => Container::Tuple(map_row(*elems, f)), - Container::Sum(variants) => Container::Sum(map_row(*variants, f)), - Container::Array(elem, sz) => Container::Array(Box::new(f(*elem)), sz), - Container::Alias(s) => Container::Alias(s), - Container::Opaque(custom) => Container::Opaque(custom), - } -} - -/// Typecheck a constant value -pub(super) fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> { - match (typ, val) { - (ClassicType::Hashable(HashableType::Int(exp_width)), ConstValue::Int(value)) => { - check_int_fits_in_width(*value, *exp_width).map_err(ConstTypeError::Int) - } - (ClassicType::F64, ConstValue::F64(_)) => Ok(()), - (ty @ ClassicType::Container(c), tm) => match (c, tm) { - (Container::Tuple(row), ConstValue::Tuple(xs)) => { - if row.len() != xs.len() { - return Err(ConstTypeError::TupleWrongLength); - } - for (ty, tm) in row.iter().zip(xs.iter()) { - typecheck_const(ty, tm)? - } - Ok(()) - } - (Container::Tuple(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())), - (Container::Sum(row), ConstValue::Sum(tag, val)) => { - if let Some(ty) = row.get(*tag) { - typecheck_const(ty, val.as_ref()) - } else { - Err(ConstTypeError::InvalidSumTag) - } - } - (Container::Sum(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())), - (Container::Opaque(ty), ConstValue::Opaque((val,))) => { - val.check_custom_type(ty).map_err(ConstTypeError::from) - } - _ => Err(ConstTypeError::Unimplemented(ty.clone())), - }, - (ClassicType::Hashable(HashableType::Container(c)), tm) => { - // Here we deliberately build malformed Container-of-Hashable types - // (rather than Hashable-of-Container) in order to reuse logic above - typecheck_const( - &ClassicType::Container(map_vals(c.clone(), &ClassicType::Hashable)), - tm, - ) - } - (ty @ ClassicType::Graph(_), _) => Err(ConstTypeError::Unimplemented(ty.clone())), - (ty @ ClassicType::Hashable(HashableType::String), _) => { - Err(ConstTypeError::Unimplemented(ty.clone())) - } - (ClassicType::Hashable(HashableType::Variable(_)), _) => { - Err(ConstTypeError::ConstCantBeVar) - } - (ty, _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), val.clone())), - } -} - -#[cfg(test)] -mod test { - use cool_asserts::assert_matches; - - use crate::{classic_row, types::ClassicType}; - - use super::*; - - #[test] - fn test_typecheck_const() { - const INT: ClassicType = ClassicType::int::<64>(); - typecheck_const(&INT, &ConstValue::Int(3)).unwrap(); - typecheck_const(&ClassicType::F64, &ConstValue::F64(17.4)).unwrap(); - assert_eq!( - typecheck_const(&ClassicType::F64, &ConstValue::Int(5)), - Err(ConstTypeError::ValueCheckFail( - ClassicType::F64, - ConstValue::Int(5) - )) - ); - let tuple_ty = ClassicType::new_tuple(classic_row![INT, ClassicType::F64,]); - typecheck_const( - &tuple_ty, - &ConstValue::Tuple(vec![ConstValue::Int(7), ConstValue::F64(5.1)]), - ) - .unwrap(); - assert_matches!( - typecheck_const( - &tuple_ty, - &ConstValue::Tuple(vec![ConstValue::F64(4.8), ConstValue::Int(2)]) - ), - Err(ConstTypeError::ValueCheckFail(_, _)) - ); - assert_eq!( - typecheck_const( - &tuple_ty, - &ConstValue::Tuple(vec![ - ConstValue::Int(5), - ConstValue::F64(3.3), - ConstValue::Int(2) - ]) - ), - Err(ConstTypeError::TupleWrongLength) - ); - } -} diff --git a/src/values.rs b/src/values.rs new file mode 100644 index 000000000..6f0c28565 --- /dev/null +++ b/src/values.rs @@ -0,0 +1,251 @@ +//! Representation of values (shared between [Const] and in future [TypeArg]) +//! +//! [Const]: crate::ops::Const +//! [TypeArg]: crate::types::type_param::TypeArg + +use thiserror::Error; + +use crate::types::{ClassicType, Container, HashableType, PrimType}; +use crate::{ + ops::constant::{ + typecheck::{check_int_fits_in_width, ConstIntError}, + ConstValue, HugrIntValueStore, + }, + types::TypeRow, +}; + +/// A constant value/instance of a [HashableType]. Note there is no +/// equivalent of [HashableType::Variable]; we can't have instances of that. +#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum HashableValue { + /// A string, i.e. corresponding to [HashableType::String] + String(String), + /// An integer, i.e. an instance of all [HashableType::Int]s of sufficient width + Int(HugrIntValueStore), + /// A container of other hashable values + Container(ContainerValue), +} + +/// Trait for classes which represent values of some kind of [PrimType] +pub trait ValueOfType: Clone { + /// The exact type whose values the type implementing [ValueOfType] represents + type T: PrimType; + + /// Checks that a value can be an instance of the specified type. + fn check_type(&self, ty: &Self::T) -> Result<(), ConstTypeError>; + + /// Unique name of the constant/value. + fn name(&self) -> String; + + /// When there is an error fitting a [ContainerValue] of these values + /// into a [Container] (type), produce a [ConstTypeError::ValueCheckFail] for that. + fn container_error(typ: Container, vals: ContainerValue) -> ConstTypeError; +} + +impl ValueOfType for HashableValue { + type T = HashableType; + + fn name(&self) -> String { + match self { + HashableValue::String(s) => format!("const:string:\"{}\"", s), + HashableValue::Int(v) => format!("const:int:{}", v), + HashableValue::Container(c) => c.desc(), + } + } + + fn check_type(&self, ty: &HashableType) -> Result<(), ConstTypeError> { + if let HashableType::Container(Container::Alias(s)) = ty { + return Err(ConstTypeError::NoAliases(s.to_string())); + }; + match self { + HashableValue::String(_) => { + if let HashableType::String = ty { + return Ok(()); + }; + } + HashableValue::Int(value) => { + if let HashableType::Int(width) = ty { + return check_int_fits_in_width(*value, *width).map_err(ConstTypeError::Int); + }; + } + HashableValue::Container(vals) => { + if let HashableType::Container(c_ty) = ty { + return vals.check_container(c_ty); + }; + } + } + Err(ConstTypeError::ValueCheckFail( + ClassicType::Hashable(ty.clone()), + ConstValue::Hashable(self.clone()), + )) + } + + fn container_error( + typ: Container, + vals: ContainerValue, + ) -> ConstTypeError { + ConstTypeError::ValueCheckFail( + ClassicType::Hashable(HashableType::Container(typ)), + ConstValue::Hashable(HashableValue::Container(vals)), + ) + } +} + +/// A value that is a container of other values, e.g. a tuple or sum; +/// thus, corresponding to [Container]. Note there is no member +/// corresponding to [Container::Alias]; such types must have been +/// resolved to concrete types in order to create instances (values). +#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum ContainerValue { + /// A [Container::Array] or [Container::Tuple] or [Container::List] + Sequence(Vec), + /// A [Container::Map] + Map(Vec<(HashableValue, T)>), // TODO try to make this an actual map? + /// A [Container::Sum] - for any Sum type where this value meets + /// the type of the variant indicated by the tag + Sum(usize, Box), // Tag and value + /// A value of a custom type defined by an extension/[Resource]. + /// + /// [Resource]: crate::resource::Resource + // TODO replace this with CustomConst + Opaque(serde_yaml::Value), +} + +impl ContainerValue { + pub(crate) fn desc(&self) -> String { + match self { + ContainerValue::Sequence(vals) => { + let names: Vec<_> = vals.iter().map(ValueOfType::name).collect(); + format!("const:seq:{{{}}}", names.join(", ")) + } + ContainerValue::Map(_) => "a map".to_string(), + ContainerValue::Sum(tag, val) => format!("const:sum:{{tag:{tag}, val:{}}}", val.name()), + ContainerValue::Opaque(c) => format!("const:yaml:{:?}", c), + } + } + pub(crate) fn check_container(&self, ty: &Container) -> Result<(), ConstTypeError> { + match (self, ty) { + (ContainerValue::Sequence(elems), Container::List(elem_ty)) => { + for elem in elems { + elem.check_type(elem_ty)?; + } + Ok(()) + } + (ContainerValue::Sequence(elems), Container::Tuple(tup_tys)) => { + if elems.len() != tup_tys.len() { + return Err(ConstTypeError::TupleWrongLength); + } + for (elem, ty) in elems.iter().zip(tup_tys.iter()) { + elem.check_type(ty)?; + } + Ok(()) + } + (ContainerValue::Sequence(elems), Container::Array(elem_ty, sz)) => { + if elems.len() != *sz { + return Err(ConstTypeError::TupleWrongLength); + } + for elem in elems { + elem.check_type(elem_ty)?; + } + Ok(()) + } + (ContainerValue::Map(mappings), Container::Map(kv)) => { + let (key_ty, val_ty) = &**kv; + for (key, val) in mappings { + key.check_type(key_ty)?; + val.check_type(val_ty)?; + } + Ok(()) + } + (ContainerValue::Sum(tag, value), Container::Sum(variants)) => { + value.check_type(variants.get(*tag).ok_or(ConstTypeError::InvalidSumTag)?) + } + (ContainerValue::Opaque(_), Container::Opaque(_)) => Ok(()), // TODO + (_, Container::Alias(s)) => Err(ConstTypeError::NoAliases(s.to_string())), + (_, _) => Err(ValueOfType::container_error(ty.clone(), self.clone())), + } + } + + pub(crate) fn map_vals(&self, f: &impl Fn(Elem) -> T2) -> ContainerValue { + match self { + ContainerValue::Sequence(vals) => { + ContainerValue::Sequence(vals.iter().cloned().map(f).collect()) + } + ContainerValue::Map(_) => todo!(), + ContainerValue::Sum(tag, value) => { + ContainerValue::Sum(*tag, Box::new(f((**value).clone()))) + } + ContainerValue::Opaque(v) => ContainerValue::Opaque(v.clone()), + } + } +} + +pub(crate) fn map_container_type( + container: &Container, + f: &impl Fn(T) -> T2, +) -> Container { + fn map_row( + row: &TypeRow, + f: &impl Fn(T) -> T2, + ) -> Box> { + Box::new(TypeRow::from( + (*row) + .to_owned() + .into_owned() + .into_iter() + .map(f) + .collect::>(), + )) + } + match container { + Container::List(elem) => Container::List(Box::new(f(*(elem).clone()))), + Container::Map(kv) => { + let (k, v) = (**kv).clone(); + Container::Map(Box::new((k, f(v)))) + } + Container::Tuple(elems) => Container::Tuple(map_row(elems, f)), + Container::Sum(variants) => Container::Sum(map_row(variants, f)), + Container::Array(elem, sz) => Container::Array(Box::new(f((**elem).clone())), *sz), + Container::Alias(s) => Container::Alias(s.clone()), + Container::Opaque(custom) => Container::Opaque(custom.clone()), + } +} + +/// Struct for custom type check fails. +#[derive(Clone, Debug, PartialEq, Error)] +#[error("Error when checking custom type.")] +pub struct CustomCheckFail(String); + +impl CustomCheckFail { + /// Creates a new [`CustomCheckFail`]. + pub fn new(message: String) -> Self { + Self(message) + } +} + +/// Errors that arise from typechecking constants +#[derive(Clone, Debug, PartialEq, Error)] +pub enum ConstTypeError { + /// There was some problem fitting a const int into its declared size + #[error("Error with int constant")] + Int(#[from] ConstIntError), + /// Found a Var type constructor when we're checking a const val + #[error("Type of a const value can't be Var")] + ConstCantBeVar, + /// Type we were checking against was an Alias. + /// This should have been resolved to an actual type. + #[error("Type of a const value can't be an Alias {0}")] + NoAliases(String), + /// The length of the tuple value doesn't match the length of the tuple type + #[error("Tuple of wrong length")] + TupleWrongLength, + /// Tag for a sum value exceeded the number of variants + #[error("Tag of Sum value is invalid")] + InvalidSumTag, + /// A mismatch between the type expected and the value. + #[error("Value {1:?} does not match expected type {0}")] + ValueCheckFail(ClassicType, ConstValue), + /// Error when checking a custom value. + #[error("Custom value type check error: {0:?}")] + CustomCheckFail(#[from] CustomCheckFail), +}