Skip to content

Commit

Permalink
Cleanup ConstValue::predicate, improve ConstTypeErrors, add tests (#257)
Browse files Browse the repository at this point in the history
* Fix ConstValue::predicate, must provide the value
* Temporarily assert value matches the Sum type until #231 fixed
* Remove SimpleType::new_unit
* Fix some ConstTypeErrors for F64 and containers
* Add tests of typecheck_const and ConstValue::predicate...
    * including one [should_panic] documenting failure wrt #231
  • Loading branch information
acl-cqc authored Jul 11, 2023
1 parent 2819425 commit 69881fd
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 18 deletions.
57 changes: 55 additions & 2 deletions src/hugr/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WI
pub enum ConstTypeError {
/// This case hasn't been implemented. Possibly because we don't have value
/// constructors to check against it
#[error("Const type checking unimplemented for {0}")]
#[error("Unimplemented: there are no constants of type {0}")]
Unimplemented(ClassicType),
/// The given type and term are incompatible
#[error("Invalid const value for type {0}")]
Expand Down Expand Up @@ -100,7 +100,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
Err(ConstTypeError::IntWidthMismatch(*exp_width, *width))
}
}
(ty @ ClassicType::F64, _) => Err(ConstTypeError::Unimplemented(ty.clone())),
(ClassicType::F64, ConstValue::F64(_)) => Ok(()),
(ty @ ClassicType::Container(c), tm) => match (c, tm) {
(Container::Tuple(row), ConstValue::Tuple(xs)) => {
if row.len() != xs.len() {
Expand All @@ -114,6 +114,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
}
Ok(())
}
(Container::Tuple(_), _) => Err(ConstTypeError::Failed(ty.clone())),
(Container::Sum(row), ConstValue::Sum { tag, variants, val }) => {
if tag > &row.len() {
return Err(ConstTypeError::InvalidSumTag);
Expand All @@ -130,6 +131,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
_ => Err(ConstTypeError::LinearTypeDisallowed),
}
}
(Container::Sum(_), _) => Err(ConstTypeError::Failed(ty.clone())),
_ => Err(ConstTypeError::Unimplemented(ty.clone())),
},
(ty @ ClassicType::Graph(_), _) => Err(ConstTypeError::Unimplemented(ty.clone())),
Expand All @@ -147,3 +149,54 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT
(ty, _) => Err(ConstTypeError::Failed(ty.clone())),
}
}

