Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom types defined by static references #418

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,23 @@ lazy_static! {
};
}

pub(crate) const USIZE_CUSTOM_T: CustomType = CustomType::new_simple(
const USIZE_CUSTOM_T: CustomType = CustomType::new_simple(
SmolStr::new_inline("usize"),
SmolStr::new_inline("prelude"),
TypeBound::Eq,
);

pub(crate) const QB_CUSTOM_T: CustomType = CustomType::new_simple(
const QB_CUSTOM_T: CustomType = CustomType::new_simple(
SmolStr::new_inline("qubit"),
SmolStr::new_inline("prelude"),
TypeBound::Any,
);

pub(crate) const QB_T: Type = Type::new_extension(QB_CUSTOM_T);
pub(crate) const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T);
const USIZE_CUSTOM_REF: &CustomType = &USIZE_CUSTOM_T;
const QB_CUSTOM_REF: &CustomType = &QB_CUSTOM_T;

pub(crate) const QB_T: Type = Type::new_static_extension(QB_CUSTOM_REF);
pub(crate) const USIZE_T: Type = Type::new_static_extension(USIZE_CUSTOM_REF);
pub(crate) const BOOL_T: Type = Type::new_simple_predicate(2);

/// Initialize a new array of type `typ` of length `size`
Expand Down
24 changes: 19 additions & 5 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use serde::{Deserialize, Serialize};

use crate::ops::AliasDecl;
use crate::type_row;
use crate::utils::MaybeRef;
use std::fmt::Debug;

use self::primitive::PrimType;
Expand Down Expand Up @@ -212,7 +213,20 @@ impl Type {
// TODO remove? Extensions/TypeDefs should just provide `Type` directly
pub const fn new_extension(opaque: CustomType) -> Self {
let bound = opaque.bound();
Type(TypeEnum::Prim(PrimType::Extension(opaque)), bound)
Type(
TypeEnum::Prim(PrimType::Extension(MaybeRef::new_value(opaque))),
bound,
)
}

/// Initialize a new custom type with a static definition - allowing for
/// pointer equality comparisons.
pub const fn new_static_extension(opaque: &'static CustomType) -> Self {
let bound = opaque.bound();
Type(
TypeEnum::Prim(PrimType::Extension(MaybeRef::new_static(opaque))),
bound,
)
}

/// Initialize a new alias.
Expand Down Expand Up @@ -270,14 +284,14 @@ where
pub(crate) mod test {

use super::{
custom::test::{ANY_CUST, COPYABLE_CUST, EQ_CUST},
custom::test::{ANY_CUST_REF, COPYABLE_CUST_REF, EQ_CUST_REF},
*,
};
use crate::{extension::prelude::USIZE_T, ops::AliasDecl};

pub(crate) const EQ_T: Type = Type::new_extension(EQ_CUST);
pub(crate) const COPYABLE_T: Type = Type::new_extension(COPYABLE_CUST);
pub(crate) const ANY_T: Type = Type::new_extension(ANY_CUST);
pub(crate) const EQ_T: Type = Type::new_static_extension(EQ_CUST_REF);
pub(crate) const COPYABLE_T: Type = Type::new_static_extension(COPYABLE_CUST_REF);
pub(crate) const ANY_T: Type = Type::new_static_extension(ANY_CUST_REF);

#[test]
fn construct() {
Expand Down
2 changes: 1 addition & 1 deletion src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl PrimType {

match (self, val) {
(PrimType::Extension(e), PrimValue::Extension(e_val)) => {
e_val.0.check_custom_type(e)?;
e_val.0.check_custom_type(e.as_ref())?;
Ok(())
}
(PrimType::Graph(_), PrimValue::Graph) => todo!(),
Expand Down
4 changes: 4 additions & 0 deletions src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,8 @@ pub(crate) mod test {
SmolStr::new_inline("MyRsrc"),
TypeBound::Any,
);

pub(crate) const EQ_CUST_REF: &CustomType = &EQ_CUST;
pub(crate) const COPYABLE_CUST_REF: &CustomType = &COPYABLE_CUST;
pub(crate) const ANY_CUST_REF: &CustomType = &ANY_CUST;
}
11 changes: 5 additions & 6 deletions src/types/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
//! Primitive types which are leaves of the type tree

use crate::ops::AliasDecl;
use crate::{ops::AliasDecl, utils::MaybeRef};

use super::{AbstractSignature, CustomType, TypeBound};

#[derive(
Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
)]
#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)]
pub(super) enum PrimType {
// TODO optimise with Box<CustomType> ?
// or some static version of this?
Extension(CustomType),
#[display(fmt = "{}", "_0.as_ref()")]
Extension(MaybeRef<'static, CustomType>),
#[display(fmt = "Alias({})", "_0.name()")]
Alias(AliasDecl),
#[display(fmt = "Graph({})", "_0")]
Expand All @@ -20,7 +19,7 @@ pub(super) enum PrimType {
impl PrimType {
pub(super) fn bound(&self) -> TypeBound {
match self {
PrimType::Extension(c) => c.bound(),
PrimType::Extension(c) => c.as_ref().bound(),
PrimType::Alias(a) => a.bound,
PrimType::Graph(_) => TypeBound::Copyable,
}
Expand Down
2 changes: 1 addition & 1 deletion src/types/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl From<Type> for SerSimpleType {
let Type(value, _) = value;
match value {
TypeEnum::Prim(t) => match t {
PrimType::Extension(c) => SerSimpleType::Opaque(c),
PrimType::Extension(c) => SerSimpleType::Opaque(c.into_inner()),
PrimType::Alias(a) => SerSimpleType::Alias(a),
PrimType::Graph(sig) => SerSimpleType::G(Box::new(*sig)),
},
Expand Down
48 changes: 48 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,54 @@ pub fn collect_array<const N: usize, T: Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()
}

#[derive(Clone, Debug, Eq)]
/// Utility struct that can be either owned value or reference, used to short
/// circuit PartialEq with pointer equality when possible.
pub(crate) enum MaybeRef<'a, T> {
Value(T),
Ref(&'a T),
}

impl<'a, T> MaybeRef<'a, T> {
pub(super) const fn new_static(v_ref: &'a T) -> Self {
Self::Ref(v_ref)
}

pub(super) const fn new_value(v: T) -> Self {
Self::Value(v)
}
}

impl<'a, T: Clone> MaybeRef<'a, T> {
pub(crate) fn into_inner(self) -> T {
match self {
MaybeRef::Value(v) => v,
MaybeRef::Ref(v_ref) => v_ref.clone(),
}
}
}

impl<'a, T> AsRef<T> for MaybeRef<'a, T> {
fn as_ref(&self) -> &T {
match self {
MaybeRef::Value(v) => v,
MaybeRef::Ref(v_ref) => v_ref,
}
}
}

// can use pointer equality to compare static instances
impl<'a, T: PartialEq + Eq> PartialEq for MaybeRef<'a, T> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
// pointer equality can give false-negative
(Self::Ref(l0), Self::Ref(r0)) => std::ptr::eq(*l0, *r0) || l0 == r0,
(Self::Value(l0), Self::Value(r0)) => l0 == r0,
(Self::Value(v), Self::Ref(v_ref)) | (Self::Ref(v_ref), Self::Value(v)) => v == *v_ref,
}
}
}

#[allow(dead_code)]
// Test only utils
#[cfg(test)]
Expand Down