Skip to content

Commit

Permalink
Add Open graph class (#191)
Browse files Browse the repository at this point in the history
* Add OpenGraph class for graphical compilation

This is based on the Open Graph definition from the "Picturing Quantum
Software" textbook.
Comes with methods to convert to and from pyzx diagrams.

Provides properties inputs(), outputs(), and measurements() that are
compatible with the rest of graphix.

* Add OpenGraph to docs

* Add Open Graph tests

* Implement suggested changes on PR

* Add PyZX requirement into dev requirements

* Use enum for measurement planes

* Skip tests if pyzx isn't installed

* Add conversion between patterns and Open graphs

* Extract PyZX code into separate file

This is because we want to treat PyZX as an optional dependency.
Since OpenGraph is meant to be a standardised interface for many
projects to use, it shouldn't include any optional dependencies

* Update graphix/open_graph.py

Co-authored-by: Shinichi Sunami <[email protected]>

* Improve docstrings on pattern methods

* Rename open_graph.py to opengraph.py

* Change docs for measurement angle

It was always between [0, 2). The docstring was wrong.

* Improve Measurement class's comparisons

Removed is_z_measurement() method since there is no need for it here.

Since we compare floats in the equality operation, we have converted it
into an `isclose` method so we can include the relative and absolute
error tolerances.

* Simplify graph equality operation

* Simplify internal datastructure for OpenGraph

* Improve code quality

* Use Mapping for the iterface for extensibility

* Add warning for PyZX version

Co-authored-by: Shinichi Sunami <[email protected]>

* Remove classmethod for OpenGraph

* Add comments for clarity

Co-authored-by: Shinichi Sunami <[email protected]>

* Fix lints

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* Implement changes for random circuit testing

Co-authored-by: Shinichi Sunami <[email protected]>

* No need to reset the random seed

* Fix formatting

* Remove qasm files

* Add type annotations

Co-authored-by: S.S. <[email protected]>

* Add type annotations

* Highlight that inputs/outputs are ordered

* Simplify

* Add type annotations

* Avoid consuming the iterator

* Switch from NamedTuple to dataclass for validation

Authored by thierry-martinez

* Check open graphs are close, not equal

---------

Co-authored-by: Shinichi Sunami <[email protected]>
Co-authored-by: S.S. <[email protected]>
  • Loading branch information
3 people authored Aug 23, 2024
1 parent 364be19 commit aec54c0
Show file tree
Hide file tree
Showing 6 changed files with 438 additions and 0 deletions.
13 changes: 13 additions & 0 deletions docs/source/open_graph.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Open Graph
======================

:mod:`graphix.opengraph` module
+++++++++++++++++++++++++++++

This module defines classes for defining MBQC patterns as Open Graphs.

.. currentmodule:: graphix.opengraph

.. autoclass:: OpenGraph

.. autoclass:: Measurement
142 changes: 142 additions & 0 deletions graphix/opengraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Provides a class for open graphs."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING

import networkx as nx

from graphix.generator import generate_from_graph

if TYPE_CHECKING:
from graphix.pattern import Pattern
from graphix.pauli import Plane


@dataclass(frozen=True)
class Measurement:
"""An MBQC measurement.
:param angle: the angle of the measurement. Should be between [0, 2)
:param plane: the measurement plane
"""

angle: float
plane: Plane

def isclose(self, other: Measurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Compares if two measurements have the same plane and their angles
are close.
Example
-------
>>> from graphix.opengraph import Measurement
>>> from graphix.pauli import Plane
>>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.XY))
True
>>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.YZ))
False
>>> Measurement(0.1, Plane.XY).isclose(Measurement(0.0, Plane.XY))
False
"""
return math.isclose(self.angle, other.angle, rel_tol=rel_tol, abs_tol=abs_tol) and self.plane == other.plane


@dataclass(frozen=True)
class OpenGraph:
"""Open graph contains the graph, measurement, and input and output
nodes. This is the graph we wish to implement deterministically
:param inside: the underlying graph state
:param measurements: a dictionary whose key is the ID of a node and the
value is the measurement at that node
:param inputs: an ordered list of node IDs that are inputs to the graph
:param outputs: an ordered list of node IDs that are outputs of the graph
Example
-------
>>> import networkx as nx
>>> from graphix.opengraph import OpenGraph, Measurement
>>>
>>> inside_graph = nx.Graph([(0, 1), (1, 2), (2, 0)])
>>>
>>> measurements = {i: Measurement(0.5 * i, Plane.XY) for i in range(2)}
>>> inputs = [0]
>>> outputs = [2]
>>> og = OpenGraph(inside_graph, measurements, inputs, outputs)
"""

inside: nx.Graph
measurements: dict[int, Measurement]
inputs: list[int] # Inputs are ordered
outputs: list[int] # Outputs are ordered

def __post_init__(self) -> None:
if not all(node in self.inside.nodes for node in self.measurements):
raise ValueError("All measured nodes must be part of the graph's nodes.")
if not all(node in self.inside.nodes for node in self.inputs):
raise ValueError("All input nodes must be part of the graph's nodes.")
if not all(node in self.inside.nodes for node in self.outputs):
raise ValueError("All output nodes must be part of the graph's nodes.")
if any(node in self.outputs for node in self.measurements):
raise ValueError("Output node cannot be measured.")
if len(set(self.inputs)) != len(self.inputs):
raise ValueError("Input nodes contain duplicates.")
if len(set(self.outputs)) != len(self.outputs):
raise ValueError("Output nodes contain duplicates.")

def isclose(self, other: OpenGraph, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Compared two open graphs implement approximately the same unitary
operator by ensuring the structure of the graphs are the same and all
measurement angles are sufficiently close.
This doesn't check they are equal up to an isomorphism"""

