diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index bd1e34e73ee..71b308c81f4 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -14,11 +14,24 @@ """Abstract base class for things sampling quantum circuits.""" import collections -from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from itertools import islice +from typing import ( + Dict, + FrozenSet, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, + TYPE_CHECKING, + Union, +) import duet import pandas as pd + from cirq import ops, protocols, study, value from cirq.work.observable_measurement import ( measure_observables, @@ -30,10 +43,17 @@ if TYPE_CHECKING: import cirq +T = TypeVar('T') + class Sampler(metaclass=value.ABCMetaImplementAnyOneOf): """Something capable of sampling quantum circuits. Simulator or hardware.""" + # Users have a rate limit of 1000 QPM for read/write requests to + # the Quantum Engine. 1000/60 ~= 16 QPS. So requests are sent + # in chunks of size 16 per second. + CHUNK_SIZE: int = 16 + def run( self, program: 'cirq.AbstractCircuit', @@ -294,9 +314,26 @@ async def run_batch_async( See docs for `cirq.Sampler.run_batch`. """ params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions) - return await duet.pstarmap_async( - self.run_sweep_async, zip(programs, params_list, repetitions) - ) + if len(programs) <= self.CHUNK_SIZE: + return await duet.pstarmap_async( + self.run_sweep_async, zip(programs, params_list, repetitions) + ) + + results = [] + for program_chunk, params_chunk, reps_chunk in zip( + _chunked(programs, self.CHUNK_SIZE), + _chunked(params_list, self.CHUNK_SIZE), + _chunked(repetitions, self.CHUNK_SIZE), + ): + # Run_sweep_async for the current chunk + await duet.sleep(1) # Delay for 1 second between chunk + results.extend( + await duet.pstarmap_async( + self.run_sweep_async, zip(program_chunk, params_chunk, reps_chunk) + ) + ) + + return results def _normalize_batch_args( self, @@ -449,3 +486,8 @@ def _get_measurement_shapes( ) num_instances[key] += 1 return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()} + + +def _chunked(iterable: Sequence[T], n: int) -> Iterator[tuple[T, ...]]: + it = iter(iterable) + return iter(lambda: tuple(islice(it, n)), ()) diff --git a/cirq-core/cirq/work/sampler_test.py b/cirq-core/cirq/work/sampler_test.py index 195b44c0ff2..2c20ccbf030 100644 --- a/cirq-core/cirq/work/sampler_test.py +++ b/cirq-core/cirq/work/sampler_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for cirq.Sampler.""" from typing import Sequence +from unittest import mock import pytest @@ -266,6 +267,55 @@ def test_sampler_run_batch_bad_input_lengths(): ) +@mock.patch('duet.pstarmap_async') +@pytest.mark.parametrize('call_count', [1, 2, 3]) +@duet.sync +async def test_run_batch_async_sends_circuits_in_chunks(spy, call_count): + class AsyncSampler(cirq.Sampler): + CHUNK_SIZE = 3 + + async def run_sweep_async(self, _, params, __: int = 1): + pass # pragma: no cover + + sampler = AsyncSampler() + a = cirq.LineQubit(0) + circuit_list = [cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))] * ( + sampler.CHUNK_SIZE * call_count + ) + param_list = [cirq.Points('t', [0.3, 0.7])] * (sampler.CHUNK_SIZE * call_count) + + await sampler.run_batch_async(circuit_list, params_list=param_list) + + assert spy.call_count == call_count + + +@pytest.mark.parametrize('call_count', [1, 2, 3]) +@duet.sync +async def test_run_batch_async_runs_runs_sequentially(call_count): + a = cirq.LineQubit(0) + finished = [] + circuit1 = cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit(cirq.Y(a) ** sympy.Symbol('t'), cirq.measure(a, key='m')) + params1 = cirq.Points('t', [0.3, 0.7]) + params2 = cirq.Points('t', [0.4, 0.6]) + + class AsyncSampler(cirq.Sampler): + CHUNK_SIZE = 1 + + async def run_sweep_async(self, _, params, __: int = 1): + if params == params1: + await duet.sleep(0.001) + + finished.append(params) + + sampler = AsyncSampler() + circuit_list = [circuit1, circuit2] * call_count + param_list = [params1, params2] * call_count + await sampler.run_batch_async(circuit_list, params_list=param_list) + + assert finished == param_list + + def test_sampler_simple_sample_expectation_values(): a = cirq.LineQubit(0) sampler = cirq.Simulator()