Skip to content

Commit

Permalink
Add HashableValue as part of ConstValue, and ContainerValue (#322)
Browse files Browse the repository at this point in the history
Add src/values.rs w/new {Hashable,Container}Value, also move ConstTypeError etc.
  +trait ValueOfType defines 'check_type' method breaking up old typecheck_const
src/ops/constant.rs contains ConstValue (which uses HashableValue)
src/ops/constant/typecheck.rs contains bits still shared with TypeArg/TypeParam
  • Loading branch information
acl-cqc authored Aug 2, 2023
1 parent b9ffa88 commit 7b927eb
Show file tree
Hide file tree
Showing 7 changed files with 425 additions and 232 deletions.
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use thiserror::Error;
use pyo3::prelude::*;

use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::ops::constant::typecheck::ConstTypeError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::SimpleType;
use crate::values::ConstTypeError;

pub mod handle;
pub use handle::BuildHandle;
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ use std::collections::HashMap;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::ops::constant::typecheck::CustomCheckFail;
use crate::ops::constant::CustomConst;
use crate::resource::{OpDef, ResourceSet, TypeDef};
use crate::types::type_param::TypeArg;
use crate::types::{CustomType, SimpleRow, TypeTag};
use crate::values::CustomCheckFail;
use crate::Resource;

pub const fn resource_id() -> SmolStr {
Expand Down
10 changes: 3 additions & 7 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ use thiserror::Error;
use pyo3::prelude::*;

use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::OpTag;
use crate::ops::{OpTrait, OpType, ValidateOp};
use crate::ops::{OpTag, OpTrait, OpType, ValidateOp};
use crate::resource::ResourceSet;
use crate::types::{ClassicType, EdgeKind, SimpleType};
use crate::{Direction, Hugr, Node, Port};
Expand Down Expand Up @@ -748,7 +747,7 @@ mod test {
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder};
use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, ConstValue, LeafOp, OpType};
use crate::ops::{self, LeafOp, OpType};
use crate::types::{AbstractSignature, ClassicType};
use crate::Direction;
use crate::{type_row, Node};
Expand Down Expand Up @@ -1146,10 +1145,7 @@ mod test {
})
);
// Second input of Xor from a constant
let cst = h.add_op_with_parent(
h.root(),
ops::Const::new(ConstValue::Int(1), ClassicType::int::<1>()).unwrap(),
)?;
let cst = h.add_op_with_parent(h.root(), ops::Const::int::<1>(1).unwrap())?;
let lcst = h.add_op_with_parent(
h.root(),
ops::LoadConstant {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod ops;
pub mod resource;
pub mod types;
mod utils;
pub mod values;

pub use crate::hugr::{Direction, Hugr, HugrView, Node, Port, SimpleReplacement, Wire};
pub use crate::resource::Resource;
214 changes: 162 additions & 52 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ use std::any::Any;

use crate::{
macros::impl_box_clone,
types::{ClassicRow, ClassicType, CustomType, EdgeKind},
types::{simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType},
values::{
map_container_type, ConstTypeError, ContainerValue, CustomCheckFail, HashableValue,
ValueOfType,
},
};

use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use self::typecheck::{typecheck_const, ConstTypeError, CustomCheckFail};

use super::OpTag;
use super::{OpName, OpTrait, StaticTag};

Expand Down Expand Up @@ -48,8 +50,7 @@ impl Const {
variant_rows: impl IntoIterator<Item = ClassicRow>,
) -> Result<Self, ConstTypeError> {
let typ = ClassicType::new_predicate(variant_rows);

Self::new(ConstValue::Sum(tag, Box::new(value)), typ)
Self::new(ConstValue::sum(tag, value), typ)
}

/// Constant Sum over units, used as predicates.
Expand Down Expand Up @@ -80,18 +81,30 @@ impl Const {

/// Fixed width integer
pub fn int<const N: u8>(value: HugrIntValueStore) -> Result<Self, ConstTypeError> {
Self::new(ConstValue::Int(value), ClassicType::int::<N>())
Self::new(
ConstValue::Hashable(HashableValue::Int(value)),
ClassicType::int::<N>(),
)
}

/// 64-bit integer
pub fn i64(value: i64) -> Result<Self, ConstTypeError> {
Self::int::<64>(value as HugrIntValueStore)
}

/// Tuple of values
pub fn new_tuple(items: impl IntoIterator<Item = Const>) -> Self {
let (values, types): (Vec<ConstValue>, Vec<ClassicType>) = items
.into_iter()
.map(|Const { value, typ }| (value, typ))
.unzip();
Self::new(ConstValue::sequence(&values), ClassicType::new_tuple(types)).unwrap()
}
}

impl OpName for Const {
fn name(&self) -> SmolStr {
self.value.name()
self.value.name().into()
}
}
impl StaticTag for Const {
Expand All @@ -116,23 +129,17 @@ pub(crate) type HugrIntWidthStore = u8;
pub(crate) const HUGR_MAX_INT_WIDTH: HugrIntWidthStore =
HugrIntValueStore::BITS as HugrIntWidthStore;

/// Value constants
///
/// TODO: Add more constants
/// TODO: bigger/smaller integers.
/// Value constants. (This could be "ClassicValue" to parallel [HashableValue])
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum ConstValue {
/// An arbitrary length integer constant.
Int(HugrIntValueStore),
Hashable(HashableValue),
/// A collection of constant values (at least some of which are not [ConstValue::Hashable])
Container(ContainerValue<ConstValue>),
/// Double precision float
F64(f64),
/// A constant specifying a variant of a Sum type.
Sum(usize, Box<ConstValue>),
/// A tuple of constant values.
Tuple(Vec<ConstValue>),
/// An opaque constant value, with cached type
/// An opaque constant value, with cached type. TODO put this into ContainerValue.
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Opaque((Box<dyn CustomConst>,)),
}
Expand All @@ -143,59 +150,129 @@ impl PartialEq for dyn CustomConst {
}
}

impl Default for ConstValue {
fn default() -> Self {
Self::Int(0)
}
}
impl ValueOfType for ConstValue {
type T = ClassicType;

impl ConstValue {
/// Returns the datatype of the constant.
pub fn check_type(&self, typ: &ClassicType) -> Result<(), ConstTypeError> {
typecheck_const(typ, self)
fn name(&self) -> String {
match self {
ConstValue::F64(f) => format!("const:float:{}", f),
ConstValue::Hashable(hv) => hv.name(),
ConstValue::Container(ctr) => ctr.desc(),
ConstValue::Opaque((v,)) => format!("const:custom:{}", v.name()),
}
}
/// Unique name of the constant.
pub fn name(&self) -> SmolStr {

fn check_type(&self, ty: &ClassicType) -> Result<(), ConstTypeError> {
match self {
Self::Int(value) => format!("const:int{value}"),
Self::F64(f) => format!("const:float:{f}"),
Self::Opaque((v,)) => format!("const:{}", v.name()),
Self::Sum(tag, val) => {
format!("const:sum:{{tag:{tag}, val:{}}}", val.name())
ConstValue::F64(_) => {
if let ClassicType::F64 = ty {
return Ok(());
}
}
Self::Tuple(vals) => {
let valstr: Vec<_> = vals.iter().map(|v| v.name()).collect();
let valstr = valstr.join(", ");
format!("const:tuple:{{{valstr}}}")
ConstValue::Hashable(hv) => {
match ty {
ClassicType::Hashable(exp) => return hv.check_type(exp),
ClassicType::Container(cty) => {
// A "hashable" value might be an instance of a non-hashable type:
// e.g. an empty list is hashable, yet can be checked against a classic element type!
if let HashableValue::Container(ctr) = hv {
// Note if ctr is a ContainerValue<HashableValue>::Opaque, this means we can check that
// against a Container<ClassicType>::Opaque, which is perhaps unnecessary, but harmless.
return ctr.map_vals(&ConstValue::Hashable).check_container(cty);
}
}
_ => (),
}
}
}
.into()
ConstValue::Container(vals) => {
match ty {
ClassicType::Container(cty) => return vals.check_container(cty),
// We might also fail to deduce a container *value* was hashable,
// because it contains opaque values whose tag is unknown.
ClassicType::Hashable(HashableType::Container(cty)) => {
return vals
.check_container(&map_container_type(cty, &ClassicType::Hashable))
}
_ => (),
};
}
ConstValue::Opaque((val,)) => {
let maybe_cty = match ty {
ClassicType::Container(Container::Opaque(t)) => Some(t),
ClassicType::Hashable(HashableType::Container(Container::Opaque(t))) => Some(t),
_ => None,
};
if let Some(cu_ty) = maybe_cty {
return val.check_custom_type(cu_ty).map_err(ConstTypeError::from);
}
}
};
Err(ConstTypeError::ValueCheckFail(ty.clone(), self.clone()))
}

fn container_error(
typ: Container<ClassicType>,
vals: ContainerValue<ConstValue>,
) -> ConstTypeError {
ConstTypeError::ValueCheckFail(ClassicType::Container(typ), ConstValue::Container(vals))
}
}

impl ConstValue {
/// Description of the constant.
pub fn description(&self) -> &str {
"Constant value"
}

/// Constant unit type (empty Tuple).
pub const fn unit() -> ConstValue {
ConstValue::Tuple(vec![])
pub const fn unit() -> Self {
Self::Hashable(HashableValue::Container(ContainerValue::Sequence(vec![])))
}

/// Constant Sum over units, used as predicates.
pub fn simple_predicate(tag: usize) -> Self {
Self::predicate(tag, Self::unit())
}

/// Constant Sum over Tuples, used as predicates.
pub fn predicate(tag: usize, val: ConstValue) -> Self {
ConstValue::Sum(tag, Box::new(val))
Self::sum(tag, Self::unit())
}

/// Constant Sum over Tuples with just one variant of unit type
pub fn simple_unary_predicate() -> Self {
Self::simple_predicate(0)
}

/// Sequence of values (could be a tuple, list or array)
pub fn sequence(items: &[ConstValue]) -> Self {
// Keep Hashable at the outside (if all values are)
match items
.iter()
.map(|item| match item {
ConstValue::Hashable(h) => Some(h),
_ => None,
})
.collect::<Option<Vec<&HashableValue>>>()
{
Some(hashables) => ConstValue::Hashable(HashableValue::Container(
ContainerValue::Sequence(hashables.into_iter().cloned().collect()),
)),
None => ConstValue::Container(ContainerValue::Sequence(items.to_vec())),
}
}

/// Sum value (could be of any compatible type, e.g. a predicate)
pub fn sum(tag: usize, value: ConstValue) -> Self {
// Keep Hashable as outermost constructor
match value {
ConstValue::Hashable(hv) => {
HashableValue::Container(ContainerValue::Sum(tag, Box::new(hv))).into()
}
_ => ConstValue::Container(ContainerValue::Sum(tag, Box::new(value))),
}
}
}

impl From<HashableValue> for ConstValue {
fn from(hv: HashableValue) -> Self {
Self::Hashable(hv)
}
}

impl<T: CustomConst> From<T> for ConstValue {
Expand Down Expand Up @@ -232,12 +309,12 @@ impl_box_clone!(CustomConst, CustomConstBoxClone);
mod test {
use cool_asserts::assert_matches;

use super::ConstValue;
use super::{typecheck::ConstTypeError, Const};
use super::{typecheck::ConstIntError, Const, ConstValue};
use crate::{
builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr},
classic_row, type_row,
types::{ClassicType, SimpleRow, SimpleType},
values::{ConstTypeError, HashableValue, ValueOfType},
};

#[test]
Expand All @@ -251,7 +328,10 @@ mod test {
let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty.clone()]))?;
let c = b.add_constant(Const::predicate(
0,
ConstValue::Tuple(vec![ConstValue::Int(3), ConstValue::F64(3.15)]),
ConstValue::sequence(&[
ConstValue::Hashable(HashableValue::Int(3)),
ConstValue::F64(3.15),
]),
pred_rows.clone(),
)?)?;
let w = b.load_const(&c)?;
Expand All @@ -272,7 +352,37 @@ mod test {
type_row![],
];

let res = Const::predicate(0, ConstValue::Tuple(vec![]), pred_rows);
let res = Const::predicate(0, ConstValue::sequence(&[]), pred_rows);
assert_matches!(res, Err(ConstTypeError::TupleWrongLength));
}

#[test]
fn test_constant_values() {
const T_INT: ClassicType = ClassicType::int::<64>();
const V_INT: ConstValue = ConstValue::Hashable(HashableValue::Int(257));
V_INT.check_type(&T_INT).unwrap();
assert_eq!(
V_INT.check_type(&ClassicType::int::<8>()),
Err(ConstTypeError::Int(ConstIntError::IntTooLarge(8, 257)))
);
ConstValue::F64(17.4).check_type(&ClassicType::F64).unwrap();
assert_matches!(
V_INT.check_type(&ClassicType::F64),
Err(ConstTypeError::ValueCheckFail(ClassicType::F64, v)) => v == V_INT
);
let tuple_ty = ClassicType::new_tuple(classic_row![T_INT, ClassicType::F64]);
let tuple_val = ConstValue::sequence(&[V_INT, ConstValue::F64(5.1)]);
tuple_val.check_type(&tuple_ty).unwrap();
let tuple_val2 = ConstValue::sequence(&[ConstValue::F64(5.1), V_INT]);
assert_matches!(
tuple_val2.check_type(&tuple_ty),
Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2
);
let tuple_val3 =
ConstValue::sequence(&vec![V_INT, ConstValue::F64(3.3), ConstValue::F64(2.0)]);
assert_eq!(
tuple_val3.check_type(&tuple_ty),
Err(ConstTypeError::TupleWrongLength)
);
}
}
Loading

0 comments on commit 7b927eb

Please sign in to comment.