Skip to content

Commit

Permalink
Remove usage of tff.async_utils.SharedAwaitable use asyncio.Task
Browse files Browse the repository at this point in the history
…intead.

PiperOrigin-RevId: 643382405
  • Loading branch information
michaelreneer authored and copybara-github committed Jun 14, 2024
1 parent 93990d1 commit b1459df
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 27 deletions.
7 changes: 4 additions & 3 deletions fcp/demo/checkpoint_tensor_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""MaterializableValueReference that reads from a TensorFlow checkpoint."""

import asyncio
import typing
from typing import Any, Optional
import uuid
Expand All @@ -29,7 +30,7 @@ def __init__(
tensor_name: str,
dtype: tf.dtypes.DType,
shape: Any,
checkpoint_future: tff.async_utils.SharedAwaitable,
checkpoint_future: asyncio.Task,
):
"""Constructs a new CheckpointTensorReference object.
Expand All @@ -38,8 +39,8 @@ def __init__(
dtype: The type of the tensor.
shape: The shape of the tensor, expressed as a value convertible to
`tf.TensorShape`.
checkpoint_future: A `tff.async_utils.SharedAwaitable` that resolves to
the TF checkpoint bytes once they're available.
checkpoint_future: A `asyncio.Task` that resolves to the TF checkpoint
bytes once they're available.
"""
self._tensor_name = tensor_name
type_signature = tff.types.tensorflow_to_type((dtype, shape))
Expand Down
33 changes: 17 additions & 16 deletions fcp/demo/checkpoint_tensor_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import unittest

from absl.testing import absltest
Expand All @@ -35,10 +36,10 @@ async def get_test_checkpoint():
class CheckpointTensorReferenceTest(absltest.TestCase,
unittest.IsolatedAsyncioTestCase):

def test_type_signature(self):
ref = ctr.CheckpointTensorReference(
TENSOR_NAME, DTYPE, SHAPE,
tff.async_utils.SharedAwaitable(get_test_checkpoint()))
async def test_type_signature(self):
coro = get_test_checkpoint()
task = asyncio.create_task(coro)
ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task)
self.assertEqual(
ref.type_signature,
tff.types.tensorflow_to_type((DTYPE, SHAPE)),
Expand All @@ -49,16 +50,16 @@ async def test_get_value(self):
async def get_checkpoint():
return test_utils.create_checkpoint({TENSOR_NAME: TEST_VALUE})

ref = ctr.CheckpointTensorReference(
TENSOR_NAME, DTYPE, SHAPE,
tff.async_utils.SharedAwaitable(get_checkpoint()))
coro = get_checkpoint()
task = asyncio.create_task(coro)
ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task)
self.assertTrue(numpy.array_equiv(await ref.get_value(), TEST_VALUE))

async def test_get_value_in_graph_mode(self):
with tf.compat.v1.Graph().as_default():
ref = ctr.CheckpointTensorReference(
TENSOR_NAME, DTYPE, SHAPE,
tff.async_utils.SharedAwaitable(get_test_checkpoint()))
coro = get_test_checkpoint()
task = asyncio.create_task(coro)
ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task)
with self.assertRaisesRegex(ValueError,
'get_value is only supported in eager mode'):
await ref.get_value()
Expand All @@ -68,9 +69,9 @@ async def test_get_value_not_found(self):
async def get_not_found_checkpoint():
return test_utils.create_checkpoint({'other': TEST_VALUE})

ref = ctr.CheckpointTensorReference(
TENSOR_NAME, DTYPE, SHAPE,
tff.async_utils.SharedAwaitable(get_not_found_checkpoint()))
coro = get_not_found_checkpoint()
task = asyncio.create_task(coro)
ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task)
with self.assertRaises(tf.errors.NotFoundError):
await ref.get_value()

Expand All @@ -79,9 +80,9 @@ async def test_get_value_with_invalid_checkpoint(self):
async def get_invalid_checkpoint():
return b'invalid'

ref = ctr.CheckpointTensorReference(
TENSOR_NAME, DTYPE, SHAPE,
tff.async_utils.SharedAwaitable(get_invalid_checkpoint()))
coro = get_invalid_checkpoint()
task = asyncio.create_task(coro)
ref = ctr.CheckpointTensorReference(TENSOR_NAME, DTYPE, SHAPE, task)
with self.assertRaises(tf.errors.DataLossError):
await ref.get_value()

Expand Down
13 changes: 7 additions & 6 deletions fcp/demo/federated_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""TFF FederatedContext subclass for the demo Federated Computation platform."""

from collections.abc import Awaitable
import asyncio
from collections.abc import Coroutine
import socket
import ssl
import threading
Expand Down Expand Up @@ -292,11 +293,10 @@ def _state_to_checkpoint(
tf.io.gfile.remove(tmpfile)

def _create_tensor_reference_struct(
self, result_type: tff.Type,
checkpoint_future: Awaitable[bytes]) -> tff.structure.Struct:
self, result_type: tff.Type, checkpoint_future: Coroutine[Any, Any, bytes]
) -> tff.structure.Struct:
"""Creates the CheckpointTensorReference struct for a result type."""
shared_checkpoint_future = tff.async_utils.SharedAwaitable(
checkpoint_future)
task = asyncio.create_task(checkpoint_future)
tensor_specs = checkpoint_utils.tff_type_to_tensor_spec_list(result_type) # pytype: disable=wrong-arg-types
var_names = variable_helpers.variable_names_from_type(
result_type[0], # pytype: disable=unsupported-operands
Expand All @@ -307,7 +307,8 @@ def _create_tensor_reference_struct(
)
tensor_refs = [
checkpoint_tensor_reference.CheckpointTensorReference(
var_name, spec.dtype, spec.shape, shared_checkpoint_future)
var_name, spec.dtype, spec.shape, task
)
for var_name, spec in zip(var_names, tensor_specs)
]
return checkpoint_utils.pack_tff_value(
Expand Down
4 changes: 2 additions & 2 deletions fcp/demo/federated_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_invoke_with_invalid_data_source_type(self):
r'FederatedDataSource.iterator\(\).select\(\)'):
comp(0, plan_pb2.Plan())

def test_invoke_succeeds_with_structure_state_type(self):
async def test_invoke_succeeds_with_structure_state_type(self):
comp = federated_computation.FederatedComputation(
irregular_arrays, name='x'
)
Expand All @@ -195,7 +195,7 @@ def test_invoke_succeeds_with_structure_state_type(self):
state = {'foo': (3, 1), 'bar': (4, 5, 6)}
comp(state, DATA_SOURCE.iterator().select(1))

def test_invoke_succeeds_with_attrs_state_type(self):
async def test_invoke_succeeds_with_attrs_state_type(self):
comp = federated_computation.FederatedComputation(
attrs_computation, name='x'
)
Expand Down

0 comments on commit b1459df

Please sign in to comment.