From 40caf6860b7303a9bdbf4507845c875d8be98af9 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 27 Jul 2023 11:17:37 +0100 Subject: [PATCH] move `Opaque` to `Container` (#297) * move `Opaque` to `Container` mentioned in #286 * review comments: docstring & simpler map_vals --- src/hugr/typecheck.rs | 48 +++++++++++++++++++++++++++-------- src/ops/constant.rs | 4 +-- src/types/custom.rs | 4 +-- src/types/simple.rs | 38 +++------------------------ src/types/simple/serialize.rs | 21 +++------------ 5 files changed, 49 insertions(+), 66 deletions(-) diff --git a/src/hugr/typecheck.rs b/src/hugr/typecheck.rs index d3821bcfc..2ec9fd243 100644 --- a/src/hugr/typecheck.rs +++ b/src/hugr/typecheck.rs @@ -9,7 +9,7 @@ use crate::hugr::*; // For static typechecking use crate::ops::ConstValue; -use crate::types::{ClassicRow, ClassicType, Container, HashableType}; +use crate::types::{ClassicRow, ClassicType, Container, HashableType, PrimType, TypeRow}; use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WIDTH}; @@ -69,6 +69,32 @@ fn check_valid_width(width: HugrIntWidthStore) -> Result<(), ConstTypeError> { } } +fn map_vals( + container: Container, + f: &impl Fn(T) -> T2, +) -> Container { + fn map_row( + row: TypeRow, + f: &impl Fn(T) -> T2, + ) -> Box> { + Box::new(TypeRow::from( + row.into_owned().into_iter().map(f).collect::>(), + )) + } + match container { + Container::List(elem) => Container::List(Box::new(f(*elem))), + Container::Map(kv) => { + let (k, v) = *kv; + Container::Map(Box::new((k, f(v)))) + } + Container::Tuple(elems) => Container::Tuple(map_row(*elems, f)), + Container::Sum(variants) => Container::Sum(map_row(*variants, f)), + Container::Array(elem, sz) => Container::Array(Box::new(f(*elem)), sz), + Container::Alias(s) => Container::Alias(s), + Container::Opaque(custom) => Container::Opaque(custom), + } +} + /// Typecheck a constant value pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> { match (typ, val) { @@ -122,13 +148,22 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT (Container::Sum(_), _) => { Err(ConstTypeError::TypeMismatch(ty.clone(), tm.const_type())) } + (Container::Opaque(ty), ConstValue::Opaque(ty_act, _val)) => { + if ty_act != ty { + return Err(ConstTypeError::TypeMismatch( + ty.clone().into(), + ty_act.clone().into(), + )); + } + Ok(()) + } _ => Err(ConstTypeError::Unimplemented(ty.clone())), }, (ClassicType::Hashable(HashableType::Container(c)), tm) => { // Here we deliberately build malformed Container-of-Hashable types // (rather than Hashable-of-Container) in order to reuse logic above typecheck_const( - &ClassicType::Container(c.clone().map_vals(&ClassicType::Hashable)), + &ClassicType::Container(map_vals(c.clone(), &ClassicType::Hashable)), tm, ) } @@ -139,15 +174,6 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT (ClassicType::Hashable(HashableType::Variable(_)), _) => { Err(ConstTypeError::ConstCantBeVar) } - (ClassicType::Opaque(ty), ConstValue::Opaque(ty_act, _val)) => { - if ty_act != ty { - return Err(ConstTypeError::TypeMismatch( - ty.clone().into(), - ty_act.clone().into(), - )); - } - Ok(()) - } (ty, _) => Err(ConstTypeError::TypeMismatch(ty.clone(), val.const_type())), } } diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 8bb364f98..b6f6ddbae 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -5,7 +5,7 @@ use std::any::Any; use crate::{ classic_row, macros::impl_box_clone, - types::{ClassicRow, ClassicType, CustomType, EdgeKind, HashableType}, + types::{ClassicRow, ClassicType, Container, CustomType, EdgeKind, HashableType}, }; use downcast_rs::{impl_downcast, Downcast}; @@ -120,7 +120,7 @@ impl ConstValue { pub fn const_type(&self) -> ClassicType { match self { Self::Int { value: _, width } => HashableType::Int(*width).into(), - Self::Opaque(_, b) => ClassicType::Opaque((*b).custom_type()), + Self::Opaque(_, b) => Container::Opaque((*b).custom_type()).into(), Self::Sum { variants, .. } => ClassicType::new_sum(variants.clone()), Self::Tuple(vals) => { let row: Vec<_> = vals.iter().map(|val| val.const_type()).collect(); diff --git a/src/types/custom.rs b/src/types/custom.rs index fee759d94..e64527ec8 100644 --- a/src/types/custom.rs +++ b/src/types/custom.rs @@ -4,7 +4,7 @@ use smol_str::SmolStr; use std::fmt::{self, Display}; -use super::{type_param::TypeArg, ClassicType}; +use super::{type_param::TypeArg, ClassicType, Container}; /// An opaque type element. Contains the unique identifier of its definition. #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] @@ -46,7 +46,7 @@ impl CustomType { /// Returns a [`ClassicType`] containing this opaque type. pub const fn classic_type(self) -> ClassicType { - ClassicType::Opaque(self) + ClassicType::Container(Container::Opaque(self)) } } diff --git a/src/types/simple.rs b/src/types/simple.rs index a8b77b870..ac4952651 100644 --- a/src/types/simple.rs +++ b/src/types/simple.rs @@ -29,8 +29,6 @@ pub enum SimpleType { Classic(ClassicType), /// A qubit. Qubit, - /// A linear opaque type that can be downcasted by the extensions that define it. - Qpaque(CustomType), /// A nested definition containing other linear types (possibly as well as classical ones) Qontainer(Container), } @@ -42,7 +40,6 @@ impl Display for SimpleType { match self { SimpleType::Classic(ty) => ty.fmt(f), SimpleType::Qubit => f.write_str("Qubit"), - SimpleType::Qpaque(custom) => custom.fmt(f), SimpleType::Qontainer(c) => c.fmt(f), } } @@ -125,31 +122,9 @@ pub enum Container { Array(Box, usize), /// Alias defined in AliasDefn or AliasDecl nodes. Alias(SmolStr), -} - -impl Container { - /// Applies the specified function to the value types of this Container - pub fn map_vals(self, f: &impl Fn(T) -> T2) -> Container { - fn map_row( - row: TypeRow, - f: &impl Fn(T) -> T2, - ) -> Box> { - Box::new(TypeRow::from( - row.into_owned().into_iter().map(f).collect::>(), - )) - } - match self { - Self::List(elem) => Container::List(Box::new(f(*elem))), - Self::Map(kv) => { - let (k, v) = *kv; - Container::Map(Box::new((k, f(v)))) - } - Self::Tuple(elems) => Container::Tuple(map_row(*elems, f)), - Self::Sum(variants) => Container::Sum(map_row(*variants, f)), - Self::Array(elem, sz) => Container::Array(Box::new(f(*elem)), sz), - Self::Alias(s) => Container::Alias(s), - } - } + /// An opaque type that can be downcasted by the extensions that define it. + /// This will always have a [`TypeTag`] contained within that of the Container + Opaque(CustomType), } impl Display for Container { @@ -161,6 +136,7 @@ impl Display for Container { Container::Sum(row) => write!(f, "Sum({})", row.as_ref()), Container::Array(t, size) => write!(f, "Array({}, {})", t, size), Container::Alias(str) => f.write_str(str), + Container::Opaque(c) => write!(f, "Opaque({})", c), } } } @@ -208,8 +184,6 @@ pub enum ClassicType { Graph(Box<(ResourceSet, Signature)>), /// A nested definition containing other classic types. Container(Container), - /// An opaque operation that can be downcasted by the extensions that define it. - Opaque(CustomType), /// A type which can be hashed Hashable(HashableType), } @@ -228,8 +202,6 @@ pub enum HashableType { Int(HugrIntWidthStore), /// An arbitrary length string. String, - /// An opaque type defined by an extension as being hashable - Opaque(CustomType), /// A container (all of whose elements can be hashed) Container(Container), } @@ -316,7 +288,6 @@ impl Display for ClassicType { sig.fmt(f) } ClassicType::Container(c) => c.fmt(f), - ClassicType::Opaque(custom) => custom.fmt(f), ClassicType::Hashable(h) => h.fmt(f), } } @@ -330,7 +301,6 @@ impl Display for HashableType { f.write_char('I')?; f.write_str(&i.to_string()) } - HashableType::Opaque(custom) => custom.fmt(f), HashableType::String => f.write_str("String"), HashableType::Container(c) => c.fmt(f), } diff --git a/src/types/simple/serialize.rs b/src/types/simple/serialize.rs index a101fdd77..0bfe8c27d 100644 --- a/src/types/simple/serialize.rs +++ b/src/types/simple/serialize.rs @@ -120,6 +120,7 @@ where c: T::TAG, }, Container::Alias(name) => SerSimpleType::Alias { name, c: T::TAG }, + Container::Opaque(custom) => SerSimpleType::Opaque { custom, c: T::TAG }, } } } @@ -129,10 +130,6 @@ impl From for SerSimpleType { match value { HashableType::Variable(s) => SerSimpleType::Var { name: s }, HashableType::Int(w) => SerSimpleType::I { width: w }, - HashableType::Opaque(c) => SerSimpleType::Opaque { - custom: c, - c: TypeTag::Hashable, - }, HashableType::String => SerSimpleType::S, HashableType::Container(c) => c.into(), } @@ -148,10 +145,6 @@ impl From for SerSimpleType { signature: Box::new(inner.1), }, ClassicType::Container(c) => c.into(), - ClassicType::Opaque(inner) => SerSimpleType::Opaque { - custom: inner, - c: TypeTag::Classic, - }, ClassicType::Hashable(h) => h.into(), } } @@ -163,10 +156,6 @@ impl From for SerSimpleType { SimpleType::Classic(c) => c.into(), SimpleType::Qubit => SerSimpleType::Q, SimpleType::Qontainer(c) => c.into(), - SimpleType::Qpaque(inner) => SerSimpleType::Opaque { - custom: inner, - c: TypeTag::Simple, - }, } } } @@ -225,11 +214,9 @@ impl From for SimpleType { handle_container!(c, Array(box_convert_try(*inner), len)) } SerSimpleType::Alias { name: s, c } => handle_container!(c, Alias(s)), - SerSimpleType::Opaque { custom, c } => match c { - TypeTag::Simple => SimpleType::Qpaque(custom), - TypeTag::Classic => ClassicType::Opaque(custom).into(), - TypeTag::Hashable => HashableType::Opaque(custom).into(), - }, + SerSimpleType::Opaque { custom, c } => { + handle_container!(c, Opaque(custom)) + } SerSimpleType::Var { name: s } => { ClassicType::Hashable(HashableType::Variable(s)).into() }