Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Allow linear Constants #369

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use pyo3::prelude::*;
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{OpTag, OpTrait, OpType, ValidateOp};
use crate::resource::ResourceSet;
use crate::types::{ClassicType, EdgeKind, SimpleType};
use crate::types::{EdgeKind, SimpleType};
use crate::{Direction, Hugr, Node, Port};

use super::hierarchical_views::{HierarchyView, SiblingGraph};
Expand Down Expand Up @@ -734,7 +734,7 @@ pub enum InterGraphEdgeError {
InvalidConstSrc {
from: Node,
from_offset: Port,
typ: ClassicType,
typ: SimpleType,
},
}

Expand Down Expand Up @@ -823,12 +823,7 @@ mod test {
.unwrap();
let tag_def = b.add_op_with_parent(b.root(), const_op).unwrap();
let tag = b
.add_op_with_parent(
parent,
ops::LoadConstant {
datatype: tag_type.try_into().unwrap(),
},
)
.add_op_with_parent(parent, ops::LoadConstant { datatype: tag_type })
.unwrap();

b.connect(tag_def, 0, tag, 0).unwrap();
Expand Down Expand Up @@ -1149,7 +1144,7 @@ mod test {
let lcst = h.add_op_with_parent(
h.root(),
ops::LoadConstant {
datatype: ClassicType::int::<1>(),
datatype: SimpleType::int::<1>(),
},
)?;
h.connect(cst, 0, lcst, 0)?;
Expand Down
100 changes: 57 additions & 43 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use std::any::Any;

use crate::{
macros::impl_box_clone,
types::{simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType},
types::{
simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType, SimpleType,
},
values::{
map_container_type, ConstTypeError, ContainerValue, CustomCheckFail, HashableValue,
ValueOfType,
Expand All @@ -21,12 +23,12 @@ use super::{OpName, OpTrait, StaticTag};
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Const {
value: ConstValue,
typ: ClassicType,
typ: SimpleType,
}

impl Const {
/// Creates a new Const, type-checking the value.
pub fn new(value: ConstValue, typ: ClassicType) -> Result<Self, ConstTypeError> {
pub fn new(value: ConstValue, typ: SimpleType) -> Result<Self, ConstTypeError> {
value.check_type(&typ)?;
Ok(Self { value, typ })
}
Expand All @@ -37,7 +39,7 @@ impl Const {
}

/// Returns a reference to the type of this [`Const`].
pub fn const_type(&self) -> &ClassicType {
pub fn const_type(&self) -> &SimpleType {
&self.typ
}

Expand All @@ -48,23 +50,23 @@ impl Const {
value: ConstValue,
variant_rows: impl IntoIterator<Item = ClassicRow>,
) -> Result<Self, ConstTypeError> {
let typ = ClassicType::new_predicate(variant_rows);
let typ = SimpleType::new_predicate(variant_rows);
Self::new(ConstValue::sum(tag, value), typ)
}

/// Constant Sum over units, used as predicates.
pub fn simple_predicate(tag: usize, size: usize) -> Self {
Self {
value: ConstValue::simple_predicate(tag),
typ: ClassicType::new_simple_predicate(size),
typ: SimpleType::new_simple_predicate(size),
}
}

/// Constant Sum over units, with only one variant.
pub fn simple_unary_predicate() -> Self {
Self {
value: ConstValue::simple_unary_predicate(),
typ: ClassicType::new_simple_predicate(1),
typ: SimpleType::new_simple_predicate(1),
}
}

Expand All @@ -82,7 +84,7 @@ impl Const {
pub fn int<const N: u8>(value: HugrIntValueStore) -> Result<Self, ConstTypeError> {
Self::new(
ConstValue::Hashable(HashableValue::Int(value)),
ClassicType::int::<N>(),
SimpleType::int::<N>(),
)
}

Expand All @@ -93,11 +95,11 @@ impl Const {

/// Tuple of values
pub fn new_tuple(items: impl IntoIterator<Item = Const>) -> Self {
let (values, types): (Vec<ConstValue>, Vec<ClassicType>) = items
let (values, types): (Vec<ConstValue>, Vec<SimpleType>) = items
.into_iter()
.map(|Const { value, typ }| (value, typ))
.unzip();
Self::new(ConstValue::sequence(&values), ClassicType::new_tuple(types)).unwrap()
Self::new(ConstValue::sequence(&values), SimpleType::new_tuple(types)).unwrap()
}
}

Expand Down Expand Up @@ -153,7 +155,7 @@ impl PartialEq for dyn CustomConst {
}

impl ValueOfType for ConstValue {
type T = ClassicType;
type T = SimpleType;

fn name(&self) -> String {
match self {
Expand All @@ -164,42 +166,54 @@ impl ValueOfType for ConstValue {
}
}

fn check_type(&self, ty: &ClassicType) -> Result<(), ConstTypeError> {
fn check_type(&self, ty: &SimpleType) -> Result<(), ConstTypeError> {
match self {
ConstValue::F64(_) => {
if let ClassicType::F64 = ty {
if ty == &SimpleType::Classic(ClassicType::F64) {
return Ok(());
}
}
ConstValue::Hashable(hv) => {
match ty {
ClassicType::Hashable(exp) => return hv.check_type(exp),
ClassicType::Container(cty) => {
// A "hashable" value might be an instance of a non-hashable type:
// e.g. an empty list is hashable, yet can be checked against a classic element type!
if let HashableValue::Container(ctr) = hv {
return ctr.map_vals(&ConstValue::Hashable).check_container(cty);
if let SimpleType::Classic(ClassicType::Hashable(typ)) = ty {
return hv.check_type(typ);
}
if let HashableValue::Container(ctr) = hv {
// An empty list is a hashable value, but could be an instance of a non-hashable list type
// such as List<F64> or even List<Qubit> !
let mapped_cty = || ctr.map_vals(&ConstValue::Hashable);
match ty {
SimpleType::Qontainer(cty) => return mapped_cty().check_container(cty),
SimpleType::Classic(ClassicType::Container(cty)) => {
return mapped_cty()
.check_container(&map_container_type(cty, &SimpleType::Classic))
}
}
_ => (),
_ => (),
};
}
}
ConstValue::Container(vals) => {
match ty {
ClassicType::Container(cty) => return vals.check_container(cty),
// We might also fail to deduce a container *value* was hashable,
SimpleType::Qontainer(cty) => return vals.check_container(cty),
SimpleType::Classic(ClassicType::Container(cty)) => {
return vals.check_container(&map_container_type(cty, &SimpleType::Classic))
}
// We might also fail to deduce/represent a container *value* was hashable,
// because it contains opaque values whose tag is unknown.
ClassicType::Hashable(HashableType::Container(cty)) => {
return vals
.check_container(&map_container_type(cty, &ClassicType::Hashable))
SimpleType::Classic(ClassicType::Hashable(HashableType::Container(cty))) => {
return vals.check_container(&map_container_type(cty, &|elemty| {
SimpleType::Classic(ClassicType::Hashable(elemty))
}))
}
_ => (),
};
}
ConstValue::Opaque((val,)) => {
let maybe_cty = match ty {
ClassicType::Container(Container::Opaque(t)) => Some(t),
ClassicType::Hashable(HashableType::Container(Container::Opaque(t))) => Some(t),
SimpleType::Qontainer(Container::Opaque(t)) => Some(t),
SimpleType::Classic(ClassicType::Container(Container::Opaque(t))) => Some(t),
SimpleType::Classic(ClassicType::Hashable(HashableType::Container(
Container::Opaque(t),
))) => Some(t),
_ => None,
};
if let Some(cu_ty) = maybe_cty {
Expand All @@ -211,10 +225,10 @@ impl ValueOfType for ConstValue {
}

fn container_error(
typ: Container<ClassicType>,
typ: Container<SimpleType>,
vals: ContainerValue<ConstValue>,
) -> ConstTypeError {
ConstTypeError::ValueCheckFail(ClassicType::Container(typ), ConstValue::Container(vals))
ConstTypeError::ValueCheckFail(SimpleType::Qontainer(typ), ConstValue::Container(vals))
}
}

Expand Down Expand Up @@ -394,19 +408,20 @@ mod test {

#[test]
fn test_constant_values() {
const T_INT: ClassicType = ClassicType::int::<64>();
const T_INT: SimpleType = SimpleType::int::<64>();
const V_INT: ConstValue = ConstValue::Hashable(HashableValue::Int(257));
const T_F64: SimpleType = SimpleType::Classic(ClassicType::F64);
V_INT.check_type(&T_INT).unwrap();
assert_eq!(
V_INT.check_type(&ClassicType::int::<8>()),
V_INT.check_type(&SimpleType::int::<8>()),
Err(ConstTypeError::Int(ConstIntError::IntTooLarge(8, 257)))
);
ConstValue::F64(17.4).check_type(&ClassicType::F64).unwrap();
ConstValue::F64(17.4).check_type(&T_F64).unwrap();
assert_matches!(
V_INT.check_type(&ClassicType::F64),
Err(ConstTypeError::ValueCheckFail(ClassicType::F64, v)) => v == V_INT
V_INT.check_type(&T_F64),
Err(ConstTypeError::ValueCheckFail(T_F64, v)) => v == V_INT
);
let tuple_ty = ClassicType::new_tuple(classic_row![T_INT, ClassicType::F64]);
let tuple_ty = SimpleType::new_tuple(type_row![T_INT, T_F64]);
let tuple_val = ConstValue::sequence(&[V_INT, ConstValue::F64(5.1)]);
tuple_val.check_type(&tuple_ty).unwrap();
let tuple_val2 = ConstValue::sequence(&[ConstValue::F64(5.1), V_INT]);
Expand All @@ -433,14 +448,13 @@ mod test {
typ: typ_int.clone(),
value: Value::Number(6.into()),
}),));
let SimpleType::Classic(classic_t) = typ_int.clone().into()
else {panic!("Hashable CustomType returned as non-Classic");};
assert_matches!(classic_t, ClassicType::Hashable(_));
val.check_type(&classic_t).unwrap();
let simp_t: SimpleType = typ_int.clone().into();
assert_matches!(simp_t, SimpleType::Classic(ClassicType::Hashable(_)));
val.check_type(&simp_t).unwrap();

// This misrepresents the CustomType, so doesn't really "have to work".
// But just as documentation of current behaviour:
val.check_type(&ClassicType::Container(Container::Opaque(typ_int.clone())))
val.check_type(&SimpleType::Qontainer(Container::Opaque(typ_int.clone())))
.unwrap();

let typ_qb = CustomType::new(
Expand All @@ -450,7 +464,7 @@ mod test {
TypeTag::Hashable,
);
let t: SimpleType = typ_qb.clone().into();
assert_matches!(val.check_type(&t.try_into().unwrap()),
assert_matches!(val.check_type(&t),
Err(ConstTypeError::CustomCheckFail(CustomCheckFail::TypeMismatch(a, b))) => a == typ_int && b == typ_qb);

assert_eq!(val, val);
Expand Down
6 changes: 3 additions & 3 deletions src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl DataflowOpTrait for Call {

fn signature(&self) -> AbstractSignature {
AbstractSignature {
static_input: vec![ClassicType::graph_from_sig(self.signature.clone())].into(),
static_input: vec![ClassicType::graph_from_sig(self.signature.clone()).into()].into(),
..self.signature.clone()
}
}
Expand Down Expand Up @@ -181,7 +181,7 @@ impl DataflowOpTrait for CallIndirect {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct LoadConstant {
/// Constant type
pub datatype: ClassicType,
pub datatype: SimpleType,
}
impl_op_name!(LoadConstant);
impl DataflowOpTrait for LoadConstant {
Expand All @@ -194,7 +194,7 @@ impl DataflowOpTrait for LoadConstant {
fn signature(&self) -> AbstractSignature {
AbstractSignature::new(
SimpleRow::new(),
vec![SimpleType::Classic(self.datatype.clone())],
vec![self.datatype.clone()],
vec![self.datatype.clone()],
)
}
Expand Down
12 changes: 6 additions & 6 deletions src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ impl OpTrait for FuncDefn {
}

fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::Static(ClassicType::graph_from_sig(
self.signature.clone(),
)))
Some(EdgeKind::Static(
ClassicType::graph_from_sig(self.signature.clone()).into(),
))
}
}

Expand All @@ -82,9 +82,9 @@ impl OpTrait for FuncDecl {
}

fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::Static(ClassicType::graph_from_sig(
self.signature.clone(),
)))
Some(EdgeKind::Static(
ClassicType::graph_from_sig(self.signature.clone()).into(),
))
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub enum EdgeKind {
/// Data edges of a DDG region, also known as "wires".
Value(SimpleType),
/// A reference to a static value definition.
Static(ClassicType),
Static(SimpleType),
/// Explicitly enforce an ordering between nodes in a DDG.
StateOrder,
}
Expand All @@ -59,7 +59,7 @@ pub struct AbstractSignature {
/// Value outputs of the function.
pub output: SimpleRow,
/// Possible static input (for call / load-constant).
pub static_input: ClassicRow,
pub static_input: SimpleRow,
/// The resource requirements which are added by the operation
pub resource_reqs: ResourceSet,
}
Expand All @@ -78,7 +78,7 @@ impl AbstractSignature {
pub fn new(
input: impl Into<SimpleRow>,
output: impl Into<SimpleRow>,
static_input: impl Into<ClassicRow>,
static_input: impl Into<SimpleRow>,
) -> Self {
Self {
input: input.into(),
Expand Down Expand Up @@ -241,7 +241,7 @@ impl AbstractSignature {

#[inline]
/// Returns the row of static inputs
pub fn static_input(&self) -> &ClassicRow {
pub fn static_input(&self) -> &SimpleRow {
&self.static_input
}
}
Expand Down Expand Up @@ -318,7 +318,7 @@ impl Signature {
/// Outputs of the abstract signature
pub fn output(&self) -> &SimpleRow;
/// Static inputs of the abstract signature
pub fn static_input(&self) -> &ClassicRow;
pub fn static_input(&self) -> &SimpleRow;
}
}
}
Expand Down Expand Up @@ -438,7 +438,7 @@ impl SignatureDescription {
pub fn static_input_zip<'a>(
&'a self,
signature: &'a Signature,
) -> impl Iterator<Item = (&SmolStr, &ClassicType)> {
) -> impl Iterator<Item = (&SmolStr, &SimpleType)> {
Self::row_zip(signature.static_input(), &self.static_input)
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/types/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ impl SimpleType {
pub fn new_simple_predicate(size: usize) -> Self {
Self::Classic(ClassicType::new_simple_predicate(size))
}

/// Returns a new integer type with the given number of bits.
#[inline]
pub const fn int<const N: HugrIntWidthStore>() -> Self {
Self::Classic(ClassicType::int::<N>())
}
}

impl From<ClassicType> for SimpleType {
Expand Down
Loading