Skip to content

Commit

Permalink
Statically query the tag of an operation (#251)
Browse files Browse the repository at this point in the history
* feat: Statically query the tag of an operation

Implemented `OpTagged` for all operation structs.
`OpTrait` still has a `tag(&self)` method for dynamic computation (e.g. for `OpType`)

* Just use an associated const

* Make tag comparison constant

* s/OpTagged/StaticTag/

* `contains` -> `is_superset`, `parent_tags` -> `immediate_supersets`
  • Loading branch information
aborgna-q authored Jul 11, 2023
1 parent e57500e commit 3ca9fed
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/algorithm/half_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::hash::Hash;

use super::nest_cfgs::CfgView;
use crate::hugr::view::HugrView;
use crate::ops::tag::OpTag;
use crate::ops::OpTag;
use crate::ops::OpTrait;
use crate::{Direction, Node};

Expand Down
2 changes: 1 addition & 1 deletion src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use std::hash::Hash;
use itertools::Itertools;

use crate::hugr::view::HugrView;
use crate::ops::tag::OpTag;
use crate::ops::OpTag;
use crate::ops::OpTrait;
use crate::{Direction, Node};

Expand Down
2 changes: 1 addition & 1 deletion src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ mod test {

use crate::builder::build_traits::DataflowHugr;
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::ops::tag::OpTag;
use crate::ops::OpTag;
use crate::ops::OpTrait;
use crate::{
builder::{
Expand Down
2 changes: 1 addition & 1 deletion src/builder/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::{
ops::{
handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID},
tag::OpTag,
OpTag,
},
Port,
};
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use portgraph::{LinkMut, LinkView, MultiMut, NodeIndex, PortView};
use crate::hugr::{HugrMut, HugrView, NodeMetadata};
use crate::{
hugr::{Node, Rewrite},
ops::{tag::OpTag, OpTrait, OpType},
ops::{OpTag, OpTrait, OpType},
Hugr, Port,
};
use thiserror::Error;
Expand Down Expand Up @@ -264,7 +264,7 @@ mod test {
};
use crate::hugr::view::HugrView;
use crate::hugr::{Hugr, Node};
use crate::ops::tag::OpTag;
use crate::ops::OpTag;
use crate::ops::{LeafOp, OpTrait, OpType};
use crate::types::{ClassicType, LinearType, Signature, SimpleType};
use crate::{type_row, Port};
Expand Down
12 changes: 6 additions & 6 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::hugr::typecheck::{typecheck_const, ConstTypeError};
use crate::ops::tag::OpTag;
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::OpTag;
use crate::ops::{self, OpTrait, OpType, ValidateOp};
use crate::resource::ResourceSet;
use crate::types::ClassicType;
Expand Down Expand Up @@ -123,7 +123,7 @@ impl<'a> ValidationContext<'a> {

let parent_optype = self.hugr.get_optype(parent);
let allowed_children = parent_optype.validity_flags().allowed_children;
if !allowed_children.contains(optype.tag()) {
if !allowed_children.is_superset(optype.tag()) {
return Err(ValidationError::InvalidParentOp {
child: node,
child_optype: optype.clone(),
Expand Down Expand Up @@ -296,7 +296,7 @@ impl<'a> ValidationContext<'a> {
let all_children = self.hugr.children(node);
let mut first_two_children = all_children.clone().take(2);
let first_child = self.hugr.get_optype(first_two_children.next().unwrap());
if !flags.allowed_first_child.contains(first_child.tag()) {
if !flags.allowed_first_child.is_superset(first_child.tag()) {
return Err(ValidationError::InvalidInitialChild {
parent: node,
parent_optype: optype.clone(),
Expand All @@ -310,7 +310,7 @@ impl<'a> ValidationContext<'a> {
.next()
.map(|child| self.hugr.get_optype(child))
{
if !flags.allowed_second_child.contains(second_child.tag()) {
if !flags.allowed_second_child.is_superset(second_child.tag()) {
return Err(ValidationError::InvalidInitialChild {
parent: node,
parent_optype: optype.clone(),
Expand Down Expand Up @@ -396,7 +396,7 @@ impl<'a> ValidationContext<'a> {
let non_defn_count = self
.hugr
.children(parent)
.filter(|n| !OpTag::ScopedDefn.contains(self.hugr.get_optype(*n).tag()))
.filter(|n| !OpTag::ScopedDefn.is_superset(self.hugr.get_optype(*n).tag()))
.count();
if nodes_visited != non_defn_count {
return Err(ValidationError::NotABoundedDag {
Expand Down Expand Up @@ -443,7 +443,7 @@ impl<'a> ValidationContext<'a> {
} else {
// If const edges aren't coming from const nodes, they're graph
// edges coming from FuncDecl or FuncDefn
return if OpTag::Function.contains(from_optype.tag()) {
return if OpTag::Function.is_superset(from_optype.tag()) {
Ok(())
} else {
Err(InterGraphEdgeError::InvalidConstSrc {
Expand Down
13 changes: 12 additions & 1 deletion src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ use crate::{Direction, Port};
use portgraph::NodeIndex;
use smol_str::SmolStr;

use self::tag::OpTag;
use enum_dispatch::enum_dispatch;

pub use constant::{Const, ConstValue};
pub use controlflow::{BasicBlock, Case, Conditional, TailLoop, CFG};
pub use dataflow::{Call, CallIndirect, Input, LoadConstant, Output, DFG};
pub use leaf::LeafOp;
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
pub use tag::OpTag;

#[enum_dispatch(OpTrait, OpName, ValidateOp)]
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -139,13 +139,24 @@ pub trait OpName {
fn name(&self) -> SmolStr;
}

/// Trait statically querying the tag of an operation.
///
/// This is implemented by all OpType variants, and always contains the dynamic
/// tag returned by `OpType::tag(&self)`.
pub trait StaticTag {
/// The name of the operation.
const TAG: OpTag;
}

#[enum_dispatch]
/// Trait implemented by all OpType variants.
pub trait OpTrait {
/// A human-readable description of the operation.
fn description(&self) -> &str;

/// Tag identifying the operation.
fn tag(&self) -> OpTag;

/// The signature of the operation.
///
/// Only dataflow operations have a non-empty signature.
Expand Down
9 changes: 6 additions & 3 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::{
use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use super::tag::OpTag;
use super::{OpName, OpTrait};
use super::OpTag;
use super::{OpName, OpTrait, StaticTag};

/// A constant value definition.
#[derive(Debug, Clone, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
Expand All @@ -22,13 +22,16 @@ impl OpName for Const {
self.0.name()
}
}
impl StaticTag for Const {
const TAG: OpTag = OpTag::Const;
}
impl OpTrait for Const {
fn description(&self) -> &str {
self.0.description()
}

fn tag(&self) -> OpTag {
OpTag::Const
<Self as StaticTag>::TAG
}

fn other_output(&self) -> Option<EdgeKind> {
Expand Down
32 changes: 17 additions & 15 deletions src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use smol_str::SmolStr;
use crate::types::{EdgeKind, Signature, SimpleType, TypeRow};

use super::dataflow::DataflowOpTrait;
use super::tag::OpTag;
use super::{impl_op_name, OpName, OpTrait};
use super::OpTag;
use super::{impl_op_name, OpName, OpTrait, StaticTag};

/// Tail-controlled loop.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
Expand All @@ -22,14 +22,12 @@ pub struct TailLoop {
impl_op_name!(TailLoop);

impl DataflowOpTrait for TailLoop {
const TAG: OpTag = OpTag::TailLoop;

fn description(&self) -> &str {
"A tail-controlled loop"
}

fn tag(&self) -> OpTag {
OpTag::TailLoop
}

fn signature(&self) -> Signature {
let [inputs, outputs] =
[self.just_inputs.clone(), self.just_outputs.clone()].map(|mut row| {
Expand Down Expand Up @@ -71,14 +69,12 @@ pub struct Conditional {
impl_op_name!(Conditional);

impl DataflowOpTrait for Conditional {
const TAG: OpTag = OpTag::Conditional;

fn description(&self) -> &str {
"HUGR conditional operation"
}

fn tag(&self) -> OpTag {
OpTag::Conditional
}

fn signature(&self) -> Signature {
let mut inputs = self.other_inputs.clone();
inputs.to_mut().insert(
Expand Down Expand Up @@ -110,14 +106,12 @@ pub struct CFG {
impl_op_name!(CFG);

impl DataflowOpTrait for CFG {
const TAG: OpTag = OpTag::Cfg;

fn description(&self) -> &str {
"A dataflow node defined by a child CFG"
}

fn tag(&self) -> OpTag {
OpTag::Cfg
}

fn signature(&self) -> Signature {
Signature::new_df(self.inputs.clone(), self.outputs.clone())
}
Expand Down Expand Up @@ -149,6 +143,10 @@ impl OpName for BasicBlock {
}
}

impl StaticTag for BasicBlock {
const TAG: OpTag = OpTag::BasicBlock;
}

impl OpTrait for BasicBlock {
/// The description of the operation.
fn description(&self) -> &str {
Expand Down Expand Up @@ -210,13 +208,17 @@ pub struct Case {

impl_op_name!(Case);

impl StaticTag for Case {
const TAG: OpTag = OpTag::Case;
}

impl OpTrait for Case {
fn description(&self) -> &str {
"A case node inside a conditional"
}

fn tag(&self) -> OpTag {
OpTag::Case
<Self as StaticTag>::TAG
}
}

Expand Down
Loading

0 comments on commit 3ca9fed

Please sign in to comment.