Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable static simple predicates #417

Merged
merged 13 commits into from
Aug 18, 2023
1 change: 1 addition & 0 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub(crate) const QB_CUSTOM_T: CustomType = CustomType::new_simple(

pub(crate) const QB_T: Type = Type::new_extension(QB_CUSTOM_T);
pub(crate) const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T);
pub(crate) const BOOL_T: Type = Type::new_simple_predicate(2);

/// Initialize a new array of type `typ` of length `size`
pub fn new_array(typ: Type, size: u64) -> Type {
Expand Down
33 changes: 18 additions & 15 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct ValidationContext<'a> {
hugr: &'a Hugr,
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
/// Context for the resource validation.
extension_validator: ExtensionValidator,
}

Expand Down Expand Up @@ -137,8 +137,8 @@ impl<'a> ValidationContext<'a> {
// Check operation-specific constraints
self.validate_operation(node, node_type)?;

// If this is a container with I/O nodes, check that the extensions they
// define match the extensions of the container.
// If this is a container with I/O nodes, check that the resources they
// define match the resources of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
Expand Down Expand Up @@ -591,7 +591,7 @@ pub enum ValidationError {
/// There are invalid inter-graph edges.
#[error(transparent)]
InterGraphEdgeError(#[from] InterGraphEdgeError),
/// There are errors in the extension declarations.
/// There are errors in the resource declarations.
#[error(transparent)]
ExtensionError(#[from] ExtensionError),
}
Expand Down Expand Up @@ -736,8 +736,8 @@ mod test {
parent: Node,
predicate_size: usize,
) -> (Node, Node, Node, Node) {
let const_op = ops::Const::simple_predicate(0, predicate_size);
let tag_type = Type::new_simple_predicate(predicate_size);
let const_op = ops::Const::simple_predicate(0, predicate_size as u8);
let tag_type = Type::new_simple_predicate(predicate_size as u8);

let input = b
.add_op_with_parent(parent, ops::Input::new(type_row![B]))
Expand Down Expand Up @@ -986,7 +986,10 @@ mod test {
b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q])));
b.replace_op(
block_output,
NodeType::pure(ops::Output::new(vec![Type::new_simple_predicate(1), Q])),
NodeType::pure(ops::Output::new(type_row![
Type::new_simple_predicate(1),
Q
])),
);
assert_matches!(
b.validate(),
Expand Down Expand Up @@ -1070,8 +1073,8 @@ mod test {
}

#[test]
/// A wire with no extension requirements is wired into a node which has
/// [A,B] extensions required on its inputs and outputs. This could be fixed
/// A wire with no resource requirements is wired into a node which has
/// [A,B] resources required on its inputs and outputs. This could be fixed
/// by adding a lift node, but for validation this is an error.
fn missing_lift_node() -> Result<(), BuildError> {
let mut module_builder = ModuleBuilder::new();
Expand All @@ -1082,7 +1085,7 @@ mod test {
let [main_input] = main.input_wires_arr();

let inner_sig = AbstractSignature::new_df(type_row![NAT], type_row![NAT])
// Inner DFG has extension requirements that the wire wont satisfy
// Inner DFG has resource requirements that the wire wont satisfy
.with_input_extensions(ExtensionSet::from_iter(["A".into(), "B".into()]));

let f_builder = main.dfg_builder(
Expand All @@ -1106,11 +1109,11 @@ mod test {
}

#[test]
/// A wire with extension requirement `[A]` is wired into a an output with no
/// extension req. In the validation extension typechecking, we don't do any
/// unification, so don't allow open extension variables on the function
/// A wire with resource requirement `[A]` is wired into a an output with no
/// resource req. In the validation resource typechecking, we don't do any
/// unification, so don't allow open resource variables on the function
/// signature, so this fails.
fn too_many_extensions() -> Result<(), BuildError> {
fn too_many_resources() -> Result<(), BuildError> {
let mut module_builder = ModuleBuilder::new();

let main_sig = AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure();
Expand Down Expand Up @@ -1142,7 +1145,7 @@ mod test {
}

#[test]
/// A wire with extension requirements `[A]` and another with requirements
/// A wire with resource requirements `[A]` and another with requirements
/// `[B]` are both wired into a node which requires its inputs to have
/// requirements `[A,B]`. A slightly more complex test of the error from
/// `missing_lift_node`.
Expand Down
2 changes: 1 addition & 1 deletion src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub(crate) use impl_box_clone;
/// ```
/// # use hugr::macros::type_row;
/// # use hugr::types::{AbstractSignature, Type, TypeRow};
/// const U: Type = Type::new_unit();
/// const U: Type = Type::UNIT;
/// let static_row: TypeRow = type_row![U, U];
/// let dynamic_row: TypeRow = vec![U, U, U].into();
/// let sig = AbstractSignature::new_df(static_row, dynamic_row).pure();
Expand Down
2 changes: 1 addition & 1 deletion src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl Const {
}

/// Constant Sum over units, used as predicates.
pub fn simple_predicate(tag: usize, size: usize) -> Self {
pub fn simple_predicate(tag: usize, size: u8) -> Self {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why restrict to u8?

Self {
value: Value::simple_predicate(tag),
typ: Type::new_simple_predicate(size),
Expand Down
4 changes: 1 addition & 3 deletions src/ops/leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ pub enum LeafOp {

impl Default for LeafOp {
fn default() -> Self {
Self::Noop {
ty: Type::new_unit(),
}
Self::Noop { ty: Type::UNIT }
}
}
impl OpName for LeafOp {
Expand Down
7 changes: 3 additions & 4 deletions src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@ use crate::{
Extension,
};

use super::super::logic::bool_type;
use super::float_types::FLOAT64_TYPE;

/// The extension identifier.
pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.float");

fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
vec![FLOAT64_TYPE; 2].into(),
vec![bool_type()].into(),
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
ExtensionSet::default(),
))
}

fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
vec![FLOAT64_TYPE; 2].into(),
type_row![FLOAT64_TYPE; 2],
type_row![FLOAT64_TYPE],
ExtensionSet::default(),
))
Expand Down
10 changes: 5 additions & 5 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

use smol_str::SmolStr;

use super::super::logic::bool_type;
use super::int_types::{get_width, int_type};
use crate::extension::prelude::ERROR_TYPE;
use crate::extension::prelude::{BOOL_T, ERROR_TYPE};
use crate::type_row;
use crate::types::type_param::TypeParam;
use crate::utils::collect_array;
use crate::{
Expand Down Expand Up @@ -47,14 +47,14 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet
fn itob_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
vec![int_type(1)].into(),
vec![bool_type()].into(),
type_row![BOOL_T],
ExtensionSet::default(),
))
}

fn btoi_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
vec![bool_type()].into(),
type_row![BOOL_T],
vec![int_type(1)].into(),
ExtensionSet::default(),
))
Expand All @@ -65,7 +65,7 @@ fn icmp_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet),
let n: u8 = get_width(arg)?;
Ok((
vec![int_type(n); 2].into(),
vec![bool_type()].into(),
type_row![BOOL_T],
ExtensionSet::default(),
))
}
Expand Down
32 changes: 12 additions & 20 deletions src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ use itertools::Itertools;
use smol_str::SmolStr;

use crate::{
extension::ExtensionSet,
ops,
types::{
type_param::{TypeArg, TypeArgError, TypeParam},
Type,
},
extension::{prelude::BOOL_T, ExtensionSet},
ops, type_row,
types::type_param::{TypeArg, TypeArgError, TypeParam},
Extension,
};

Expand All @@ -21,11 +18,6 @@ pub const TRUE_NAME: &str = "TRUE";
/// The extension identifier.
pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("logic");

/// Construct a boolean type.
pub fn bool_type() -> Type {
Type::new_simple_predicate(2)
}

/// Extension for basic logical operations.
pub fn extension() -> Extension {
const H_INT: TypeParam = TypeParam::USize;
Expand All @@ -38,8 +30,8 @@ pub fn extension() -> Extension {
vec![],
|_arg_values: &[TypeArg]| {
Ok((
vec![bool_type()].into(),
vec![bool_type()].into(),
type_row![BOOL_T],
type_row![BOOL_T],
ExtensionSet::default(),
))
},
Expand All @@ -64,8 +56,8 @@ pub fn extension() -> Extension {
}
};
Ok((
vec![bool_type(); n as usize].into(),
vec![bool_type()].into(),
vec![BOOL_T; n as usize].into(),
type_row![BOOL_T],
ExtensionSet::default(),
))
},
Expand All @@ -90,8 +82,8 @@ pub fn extension() -> Extension {
}
};
Ok((
vec![bool_type(); n as usize].into(),
vec![bool_type()].into(),
vec![BOOL_T; n as usize].into(),
type_row![BOOL_T],
ExtensionSet::default(),
))
},
Expand All @@ -109,9 +101,9 @@ pub fn extension() -> Extension {

#[cfg(test)]
mod test {
use crate::Extension;
use crate::{extension::prelude::BOOL_T, Extension};

use super::{bool_type, extension, FALSE_NAME, TRUE_NAME};
use super::{extension, FALSE_NAME, TRUE_NAME};

#[test]
fn test_logic_extension() {
Expand All @@ -128,7 +120,7 @@ mod test {

for v in [false_val, true_val] {
let simpl = v.typed_value().const_type();
assert_eq!(simpl, &bool_type());
assert_eq!(simpl, &BOOL_T);
}
}
}
Loading