if not nx.utils.graphs_equal(self.inside, other.inside):
return False

if self.inputs != other.inputs or self.outputs != other.outputs:
return False

if set(self.measurements.keys()) != set(other.measurements.keys()):
return False

return all(m.isclose(other.measurements[node]) for node, m in self.measurements.items())

@classmethod
def from_pattern(cls, pattern: Pattern) -> OpenGraph:
"""Initialises an `OpenGraph` object based on the resource-state graph
associated with the measurement pattern."""
g = nx.Graph()
nodes, edges = pattern.get_graph()
g.add_nodes_from(nodes)
g.add_edges_from(edges)

inputs = pattern.input_nodes
outputs = pattern.output_nodes

meas_planes = pattern.get_meas_plane()
meas_angles = pattern.get_angles()
meas = {node: Measurement(meas_angles[node], meas_planes[node]) for node in meas_angles}

return cls(g, meas, inputs, outputs)

def to_pattern(self) -> Pattern:
"""Converts the `OpenGraph` into a `Pattern`.
Will raise an exception if the open graph does not have flow, gflow, or
Pauli flow.
The pattern will be generated using maximally-delayed flow.
"""

g = self.inside.copy()
inputs = self.inputs
outputs = self.outputs
meas = self.measurements

angles = {node: m.angle for node, m in meas.items()}
planes = {node: m.plane for node, m in meas.items()}

