diff --git a/fcp/demo/checkpoint_tensor_reference.py b/fcp/demo/checkpoint_tensor_reference.py index 23bed53..5c6943d 100644 --- a/fcp/demo/checkpoint_tensor_reference.py +++ b/fcp/demo/checkpoint_tensor_reference.py @@ -13,6 +13,7 @@ # limitations under the License. """MaterializableValueReference that reads from a TensorFlow checkpoint.""" +import asyncio import typing from typing import Any, Optional import uuid @@ -29,7 +30,7 @@ def __init__( tensor_name: str, dtype: tf.dtypes.DType, shape: Any, - checkpoint_future: tff.async_utils.SharedAwaitable, + checkpoint_future: asyncio.Task, ): """Constructs a new CheckpointTensorReference object. @@ -38,8 +39,8 @@ def __init__( dtype: The type of the tensor. shape: The shape of the tensor, expressed as a value convertible to `tf.TensorShape`. - checkpoint_future: A `tff.async_utils.SharedAwaitable` that resolves to - the TF checkpoint bytes once they're available. + checkpoint_future: A `asyncio.Task` that resolves to the TF checkpoint + bytes once they're available. """ self._tensor_name = tensor_name type_signature = tff.types.tensorflow_to_type((dtype, shape)) diff --git a/fcp/demo/checkpoint_tensor_reference_test.py b/fcp/demo/checkpoint_tensor_reference_test.py index bd087cf..e72dd0e 100644 --- a/fcp/demo/checkpoint_tensor_reference_test.py +++ b/fcp/demo/checkpoint_tensor_reference_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import unittest from absl.testing import absltest @@ -35,10 +36,10 @@ async def get_test_checkpoint(): class CheckpointTensorReferenceTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): - def test_type_signature(self): - ref = ctr.CheckpointTensorReference( - TENSOR_NAME, DTYPE, SHAPE, - tff.async_utils.SharedAwaitable(get_test_checkpoint())) + async def test_type_signature(self): + coro = get_test_checkpoint() + task = asyncio.create_task(coro) + ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task) self.assertEqual( ref.type_signature, tff.types.tensorflow_to_type((DTYPE, SHAPE)), @@ -49,16 +50,16 @@ async def test_get_value(self): async def get_checkpoint(): return test_utils.create_checkpoint({TENSOR_NAME: TEST_VALUE}) - ref = ctr.CheckpointTensorReference( - TENSOR_NAME, DTYPE, SHAPE, - tff.async_utils.SharedAwaitable(get_checkpoint())) + coro = get_checkpoint() + task = asyncio.create_task(coro) + ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task) self.assertTrue(numpy.array_equiv(await ref.get_value(), TEST_VALUE)) async def test_get_value_in_graph_mode(self): with tf.compat.v1.Graph().as_default(): - ref = ctr.CheckpointTensorReference( - TENSOR_NAME, DTYPE, SHAPE, - tff.async_utils.SharedAwaitable(get_test_checkpoint())) + coro = get_test_checkpoint() + task = asyncio.create_task(coro) + ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task) with self.assertRaisesRegex(ValueError, 'get_value is only supported in eager mode'): await ref.get_value() @@ -68,9 +69,9 @@ async def test_get_value_not_found(self): async def get_not_found_checkpoint(): return test_utils.create_checkpoint({'other': TEST_VALUE}) - ref = ctr.CheckpointTensorReference( - TENSOR_NAME, DTYPE, SHAPE, - tff.async_utils.SharedAwaitable(get_not_found_checkpoint())) + coro = get_not_found_checkpoint() + task = asyncio.create_task(coro) + ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task) with self.assertRaises(tf.errors.NotFoundError): await ref.get_value() @@ -79,9 +80,9 @@ async def test_get_value_with_invalid_checkpoint(self): async def get_invalid_checkpoint(): return b'invalid' - ref = ctr.CheckpointTensorReference( - TENSOR_NAME, DTYPE, SHAPE, - tff.async_utils.SharedAwaitable(get_invalid_checkpoint())) + coro = get_invalid_checkpoint() + task = asyncio.create_task(coro) + ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task) with self.assertRaises(tf.errors.DataLossError): await ref.get_value() diff --git a/fcp/demo/federated_context.py b/fcp/demo/federated_context.py index ed976f9..81d8978 100644 --- a/fcp/demo/federated_context.py +++ b/fcp/demo/federated_context.py @@ -13,7 +13,8 @@ # limitations under the License. """TFF FederatedContext subclass for the demo Federated Computation platform.""" -from collections.abc import Awaitable +import asyncio +from collections.abc import Coroutine import socket import ssl import threading @@ -292,11 +293,10 @@ def _state_to_checkpoint( tf.io.gfile.remove(tmpfile) def _create_tensor_reference_struct( - self, result_type: tff.Type, - checkpoint_future: Awaitable[bytes]) -> tff.structure.Struct: + self, result_type: tff.Type, checkpoint_future: Coroutine[Any, Any, bytes] + ) -> tff.structure.Struct: """Creates the CheckpointTensorReference struct for a result type.""" - shared_checkpoint_future = tff.async_utils.SharedAwaitable( - checkpoint_future) + task = asyncio.create_task(checkpoint_future) tensor_specs = checkpoint_utils.tff_type_to_tensor_spec_list(result_type) # pytype: disable=wrong-arg-types var_names = variable_helpers.variable_names_from_type( result_type[0], # pytype: disable=unsupported-operands @@ -307,7 +307,8 @@ def _create_tensor_reference_struct( ) tensor_refs = [ checkpoint_tensor_reference.CheckpointTensorReference( - var_name, spec.dtype, spec.shape, shared_checkpoint_future) + var_name, spec.dtype, spec.shape, task + ) for var_name, spec in zip(var_names, tensor_specs) ] return checkpoint_utils.pack_tff_value( diff --git a/fcp/demo/federated_context_test.py b/fcp/demo/federated_context_test.py index eee67fa..be29167 100644 --- a/fcp/demo/federated_context_test.py +++ b/fcp/demo/federated_context_test.py @@ -184,7 +184,7 @@ def test_invoke_with_invalid_data_source_type(self): r'FederatedDataSource.iterator\(\).select\(\)'): comp(0, plan_pb2.Plan()) - def test_invoke_succeeds_with_structure_state_type(self): + async def test_invoke_succeeds_with_structure_state_type(self): comp = federated_computation.FederatedComputation( irregular_arrays, name='x' ) @@ -195,7 +195,7 @@ def test_invoke_succeeds_with_structure_state_type(self): state = {'foo': (3, 1), 'bar': (4, 5, 6)} comp(state, DATA_SOURCE.iterator().select(1)) - def test_invoke_succeeds_with_attrs_state_type(self): + async def test_invoke_succeeds_with_attrs_state_type(self): comp = federated_computation.FederatedComputation( attrs_computation, name='x' )