#[cfg(test)]
mod test {
use cool_asserts::assert_matches;

use crate::{type_row, types::ClassicType};

use super::*;

#[test]
fn test_typecheck_const() {
const INT: ClassicType = ClassicType::Int(64);
typecheck_const(&INT, &ConstValue::i64(3)).unwrap();
assert_eq!(
typecheck_const(&ClassicType::Int(32), &ConstValue::i64(3)),
Err(ConstTypeError::IntWidthMismatch(32, 64))
);
typecheck_const(&ClassicType::F64, &ConstValue::F64(17.4)).unwrap();
assert_eq!(
typecheck_const(&ClassicType::F64, &ConstValue::i64(5)),
Err(ConstTypeError::Failed(ClassicType::F64))
);
let tuple_ty = ClassicType::Container(Container::Tuple(Box::new(type_row![
SimpleType::Classic(INT),
SimpleType::Classic(ClassicType::F64)
])));
typecheck_const(
&tuple_ty,
&ConstValue::Tuple(vec![ConstValue::i64(7), ConstValue::F64(5.1)]),
)
.unwrap();
assert_matches!(
typecheck_const(
&tuple_ty,
&ConstValue::Tuple(vec![ConstValue::F64(4.8), ConstValue::i64(2)])
),
Err(ConstTypeError::Failed(_))
);
assert_eq!(
typecheck_const(
&tuple_ty,
&ConstValue::Tuple(vec![
ConstValue::i64(5),
ConstValue::F64(3.3),
ConstValue::i64(2)
])
),
Err(ConstTypeError::TupleWrongLength)
);
}
}
97 changes: 88 additions & 9 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,29 @@ impl ConstValue {

/// Constant Sum over units, used as predicates.
pub fn simple_predicate(tag: usize, size: usize) -> Self {
Self::predicate(tag, std::iter::repeat(type_row![]).take(size))
Self::predicate(tag, Self::unit(), std::iter::repeat(type_row![]).take(size))
}

/// Constant Sum over Tuples, used as predicates.
pub fn predicate(tag: usize, variant_rows: impl IntoIterator<Item = TypeRow>) -> Self {
pub fn predicate(
tag: usize,
val: ConstValue,
variant_rows: impl IntoIterator<Item = TypeRow>,
) -> Self {
let variants = TypeRow::predicate_variants_row(variant_rows);
let const_type = SimpleType::Classic(val.const_type());
// TODO This assert is not appropriate for a public API and if the specified `val`
// is not of tuple type matching the `tag`th element of `variant_rows` then
// really the Hugr will fail in validate. However it doesn't at the moment
// (https://github.com/CQCL-DEV/hugr/issues/231).
assert!(Some(&const_type) == variants.get(tag));
ConstValue::Sum {
tag,
variants: TypeRow::predicate_variants_row(variant_rows),
val: Box::new(Self::unit()),
variants,
val: Box::new(val),
}
}

/// Constant Sum over Tuples with just one variant
pub fn unary_predicate(row: impl Into<TypeRow>) -> Self {
Self::predicate(0, [row.into()])
}

/// Constant Sum over Tuples with just one variant of unit type
pub fn simple_unary_predicate() -> Self {
Self::simple_predicate(0, 1)
Expand Down Expand Up @@ -232,3 +238,76 @@ pub trait CustomConst:

impl_downcast!(CustomConst);
impl_box_clone!(CustomConst, CustomConstBoxClone);

#[cfg(test)]
mod test {
use super::ConstValue;
use crate::{
builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr},
hugr::{typecheck::ConstTypeError, ValidationError},
type_row,
types::{ClassicType, SimpleType, TypeRow},
};

#[test]
fn test_predicate() -> Result<(), BuildError> {
let pred_rows = vec![
type_row![
SimpleType::Classic(ClassicType::i64()),
SimpleType::Classic(ClassicType::F64)
],
type_row![],
];
let pred_ty = SimpleType::new_predicate(pred_rows.clone());

let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty.clone()]))?;
let c = b.add_constant(ConstValue::predicate(
0,
ConstValue::Tuple(vec![ConstValue::i64(3), ConstValue::F64(3.15)]),
pred_rows.clone(),
))?;
let w = b.load_const(&c)?;
b.finish_hugr_with_outputs([w]).unwrap();

let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty]))?;
let c = b.add_constant(ConstValue::predicate(
1,
ConstValue::Tuple(vec![]),
pred_rows,
))?;
let w = b.load_const(&c)?;
b.finish_hugr_with_outputs([w]).unwrap();

Ok(())
}

#[test]
#[should_panic] // Pending resolution of https://github.com/CQCL-DEV/hugr/issues/231
fn test_bad_predicate() {
let pred_rows = vec![
type_row![
SimpleType::Classic(ClassicType::i64()),
SimpleType::Classic(ClassicType::F64)
],
type_row![],
];
let pred_ty = SimpleType::new_predicate(pred_rows.clone());

let mut b = DFGBuilder::new(type_row![], TypeRow::from(vec![pred_ty])).unwrap();
// Until #231 is fixed, this is made to fail by an assert in ConstValue::predicate
let c = b
.add_constant(ConstValue::predicate(
0,
ConstValue::Tuple(vec![]),
pred_rows,
))
.unwrap();
let w = b.load_const(&c).unwrap();
assert_eq!(
b.finish_hugr_with_outputs([w]),
Err(BuildError::InvalidHUGR(ValidationError::ConstTypeError(
ConstTypeError::TupleWrongLength
)))
);
}
}
7 changes: 0 additions & 7 deletions src/types/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,6 @@ impl SimpleType {
}
}

/// New unit type, defined as an empty Tuple.
pub fn new_unit() -> Self {
Self::Classic(ClassicType::Container(Container::Tuple(Box::new(
TypeRow::new(),
))))
}

/// New Sum of Tuple types, used as predicates in branching.
/// Tuple rows are defined in order by input rows.
pub fn new_predicate(variant_rows: impl IntoIterator<Item = TypeRow>) -> Self {
Expand Down

0 comments on commit 69881fd

Please sign in to comment.