Skip to content

Commit

Permalink
move Opaque to Container (#297)
Browse files Browse the repository at this point in the history
* move `Opaque` to `Container`

mentioned in #286

* review comments: docstring & simpler map_vals
  • Loading branch information
ss2165 authored Jul 27, 2023
1 parent 617e589 commit 40caf68
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 66 deletions.
48 changes: 37 additions & 11 deletions src/hugr/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hugr::*;

// For static typechecking
use crate::ops::ConstValue;
use crate::types::{ClassicRow, ClassicType, Container, HashableType};
use crate::types::{ClassicRow, ClassicType, Container, HashableType, PrimType, TypeRow};

use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WIDTH};

Expand Down Expand Up @@ -69,6 +69,32 @@ fn check_valid_width(width: HugrIntWidthStore) -> Result<(), ConstTypeError> {
}
}

fn map_vals<T: PrimType, T2: PrimType>(
container: Container<T>,
f: &impl Fn(T) -> T2,
) -> Container<T2> {
fn map_row<T: PrimType, T2: PrimType>(
row: TypeRow<T>,
f: &impl Fn(T) -> T2,
) -> Box<TypeRow<T2>> {
Box::new(TypeRow::from(
row.into_owned().into_iter().map(f).collect::<Vec<T2>>(),
))
}
match container {
Container::List(elem) => Container::List(Box::new(f(*elem))),
Container::Map(kv) => {
let (k, v) = *kv;
Container::Map(Box::new((k, f(v))))
}
Container::Tuple(elems) => Container::Tuple(map_row(*elems, f)),
Container::Sum(variants) => Container::Sum(map_row(*variants, f)),
Container::Array(elem, sz) => Container::Array(Box::new(f(*elem)), sz),
Container::Alias(s) => Container::Alias(s),
Container::Opaque(custom) => Container::Opaque(custom),
}
}

