Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #215: Handle non-contiguous numbering in rustworkx backend #225

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphix/graphsim/rxgraphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def neighbors(self, node) -> Iterator:
See :meth:`BaseGraphState.neighbors`.
"""
nidx = self.nodes.get_node_index(node)
return iter(self._graph.neighbors(nidx))
return iter(self.nodes.idx_to_num[idx] for idx in self._graph.neighbors(nidx))
thierry-martinez marked this conversation as resolved.
Show resolved Hide resolved

def subgraph(self, nodes: list) -> rx.PyGraph:
"""Return a subgraph of the graph.
Expand Down
5 changes: 4 additions & 1 deletion graphix/graphsim/rxgraphviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self.nodes = set(node_nums)
self.num_to_data = {nnum: node_datas[nidx] for nidx, nnum in zip(node_indices, node_nums)}
self.num_to_idx = {nnum: nidx for nidx, nnum in zip(node_indices, node_nums)}
self.idx_to_num = {nidx: nnum for nidx, nnum in zip(node_indices, node_nums)}

def __contains__(self, nnum: int) -> bool:
"""Return `True` if the node `nnum` belongs to the list, `False` otherwise."""
Expand Down Expand Up @@ -64,6 +65,7 @@ def add_node(self, nnum: int, ndata: dict, nidx: int) -> None:
self.nodes.add(nnum)
self.num_to_data[nnum] = ndata
self.num_to_idx[nnum] = nidx
self.idx_to_num[nidx] = nnum

def add_nodes_from(self, node_nums: list[int], node_datas: list[dict], node_indices: list[int]) -> None:
"""Add nodes to the list."""
Expand All @@ -80,7 +82,8 @@ def remove_node(self, nnum: int) -> None:
raise ValueError(f"Node {nnum} does not exist")
self.nodes.remove(nnum)
del self.num_to_data[nnum]
del self.num_to_idx[nnum]
idx = self.num_to_idx.pop(nnum)
del self.idx_to_num[idx]

def remove_nodes_from(self, node_nums: list[int]) -> None:
"""Remove nodes from the list."""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ def test_minimize_space(self, fx_rng: Generator) -> None:
state_mbqc = pattern.simulate_pattern(rng=fx_rng)
assert np.abs(np.dot(state_mbqc.flatten().conjugate(), state.flatten())) == pytest.approx(1)

@pytest.mark.parametrize("use_rustworkx", [False, True])
def test_pauli_non_contiguous(self, use_rustworkx) -> None:
thierry-martinez marked this conversation as resolved.
Show resolved Hide resolved
pattern = Pattern(input_nodes=[0])
pattern.extend(
[
N(node=2, state=PlanarState(plane=Plane.XY, angle=0.0)),
E(nodes=(0, 2)),
M(node=0, plane=Plane.XY, angle=0.0, s_domain=set(), t_domain=set()),
]
)
pattern.perform_pauli_measurements(use_rustworkx=use_rustworkx)

@pytest.mark.parametrize("jumps", range(1, 11))
def test_minimize_space_with_gflow(self, fx_bg: PCG64, jumps: int, use_rustworkx: bool = True) -> None:
rng = Generator(fx_bg.jumped(jumps))
Expand Down
Loading