diff --git a/src/hugr/typecheck.rs b/src/hugr/typecheck.rs index a7189bcae..5990d9126 100644 --- a/src/hugr/typecheck.rs +++ b/src/hugr/typecheck.rs @@ -19,7 +19,7 @@ use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WI pub enum ConstTypeError { /// This case hasn't been implemented. Possibly because we don't have value /// constructors to check against it - #[error("Const type checking unimplemented for {0}")] + #[error("Unimplemented: there are no constants of type {0}")] Unimplemented(ClassicType), /// The given type and term are incompatible #[error("Invalid const value for type {0}")] @@ -100,7 +100,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT Err(ConstTypeError::IntWidthMismatch(*exp_width, *width)) } } - (ty @ ClassicType::F64, _) => Err(ConstTypeError::Unimplemented(ty.clone())), + (ClassicType::F64, ConstValue::F64(_)) => Ok(()), (ty @ ClassicType::Container(c), tm) => match (c, tm) { (Container::Tuple(row), ConstValue::Tuple(xs)) => { if row.len() != xs.len() { @@ -114,6 +114,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT } Ok(()) } + (Container::Tuple(_), _) => Err(ConstTypeError::Failed(ty.clone())), (Container::Sum(row), ConstValue::Sum { tag, variants, val }) => { if tag > &row.len() { return Err(ConstTypeError::InvalidSumTag); @@ -130,6 +131,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT _ => Err(ConstTypeError::LinearTypeDisallowed), } } + (Container::Sum(_), _) => Err(ConstTypeError::Failed(ty.clone())), _ => Err(ConstTypeError::Unimplemented(ty.clone())), }, (ty @ ClassicType::Graph(_), _) => Err(ConstTypeError::Unimplemented(ty.clone())), @@ -147,3 +149,54 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT (ty, _) => Err(ConstTypeError::Failed(ty.clone())), } } + +#[cfg(test)] +mod test { + use cool_asserts::assert_matches; + + use crate::{type_row, types::ClassicType}; + + use super::*; + + #[test] + fn test_typecheck_const() { + const INT: ClassicType = ClassicType::Int(64); + typecheck_const(&INT, &ConstValue::i64(3)).unwrap(); + assert_eq!( + typecheck_const(&ClassicType::Int(32), &ConstValue::i64(3)), + Err(ConstTypeError::IntWidthMismatch(32, 64)) + ); + typecheck_const(&ClassicType::F64, &ConstValue::F64(17.4)).unwrap(); + assert_eq!( + typecheck_const(&ClassicType::F64, &ConstValue::i64(5)), + Err(ConstTypeError::Failed(ClassicType::F64)) + ); + let tuple_ty = ClassicType::Container(Container::Tuple(Box::new(type_row![ + SimpleType::Classic(INT), + SimpleType::Classic(ClassicType::F64) + ]))); + typecheck_const( + &tuple_ty, + &ConstValue::Tuple(vec![ConstValue::i64(7), ConstValue::F64(5.1)]), + ) + .unwrap(); + assert_matches!( + typecheck_const( + &tuple_ty, + &ConstValue::Tuple(vec![ConstValue::F64(4.8), ConstValue::i64(2)]) + ), + Err(ConstTypeError::Failed(_)) + ); + assert_eq!( + typecheck_const( + &tuple_ty, + &ConstValue::Tuple(vec![ + ConstValue::i64(5), + ConstValue::F64(3.3), + ConstValue::i64(2) + ]) + ), + Err(ConstTypeError::TupleWrongLength) + ); + } +} diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 1d498b736..952f18400 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -172,23 +172,29 @@ impl ConstValue { /// Constant Sum over units, used as predicates. pub fn simple_predicate(tag: usize, size: usize) -> Self { - Self::predicate(tag, std::iter::repeat(type_row![]).take(size)) + Self::predicate(tag, Self::unit(), std::iter::repeat(type_row![]).take(size)) } /// Constant Sum over Tuples, used as predicates. - pub fn predicate(tag: usize, variant_rows: impl IntoIterator) -> Self { + pub fn predicate( + tag: usize, + val: ConstValue, + variant_rows: impl IntoIterator, + ) -> Self { + let variants = TypeRow::predicate_variants_row(variant_rows); + let const_type = SimpleType::Classic(val.const_type()); + // TODO This assert is not appropriate for a public API and if the specified `val` + // is not of tuple type matching the `tag`th element of `variant_rows` then + // really the Hugr will fail in validate. However it doesn't at the moment + // (https://github.com/CQCL-DEV/hugr/issues/231). + assert!(Some(&const_type) == variants.get(tag)); ConstValue::Sum { tag, - variants: TypeRow::predicate_variants_row(variant_rows), - val: Box::new(Self::unit()), + variants, + val: Box::new(val), } } - /// Constant Sum over Tuples with just one variant - pub fn unary_predicate(row: impl Into) -> Self { - Self::predicate(0, [row.into()]) - } - /// Constant Sum over Tuples with just one variant of unit type pub fn simple_unary_predicate() -> Self { Self::simple_predicate(0, 1) @@ -232,3 +238,76 @@ pub trait CustomConst: impl_downcast!(CustomConst); impl_box_clone!(CustomConst, CustomConstBoxClone); + +#[cfg(test)] +mod test { + use super::ConstValue; + use crate::{ + builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr}, + hugr::{typecheck::ConstTypeError, ValidationError}, + type_row, + types::{ClassicType, SimpleType, TypeRow}, + }; + + #[test] + fn test_predicate() -> Result<(), BuildError> { + let pred_rows = vec![ + type_row![ + SimpleType::Classic(ClassicType::i64()), + SimpleType::Classic(ClassicType::F64) + ], + type_row![], + ]; + let pred_ty = SimpleType::new_predicate(pred_rows.clone()); + + let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty.clone()]))?; + let c = b.add_constant(ConstValue::predicate( + 0, + ConstValue::Tuple(vec![ConstValue::i64(3), ConstValue::F64(3.15)]), + pred_rows.clone(), + ))?; + let w = b.load_const(&c)?; + b.finish_hugr_with_outputs([w]).unwrap(); + + let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty]))?; + let c = b.add_constant(ConstValue::predicate( + 1, + ConstValue::Tuple(vec![]), + pred_rows, + ))?; + let w = b.load_const(&c)?; + b.finish_hugr_with_outputs([w]).unwrap(); + + Ok(()) + } + + #[test] + #[should_panic] // Pending resolution of https://github.com/CQCL-DEV/hugr/issues/231 + fn test_bad_predicate() { + let pred_rows = vec![ + type_row![ + SimpleType::Classic(ClassicType::i64()), + SimpleType::Classic(ClassicType::F64) + ], + type_row![], + ]; + let pred_ty = SimpleType::new_predicate(pred_rows.clone()); + + let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty])).unwrap(); + // Until #231 is fixed, this is made to fail by an assert in ConstValue::predicate + let c = b + .add_constant(ConstValue::predicate( + 0, + ConstValue::Tuple(vec![]), + pred_rows, + )) + .unwrap(); + let w = b.load_const(&c).unwrap(); + assert_eq!( + b.finish_hugr_with_outputs([w]), + Err(BuildError::InvalidHUGR(ValidationError::ConstTypeError( + ConstTypeError::TupleWrongLength + ))) + ); + } +} diff --git a/src/types/simple.rs b/src/types/simple.rs index 04b0f29ed..f06738060 100644 --- a/src/types/simple.rs +++ b/src/types/simple.rs @@ -254,13 +254,6 @@ impl SimpleType { } } - /// New unit type, defined as an empty Tuple. - pub fn new_unit() -> Self { - Self::Classic(ClassicType::Container(Container::Tuple(Box::new( - TypeRow::new(), - )))) - } - /// New Sum of Tuple types, used as predicates in branching. /// Tuple rows are defined in order by input rows. pub fn new_predicate(variant_rows: impl IntoIterator) -> Self {