diff --git a/graphix/command.py b/graphix/command.py index 03c46f11..0d16f653 100644 --- a/graphix/command.py +++ b/graphix/command.py @@ -10,7 +10,8 @@ from pydantic import BaseModel import graphix.clifford -from graphix.pauli import Plane +from graphix.clifford import Clifford +from graphix.pauli import Axis, Plane, Sign if TYPE_CHECKING: from graphix.clifford import Clifford @@ -58,6 +59,19 @@ class M(Command): s_domain: set[Node] = set() t_domain: set[Node] = set() + def is_pauli(self, precision: float = 1e-6) -> tuple[Axis, Sign] | None: + angle_double = 2 * self.angle + angle_double_int = int(angle_double) + if abs(angle_double - angle_double_int) > precision: + return None + angle_double_mod_4 = angle_double_int % 4 + if angle_double_mod_4 % 2 == 0: + axis = self.plane.cos + else: + axis = self.plane.sin + sign = Sign.minus_if(angle_double_mod_4 >= 2) + return (axis, sign) + def clifford(self, clifford: Clifford) -> M: s_domain = self.s_domain t_domain = self.t_domain diff --git a/graphix/pattern.py b/graphix/pattern.py index 67c4c5da..b22db5c4 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -8,13 +8,12 @@ from dataclasses import dataclass import networkx as nx -import numpy as np import typing_extensions import graphix.clifford import graphix.pauli from graphix import command -from graphix.clifford import CLIFFORD_CONJ, CLIFFORD_MEASURE, CLIFFORD_TO_QASM3 +from graphix.clifford import CLIFFORD_CONJ, CLIFFORD_TO_QASM3 from graphix.device_interface import PatternRunner from graphix.gflow import find_flow, find_gflow, get_layers from graphix.graphsim.graphstate import GraphState @@ -2024,54 +2023,11 @@ def is_pauli_measurement(cmd: command.Command, ignore_vop=True): if the measurement is not in Pauli basis, returns None. """ assert cmd.kind == command.CommandKind.M - basis_str = [("+X", "-X"), ("+Y", "-Y"), ("+Z", "-Z")] - # first item: 0, 1 or 2. correspond to choice of X, Y and Z - # second item: 0 or 1. correspond to sign (+, -) - basis_index = (0, 0) - if np.mod(cmd.angle, 2) == 0: - if cmd.plane == graphix.pauli.Plane.XY: - basis_index = (0, 0) - elif cmd.plane == graphix.pauli.Plane.YZ: - basis_index = (1, 0) - elif cmd.plane == graphix.pauli.Plane.XZ: - basis_index = (0, 0) - else: - raise ValueError("Unknown measurement plane") - elif np.mod(cmd.angle, 2) == 1: - if cmd.plane == graphix.pauli.Plane.XY: - basis_index = (0, 1) - elif cmd.plane == graphix.pauli.Plane.YZ: - basis_index = (1, 1) - elif cmd.plane == graphix.pauli.Plane.XZ: - basis_index = (0, 1) - else: - raise ValueError("Unknown measurement plane") - elif np.mod(cmd.angle, 2) == 0.5: - if cmd.plane == graphix.pauli.Plane.XY: - basis_index = (1, 0) - elif cmd.plane == graphix.pauli.Plane.YZ: - basis_index = (2, 0) - elif cmd.plane == graphix.pauli.Plane.XZ: - basis_index = (2, 0) - else: - raise ValueError("Unknown measurement plane") - elif np.mod(cmd.angle, 2) == 1.5: - if cmd.plane == graphix.pauli.Plane.XY: - basis_index = (1, 1) - elif cmd.plane == graphix.pauli.Plane.YZ: - basis_index = (2, 1) - elif cmd.plane == graphix.pauli.Plane.XZ: - basis_index = (2, 1) - else: - raise ValueError("Unknown measurement plane") - else: + pauli = cmd.is_pauli() + if pauli is None: return None - if not ignore_vop: - basis_index = ( - CLIFFORD_MEASURE[cmd.vop][basis_index[0]][0], - int(np.abs(basis_index[1] - CLIFFORD_MEASURE[cmd.vop][basis_index[0]][1])), - ) - return basis_str[basis_index[0]][basis_index[1]] + axis, sign = pauli + return f"{sign}{axis.name}" def cmd_to_qasm3(cmd): diff --git a/graphix/pauli.py b/graphix/pauli.py index 91d4634f..7aa5a4c9 100644 --- a/graphix/pauli.py +++ b/graphix/pauli.py @@ -5,6 +5,8 @@ from __future__ import annotations import enum +import typing +from numbers import Number import numpy as np import pydantic @@ -18,6 +20,59 @@ class IXYZ(enum.Enum): Y = 1 Z = 2 +class Sign(enum.Enum): + Plus = 1 + Minus = -1 + + def __str__(self) -> str: + if self == Sign.Plus: + return "+" + return "-" + + @staticmethod + def plus_if(b: bool) -> Sign: + if b: + return Sign.Plus + return Sign.Minus + + @staticmethod + def minus_if(b: bool) -> Sign: + if b: + return Sign.Minus + return Sign.Plus + + def __neg__(self) -> Sign: + return Sign.minus_if(self == Sign.Plus) + + @typing.overload + def __mul__(self, other: Sign) -> Sign: + ... + + @typing.overload + def __mul__(self, other: Number) -> Number: + ... + + def __mul__(self, other): + if isinstance(other, Sign): + return Sign.plus_if(self == other) + if isinstance(other, Number): + return self.value * other + return NotImplemented + + def __rmul__(self, other) -> Number | type(NotImplemented): + if isinstance(other, Number): + return self.value * other + return NotImplemented + + def __int__(self) -> int: + return self.value + + def __float__(self) -> float: + return float(self.value) + + def __complex__(self) -> complex: + return complex(self.value) + class ComplexUnit: """ @@ -27,7 +82,7 @@ class ComplexUnit: with Python constants 1, -1, 1j, -1j, and can be negated. """ - def __init__(self, sign: bool, im: bool): + def __init__(self, sign: Sign, im: bool): self.__sign = sign self.__im = im @@ -36,27 +91,24 @@ def sign(self): return self.__sign @property - def im(self): + def im(self) -> bool: return self.__im - @property - def complex(self) -> complex: + def __complex__(self) -> complex: """ Return the unit as complex number """ - result: complex = 1 - if self.__sign: - result *= -1 + result: complex = complex(self.__sign) if self.__im: result *= 1j return result - def __repr__(self): + def __str__(self) -> str: if self.__im: result = "1j" else: result = "1" - if self.__sign: + if self.__sign == Sign.Minus: result = "-" + result return result @@ -69,32 +121,32 @@ def prefix(self, s: str) -> str: result = "1j*" + s else: result = s - if self.__sign: + if self.__sign == Sign.Minus: result = "-" + result return result def __mul__(self, other): if isinstance(other, ComplexUnit): im = self.__im != other.__im - sign = (self.__sign != other.__sign) != (self.__im and other.__im) - return COMPLEX_UNITS[sign][im] + sign = self.__sign * other.__sign * Sign.minus_if(self.__im and other.__im) + return COMPLEX_UNITS[sign == Sign.Minus][im] return NotImplemented def __rmul__(self, other): if other == 1: return self elif other == -1: - return COMPLEX_UNITS[not self.__sign][self.__im] + return COMPLEX_UNITS[self.__sign == Sign.Plus][self.__im] elif other == 1j: - return COMPLEX_UNITS[self.__sign != self.__im][not self.__im] + return COMPLEX_UNITS[self.__sign == Sign.plus_if(self.__im)][not self.__im] elif other == -1j: - return COMPLEX_UNITS[self.__sign == self.__im][not self.__im] + return COMPLEX_UNITS[self.__sign == Sign.minus_if(self.__im)][not self.__im] def __neg__(self): - return COMPLEX_UNITS[not self.__sign][self.__im] + return COMPLEX_UNITS[self.__sign == Sign.Plus][self.__im] -COMPLEX_UNITS = [[ComplexUnit(sign, im) for im in (False, True)] for sign in (False, True)] +COMPLEX_UNITS = [[ComplexUnit(sign, im) for im in (False, True)] for sign in (Sign.Plus, Sign.Minus)] UNIT = COMPLEX_UNITS[False][False] @@ -225,7 +277,7 @@ def matrix(self) -> np.ndarray: """ Return the matrix of the Pauli gate. """ - return self.__unit.complex * graphix.clifford.CLIFFORD[self.__symbol.value + 1] + return complex(self.__unit) * graphix.clifford.CLIFFORD[self.__symbol.value + 1] def __repr__(self): return self.__unit.prefix(self.__symbol.name) @@ -268,7 +320,7 @@ def __neg__(self): def get(symbol: IXYZ, unit: ComplexUnit) -> Pauli: """Return the Pauli gate with given symbol and unit.""" - return TABLE[symbol.value + 1][unit.sign][unit.im] + return TABLE[symbol.value + 1][unit.sign == Sign.Minus][unit.im] I = get(IXYZ.I, UNIT) @@ -306,7 +358,7 @@ def compute(plane: Plane, s: bool, t: bool, clifford: graphix.clifford.Clifford) else: coeff = 1 add_term = 0 - if cos_pauli.unit.sign: + if cos_pauli.unit.sign == Sign.Minus: add_term += np.pi if exchange: add_term = np.pi / 2 - add_term diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 192a2eee..1c79c292 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -171,6 +171,18 @@ def test_pauli_measurement_random_circuit( state_mbqc = pattern.simulate_pattern(backend) assert compare_backend_result_with_statevec(backend, state_mbqc, state) == pytest.approx(1) + @pytest.mark.parametrize("plane", Plane) + @pytest.mark.parametrize("angle", [0., 0.5, 1., 1.5]) + def test_pauli_measurement_single(self, plane: Plane, angle: float, use_rustworkx: bool = True) -> None: + pattern = Pattern(input_nodes=[0, 1]) + pattern.add(E(nodes=[0, 1])) + pattern.add(M(node=0, plane=plane, angle=angle)) + pattern_ref = pattern.copy() + pattern.perform_pauli_measurements(use_rustworkx=use_rustworkx) + state = pattern.simulate_pattern() + state_ref = pattern_ref.simulate_pattern(pr_calc=False, rng=IterGenerator([0])) + assert np.abs(np.dot(state.flatten().conjugate(), state_ref.flatten())) == pytest.approx(1) + @pytest.mark.parametrize("jumps", range(1, 11)) def test_pauli_measurement_leave_input_random_circuit( self, fx_bg: PCG64, jumps: int, use_rustworkx: bool = True diff --git a/tests/test_pauli.py b/tests/test_pauli.py index 1177dad0..c6ab7d27 100644 --- a/tests/test_pauli.py +++ b/tests/test_pauli.py @@ -23,7 +23,7 @@ class TestPauli: ), ) def test_unit_mul(self, u: ComplexUnit, p: Pauli) -> None: - assert np.allclose((u * p).matrix, u.complex * p.matrix) + assert np.allclose((u * p).matrix, complex(u) * p.matrix) @pytest.mark.parametrize( ("a", "b"),