From 15bc0c27dc5f3efa598af5b1e7f3b12eabcc2c59 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 1 Aug 2023 11:11:38 +0100 Subject: [PATCH] require type checking on `CustomConst` this now makes it impossible for use of static values of types from resources that do not provide binary `CustomConst`. Is this what we want? --- src/extensions/rotation.rs | 67 +++++++++++++++++++++++++++++++---- src/ops/constant.rs | 11 +++--- src/ops/constant/typecheck.rs | 13 +++---- 3 files changed, 70 insertions(+), 21 deletions(-) diff --git a/src/extensions/rotation.rs b/src/extensions/rotation.rs index d02d27187..cee6018fe 100644 --- a/src/extensions/rotation.rs +++ b/src/extensions/rotation.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; #[cfg(feature = "pyo3")] use pyo3::prelude::*; +use crate::ops::constant::typecheck::ConstTypeError; use crate::ops::constant::CustomConst; use crate::resource::{OpDef, ResourceSet, TypeDef}; use crate::types::type_param::TypeArg; @@ -93,6 +94,15 @@ pub enum Constant { Quaternion(cgmath::Quaternion), } +impl Constant { + fn rotation_type(&self) -> Type { + match self { + Constant::Angle(_) => Type::Angle, + Constant::Quaternion(_) => Type::Quaternion, + } + } +} + #[typetag::serde] impl CustomConst for Constant { fn name(&self) -> SmolStr { @@ -103,12 +113,25 @@ impl CustomConst for Constant { .into() } - fn custom_type(&self) -> CustomType { - let t: Type = match self { - Constant::Angle(_) => Type::Angle, - Constant::Quaternion(_) => Type::Quaternion, - }; - t.custom_type() + fn check_type(&self, typ: &CustomType) -> Result<(), ConstTypeError> { + let self_typ = self.rotation_type(); + + if &self_typ.custom_type() == typ { + Ok(()) + } else { + Err(ConstTypeError::ValueCheckFail( + typ.clone().into(), + self.clone().into(), + )) + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self == other + } else { + false + } } } @@ -315,4 +338,36 @@ mod test { )) ); } + + #[test] + fn test_type_check() { + let resource = resource(); + + let custom_type = resource + .types() + .get("angle") + .unwrap() + .instantiate_concrete([]) + .unwrap(); + + let custom_value = Constant::Angle(AngleValue::F64(0.0)); + + // correct type + custom_value.check_type(&custom_type).unwrap(); + + let wrong_custom_type = resource + .types() + .get("quat") + .unwrap() + .instantiate_concrete([]) + .unwrap(); + let res = custom_value.check_type(&wrong_custom_type); + assert_eq!( + res, + Err(ConstTypeError::ValueCheckFail( + wrong_custom_type.into(), + custom_value.into(), + )), + ); + } } diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 247659bd5..93cdbe167 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -134,7 +134,7 @@ pub enum ConstValue { Tuple(Vec), /// An opaque constant value, with cached type // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 - Opaque((CustomType, Box)), + Opaque((Box,)), } impl PartialEq for dyn CustomConst { @@ -159,7 +159,7 @@ impl ConstValue { match self { Self::Int(value) => format!("const:int{value}"), Self::F64(f) => format!("const:float:{f}"), - Self::Opaque((_, v)) => format!("const:{}", v.name()), + Self::Opaque((v,)) => format!("const:{}", v.name()), Self::Sum(tag, val) => { format!("const:sum:{{tag:{tag}, val:{}}}", val.name()) } @@ -200,7 +200,7 @@ impl ConstValue { impl From for ConstValue { fn from(v: T) -> Self { - Self::Opaque((v.custom_type(), Box::new(v))) + Self::Opaque((Box::new(v),)) } } @@ -215,9 +215,8 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> SmolStr; - /// Returns the type of the constant. - // TODO it would be good to ensure that this is a *classic* CustomType not a linear one! - fn custom_type(&self) -> CustomType; + /// Check the value is a valid instance of the provided type. + fn check_type(&self, typ: &CustomType) -> Result<(), ConstTypeError>; /// Compare two constants for equality, using downcasting and comparing the definitions. fn equal_consts(&self, other: &dyn CustomConst) -> bool { diff --git a/src/ops/constant/typecheck.rs b/src/ops/constant/typecheck.rs index 6c2e2c2c1..48881998e 100644 --- a/src/ops/constant/typecheck.rs +++ b/src/ops/constant/typecheck.rs @@ -56,6 +56,9 @@ pub enum ConstTypeError { /// A mismatch between the type expected and the value. #[error("Value {1:?} does not match expected type {0}")] ValueCheckFail(ClassicType, ConstValue), + /// Error when checking a custom value. + #[error("Custom value type check error: {0:?}")] + CustomCheckFail(String), } lazy_static! { @@ -140,15 +143,7 @@ pub(super) fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), } } (Container::Sum(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())), - (Container::Opaque(ty), ConstValue::Opaque((ty_act, _val))) => { - if ty_act != ty { - return Err(ConstTypeError::TypeMismatch( - ty.clone().into(), - ty_act.clone().into(), - )); - } - Ok(()) - } + (Container::Opaque(ty), ConstValue::Opaque((val,))) => val.check_type(ty), _ => Err(ConstTypeError::Unimplemented(ty.clone())), }, (ClassicType::Hashable(HashableType::Container(c)), tm) => {