return generate_from_graph(g, angles, inputs, outputs, planes)
178 changes: 178 additions & 0 deletions graphix/pyzx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""Functionality for converting between OpenGraphs and PyZX
These functions are held in their own file rather than including them in the
OpenGraph class because we want PyZX to be an optional dependency.
"""

from __future__ import annotations

import warnings

import networkx as nx
import pyzx as zx

from graphix.opengraph import Measurement, OpenGraph
from graphix.pauli import Plane


def to_pyzx_graph(og: OpenGraph) -> zx.graph.base.BaseGraph:
"""Return a PyZX graph corresponding to the the open graph.
Example
-------
>>> import networkx as nx
>>> g = nx.Graph([(0, 1), (1, 2)])
>>> inputs = [0]
>>> outputs = [2]
>>> measurements = {0: Measurement(0, Plane.XY), 1: Measurement(1, Plane.YZ)}
>>> og = OpenGraph(g, measurements, inputs, outputs)
>>> reconstructed_pyzx_graph = og.to_pyzx_graph()
"""
# check pyzx availability and version
try:
import pyzx as zx
except ModuleNotFoundError as e:
msg = "Cannot find pyzx (optional dependency)."
raise RuntimeError(msg) from e
if zx.__version__ != "0.8.0":
warnings.warn(
"`to_pyzx_graph` is guaranteed to work only with pyzx==0.8.0 due to possible breaking changes in `pyzx`.",
stacklevel=1,
)
g = zx.Graph()

# Add vertices into the graph and set their type
def add_vertices(n: int, ty: zx.VertexType) -> list[zx.VertexType]:
verts = g.add_vertices(n)
for vert in verts:
g.set_type(vert, ty)

return verts

# Add input boundary nodes
in_verts = add_vertices(len(og.inputs), zx.VertexType.BOUNDARY)
g.set_inputs(in_verts)

# Add nodes for internal Z spiders - not including the phase gadgets
body_verts = add_vertices(len(og.inside), zx.VertexType.Z)

# Add nodes for the phase gadgets. In OpenGraph we don't store the
# effect as a seperate node, it is instead just stored in the
# "measurement" attribute of the node it measures.
x_meas = [i for i, m in og.measurements.items() if m.plane == Plane.YZ]
x_meas_verts = add_vertices(len(x_meas), zx.VertexType.Z)

out_verts = add_vertices(len(og.outputs), zx.VertexType.BOUNDARY)
g.set_outputs(out_verts)

# Maps a node's ID in the Open Graph to it's corresponding node ID in
# the PyZX graph and vice versa.
map_to_og = dict(zip(body_verts, og.inside.nodes()))
map_to_pyzx = {v: i for i, v in map_to_og.items()}

# Open Graph's don't have boundary nodes, so we need to connect the
# input and output Z spiders to their corresponding boundary nodes in
# pyzx.
for pyzx_index, og_index in zip(in_verts, og.inputs):
g.add_edge((pyzx_index, map_to_pyzx[og_index]))
for pyzx_index, og_index in zip(out_verts, og.outputs):
g.add_edge((pyzx_index, map_to_pyzx[og_index]))

og_edges = og.inside.edges()
pyzx_edges = ((map_to_pyzx[a], map_to_pyzx[b]) for a, b in og_edges)
g.add_edges(pyzx_edges, zx.EdgeType.HADAMARD)

# Add the edges between the Z spiders in the graph body
for og_index, meas in og.measurements.items():
# If it's an X measured node, then we handle it in the next loop
if meas.plane == Plane.XY:
g.set_phase(map_to_pyzx[og_index], meas.angle)

# Connect the X measured vertices
for og_index, pyzx_index in zip(x_meas, x_meas_verts):
g.add_edge((map_to_pyzx[og_index], pyzx_index), zx.EdgeType.HADAMARD)
g.set_phase(pyzx_index, og.measurements[og_index].angle)

return g


def from_pyzx_graph(g: zx.graph.base.BaseGraph) -> OpenGraph:
"""Constructs an Optyx Open Graph from a PyZX graph.
This method may add additional nodes to the graph so that it adheres
with the definition of an OpenGraph. For instance, if the final node on
a qubit is measured, it will add two nodes behind it so that no output
nodes are measured to satisfy the requirements of an open graph.
.. warning::
works with `pyzx==0.8.0` (see `requirements-dev.txt`). Other versions may not be compatible due to breaking changes in `pyzx`
Example
-------
>>> import pyzx as zx
>>> from graphix.opengraph import OpenGraph
>>> circ = zx.qasm("qreg q[2]; h q[1]; cx q[0], q[1]; h q[1];")
>>> g = circ.to_graph()
>>> og = OpenGraph.from_pyzx_graph(g)
"""
zx.simplify.to_graph_like(g)

measurements = {}
inputs = g.inputs()
outputs = g.outputs()

g_nx = nx.Graph(g.edges())

# We need to do this since the full reduce simplification can
# leave either hadamard or plain wires on the inputs and outputs
for inp in g.inputs():
first_nbr = next(iter(g.neighbors(inp)))
et = g.edge_type((first_nbr, inp))

if et == zx.EdgeType.SIMPLE:
g_nx.remove_node(inp)
inputs = [i if i != inp else first_nbr for i in inputs]

for out in g.outputs():
first_nbr = next(iter(g.neighbors(out)))
et = g.edge_type((first_nbr, out))

if et == zx.EdgeType.SIMPLE:
g_nx.remove_node(out)
outputs = [o if o != out else first_nbr for o in outputs]

# Turn all phase gadgets into measurements
# Since we did a full reduce, any node that isn't an input or output
# node and has only one neighbour is definitely a phase gadget.
nodes = list(g_nx.nodes())
for v in nodes:
if v in inputs or v in outputs:
continue

nbrs = list(g.neighbors(v))
if len(nbrs) == 1:
measurements[nbrs[0]] = Measurement(float(g.phase(v)), Plane.YZ)
g_nx.remove_node(v)

next_id = max(g_nx.nodes) + 1

# Since outputs can't be measured, we need to add an extra two nodes
# in to counter it
for out in outputs:
if g.phase(out) == 0:
continue

g_nx.add_edges_from([(out, next_id), (next_id, next_id + 1)])
measurements[next_id] = Measurement(0, Plane.XY)

outputs = [o if o != out else next_id + 1 for o in outputs]
next_id += 2

# Add the phase to all XY measured nodes
for v in g_nx.nodes:
if v in outputs or v in measurements:
continue

# g.phase() may be a fractions.Fraction object, but Measurement
# expects a float
measurements[v] = Measurement(float(g.phase(v)), Plane.XY)

return OpenGraph(g_nx, measurements, inputs, outputs)
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ tox
qiskit>=1.0
qiskit-aer
rustworkx

# Optional dependency. Pinned due to version changes often being incompatible
pyzx==0.8.0
41 changes: 41 additions & 0 deletions tests/test_opengraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import networkx as nx

from graphix.opengraph import Measurement, OpenGraph
from graphix.pauli import Plane


# Tests whether an open graph can be converted to and from a pattern and be
# successfully reconstructed.
def test_open_graph_to_pattern() -> None:
g = nx.Graph([(0, 1), (1, 2)])
inputs = [0]
outputs = [2]
meas = {0: Measurement(0, Plane.XY), 1: Measurement(0, Plane.XY)}
og = OpenGraph(g, meas, inputs, outputs)

pattern = og.to_pattern()
og_reconstructed = OpenGraph.from_pattern(pattern)

assert og.isclose(og_reconstructed)

# 0 -- 1 -- 2
# |
# 3 -- 4 -- 5
g = nx.Graph([(0, 1), (1, 2), (1, 4), (3, 4), (4, 5)])
inputs = [0, 3]
outputs = [2, 5]
meas = {
0: Measurement(0, Plane.XY),
1: Measurement(1.0, Plane.XY),
3: Measurement(1.0, Plane.YZ),
4: Measurement(1.0, Plane.XY),
}

og = OpenGraph(g, meas, inputs, outputs)

pattern = og.to_pattern()
og_reconstructed = OpenGraph.from_pattern(pattern)

assert og.isclose(og_reconstructed)
Loading

0 comments on commit aec54c0

Please sign in to comment.