Skip to content

Commit

Permalink
Cache Qid instances for common types (#6371)
Browse files Browse the repository at this point in the history
Review: @dstrain115
  • Loading branch information
maffoo authored Dec 8, 2023
1 parent 6d437c4 commit 33c2573
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 67 deletions.
75 changes: 51 additions & 24 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import abc
import functools
import weakref
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union
from typing_extensions import Self

Expand All @@ -34,14 +35,6 @@ class _BaseGridQid(ops.Qid):
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._row, self._col, self._dimension))
Expand All @@ -50,7 +43,7 @@ def __hash__(self) -> int:
def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
return self is other or (
self._row == other._row
and self._col == other._col
and self._dimension == other._dimension
Expand All @@ -60,7 +53,7 @@ def __eq__(self, other):
def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
return self is not other and (
self._row != other._row
or self._col != other._col
or self._dimension != other._dimension
Expand Down Expand Up @@ -178,22 +171,36 @@ class GridQid(_BaseGridQid):
cirq.GridQid(5, 4, dimension=2)
"""

def __init__(self, row: int, col: int, *, dimension: int) -> None:
"""Initializes a grid qid at the given row, col coordinate
# Cache of existing GridQid instances, returned by __new__ if available.
# Holds weak references so instances can still be garbage collected.
_cache = weakref.WeakValueDictionary[Tuple[int, int, int], 'cirq.GridQid']()

def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
"""Creates a grid qid at the given row, col coordinate
Args:
row: the row coordinate
col: the column coordinate
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
self.validate_dimension(dimension)
self._row = row
self._col = col
self._dimension = dimension
key = (row, col, dimension)
inst = cls._cache.get(key)
if inst is None:
cls.validate_dimension(dimension)
inst = super().__new__(cls)
inst._row = row
inst._col = col
inst._dimension = dimension
cls._cache[key] = inst
return inst

def __getnewargs_ex__(self):
"""Returns a tuple of (args, kwargs) to pass to __new__ when unpickling."""
return (self._row, self._col), {"dimension": self._dimension}

def _with_row_col(self, row: int, col: int) -> 'GridQid':
return GridQid(row, col, dimension=self.dimension)
return GridQid(row, col, dimension=self._dimension)

@staticmethod
def square(diameter: int, top: int = 0, left: int = 0, *, dimension: int) -> List['GridQid']:
Expand Down Expand Up @@ -290,16 +297,16 @@ def from_diagram(diagram: str, dimension: int) -> List['GridQid']:
return [GridQid(*c, dimension=dimension) for c in coords]

def __repr__(self) -> str:
return f"cirq.GridQid({self._row}, {self._col}, dimension={self.dimension})"
return f"cirq.GridQid({self._row}, {self._col}, dimension={self._dimension})"

def __str__(self) -> str:
return f"q({self._row}, {self._col}) (d={self.dimension})"
return f"q({self._row}, {self._col}) (d={self._dimension})"

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.CircuitDiagramInfo(
wire_symbols=(f"({self._row}, {self._col}) (d={self.dimension})",)
wire_symbols=(f"({self._row}, {self._col}) (d={self._dimension})",)
)

def _json_dict_(self) -> Dict[str, Any]:
Expand All @@ -325,11 +332,31 @@ class GridQubit(_BaseGridQid):

_dimension = 2

def __init__(self, row: int, col: int) -> None:
self._row = row
self._col = col
# Cache of existing GridQubit instances, returned by __new__ if available.
# Holds weak references so instances can still be garbage collected.
_cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.GridQubit']()

def _with_row_col(self, row: int, col: int):
def __new__(cls, row: int, col: int) -> 'cirq.GridQubit':
"""Creates a grid qubit at the given row, col coordinate
Args:
row: the row coordinate
col: the column coordinate
"""
key = (row, col)
inst = cls._cache.get(key)
if inst is None:
inst = super().__new__(cls)
inst._row = row
inst._col = col
cls._cache[key] = inst
return inst

def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._row, self._col)

def _with_row_col(self, row: int, col: int) -> 'GridQubit':
return GridQubit(row, col)

def _cmp_tuple(self):
Expand Down
26 changes: 22 additions & 4 deletions cirq-core/cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,29 @@ def test_eq():
eq.make_equality_group(lambda: cirq.GridQid(0, 0, dimension=3))


def test_pickled_hash():
q = cirq.GridQubit(3, 4)
q_bad = cirq.GridQubit(3, 4)
def test_grid_qubit_pickled_hash():
# Use a large number that is unlikely to be used by any other tests.
row, col = 123456789, 2345678910
q_bad = cirq.GridQubit(row, col)
cirq.GridQubit._cache.pop((row, col))
q = cirq.GridQubit(row, col)
_test_qid_pickled_hash(q, q_bad)


def test_grid_qid_pickled_hash():
# Use a large number that is unlikely to be used by any other tests.
row, col = 123456789, 2345678910
q_bad = cirq.GridQid(row, col, dimension=3)
cirq.GridQid._cache.pop((row, col, 3))
q = cirq.GridQid(row, col, dimension=3)
_test_qid_pickled_hash(q, q_bad)


def _test_qid_pickled_hash(q: 'cirq.Qid', q_bad: 'cirq.Qid') -> None:
"""Test that hashes are not pickled with Qid instances."""
assert q_bad is not q
_ = hash(q_bad) # compute hash to ensure it is cached.
q_bad._hash = q_bad._hash + 1
q_bad._hash = q_bad._hash + 1 # type: ignore[attr-defined]
assert q_bad == q
assert hash(q_bad) != hash(q)
data = pickle.dumps(q_bad)
Expand Down
58 changes: 40 additions & 18 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import abc
import functools
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union
import weakref
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing_extensions import Self

from cirq import ops, protocols
Expand All @@ -31,14 +32,6 @@ class _BaseLineQid(ops.Qid):
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._x, self._dimension))
Expand All @@ -47,13 +40,15 @@ def __hash__(self) -> int:
def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x == other._x and self._dimension == other._dimension
return self is other or (self._x == other._x and self._dimension == other._dimension)
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x != other._x or self._dimension != other._dimension
return self is not other and (
self._x != other._x or self._dimension != other._dimension
)
return NotImplemented

def _comparison_key(self):
Expand Down Expand Up @@ -154,17 +149,31 @@ class LineQid(_BaseLineQid):
"""

def __init__(self, x: int, dimension: int) -> None:
# Cache of existing LineQid instances, returned by __new__ if available.
# Holds weak references so instances can still be garbage collected.
_cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.LineQid']()

def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid':
"""Initializes a line qid at the given x coordinate.
Args:
x: The x coordinate.
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
self.validate_dimension(dimension)
self._x = x
self._dimension = dimension
key = (x, dimension)
inst = cls._cache.get(key)
if inst is None:
cls.validate_dimension(dimension)
inst = super().__new__(cls)
inst._x = x
inst._dimension = dimension
cls._cache[key] = inst
return inst

def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._x, self._dimension)

def _with_x(self, x: int) -> 'LineQid':
return LineQid(x, dimension=self._dimension)
Expand Down Expand Up @@ -246,13 +255,26 @@ class LineQubit(_BaseLineQid):

_dimension = 2

def __init__(self, x: int) -> None:
"""Initializes a line qubit at the given x coordinate.
# Cache of existing LineQubit instances, returned by __new__ if available.
# Holds weak references so instances can still be garbage collected.
_cache = weakref.WeakValueDictionary[int, 'cirq.LineQubit']()

def __new__(cls, x: int) -> 'cirq.LineQubit':
"""Initializes a line qid at the given x coordinate.
Args:
x: The x coordinate.
"""
self._x = x
inst = cls._cache.get(x)
if inst is None:
inst = super().__new__(cls)
inst._x = x
cls._cache[x] = inst
return inst

def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._x,)

def _with_x(self, x: int) -> 'LineQubit':
return LineQubit(x)
Expand Down
19 changes: 19 additions & 0 deletions cirq-core/cirq/devices/line_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

import cirq
from cirq.devices.grid_qubit_test import _test_qid_pickled_hash


def test_init():
Expand Down Expand Up @@ -67,6 +68,24 @@ def test_cmp_failure():
_ = cirq.LineQid(1, 3) < 0


def test_line_qubit_pickled_hash():
# Use a large number that is unlikely to be used by any other tests.
x = 1234567891011
q_bad = cirq.LineQubit(x)
cirq.LineQubit._cache.pop(x)
q = cirq.LineQubit(x)
_test_qid_pickled_hash(q, q_bad)


def test_line_qid_pickled_hash():
# Use a large number that is unlikely to be used by any other tests.
x = 1234567891011
q_bad = cirq.LineQid(x, dimension=3)
cirq.LineQid._cache.pop((x, 3))
q = cirq.LineQid(x, dimension=3)
_test_qid_pickled_hash(q, q_bad)


def test_is_adjacent():
assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(2))
assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(0))
Expand Down
Loading

0 comments on commit 33c2573

Please sign in to comment.