Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Split & reassemble circuit chunks #130

Merged
merged 12 commits into from
Sep 27, 2023
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "19ed0fc" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
36 changes: 36 additions & 0 deletions pyrs/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::prelude::*;
use hugr::{Hugr, HugrView};
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
use tket_json_rs::circuit_json::SerialCircuit;

/// Apply a fallible function expecting a hugr on a pytket circuit.
Expand Down Expand Up @@ -52,3 +53,38 @@ pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
pub fn to_hugr(c: Py<PyAny>) -> PyResult<Hugr> {
with_hugr(c, |hugr| hugr)
}

#[pyfunction]
pub fn chunks(c: Py<PyAny>, max_chunk_size: usize) -> PyResult<CircuitChunks> {
with_hugr(c, |hugr| CircuitChunks::split(&hugr, max_chunk_size))
}

/// circuit module
pub fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<tket2::passes::CircuitChunks>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;
m.add_function(wrap_pyfunction!(chunks, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}
26 changes: 1 addition & 25 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::try_with_hugr;
use circuit::{add_circuit_module, try_with_hugr};
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;
Expand All @@ -25,30 +25,6 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}

/// circuit module
fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
Expand Down
13 changes: 12 additions & 1 deletion pyrs/test/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from pyrs.pyrs import passes
from pyrs.pyrs import passes, circuit
from pytket.circuit import Circuit


Expand All @@ -19,6 +19,17 @@ def test_depth_optimise():

assert c.depth() == 2

def test_chunks():
c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3)

assert c.depth() == 3

chunks = circuit.chunks(c, 2)
circuits = chunks.circuits()
chunks.update_circuit(0, circuits[0])
c2 = chunks.reassemble()

assert c2.depth() == 3

# from dataclasses import dataclass
# from typing import Callable, Iterable
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ pub(crate) mod ops;
pub mod optimiser;
pub mod passes;
pub mod rewrite;
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;

mod utils;

pub use circuit::Circuit;
pub use ops::{symbolic_constant_op, Pauli, T2Op};
5 changes: 4 additions & 1 deletion src/passes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Optimisation passes for circuits.
//! Optimisation passes and related utilities for circuits.

mod commutation;
pub use commutation::apply_greedy_commutation;
#[cfg(feature = "pyo3")]
pub use commutation::PyPullForwardError;

pub mod chunks;
pub use chunks::CircuitChunks;
Loading