Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HashableValue as part of ConstValue, and ContainerValue #322

Merged
merged 1 commit into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use thiserror::Error;
use pyo3::prelude::*;

use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::ops::constant::typecheck::ConstTypeError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::SimpleType;
use crate::values::ConstTypeError;

pub mod handle;
pub use handle::BuildHandle;
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ use std::collections::HashMap;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::ops::constant::typecheck::CustomCheckFail;
use crate::ops::constant::CustomConst;
use crate::resource::{OpDef, ResourceSet, TypeDef};
use crate::types::type_param::TypeArg;
use crate::types::{CustomType, SimpleRow, TypeTag};
use crate::values::CustomCheckFail;
use crate::Resource;

pub const fn resource_id() -> SmolStr {
Expand Down
10 changes: 3 additions & 7 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ use thiserror::Error;
use pyo3::prelude::*;

use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::OpTag;
use crate::ops::{OpTrait, OpType, ValidateOp};
use crate::ops::{OpTag, OpTrait, OpType, ValidateOp};
use crate::resource::ResourceSet;
use crate::types::{ClassicType, EdgeKind, SimpleType};
use crate::{Direction, Hugr, Node, Port};
Expand Down Expand Up @@ -748,7 +747,7 @@ mod test {
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder};
use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, ConstValue, LeafOp, OpType};
use crate::ops::{self, LeafOp, OpType};
use crate::types::{AbstractSignature, ClassicType};
use crate::Direction;
use crate::{type_row, Node};
Expand Down Expand Up @@ -1146,10 +1145,7 @@ mod test {
})
);
// Second input of Xor from a constant
let cst = h.add_op_with_parent(
h.root(),
ops::Const::new(ConstValue::Int(1), ClassicType::int::<1>()).unwrap(),
)?;
let cst = h.add_op_with_parent(h.root(), ops::Const::int::<1>(1).unwrap())?;
let lcst = h.add_op_with_parent(
h.root(),
ops::LoadConstant {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod ops;
pub mod resource;
pub mod types;
mod utils;
pub mod values;

pub use crate::hugr::{Direction, Hugr, HugrView, Node, Port, SimpleReplacement, Wire};
pub use crate::resource::Resource;
214 changes: 162 additions & 52 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ use std::any::Any;

use crate::{
macros::impl_box_clone,
types::{ClassicRow, ClassicType, CustomType, EdgeKind},
types::{simple::Container, ClassicRow, ClassicType, CustomType, EdgeKind, HashableType},
values::{
map_container_type, ConstTypeError, ContainerValue, CustomCheckFail, HashableValue,
ValueOfType,
},
};

use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use self::typecheck::{typecheck_const, ConstTypeError, CustomCheckFail};

use super::OpTag;
use super::{OpName, OpTrait, StaticTag};

Expand Down Expand Up @@ -48,8 +50,7 @@ impl Const {
variant_rows: impl IntoIterator<Item = ClassicRow>,
) -> Result<Self, ConstTypeError> {
let typ = ClassicType::new_predicate(variant_rows);

Self::new(ConstValue::Sum(tag, Box::new(value)), typ)
Self::new(ConstValue::sum(tag, value), typ)
}

/// Constant Sum over units, used as predicates.
Expand Down Expand Up @@ -80,18 +81,30 @@ impl Const {

/// Fixed width integer
pub fn int<const N: u8>(value: HugrIntValueStore) -> Result<Self, ConstTypeError> {
Self::new(ConstValue::Int(value), ClassicType::int::<N>())
Self::new(
ConstValue::Hashable(HashableValue::Int(value)),
ClassicType::int::<N>(),
)
}

/// 64-bit integer
pub fn i64(value: i64) -> Result<Self, ConstTypeError> {
Self::int::<64>(value as HugrIntValueStore)
}

/// Tuple of values
pub fn new_tuple(items: impl IntoIterator<Item = Const>) -> Self {
let (values, types): (Vec<ConstValue>, Vec<ClassicType>) = items
.into_iter()
.map(|Const { value, typ }| (value, typ))
.unzip();
Self::new(ConstValue::sequence(&values), ClassicType::new_tuple(types)).unwrap()
}
}

impl OpName for Const {
fn name(&self) -> SmolStr {
self.value.name()
self.value.name().into()
}
}
impl StaticTag for Const {
Expand All @@ -116,23 +129,17 @@ pub(crate) type HugrIntWidthStore = u8;
pub(crate) const HUGR_MAX_INT_WIDTH: HugrIntWidthStore =
HugrIntValueStore::BITS as HugrIntWidthStore;

/// Value constants
///
/// TODO: Add more constants
/// TODO: bigger/smaller integers.
/// Value constants. (This could be "ClassicValue" to parallel [HashableValue])
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum ConstValue {
/// An arbitrary length integer constant.
Int(HugrIntValueStore),
Hashable(HashableValue),
/// A collection of constant values (at least some of which are not [ConstValue::Hashable])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"not known to be", perhaps

Container(ContainerValue<ConstValue>),
/// Double precision float
F64(f64),
/// A constant specifying a variant of a Sum type.
Sum(usize, Box<ConstValue>),
/// A tuple of constant values.
Tuple(Vec<ConstValue>),
/// An opaque constant value, with cached type
/// An opaque constant value, with cached type. TODO put this into ContainerValue.
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Opaque((Box<dyn CustomConst>,)),
}
Expand All @@ -143,59 +150,129 @@ impl PartialEq for dyn CustomConst {
}
}

impl Default for ConstValue {
fn default() -> Self {
Self::Int(0)
}
}
impl ValueOfType for ConstValue {
type T = ClassicType;

impl ConstValue {
/// Returns the datatype of the constant.
pub fn check_type(&self, typ: &ClassicType) -> Result<(), ConstTypeError> {
typecheck_const(typ, self)
fn name(&self) -> String {
match self {
ConstValue::F64(f) => format!("const:float:{}", f),
ConstValue::Hashable(hv) => hv.name(),
ConstValue::Container(ctr) => ctr.desc(),
ConstValue::Opaque((v,)) => format!("const:custom:{}", v.name()),
}
}
/// Unique name of the constant.
pub fn name(&self) -> SmolStr {

fn check_type(&self, ty: &ClassicType) -> Result<(), ConstTypeError> {
match self {
Self::Int(value) => format!("const:int{value}"),
Self::F64(f) => format!("const:float:{f}"),
Self::Opaque((v,)) => format!("const:{}", v.name()),
Self::Sum(tag, val) => {
format!("const:sum:{{tag:{tag}, val:{}}}", val.name())
ConstValue::F64(_) => {
if let ClassicType::F64 = ty {
return Ok(());
}
}
Self::Tuple(vals) => {
let valstr: Vec<_> = vals.iter().map(|v| v.name()).collect();
let valstr = valstr.join(", ");
format!("const:tuple:{{{valstr}}}")
ConstValue::Hashable(hv) => {
match ty {
ClassicType::Hashable(exp) => return hv.check_type(exp),
ClassicType::Container(cty) => {
// A "hashable" value might be an instance of a non-hashable type:
// e.g. an empty list is hashable, yet can be checked against a classic element type!
if let HashableValue::Container(ctr) = hv {
// Note if ctr is a ContainerValue<HashableValue>::Opaque, this means we can check that
// against a Container<ClassicType>::Opaque, which is perhaps unnecessary, but harmless.
return ctr.map_vals(&ConstValue::Hashable).check_container(cty);
}
}
_ => (),
}
}
}
.into()
ConstValue::Container(vals) => {
match ty {
ClassicType::Container(cty) => return vals.check_container(cty),
// We might also fail to deduce a container *value* was hashable,
// because it contains opaque values whose tag is unknown.
ClassicType::Hashable(HashableType::Container(cty)) => {
return vals
.check_container(&map_container_type(cty, &ClassicType::Hashable))
}
_ => (),
};
}
ConstValue::Opaque((val,)) => {
let maybe_cty = match ty {
ClassicType::Container(Container::Opaque(t)) => Some(t),
ClassicType::Hashable(HashableType::Container(Container::Opaque(t))) => Some(t),
_ => None,
};
if let Some(cu_ty) = maybe_cty {
return val.check_custom_type(cu_ty).map_err(ConstTypeError::from);
}
}
};
Err(ConstTypeError::ValueCheckFail(ty.clone(), self.clone()))
}

fn container_error(
typ: Container<ClassicType>,
vals: ContainerValue<ConstValue>,
) -> ConstTypeError {
ConstTypeError::ValueCheckFail(ClassicType::Container(typ), ConstValue::Container(vals))
}
}

impl ConstValue {
/// Description of the constant.
pub fn description(&self) -> &str {
"Constant value"
}

/// Constant unit type (empty Tuple).
pub const fn unit() -> ConstValue {
ConstValue::Tuple(vec![])
pub const fn unit() -> Self {
Self::Hashable(HashableValue::Container(ContainerValue::Sequence(vec![])))
}

/// Constant Sum over units, used as predicates.
pub fn simple_predicate(tag: usize) -> Self {
Self::predicate(tag, Self::unit())
}

/// Constant Sum over Tuples, used as predicates.
pub fn predicate(tag: usize, val: ConstValue) -> Self {
ConstValue::Sum(tag, Box::new(val))
Self::sum(tag, Self::unit())
}

/// Constant Sum over Tuples with just one variant of unit type
pub fn simple_unary_predicate() -> Self {
Self::simple_predicate(0)
}

/// Sequence of values (could be a tuple, list or array)
pub fn sequence(items: &[ConstValue]) -> Self {
// Keep Hashable at the outside (if all values are)
match items
.iter()
.map(|item| match item {
ConstValue::Hashable(h) => Some(h),
_ => None,
})
.collect::<Option<Vec<&HashableValue>>>()
{
Some(hashables) => ConstValue::Hashable(HashableValue::Container(
ContainerValue::Sequence(hashables.into_iter().cloned().collect()),
)),
None => ConstValue::Container(ContainerValue::Sequence(items.to_vec())),
}
}

/// Sum value (could be of any compatible type, e.g. a predicate)
pub fn sum(tag: usize, value: ConstValue) -> Self {
// Keep Hashable as outermost constructor
match value {
ConstValue::Hashable(hv) => {
HashableValue::Container(ContainerValue::Sum(tag, Box::new(hv))).into()
}
_ => ConstValue::Container(ContainerValue::Sum(tag, Box::new(value))),
}
}
}

impl From<HashableValue> for ConstValue {
fn from(hv: HashableValue) -> Self {
Self::Hashable(hv)
}
}

impl<T: CustomConst> From<T> for ConstValue {
Expand Down Expand Up @@ -232,12 +309,12 @@ impl_box_clone!(CustomConst, CustomConstBoxClone);
mod test {
use cool_asserts::assert_matches;

use super::ConstValue;
use super::{typecheck::ConstTypeError, Const};
use super::{typecheck::ConstIntError, Const, ConstValue};
use crate::{
builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr},
classic_row, type_row,
types::{ClassicType, SimpleRow, SimpleType},
values::{ConstTypeError, HashableValue, ValueOfType},
};

#[test]
Expand All @@ -251,7 +328,10 @@ mod test {
let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty.clone()]))?;
let c = b.add_constant(Const::predicate(
0,
ConstValue::Tuple(vec![ConstValue::Int(3), ConstValue::F64(3.15)]),
ConstValue::sequence(&[
ConstValue::Hashable(HashableValue::Int(3)),
ConstValue::F64(3.15),
]),
pred_rows.clone(),
)?)?;
let w = b.load_const(&c)?;
Expand All @@ -272,7 +352,37 @@ mod test {
type_row![],
];

let res = Const::predicate(0, ConstValue::Tuple(vec![]), pred_rows);
let res = Const::predicate(0, ConstValue::sequence(&[]), pred_rows);
assert_matches!(res, Err(ConstTypeError::TupleWrongLength));
}

#[test]
fn test_constant_values() {
const T_INT: ClassicType = ClassicType::int::<64>();
const V_INT: ConstValue = ConstValue::Hashable(HashableValue::Int(257));
V_INT.check_type(&T_INT).unwrap();
assert_eq!(
V_INT.check_type(&ClassicType::int::<8>()),
Err(ConstTypeError::Int(ConstIntError::IntTooLarge(8, 257)))
);
ConstValue::F64(17.4).check_type(&ClassicType::F64).unwrap();
assert_matches!(
V_INT.check_type(&ClassicType::F64),
Err(ConstTypeError::ValueCheckFail(ClassicType::F64, v)) => v == V_INT
);
let tuple_ty = ClassicType::new_tuple(classic_row![T_INT, ClassicType::F64]);
let tuple_val = ConstValue::sequence(&[V_INT, ConstValue::F64(5.1)]);
tuple_val.check_type(&tuple_ty).unwrap();
let tuple_val2 = ConstValue::sequence(&[ConstValue::F64(5.1), V_INT]);
assert_matches!(
tuple_val2.check_type(&tuple_ty),
Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2
);
let tuple_val3 =
ConstValue::sequence(&vec![V_INT, ConstValue::F64(3.3), ConstValue::F64(2.0)]);
assert_eq!(
tuple_val3.check_type(&tuple_ty),
Err(ConstTypeError::TupleWrongLength)
);
}
}
Loading