Skip to content

Commit

Permalink
Add rules (TeamGraphix#186)
Browse files Browse the repository at this point in the history
* 🔧 Add more rules

* 🐛 Resolve invalid escapes

* ♻️ Resolve loop variable override

* ♻️ Resolve RUF rules

* 🔧 Add PERF rule

* ⚡ Resolve PREF rules

* ♻️ Use raw string

* 🔧 Add TCH rule

* ⚡ Resolve TCH rule

* 🚨 Apply isort

* ♻️ Update graphix/gflow.py

Co-authored-by: thierry-martinez <[email protected]>

* ♻️ Update graphix/pattern.py

Co-authored-by: thierry-martinez <[email protected]>

* ♻️ Fix endless while-loop in `visualization.py`

The commit 1d23a4f introduced an endless while-loop in

* 🚨 Remove dead import

* 🎨 Format

* 💡 Note on (p,) =

---------

Co-authored-by: thierry-martinez <[email protected]>
  • Loading branch information
EarlMilktea and thierry-martinez authored Jul 28, 2024
1 parent 40985e2 commit 364be19
Show file tree
Hide file tree
Showing 22 changed files with 137 additions and 97 deletions.
2 changes: 1 addition & 1 deletion examples/qft_with_tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def qft(circuit, n):
# To specify TN backend of the simulation, simply provide as a keyword argument.
# here we do a very basic check that one of the statevector amplitudes is what it is expected to be:

import time # noqa: E402
import time

t1 = time.time()
tn = pattern.simulate_pattern(backend="tensornetwork")
Expand Down
8 changes: 4 additions & 4 deletions graphix/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class KrausChannel:
"""quantum channel class in the Kraus representation.
Defined by Kraus operators :math:`K_i` with scalar prefactors :code:`coef`) :math:`c_i`,
where the channel act on density matrix as :math:`\\rho' = \sum_i K_i^\dagger \\rho K_i`.
The data should satisfy :math:`\sum K_i^\dagger K_i = I`
where the channel act on density matrix as :math:`\\rho' = \\sum_i K_i^\\dagger \\rho K_i`.
The data should satisfy :math:`\\sum K_i^\\dagger K_i = I`
Attributes
----------
Expand Down Expand Up @@ -139,7 +139,7 @@ def two_qubit_depolarising_channel(prob: float) -> KrausChannel:
"""two-qubit depolarising channel.
.. math::
\mathcal{E} (\\rho) = (1-p) \\rho + \\frac{p}{15} \sum_{P_i \in \{id, X, Y ,Z\}^{\otimes 2}/(id \otimes id)}P_i \\rho P_i
\\mathcal{E} (\\rho) = (1-p) \\rho + \\frac{p}{15} \\sum_{P_i \\in \\{id, X, Y ,Z\\}^{\\otimes 2}/(id \\otimes id)}P_i \\rho P_i
Parameters
----------
Expand Down Expand Up @@ -179,7 +179,7 @@ def two_qubit_depolarising_tensor_channel(prob: float) -> KrausChannel:
Kraus operators:
.. math::
\Big\{ \sqrt{(1-p)} id, \sqrt{(p/3)} X, \sqrt{(p/3)} Y , \sqrt{(p/3)} Z \Big\} \otimes \Big\{ \sqrt{(1-p)} id, \sqrt{(p/3)} X, \sqrt{(p/3)} Y , \sqrt{(p/3)} Z \Big\}
\\Big\\{ \\sqrt{(1-p)} id, \\sqrt{(p/3)} X, \\sqrt{(p/3)} Y , \\sqrt{(p/3)} Z \\Big\\} \\otimes \\Big\\{ \\sqrt{(1-p)} id, \\sqrt{(p/3)} X, \\sqrt{(p/3)} Y , \\sqrt{(p/3)} Z \\Big\\}
Parameters
----------
Expand Down
5 changes: 4 additions & 1 deletion graphix/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

import abc
import enum
from typing import TYPE_CHECKING

import numpy as np
from pydantic import BaseModel

import graphix.clifford
from graphix.clifford import Clifford
from graphix.pauli import Plane

if TYPE_CHECKING:
from graphix.clifford import Clifford

Node = int


Expand Down
7 changes: 5 additions & 2 deletions graphix/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

from copy import deepcopy
from enum import Enum
from typing import TYPE_CHECKING

import numpy as np

from graphix.graphsim.basegraphstate import BaseGraphState
from graphix.graphsim.graphstate import GraphState
from graphix.graphsim.rxgraphstate import RXGraphState
from graphix.graphsim.utils import is_graphs_equal

if TYPE_CHECKING:
from graphix.graphsim.basegraphstate import BaseGraphState


class ResourceType(Enum):
GHZ = "GHZ"
Expand Down Expand Up @@ -131,7 +134,7 @@ def get_fusion_network_from_graph(
for v in adjdict.keys():
if len(adjdict[v]) == 2:
neighbors = list(adjdict[v].keys())
nodes = [v] + neighbors
nodes = [v, *neighbors]
del adjdict[neighbors[0]][v]
del adjdict[neighbors[1]][v]
del adjdict[v][neighbors[0]]
Expand Down
21 changes: 11 additions & 10 deletions graphix/gflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@

from __future__ import annotations

import numbers
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from graphix.pattern import Pattern

from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING

import networkx as nx
import numpy as np
Expand All @@ -29,6 +24,11 @@
from graphix.command import CommandKind
from graphix.linalg import MatGF2

if TYPE_CHECKING:
import numbers

from graphix.pattern import Pattern


def check_meas_planes(meas_planes: dict[int, graphix.pauli.Plane]) -> None:
for node, plane in meas_planes.items():
Expand Down Expand Up @@ -332,7 +332,8 @@ def flowaux(
N = search_neighbor(q, edges)
p_set = N & (nodes - oset)
if len(p_set) == 1:
p = list(p_set)[0]
# Iterate over p_set assuming there is only one element p
(p,) = p_set
f[p] = {q}
l_k[p] = k
v_out_prime = v_out_prime | {p}
Expand Down Expand Up @@ -1163,11 +1164,11 @@ def verify_flow(
return valid_flow
# check if v ~ f(v) for each node
edges = set(graph.edges)
for node, correction in flow.items():
if len(correction) > 1:
for node, corrections in flow.items():
if len(corrections) > 1:
valid_flow = False
return valid_flow
correction = list(correction)[0]
correction = next(iter(corrections))
if (node, correction) not in edges and (correction, node) not in edges:
valid_flow = False
return valid_flow
Expand Down
5 changes: 4 additions & 1 deletion graphix/graphsim/basegraphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING

import networkx as nx
import networkx.classes.reportviews as nx_reportviews
Expand All @@ -13,6 +13,9 @@

from .rxgraphviews import EdgeList, NodeList

if TYPE_CHECKING:
from collections.abc import Iterator

RUSTWORKX_INSTALLED = False
try:
import rustworkx as rx
Expand Down
8 changes: 6 additions & 2 deletions graphix/graphsim/nxgraphstate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING

import networkx as nx
from networkx.classes.reportviews import EdgeView, NodeView

from .basegraphstate import BaseGraphState

if TYPE_CHECKING:
from collections.abc import Iterator

from networkx.classes.reportviews import EdgeView, NodeView


class NXGraphState(BaseGraphState):
"""Graph state simulator implemented with networkx.
Expand Down
7 changes: 5 additions & 2 deletions graphix/graphsim/rxgraphstate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING

from .basegraphstate import RUSTWORKX_INSTALLED, BaseGraphState
from .rxgraphviews import EdgeList, NodeList

if TYPE_CHECKING:
from collections.abc import Iterator

if RUSTWORKX_INSTALLED:
import rustworkx as rx
else:
Expand Down Expand Up @@ -87,7 +90,7 @@ def adjacency(self) -> Iterator:
nidx = self.nodes.get_node_index(n)
adjacency_dict = self._graph.adj(nidx)
new_adjacency_dict = {}
for nidx, _ in adjacency_dict.items():
for nidx in adjacency_dict.keys():
new_adjacency_dict[self.nodes.get_node_index(nidx)] = {} # replace None with {}
ret.append((n, new_adjacency_dict))
return iter(ret)
Expand Down
12 changes: 8 additions & 4 deletions graphix/graphsim/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from networkx import Graph
from networkx.utils import graphs_equal

from .graphstate import RUSTWORKX_INSTALLED
from .nxgraphstate import NXGraphState
from .rxgraphstate import RXGraphState

if TYPE_CHECKING:
from .basegraphstate import BaseGraphState


if RUSTWORKX_INSTALLED:
from rustworkx import PyGraph
else:
PyGraph = None

from .basegraphstate import BaseGraphState
from .nxgraphstate import NXGraphState
from .rxgraphstate import RXGraphState


def convert_rustworkx_to_networkx(graph: PyGraph) -> Graph:
"""Convert a rustworkx PyGraph to a networkx graph.
Expand Down
3 changes: 2 additions & 1 deletion graphix/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from pydantic import BaseModel

from graphix.pauli import Plane
# MEMO: Cannot use TYPE_CHECKING here for pydantic
from graphix.pauli import Plane # noqa: TCH001


class InstructionKind(enum.Enum):
Expand Down
2 changes: 1 addition & 1 deletion graphix/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def get_rank(self):
return len(nonzero_index[0])

def forward_eliminate(self, b=None, copy=False):
"""forward eliminate the matrix
r"""forward eliminate the matrix
|A B| --\ |I X|
|C D| --/ |0 0|
Expand Down
5 changes: 3 additions & 2 deletions graphix/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from functools import reduce
from itertools import product
from typing import ClassVar

import numpy as np

Expand Down Expand Up @@ -33,7 +34,7 @@ class Ops:
[0, 0, 0, 0, 0, 0, 1, 0],
]
)
Pauli_ops = [np.eye(2), x, y, z]
Pauli_ops: ClassVar = [np.eye(2), x, y, z]

@staticmethod
def Rx(theta):
Expand Down Expand Up @@ -102,7 +103,7 @@ def Rzz(theta):

@staticmethod
def build_tensor_Pauli_ops(n_qubits: int):
"""Method to build all the 4^n tensor Pauli operators {I, X, Y, Z}^{\otimes n}
"""Method to build all the 4^n tensor Pauli operators {I, X, Y, Z}^{\\otimes n}
:param n_qubits: number of copies (qubits) to consider
:type n_qubits: int
Expand Down
36 changes: 12 additions & 24 deletions graphix/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,7 @@ def _measurement_order_depth(self):
d, l_k = self.get_layers()
meas_order = []
for i in range(d):
for node in l_k[i]:
meas_order.append(node)
meas_order.extend(l_k[i])
return meas_order

def connected_edges(self, node, edges):
Expand Down Expand Up @@ -913,8 +912,7 @@ def get_measurement_order_from_gflow(self):
k, layers = get_layers(l_k)
meas_order = []
while k > 0:
for node in layers[k]:
meas_order.append(node)
meas_order.extend(layers[k])
k -= 1
return meas_order

Expand Down Expand Up @@ -1142,11 +1140,7 @@ def standardize_and_shift_signals(self, method="local"):
def correction_commands(self):
"""Returns the list of byproduct correction commands"""
assert self.is_standard()
Clist = []
for i in range(len(self.__seq)):
if self.__seq[i].kind in (command.CommandKind.X, command.CommandKind.Z):
Clist.append(self.__seq[i])
return Clist
return [seqi for seqi in self.__seq if seqi.kind in (command.CommandKind.X, command.CommandKind.Z)]

def parallelize_pattern(self):
"""Optimize the pattern to reduce the depth of the computation
Expand Down Expand Up @@ -1939,21 +1933,15 @@ def measure_pauli(pattern, leave_input, copy=False, use_rustworkx=False):
# update command sequence
vops = graph_state.get_vops()
new_seq = []
for index in set(graph_state.nodes) - set(new_inputs):
new_seq.append(command.N(node=index))
for edge in graph_state.edges:
new_seq.append(command.E(nodes=edge))
for cmd in pattern:
if cmd.kind == command.CommandKind.M:
if cmd.node in graph_state.nodes:
new_seq.append(cmd.clifford(graphix.clifford.get(vops[cmd.node])))
for index in pattern.output_nodes:
new_clifford_ = vops[index]
if new_clifford_ != 0:
new_seq.append(command.C(node=index, cliff_index=new_clifford_))
for cmd in pattern:
if cmd.kind == command.CommandKind.X or (cmd.kind == command.CommandKind.Z):
new_seq.append(cmd)
new_seq.extend(command.N(node=index) for index in set(graph_state.nodes) - set(new_inputs))
new_seq.extend(command.E(nodes=edge) for edge in graph_state.edges)
new_seq.extend(
cmd.clifford(graphix.clifford.get(vops[cmd.node]))
for cmd in pattern
if cmd.kind == command.CommandKind.M and cmd.node in graph_state.nodes
)
new_seq.extend(command.C(node=index, cliff_index=vops[index]) for index in pattern.output_nodes if vops[index] != 0)
new_seq.extend(cmd for cmd in pattern if cmd.kind in (command.CommandKind.X, command.CommandKind.Z))

if copy:
pat = Pattern()
Expand Down
2 changes: 1 addition & 1 deletion graphix/pauli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Pauli gates ± {1,j} × {I, X, Y, Z}
"""
""" # noqa: RUF002

from __future__ import annotations

Expand Down
8 changes: 4 additions & 4 deletions graphix/sim/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_open_tensor_from_index(self, index):
index = str(index)
assert isinstance(index, str)
tags = [index, "Open"]
tid = list(self._get_tids_from_tags(tags, which="all"))[0]
tid = next(iter(self._get_tids_from_tags(tags, which="all")))
tensor = self.tensor_map[tid]
return tensor.data

Expand Down Expand Up @@ -491,7 +491,7 @@ def get_basis_coefficient(self, basis, normalize=True, indices=None, **kwagrs):
tensor = Tensor(state_out, [tn._dangling[node]], [node, f"qubit {i}", "Close"])
# retag
old_ind = tn._dangling[node]
tid = list(tn._get_tids_from_inds(old_ind))[0]
tid = next(iter(tn._get_tids_from_inds(old_ind)))
tn.tensor_map[tid].retag({"Open": "Close"})
tn.add_tensor(tensor)

Expand Down Expand Up @@ -598,8 +598,8 @@ def expectation_value(self, op, qubit_indices, output_node_indices=None, **kwagr
# reindex & retag
for node in out_inds:
old_ind = tn_cp_left._dangling[str(node)]
tid_left = list(tn_cp_left._get_tids_from_inds(old_ind))[0]
tid_right = list(tn_cp_right._get_tids_from_inds(old_ind))[0]
tid_left = next(iter(tn_cp_left._get_tids_from_inds(old_ind)))
tid_right = next(iter(tn_cp_right._get_tids_from_inds(old_ind)))
if node in target_nodes:
tn_cp_left.tensor_map[tid_left].reindex({old_ind: new_ind_left[target_nodes.index(node)]}, inplace=True)
tn_cp_right.tensor_map[tid_right].reindex(
Expand Down
3 changes: 2 additions & 1 deletion graphix/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import abc
from typing import ClassVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -75,4 +76,4 @@ class BasicStates:
MINUS_I = PlanarState(plane=graphix.pauli.Plane.XY, angle=-np.pi / 2)
# remove that in the end
# need in TN backend
VEC = [PLUS, MINUS, ZERO, ONE, PLUS_I, MINUS_I]
VEC: ClassVar = [PLUS, MINUS, ZERO, ONE, PLUS_I, MINUS_I]
Loading

0 comments on commit 364be19

Please sign in to comment.