Skip to content

Commit

Permalink
Implement suggested changes on PR
Browse files Browse the repository at this point in the history
  • Loading branch information
wlcsm committed Aug 3, 2024
1 parent 878173b commit 7476965
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 108 deletions.
118 changes: 28 additions & 90 deletions graphix/open_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Measurement:
angle: float
plane: str

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Compares if two measurements are equal
Example
Expand All @@ -34,6 +34,9 @@ def __eq__(self, other):
>>> Measurement(0.1, "XY") == Measurement(0.0, "XY")
False
"""
if not isinstance(other, Measurement):
return NotImplemented

return np.allclose(self.angle, other.angle) and self.plane == other.plane

def is_z_measurement(self) -> bool:
Expand All @@ -52,48 +55,6 @@ def is_z_measurement(self) -> bool:
return np.allclose(self.angle, 0.0) and self.plane == "XY"


@dataclass
class Fusion:
"""A fusion between two nodes
:param node1: ID of one of the nodes in the fusion
:param node2: ID of the other node in the fusion
:param fusion_type: The type of fusion. Currently either: "X", "Y"
Example
-------
>>> from graphix.open_graph import Fusion
>>> Fusion(0, 1, "X") == Fusion(0, 1, "X")
True
>>> Fusion(0, 1, "X") == Fusion(0, 1, "Y")
False
>>> Fusion(0, 1, "X") == Fusion(1, 0, "X")
True
>>> Fusion(0, 1, "X") == Fusion(0, 2, "X")
False
"""

node1: int
node2: int
fusion_type: str

def __eq__(self, other) -> bool:
if self.fusion_type != other.fusion_type:
return False

if self.node1 == other.node1 and self.node2 == other.node2:
return True

if self.node1 == other.node2 and self.node2 == other.node1:
return True

return False

def contains(self, node_id: int) -> bool:
"""Indicates whether the node is part of the fusion"""
return node_id in (self.node1, self.node2)


class OpenGraph:
"""Open graph contains the graph, measurement, and input and output
nodes. This is the graph we wish to implement deterministically
Expand Down Expand Up @@ -137,21 +98,18 @@ class OpenGraph:
# * "is_output" - The value is always True
# * "output_order" - A zero-indexed integer used to preserve the ordering of
# the outputs
inside: nx.Graph
_inside: nx.Graph

def __eq__(self, other):
"""Checks the two open graphs are equal
This doesn't check they are equal up to an isomorphism"""

g1 = self.perform_z_deletions()
g2 = other.perform_z_deletions()

return (
g1.inputs == g2.inputs
and g1.outputs == g2.outputs
and nx.utils.graphs_equal(g1.inside, g2.inside)
and g1.measurements == g2.measurements
self.inputs == other.inputs
and self.outputs == other.outputs
and nx.utils.graphs_equal(self._inside, other._inside)
and self.measurements == other.measurements
)

