From 21ac480297104f06f9957df6f8706ea3c8263d17 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 31 Aug 2023 11:19:09 +0100 Subject: [PATCH] add check port is output --- src/hugr/rewrite/insert_identity.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index 5816fdcf1..c2d177b29 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -3,7 +3,7 @@ use crate::hugr::{HugrMut, Node}; use crate::ops::LeafOp; use crate::types::EdgeKind; -use crate::{Hugr, HugrView, Port}; +use crate::{Direction, Hugr, HugrView, Port}; use super::Rewrite; @@ -38,6 +38,10 @@ pub enum IdentityInsertionError { /// Invalid port kind. #[error("post_port has invalid kind {0:?}. Must be Value.")] InvalidPortKind(Option), + + /// Must be input port. + #[error("post_port is an output port, must be input.")] + PortIsOutput, } impl Rewrite for IdentityInsertion { @@ -53,11 +57,15 @@ impl Rewrite for IdentityInsertion { Conditions: 1. post_port is Value kind. 2. post_port is connected to a sibling of post_node. + 3. post_port is input. */ unimplemented!() } fn apply(self, h: &mut Hugr) -> Result { + if self.post_port.direction() != Direction::Incoming { + return Err(IdentityInsertionError::PortIsOutput); + } let (pre_node, pre_port) = h .linked_ports(self.post_node, self.post_port) .exactly_one() @@ -73,7 +81,6 @@ impl Rewrite for IdentityInsertion { h.connect(pre_node, pre_port.index(), new_node, 0) .expect("Should only fail if ports don't exist."); - // TODO Check type, insert Noop... h.connect(new_node, 0, self.post_node, self.post_port.index()) .expect("Should only fail if ports don't exist."); Ok(new_node) @@ -121,9 +128,14 @@ mod tests { let final_node = tail.node(); - let final_node_port = h.node_inputs(final_node).next().unwrap(); + let final_node_output = h.node_outputs(final_node).next().unwrap(); + let rw = IdentityInsertion::new(final_node, final_node_output); + let apply_result = h.apply_rewrite(rw); + assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput)); - let rw = IdentityInsertion::new(final_node, final_node_port); + let final_node_input = h.node_inputs(final_node).next().unwrap(); + + let rw = IdentityInsertion::new(final_node, final_node_input); let apply_result = h.apply_rewrite(rw); assert_eq!(