/// Typecheck a constant value
pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> {
match (typ, val) {
Expand Down Expand Up @@ -122,13 +148,22 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
(Container::Sum(_), _) => {
Err(ConstTypeError::TypeMismatch(ty.clone(), tm.const_type()))
}
(Container::Opaque(ty), ConstValue::Opaque(ty_act, _val)) => {
if ty_act != ty {
return Err(ConstTypeError::TypeMismatch(
ty.clone().into(),
ty_act.clone().into(),
));
}
Ok(())
}
_ => Err(ConstTypeError::Unimplemented(ty.clone())),
},
(ClassicType::Hashable(HashableType::Container(c)), tm) => {
// Here we deliberately build malformed Container-of-Hashable types
// (rather than Hashable-of-Container) in order to reuse logic above
typecheck_const(
&ClassicType::Container(c.clone().map_vals(&ClassicType::Hashable)),
&ClassicType::Container(map_vals(c.clone(), &ClassicType::Hashable)),
tm,
)
}
Expand All @@ -139,15 +174,6 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
(ClassicType::Hashable(HashableType::Variable(_)), _) => {
Err(ConstTypeError::ConstCantBeVar)
}
(ClassicType::Opaque(ty), ConstValue::Opaque(ty_act, _val)) => {
if ty_act != ty {
return Err(ConstTypeError::TypeMismatch(
ty.clone().into(),
ty_act.clone().into(),
));
}
Ok(())
}
(ty, _) => Err(ConstTypeError::TypeMismatch(ty.clone(), val.const_type())),
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::any::Any;
use crate::{
classic_row,
macros::impl_box_clone,
types::{ClassicRow, ClassicType, CustomType, EdgeKind, HashableType},
types::{ClassicRow, ClassicType, Container, CustomType, EdgeKind, HashableType},
};

use downcast_rs::{impl_downcast, Downcast};
Expand Down Expand Up @@ -120,7 +120,7 @@ impl ConstValue {
pub fn const_type(&self) -> ClassicType {
match self {
Self::Int { value: _, width } => HashableType::Int(*width).into(),
Self::Opaque(_, b) => ClassicType::Opaque((*b).custom_type()),
Self::Opaque(_, b) => Container::Opaque((*b).custom_type()).into(),
Self::Sum { variants, .. } => ClassicType::new_sum(variants.clone()),
Self::Tuple(vals) => {
let row: Vec<_> = vals.iter().map(|val| val.const_type()).collect();
Expand Down
4 changes: 2 additions & 2 deletions src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use smol_str::SmolStr;
use std::fmt::{self, Display};

use super::{type_param::TypeArg, ClassicType};
use super::{type_param::TypeArg, ClassicType, Container};

/// An opaque type element. Contains the unique identifier of its definition.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -46,7 +46,7 @@ impl CustomType {

/// Returns a [`ClassicType`] containing this opaque type.
pub const fn classic_type(self) -> ClassicType {
ClassicType::Opaque(self)
ClassicType::Container(Container::Opaque(self))
}
}

Expand Down
38 changes: 4 additions & 34 deletions src/types/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ pub enum SimpleType {
Classic(ClassicType),
/// A qubit.
Qubit,
/// A linear opaque type that can be downcasted by the extensions that define it.
Qpaque(CustomType),
/// A nested definition containing other linear types (possibly as well as classical ones)
Qontainer(Container<SimpleType>),
}
Expand All @@ -42,7 +40,6 @@ impl Display for SimpleType {
match self {
SimpleType::Classic(ty) => ty.fmt(f),
SimpleType::Qubit => f.write_str("Qubit"),
SimpleType::Qpaque(custom) => custom.fmt(f),
SimpleType::Qontainer(c) => c.fmt(f),
}
}
Expand Down Expand Up @@ -125,31 +122,9 @@ pub enum Container<T: PrimType> {
Array(Box<T>, usize),
/// Alias defined in AliasDefn or AliasDecl nodes.
Alias(SmolStr),
}

impl<T: PrimType> Container<T> {
/// Applies the specified function to the value types of this Container
pub fn map_vals<T2: PrimType>(self, f: &impl Fn(T) -> T2) -> Container<T2> {
fn map_row<T: PrimType, T2: PrimType>(
row: TypeRow<T>,
f: &impl Fn(T) -> T2,
) -> Box<TypeRow<T2>> {
Box::new(TypeRow::from(
row.into_owned().into_iter().map(f).collect::<Vec<T2>>(),
))
}
match self {
Self::List(elem) => Container::List(Box::new(f(*elem))),
Self::Map(kv) => {
let (k, v) = *kv;
Container::Map(Box::new((k, f(v))))
}
Self::Tuple(elems) => Container::Tuple(map_row(*elems, f)),
Self::Sum(variants) => Container::Sum(map_row(*variants, f)),
Self::Array(elem, sz) => Container::Array(Box::new(f(*elem)), sz),
Self::Alias(s) => Container::Alias(s),
}
}
/// An opaque type that can be downcasted by the extensions that define it.
/// This will always have a [`TypeTag`] contained within that of the Container
Opaque(CustomType),
}

impl<T: Display + PrimType> Display for Container<T> {
Expand All @@ -161,6 +136,7 @@ impl<T: Display + PrimType> Display for Container<T> {
Container::Sum(row) => write!(f, "Sum({})", row.as_ref()),
Container::Array(t, size) => write!(f, "Array({}, {})", t, size),
Container::Alias(str) => f.write_str(str),
Container::Opaque(c) => write!(f, "Opaque({})", c),
}
}
}
Expand Down Expand Up @@ -208,8 +184,6 @@ pub enum ClassicType {
Graph(Box<(ResourceSet, Signature)>),
/// A nested definition containing other classic types.
Container(Container<ClassicType>),
/// An opaque operation that can be downcasted by the extensions that define it.
Opaque(CustomType),
/// A type which can be hashed
Hashable(HashableType),
}
Expand All @@ -228,8 +202,6 @@ pub enum HashableType {
Int(HugrIntWidthStore),
/// An arbitrary length string.
String,
/// An opaque type defined by an extension as being hashable
Opaque(CustomType),
/// A container (all of whose elements can be hashed)
Container(Container<HashableType>),
}
Expand Down Expand Up @@ -316,7 +288,6 @@ impl Display for ClassicType {
sig.fmt(f)
}
ClassicType::Container(c) => c.fmt(f),
ClassicType::Opaque(custom) => custom.fmt(f),
ClassicType::Hashable(h) => h.fmt(f),
}
}
Expand All @@ -330,7 +301,6 @@ impl Display for HashableType {
f.write_char('I')?;
f.write_str(&i.to_string())
}
HashableType::Opaque(custom) => custom.fmt(f),
HashableType::String => f.write_str("String"),
HashableType::Container(c) => c.fmt(f),
}
Expand Down
21 changes: 4 additions & 17 deletions src/types/simple/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ where
c: T::TAG,
},
Container::Alias(name) => SerSimpleType::Alias { name, c: T::TAG },
Container::Opaque(custom) => SerSimpleType::Opaque { custom, c: T::TAG },
}
}
}
Expand All @@ -129,10 +130,6 @@ impl From<HashableType> for SerSimpleType {
match value {
HashableType::Variable(s) => SerSimpleType::Var { name: s },
HashableType::Int(w) => SerSimpleType::I { width: w },
HashableType::Opaque(c) => SerSimpleType::Opaque {
custom: c,
c: TypeTag::Hashable,
},
HashableType::String => SerSimpleType::S,
HashableType::Container(c) => c.into(),
}
Expand All @@ -148,10 +145,6 @@ impl From<ClassicType> for SerSimpleType {
signature: Box::new(inner.1),
},
ClassicType::Container(c) => c.into(),
ClassicType::Opaque(inner) => SerSimpleType::Opaque {
custom: inner,
c: TypeTag::Classic,
},
ClassicType::Hashable(h) => h.into(),
}
}
Expand All @@ -163,10 +156,6 @@ impl From<SimpleType> for SerSimpleType {
SimpleType::Classic(c) => c.into(),
SimpleType::Qubit => SerSimpleType::Q,
SimpleType::Qontainer(c) => c.into(),
SimpleType::Qpaque(inner) => SerSimpleType::Opaque {
custom: inner,
c: TypeTag::Simple,
},
}
}
}
Expand Down Expand Up @@ -225,11 +214,9 @@ impl From<SerSimpleType> for SimpleType {
handle_container!(c, Array(box_convert_try(*inner), len))
}
SerSimpleType::Alias { name: s, c } => handle_container!(c, Alias(s)),
SerSimpleType::Opaque { custom, c } => match c {
TypeTag::Simple => SimpleType::Qpaque(custom),
TypeTag::Classic => ClassicType::Opaque(custom).into(),
TypeTag::Hashable => HashableType::Opaque(custom).into(),
},
SerSimpleType::Opaque { custom, c } => {
handle_container!(c, Opaque(custom))
}
SerSimpleType::Var { name: s } => {
ClassicType::Hashable(HashableType::Variable(s)).into()
}
Expand Down

0 comments on commit 40caf68

Please sign in to comment.