def __init__(
Expand All @@ -166,21 +124,21 @@ def __init__(
The inputs() and outputs() methods will preserve the order that was
original given in to this methods.
"""
self.inside = inside
self._inside = inside

if any(node in outputs for node in measurements):
raise ValueError("output node can not be measured")

for node_id, measurement in measurements.items():
self.inside.nodes[node_id]["measurement"] = measurement
self._inside.nodes[node_id]["measurement"] = measurement

for i, node_id in enumerate(inputs):
self.inside.nodes[node_id]["is_input"] = True
self.inside.nodes[node_id]["input_order"] = i
self._inside.nodes[node_id]["is_input"] = True
self._inside.nodes[node_id]["input_order"] = i

for i, node_id in enumerate(outputs):
self.inside.nodes[node_id]["is_output"] = True
self.inside.nodes[node_id]["output_order"] = i
self._inside.nodes[node_id]["is_output"] = True
self._inside.nodes[node_id]["output_order"] = i

def to_pyzx_graph(self) -> zx.graph.base.BaseGraph:
"""Return a PyZX graph corresponding to the the open graph.
Expand Down Expand Up @@ -210,7 +168,7 @@ def add_vertices(n: int, ty: zx.VertexType) -> list[zx.VertexType]:
g.set_inputs(in_verts)

# Add nodes for internal Z spiders - not including the phase gadgets
body_verts = add_vertices(len(self.inside), zx.VertexType.Z)
body_verts = add_vertices(len(self._inside), zx.VertexType.Z)

# Add nodes for the phase gadgets. In OpenGraph we don't store the
# effect as a seperate node, it is instead just stored in the
Expand All @@ -223,7 +181,7 @@ def add_vertices(n: int, ty: zx.VertexType) -> list[zx.VertexType]:

# Maps a node's ID in the Open Graph to it's corresponding node ID in
# the PyZX graph and vice versa.
map_to_og = dict(zip(body_verts, self.inside.nodes()))
map_to_og = dict(zip(body_verts, self._inside.nodes()))
map_to_pyzx = {v: i for i, v in map_to_og.items()}

# Open Graph's don't have boundary nodes, so we need to connect the
Expand All @@ -234,7 +192,7 @@ def add_vertices(n: int, ty: zx.VertexType) -> list[zx.VertexType]:
for pyzx_index, og_index in zip(out_verts, self.outputs):
g.add_edge((pyzx_index, map_to_pyzx[og_index]))

og_edges = self.inside.edges()
og_edges = self._inside.edges()
pyzx_edges = [(map_to_pyzx[a], map_to_pyzx[b]) for a, b in og_edges]
g.add_edges(pyzx_edges, zx.EdgeType.HADAMARD)

Expand Down Expand Up @@ -269,7 +227,6 @@ def from_pyzx_graph(cls, g: zx.graph.base.BaseGraph) -> OpenGraph:
>>> og = OpenGraph.from_pyzx_graph(g)
"""
zx.simplify.to_graph_like(g)
zx.simplify.full_reduce(g)

measurements = {}
inputs = g.inputs()
Expand Down Expand Up @@ -305,7 +262,7 @@ def from_pyzx_graph(cls, g: zx.graph.base.BaseGraph) -> OpenGraph:

nbrs = list(g.neighbors(v))
if len(nbrs) == 1:
measurements[nbrs[0]] = Measurement(g.phase(v), "YZ")
measurements[nbrs[0]] = Measurement(float(g.phase(v)), "YZ")
g_nx.remove_node(v)

next_id = max(g_nx.nodes) + 1
Expand All @@ -327,7 +284,9 @@ def from_pyzx_graph(cls, g: zx.graph.base.BaseGraph) -> OpenGraph:
if v in outputs or v in measurements:
continue

measurements[v] = Measurement(g.phase(v), "XY")
# g.phase() may be a fractions.Fraction object, but Measurement
# expects a float
measurements[v] = Measurement(float(g.phase(v)), "XY")

return cls(g_nx, measurements, inputs, outputs)

Expand All @@ -349,10 +308,11 @@ def inputs(self) -> list[int]:
>>> og.inputs == inputs
True
"""
inputs = [i for i in self.inside.nodes(data=True) if "is_input" in i[1]]
inputs = [i for i in self._inside.nodes(data=True) if "is_input" in i[1]]
sorted_inputs = sorted(inputs, key=lambda x: x[1]["input_order"])
input_node_ids = [i[0] for i in sorted_inputs]
return input_node_ids

# Returns only the input ids
return [i[0] for i in sorted_inputs]

@property
def outputs(self) -> list[int]:
Expand All @@ -372,7 +332,7 @@ def outputs(self) -> list[int]:
>>> og.outputs == outputs
True
"""
outputs = [i for i in self.inside.nodes(data=True) if "is_output" in i[1]]
outputs = [i for i in self._inside.nodes(data=True) if "is_output" in i[1]]
sorted_outputs = sorted(outputs, key=lambda x: x[1]["output_order"])
output_node_ids = [i[0] for i in sorted_outputs]
return output_node_ids
Expand All @@ -399,26 +359,4 @@ def measurements(self) -> dict[int, Measurement]:
... }
True
"""
return {
n[0]: n[1]["measurement"]
for n in self.inside.nodes(data=True)
if "measurement" in n[1]
}

def perform_z_deletions_in_place(self):
"""Removes the Z-deleted nodes from the graph in place"""
z_measured_nodes = [
node
for node in self.inside.nodes
if "measurement" in self.inside.nodes[node]
and self.inside.nodes[node]["measurement"].is_z_measurement()
]

for node in z_measured_nodes:
self.inside.remove_node(node)

def perform_z_deletions(self) -> OpenGraph:
"""Removes the Z-deleted nodes from the graph"""
g = deepcopy(self)
g.perform_z_deletions_in_place()
return g
return {n[0]: n[1]["measurement"] for n in self._inside.nodes(data=True) if "measurement" in n[1]}
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ pydantic
quimb>=1.4.0
scipy
sympy>=1.9
pyzx==0.8.0
typing_extensions
51 changes: 34 additions & 17 deletions tests/test_open_graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from __future__ import annotations

import os

import networkx as nx
import numpy as np

Check failure on line 6 in tests/test_open_graph.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F401)

tests/test_open_graph.py:6:17: F401 `numpy` imported but unused
import pytest
import pyzx as zx

from graphix.open_graph import OpenGraph
from graphix.open_graph import Measurement, OpenGraph


def test_graph_no_output_measurements() -> None:
g = nx.Graph([(0, 1)])
meas = {0: Measurement(0, "XY"), 1: Measurement(0, "XY")}
inputs = [0]
outputs = [1]

# Output node can not be measurement
with pytest.raises(ValueError):
OpenGraph(g, meas, inputs, outputs)


def test_graph_equality():
def test_graph_equality() -> None:
file = "./tests/circuits/adder_n4.qasm"
circ = zx.Circuit.load(file)

Expand All @@ -21,10 +36,9 @@ def test_graph_equality():

# Converts a graph to and from an Open graph and then checks the resulting
# pyzx graph is equal to the original.
def assert_reconstructed_pyzx_graph_equal(circ: zx.Circuit):
def assert_reconstructed_pyzx_graph_equal(circ: zx.Circuit) -> None:
g = circ.to_graph()
zx.simplify.to_graph_like(g)
zx.simplify.full_reduce(g)

g_copy = circ.to_graph()
og = OpenGraph.from_pyzx_graph(g_copy)
Expand All @@ -34,28 +48,31 @@ def assert_reconstructed_pyzx_graph_equal(circ: zx.Circuit):
for v in reconstructed_pyzx_graph.vertices():
reconstructed_pyzx_graph.set_row(v, 2)

ten = zx.tensorfy(g).flatten()
ten_graph = zx.tensorfy(reconstructed_pyzx_graph).flatten()
ten = zx.tensorfy(g)
ten_graph = zx.tensorfy(reconstructed_pyzx_graph)
assert zx.compare_tensors(ten, ten_graph)

# Here we check their tensor representations instead of composing g with
# the adjoint of reconstructed_pyzx_graph and checking it reduces to the
# identity since there seems to be a bug where equal graphs don't produce
# the identity
i = np.argmax(ten)
assert np.allclose(ten / ten[i], ten_graph / ten_graph[i])


# Tests that compiling from a pyzx graph to an OpenGraph returns the same
# graph. Only works with small circuits up to 4 qubits since PyZX's `tensorfy`
# function seems to consume huge amount of memory for larger qubit
def test_all_small_circuits():
@pytest.fixture
def all_small_circuits() -> list[zx.Circuit]:
direc = "./tests/circuits/"
directory = os.fsencode(direc)

circuits = []
for file in os.listdir(directory):
filename = os.fsdecode(file)
if not filename.endswith(".qasm"):
raise Exception(f"only '.qasm' files allowed: not {filename}")

circ = zx.Circuit.load(direc + filename)
circuits.append(circ)

return circuits


# Tests that compiling from a pyzx graph to an OpenGraph returns the same
# graph. Only works with small circuits up to 4 qubits since PyZX's `tensorfy`
# function seems to consume huge amount of memory for larger qubit
def test_all_small_circuits(all_small_circuits) -> None:
for circ in all_small_circuits:
assert_reconstructed_pyzx_graph_equal(circ)

0 comments on commit 7476965

Please sign in to comment.