Skip to content

Commit

Permalink
add check port is output
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Aug 31, 2023
1 parent 344d983 commit 21ac480
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::hugr::{HugrMut, Node};
use crate::ops::LeafOp;
use crate::types::EdgeKind;
use crate::{Hugr, HugrView, Port};
use crate::{Direction, Hugr, HugrView, Port};

use super::Rewrite;

Expand Down Expand Up @@ -38,6 +38,10 @@ pub enum IdentityInsertionError {
/// Invalid port kind.
#[error("post_port has invalid kind {0:?}. Must be Value.")]
InvalidPortKind(Option<EdgeKind>),

/// Must be input port.
#[error("post_port is an output port, must be input.")]
PortIsOutput,
}

impl Rewrite for IdentityInsertion {
Expand All @@ -53,11 +57,15 @@ impl Rewrite for IdentityInsertion {
Conditions:
1. post_port is Value kind.
2. post_port is connected to a sibling of post_node.
3. post_port is input.
*/

unimplemented!()
}
fn apply(self, h: &mut Hugr) -> Result<Self::ApplyResult, IdentityInsertionError> {
if self.post_port.direction() != Direction::Incoming {
return Err(IdentityInsertionError::PortIsOutput);
}
let (pre_node, pre_port) = h
.linked_ports(self.post_node, self.post_port)
.exactly_one()
Expand All @@ -73,7 +81,6 @@ impl Rewrite for IdentityInsertion {
h.connect(pre_node, pre_port.index(), new_node, 0)
.expect("Should only fail if ports don't exist.");

// TODO Check type, insert Noop...
h.connect(new_node, 0, self.post_node, self.post_port.index())
.expect("Should only fail if ports don't exist.");
Ok(new_node)
Expand Down Expand Up @@ -121,9 +128,14 @@ mod tests {

let final_node = tail.node();

let final_node_port = h.node_inputs(final_node).next().unwrap();
let final_node_output = h.node_outputs(final_node).next().unwrap();
let rw = IdentityInsertion::new(final_node, final_node_output);
let apply_result = h.apply_rewrite(rw);
assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput));

let rw = IdentityInsertion::new(final_node, final_node_port);
let final_node_input = h.node_inputs(final_node).next().unwrap();

let rw = IdentityInsertion::new(final_node, final_node_input);

let apply_result = h.apply_rewrite(rw);
assert_eq!(
Expand Down

0 comments on commit 21ac480

Please sign in to comment.