Skip to content

Commit

Permalink
Add a new gauge for SqrtCZ and support SqrtCZ† and fix and improve sp…
Browse files Browse the repository at this point in the history
…in inversion gauge (#6571)
  • Loading branch information
NoureldinYosri authored May 3, 2024
1 parent 614c78a commit 3080d93
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 15 deletions.
38 changes: 37 additions & 1 deletion cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ConstantGauge(Gauge):
post_q1: Tuple[ops.Gate, ...] = field(
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
swap_qubits: bool = False

def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge":
return self
Expand All @@ -85,6 +86,41 @@ def post(self) -> Tuple[Tuple[ops.Gate, ...], Tuple[ops.Gate, ...]]:
"""A tuple (ops to apply to q0, ops to apply to q1)."""
return self.post_q0, self.post_q1

def on(self, q0: ops.Qid, q1: ops.Qid) -> ops.Operation:
"""Returns the operation that replaces the two qubit gate."""
if self.swap_qubits:
return self.two_qubit_gate(q1, q0)
return self.two_qubit_gate(q0, q1)


@frozen
class SameGateGauge(Gauge):
"""Same as ConstantGauge but the new two-qubit gate equals the old gate."""

pre_q0: Tuple[ops.Gate, ...] = field(
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
pre_q1: Tuple[ops.Gate, ...] = field(
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
post_q0: Tuple[ops.Gate, ...] = field(
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
post_q1: Tuple[ops.Gate, ...] = field(
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
swap_qubits: bool = False

def sample(self, gate: ops.Gate, prng: np.random.Generator) -> ConstantGauge:
return ConstantGauge(
two_qubit_gate=gate,
pre_q0=self.pre_q0,
pre_q1=self.pre_q1,
post_q0=self.post_q0,
post_q1=self.post_q1,
swap_qubits=self.swap_qubits,
)


def _select(choices: Sequence[Gauge], probabilites: np.ndarray, prng: np.random.Generator) -> Gauge:
return choices[prng.choice(len(choices), p=probabilites)]
Expand Down Expand Up @@ -154,7 +190,7 @@ def __call__(
gauge = self.gauge_selector(rng).sample(op.gate, rng)
q0, q1 = op.qubits
left.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.pre))
center.append(gauge.two_qubit_gate(q0, q1))
center.append(gauge.on(q0, q1))
right.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.post))
else:
center.append(op)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
from cirq.transformers.gauge_compiling.gauge_compiling import (
GaugeTransformer,
GaugeSelector,
ConstantGauge,
SameGateGauge,
)
from cirq import ops

SpinInversionGaugeSelector = GaugeSelector(
gauges=[
ConstantGauge(two_qubit_gate=ops.ZZ, pre_q0=ops.X, post_q0=ops.X),
ConstantGauge(two_qubit_gate=ops.ZZ, pre_q1=ops.X, post_q1=ops.X),
SameGateGauge(pre_q0=ops.X, post_q0=ops.X, pre_q1=ops.X, post_q1=ops.X),
SameGateGauge(),
]
)

SpinInversionGaugeTransformer = GaugeTransformer(
target=ops.ZZ, gauge_selector=SpinInversionGaugeSelector
target=ops.GateFamily(ops.ZZPowGate), gauge_selector=SpinInversionGaugeSelector
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import cirq
from cirq.transformers.gauge_compiling import SpinInversionGaugeTransformer
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester


class TestSpinInversionGauge(GaugeTester):
class TestSpinInversionGauge_0(GaugeTester):
two_qubit_gate = cirq.ZZ
gauge_transformer = SpinInversionGaugeTransformer


class TestSpinInversionGauge_1(GaugeTester):
two_qubit_gate = cirq.ZZ**0.1
gauge_transformer = SpinInversionGaugeTransformer


class TestSpinInversionGauge_2(GaugeTester):
two_qubit_gate = cirq.ZZ**-1
gauge_transformer = SpinInversionGaugeTransformer


class TestSpinInversionGauge_3(GaugeTester):
two_qubit_gate = cirq.ZZ**0.3
gauge_transformer = SpinInversionGaugeTransformer
49 changes: 42 additions & 7 deletions cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Gauge transformer for CZ**0.5 gate."""
"""A Gauge transformer for CZ**0.5 and CZ**-0.5 gates."""


from typing import TYPE_CHECKING
import numpy as np

from cirq.transformers.gauge_compiling.gauge_compiling import (
GaugeTransformer,
GaugeSelector,
ConstantGauge,
Gauge,
)
from cirq.ops.common_gates import CZ
from cirq import ops
from cirq.ops import CZ, S, X, Gateset

if TYPE_CHECKING:
import cirq

_SQRT_CZ = CZ**0.5
_ADJ_S = S**-1

SqrtCZGaugeSelector = GaugeSelector(
gauges=[ConstantGauge(pre_q0=ops.X, post_q0=ops.X, post_q1=ops.Z**0.5, two_qubit_gate=CZ**-0.5)]
)

SqrtCZGaugeTransformer = GaugeTransformer(target=CZ**0.5, gauge_selector=SqrtCZGaugeSelector)
class SqrtCZGauge(Gauge):

def weight(self) -> float:
return 3.0

def sample(self, gate: 'cirq.Gate', prng: np.random.Generator) -> ConstantGauge:
if prng.choice([True, False]):
return ConstantGauge(two_qubit_gate=gate)
swap_qubits = prng.choice([True, False])
if swap_qubits:
return ConstantGauge(
pre_q1=X,
post_q1=X,
post_q0=S if gate == _SQRT_CZ else _ADJ_S,
two_qubit_gate=gate**-1,
swap_qubits=True,
)
else:
return ConstantGauge(
pre_q0=X,
post_q0=X,
post_q1=S if gate == _SQRT_CZ else _ADJ_S,
two_qubit_gate=gate**-1,
)


SqrtCZGaugeTransformer = GaugeTransformer(
target=Gateset(_SQRT_CZ, _SQRT_CZ**-1), gauge_selector=GaugeSelector(gauges=[SqrtCZGauge()])
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import cirq
from cirq.transformers.gauge_compiling import SqrtCZGaugeTransformer
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester
Expand All @@ -21,3 +20,8 @@
class TestSqrtCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ**0.5
gauge_transformer = SqrtCZGaugeTransformer


class TestAdjointSqrtCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ**-0.5
gauge_transformer = SqrtCZGaugeTransformer

0 comments on commit 3080d93

Please sign in to comment.