Skip to content

Commit

Permalink
New test (breaking out to base_hugr/hugr_mut when necessary).
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Sep 12, 2023
1 parent 2930a3b commit b4b4e5e
Showing 1 changed file with 89 additions and 38 deletions.
127 changes: 89 additions & 38 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use thiserror::Error;

use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::extension::ExtensionSet;
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::{HierarchyView, SiblingGraph};
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::handle::{BasicBlockID, CfgID};
use crate::ops::{BasicBlock, OpTag, OpTrait, OpType};
use crate::{type_row, Node};

Expand Down Expand Up @@ -139,13 +142,17 @@ impl Rewrite for OutlineCfg {
};

// 3. Extract Cfg node created above (it moved when we called insert_hugr)
let cfg_node = h
// Support filtered Sibling-only views by explicitly descending into new_block
let in_bb_view: SiblingGraph<'_, BasicBlockID> = SiblingGraph::new(h, new_block).unwrap();
let cfg_node = in_bb_view
.children(new_block)
.filter(|n| h.get_optype(*n).tag() == OpTag::Cfg)
.filter(|n| in_bb_view.get_optype(*n).tag() == OpTag::Cfg)
.exactly_one()
.ok() // HugrMut::Children is not Debug
.unwrap();
let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
let in_cfg_view: SiblingGraph<'_, CfgID> =
SiblingGraph::new(&in_bb_view, cfg_node).unwrap();
let inner_exit = in_cfg_view.children(cfg_node).exactly_one().ok().unwrap();

// 4. Entry edges. Change any edges into entry_block from outside, to target new_block
let preds: Vec<_> = h
Expand All @@ -163,31 +170,37 @@ impl Rewrite for OutlineCfg {
h.move_before_sibling(new_block, outer_entry).unwrap();
}

// 5. Children of new CFG.
// Entry node must be first
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
for n in self.blocks {
// Do not move the entry node, as we have already
if n != entry {
h.set_parent(n, cfg_node).unwrap();
{
// These operations do not fit into any SiblingView
// so we need to access the Hugr directly.
let h = h.hugr_mut();

// 5. Children of new CFG.
// Entry node must be first
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
for n in self.blocks {
// Do not move the entry node, as we have already
if n != entry {
h.set_parent(n, cfg_node).unwrap();
}
}
}

// 6. Exit edges.
// Retarget edge from exit_node (that used to target outside) to inner_exit
let exit_port = h
.node_outputs(exit)
.filter(|p| {
let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap();
assert!(p2.index() == 0);
t == outside
})
.exactly_one()
.ok() // NodePorts does not implement Debug
.unwrap();
h.disconnect(exit, exit_port).unwrap();
h.connect(exit, exit_port.index(), inner_exit, 0).unwrap();
// 6. Exit edges.
// Retarget edge from exit_node (that used to target outside) to inner_exit
let exit_port = h
.node_outputs(exit)
.filter(|p| {
let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap();
assert!(p2.index() == 0);
t == outside
})
.exactly_one()
.ok() // NodePorts does not implement Debug
.unwrap();
h.disconnect(exit, exit_port).unwrap();
h.connect(exit, exit_port.index(), inner_exit, 0).unwrap();
}
// And connect new_block to outside instead
h.connect(new_block, 0, outside, 0).unwrap();

Expand Down Expand Up @@ -226,19 +239,25 @@ mod test {
use std::collections::HashSet;

use crate::algorithm::nest_cfgs::test::{
build_cond_then_loop_cfg, build_conditional_in_loop_cfg,
build_cond_then_loop_cfg, build_conditional_in_loop, build_conditional_in_loop_cfg,
};
use crate::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::USIZE_T;
use crate::extension::PRELUDE_REGISTRY;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::handle::NodeHandle;
use crate::{HugrView, Node};
use crate::ops::handle::{BasicBlockID, NodeHandle};
use crate::types::FunctionType;
use crate::{type_row, HugrView, Node};
use cool_asserts::assert_matches;
use itertools::Itertools;

use super::{OutlineCfg, OutlineCfgError};

fn depth(h: &impl HugrView, n: Node) -> u32 {
match h.get_parent(n) {
match h.base_hugr().get_parent(n) {
Some(p) => 1 + depth(h, p),
None => 0,
}
Expand Down Expand Up @@ -274,6 +293,17 @@ mod test {
#[test]
fn test_outline_cfg() {
let (mut h, head, tail) = build_conditional_in_loop_cfg(false).unwrap();
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
do_outline_cfg_test(&mut h, head, tail, 1);
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
}

fn do_outline_cfg_test(
h: &mut impl HugrMut,
head: BasicBlockID,
tail: BasicBlockID,
expected_depth: u32,
) {
let head = head.node();
let tail = tail.node();
let parent = h.get_parent(head).unwrap();
Expand All @@ -283,29 +313,50 @@ mod test {
// | \-> right -/ |
// \---<---<---<---<---<--<---/
// merge is unique predecessor of tail
let merge = h.input_neighbours(tail).exactly_one().unwrap();
let merge = h.input_neighbours(tail).exactly_one().ok().unwrap();
let [left, right]: [Node; 2] = h.output_neighbours(head).collect_vec().try_into().unwrap();
for n in [head, tail, merge] {
assert_eq!(depth(&h, n), 1);
assert_eq!(depth(h, n), expected_depth);
}
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
let blocks = [head, left, right, merge];
h.apply_rewrite(OutlineCfg::new(blocks)).unwrap();
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
for n in blocks {
assert_eq!(depth(&h, n), 3);
assert_eq!(depth(h, n), expected_depth + 2);
}
let new_block = h.output_neighbours(entry).exactly_one().unwrap();
let new_block = h.output_neighbours(entry).exactly_one().ok().unwrap();
for n in [entry, exit, tail, new_block] {
assert_eq!(depth(&h, n), 1);
assert_eq!(depth(h, n), expected_depth);
}
assert_eq!(h.input_neighbours(tail).exactly_one().unwrap(), new_block);
assert_eq!(
h.input_neighbours(tail).exactly_one().ok().unwrap(),
new_block
);
assert_eq!(
h.output_neighbours(tail).take(2).collect::<HashSet<Node>>(),
HashSet::from([exit, new_block])
);
}

#[test]
fn test_outline_cfg_subregion() {
let mut module_builder = ModuleBuilder::new();
let mut fbuild = module_builder
.define_function(
"main",
FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]).pure(),
)
.unwrap();
let [i1] = fbuild.input_wires_arr();
let mut cfg_builder = fbuild
.cfg_builder([(USIZE_T, i1)], type_row![USIZE_T], Default::default())
.unwrap();
let (head, tail) = build_conditional_in_loop(&mut cfg_builder, false).unwrap();
let cfg = cfg_builder.finish_sub_container().unwrap();
fbuild.finish_with_outputs(cfg.outputs()).unwrap();
let mut h = module_builder.finish_prelude_hugr().unwrap();
do_outline_cfg_test(&mut SiblingMut::new(&mut h, cfg.node()), head, tail, 3);
}

#[test]
fn test_outline_cfg_move_entry() {
// /-> left --\
Expand Down

0 comments on commit b4b4e5e

Please sign in to comment.