Skip to content

Commit

Permalink
feat: Split & reassemble circuit chunks (#130)
Browse files Browse the repository at this point in the history
This is a pass utility to split circuits into chunks that can be
independently optimized.

Closes #129
  • Loading branch information
aborgna-q authored Sep 27, 2023
1 parent 31ecafa commit d8fce77
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "19ed0fc" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
36 changes: 36 additions & 0 deletions pyrs/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::prelude::*;
use hugr::{Hugr, HugrView};
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
use tket_json_rs::circuit_json::SerialCircuit;

/// Apply a fallible function expecting a hugr on a pytket circuit.
Expand Down Expand Up @@ -52,3 +53,38 @@ pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
pub fn to_hugr(c: Py<PyAny>) -> PyResult<Hugr> {
with_hugr(c, |hugr| hugr)
}

#[pyfunction]
pub fn chunks(c: Py<PyAny>, max_chunk_size: usize) -> PyResult<CircuitChunks> {
with_hugr(c, |hugr| CircuitChunks::split(&hugr, max_chunk_size))
}

/// circuit module
pub fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<tket2::passes::CircuitChunks>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;
m.add_function(wrap_pyfunction!(chunks, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}
26 changes: 1 addition & 25 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::try_with_hugr;
use circuit::{add_circuit_module, try_with_hugr};
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;
Expand All @@ -25,30 +25,6 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}

/// circuit module
fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
Expand Down
13 changes: 12 additions & 1 deletion pyrs/test/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from pyrs.pyrs import passes
from pyrs.pyrs import passes, circuit
from pytket.circuit import Circuit


Expand All @@ -19,6 +19,17 @@ def test_depth_optimise():

assert c.depth() == 2

def test_chunks():
c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3)

assert c.depth() == 3

chunks = circuit.chunks(c, 2)
circuits = chunks.circuits()
chunks.update_circuit(0, circuits[0])
c2 = chunks.reassemble()

assert c2.depth() == 3

# from dataclasses import dataclass
# from typing import Callable, Iterable
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ pub(crate) mod ops;
pub mod optimiser;
pub mod passes;
pub mod rewrite;
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;

mod utils;

pub use circuit::Circuit;
pub use ops::{symbolic_constant_op, Pauli, T2Op};
5 changes: 4 additions & 1 deletion src/passes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Optimisation passes for circuits.
//! Optimisation passes and related utilities for circuits.

mod commutation;
pub use commutation::apply_greedy_commutation;
#[cfg(feature = "pyo3")]
pub use commutation::PyPullForwardError;

pub mod chunks;
pub use chunks::CircuitChunks;
Loading

0 comments on commit d8fce77

Please sign in to comment.