Skip to content

Commit

Permalink
Fix #215: Handle non-contiguous numbering in rustworkx backend (#225)
Browse files Browse the repository at this point in the history
* Fix #215: Handle non-contiguous numbering in `rustworkx` backend

This commit resolves a bug, as tested in the new test case
`test_pauli_non_contiguous`, which previously passed with the
`networkx` backend but failed with the `rustworkx` backend.

The method `BaseGraphState.neighbors` now returns an iteration
over (Graphix) node indices, rather than `rustworkx` vertex
indices.

These indices are then passed to `BaseGraphState.subgraph` in
`BaseGraphState.local_complement`. `BaseGraphState.subgraph` expects
node indices, as it converts them back to `rustworkx` vertex indices.

---------

Co-authored-by: S.S. <[email protected]>
  • Loading branch information
thierry-martinez and EarlMilktea authored Oct 18, 2024
1 parent 90c9433 commit b60ff8b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
4 changes: 2 additions & 2 deletions graphix/graphsim/rxgraphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def degree(self) -> Iterator[tuple[int, int]]:
ret.append((n, degree))
return iter(ret)

def neighbors(self, node) -> Iterator:
def neighbors(self, node) -> Iterator[int]:
"""Return an iterator over all neighbors of node n.
See :meth:`BaseGraphState.neighbors`.
"""
nidx = self.nodes.get_node_index(node)
return iter(self._graph.neighbors(nidx))
return (self.nodes.idx_to_num[idx] for idx in self._graph.neighbors(nidx))

def subgraph(self, nodes: list) -> rx.PyGraph:
"""Return a subgraph of the graph.
Expand Down
5 changes: 4 additions & 1 deletion graphix/graphsim/rxgraphviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self.nodes = set(node_nums)
self.num_to_data = {nnum: node_datas[nidx] for nidx, nnum in zip(node_indices, node_nums)}
self.num_to_idx = {nnum: nidx for nidx, nnum in zip(node_indices, node_nums)}
self.idx_to_num = {nidx: nnum for nidx, nnum in zip(node_indices, node_nums)}

def __contains__(self, nnum: int) -> bool:
"""Return `True` if the node `nnum` belongs to the list, `False` otherwise."""
Expand Down Expand Up @@ -64,6 +65,7 @@ def add_node(self, nnum: int, ndata: dict, nidx: int) -> None:
self.nodes.add(nnum)
self.num_to_data[nnum] = ndata
self.num_to_idx[nnum] = nidx
self.idx_to_num[nidx] = nnum

def add_nodes_from(self, node_nums: list[int], node_datas: list[dict], node_indices: list[int]) -> None:
"""Add nodes to the list."""
Expand All @@ -80,7 +82,8 @@ def remove_node(self, nnum: int) -> None:
raise ValueError(f"Node {nnum} does not exist")
self.nodes.remove(nnum)
del self.num_to_data[nnum]
del self.num_to_idx[nnum]
idx = self.num_to_idx.pop(nnum)
del self.idx_to_num[idx]

def remove_nodes_from(self, node_nums: list[int]) -> None:
"""Remove nodes from the list."""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ def test_minimize_space(self, fx_rng: Generator) -> None:
state_mbqc = pattern.simulate_pattern(rng=fx_rng)
assert np.abs(np.dot(state_mbqc.flatten().conjugate(), state.flatten())) == pytest.approx(1)

@pytest.mark.parametrize("use_rustworkx", [False, True])
def test_pauli_non_contiguous(self, use_rustworkx: bool) -> None:
pattern = Pattern(input_nodes=[0])
pattern.extend(
[
N(node=2, state=PlanarState(plane=Plane.XY, angle=0.0)),
E(nodes=(0, 2)),
M(node=0, plane=Plane.XY, angle=0.0, s_domain=set(), t_domain=set()),
]
)
pattern.perform_pauli_measurements(use_rustworkx=use_rustworkx)

@pytest.mark.parametrize("jumps", range(1, 11))
def test_minimize_space_with_gflow(self, fx_bg: PCG64, jumps: int, use_rustworkx: bool = True) -> None:
rng = Generator(fx_bg.jumped(jumps))
Expand Down

0 comments on commit b60ff8b

Please sign in to comment.