Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

require type checking on CustomConst #325

Merged
merged 3 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions src/extensions/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::ops::constant::typecheck::CustomCheckFail;
use crate::ops::constant::CustomConst;
use crate::resource::{OpDef, ResourceSet, TypeDef};
use crate::types::type_param::TypeArg;
Expand Down Expand Up @@ -93,6 +94,15 @@ pub enum Constant {
Quaternion(cgmath::Quaternion<f64>),
}

impl Constant {
fn rotation_type(&self) -> Type {
match self {
Constant::Angle(_) => Type::Angle,
Constant::Quaternion(_) => Type::Quaternion,
}
}
}

#[typetag::serde]
impl CustomConst for Constant {
fn name(&self) -> SmolStr {
Expand All @@ -103,12 +113,24 @@ impl CustomConst for Constant {
.into()
}

fn custom_type(&self) -> CustomType {
let t: Type = match self {
Constant::Angle(_) => Type::Angle,
Constant::Quaternion(_) => Type::Quaternion,
};
t.custom_type()
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFail> {
let self_typ = self.rotation_type();

if &self_typ.custom_type() == typ {
Ok(())
} else {
Err(CustomCheckFail::new(
"Rotation constant type mismatch.".into(),
))
}
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
if let Some(other) = other.as_any().downcast_ref::<Constant>() {
self == other
} else {
false
}
}
}

Expand Down Expand Up @@ -320,4 +342,30 @@ mod test {
))
);
}

#[test]
fn test_type_check() {
let resource = resource();

let custom_type = resource
.types()
.get("angle")
.unwrap()
.instantiate_concrete([])
.unwrap();

let custom_value = Constant::Angle(AngleValue::F64(0.0));

// correct type
custom_value.check_custom_type(&custom_type).unwrap();

let wrong_custom_type = resource
.types()
.get("quat")
.unwrap()
.instantiate_concrete([])
.unwrap();
let res = custom_value.check_custom_type(&wrong_custom_type);
assert!(res.is_err());
}
}
13 changes: 6 additions & 7 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use self::typecheck::{typecheck_const, ConstTypeError};
use self::typecheck::{typecheck_const, ConstTypeError, CustomCheckFail};

use super::OpTag;
use super::{OpName, OpTrait, StaticTag};
Expand Down Expand Up @@ -134,7 +134,7 @@ pub enum ConstValue {
Tuple(Vec<ConstValue>),
/// An opaque constant value, with cached type
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Opaque((CustomType, Box<dyn CustomConst>)),
Opaque((Box<dyn CustomConst>,)),
}

impl PartialEq for dyn CustomConst {
Expand All @@ -159,7 +159,7 @@ impl ConstValue {
match self {
Self::Int(value) => format!("const:int{value}"),
Self::F64(f) => format!("const:float:{f}"),
Self::Opaque((_, v)) => format!("const:{}", v.name()),
Self::Opaque((v,)) => format!("const:{}", v.name()),
Self::Sum(tag, val) => {
format!("const:sum:{{tag:{tag}, val:{}}}", val.name())
}
Expand Down Expand Up @@ -200,7 +200,7 @@ impl ConstValue {

impl<T: CustomConst> From<T> for ConstValue {
fn from(v: T) -> Self {
Self::Opaque((v.custom_type(), Box::new(v)))
Self::Opaque((Box::new(v),))
}
}

Expand All @@ -215,9 +215,8 @@ pub trait CustomConst:
/// An identifier for the constant.
fn name(&self) -> SmolStr;

/// Returns the type of the constant.
// TODO it would be good to ensure that this is a *classic* CustomType not a linear one!
fn custom_type(&self) -> CustomType;
/// Check the value is a valid instance of the provided type.
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFail>;

/// Compare two constants for equality, using downcasting and comparing the definitions.
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
Expand Down
25 changes: 17 additions & 8 deletions src/ops/constant/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ pub enum ConstIntError {
IntWidthInvalid(HugrIntWidthStore),
}

/// Struct for custom type check fails.
#[derive(Clone, Debug, PartialEq, Error)]
#[error("Error when checking custom type.")]
pub struct CustomCheckFail(String);

impl CustomCheckFail {
/// Creates a new [`CustomCheckFail`].
pub fn new(message: String) -> Self {
Self(message)
}
}

/// Errors that arise from typechecking constants
#[derive(Clone, Debug, PartialEq, Error)]
pub enum ConstTypeError {
Expand All @@ -50,12 +62,12 @@ pub enum ConstTypeError {
/// Tag for a sum value exceeded the number of variants
#[error("Tag of Sum value is invalid")]
InvalidSumTag,
/// A mismatch between the type expected and the actual type of the constant
#[error("Type mismatch for const - expected {0}, found {1:?}")]
TypeMismatch(ClassicType, ClassicType),
/// A mismatch between the type expected and the value.
#[error("Value {1:?} does not match expected type {0}")]
ValueCheckFail(ClassicType, ConstValue),
/// Error when checking a custom value.
#[error("Custom value type check error: {0:?}")]
CustomCheckFail(#[from] CustomCheckFail),
}

lazy_static! {
Expand Down Expand Up @@ -140,11 +152,8 @@ pub(super) fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(),
}
}
(Container::Sum(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())),
(Container::Opaque(ty), ConstValue::Opaque((ty_act, _val))) => {
if ty_act != ty {
return Err(ConstTypeError::ValueCheckFail(typ.clone(), val.clone()));
}
Ok(())
(Container::Opaque(ty), ConstValue::Opaque((val,))) => {
val.check_custom_type(ty).map_err(ConstTypeError::from)
}
_ => Err(ConstTypeError::Unimplemented(ty.clone())),
},
Expand Down