From 2b29c44ab20094a72c907457aa112ec028dcda0c Mon Sep 17 00:00:00 2001 From: Michael Reneer Date: Mon, 14 Aug 2023 14:36:24 -0700 Subject: [PATCH] Add and update type annotations for `BuildingBlock`s. * Added missing type annotations. * Fixed existing type annotations. * Removed type annotations for `cls`. * Removed `Any` type annotations. * Fixed lint errors. PiperOrigin-RevId: 556914851 --- .../core/impl/compiler/building_blocks.py | 120 +++++++++--------- 1 file changed, 61 insertions(+), 59 deletions(-) diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks.py b/tensorflow_federated/python/core/impl/compiler/building_blocks.py index dd47c0c54d..4123f7aa3d 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks.py +++ b/tensorflow_federated/python/core/impl/compiler/building_blocks.py @@ -16,7 +16,8 @@ import abc from collections.abc import Iterable, Iterator import enum -from typing import Any, Optional +import typing +from typing import Optional, Union import zlib from tensorflow_federated.proto.v0 import computation_pb2 as pb @@ -242,17 +243,14 @@ class Reference(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Reference'], - computation_proto: pb.Computation, - ) -> 'Reference': + def from_proto(cls, computation_proto: pb.Computation) -> 'Reference': _check_computation_oneof(computation_proto, 'reference') return cls( str(computation_proto.reference.name), type_serialization.deserialize_type(computation_proto.type), ) - def __init__(self, name, type_spec, context=None): + def __init__(self, name: str, type_spec: object, context=None): """Creates a reference to 'name' of type 'type_spec' in context 'context'. Args: @@ -282,7 +280,7 @@ def children(self) -> Iterator[ComputationBuildingBlock]: return iter(()) @property - def name(self): + def name(self) -> str: return self._name @property @@ -319,7 +317,12 @@ def from_proto( ) return cls(selection, index=computation_proto.selection.index) - def __init__(self, source, name=None, index=None): + def __init__( + self, + source: ComputationBuildingBlock, + name: Optional[str] = None, + index: Optional[int] = None, + ): """A selection from 'source' by a string or numeric 'name_or_index'. Exactly one of 'name' or 'index' must be specified (not None). @@ -337,20 +340,18 @@ def __init__(self, source, name=None, index=None): both are defined (not None). """ py_typecheck.check_type(source, ComputationBuildingBlock) - if name is None and index is None: - raise ValueError( - 'Must define either a name or index, and neither was specified.' - ) - if name is not None and index is not None: - raise ValueError( - 'Cannot simultaneously specify a name and an index, choose one.' - ) source_type = source.type_signature + # TODO: b/224484886 - Downcasting to all handled types. + source_type = typing.cast(Union[computation_types.StructType], source_type) if not isinstance(source_type, computation_types.StructType): raise TypeError( 'Expected the source of selection to be a TFF struct, ' 'instead found it to be of type {}.'.format(source_type) ) + if name is not None and index is not None: + raise ValueError( + 'Cannot simultaneously specify a name and an index, choose one.' + ) if name is not None: py_typecheck.check_type(name, str) if not name: @@ -363,7 +364,7 @@ def __init__(self, source, name=None, index=None): f'whose only named fields are {structure.name_list(source_type)}.' ) type_signature = source_type[name] - else: + elif index is not None: py_typecheck.check_type(index, int) length = len(source_type) if index < 0 or index >= length: @@ -372,6 +373,10 @@ def __init__(self, source, name=None, index=None): f'struct type: 0..{length}' ) type_signature = source_type[index] + else: + raise ValueError( + 'Must define either a name or index, and neither was specified.' + ) super().__init__(type_signature) self._source = source self._name = name @@ -435,15 +440,19 @@ def from_proto( computation_proto: pb.Computation, ) -> 'Struct': _check_computation_oneof(computation_proto, 'struct') - return cls( - [ - ( - str(e.name) if e.name else None, - ComputationBuildingBlock.from_proto(e.value), - ) - for e in computation_proto.struct.element - ] - ) + + def _element( + proto: pb.Struct.Element, + ) -> tuple[Optional[str], ComputationBuildingBlock]: + if proto.name: + name = str(proto.name) + else: + name = None + element = ComputationBuildingBlock.from_proto(proto.value) + return (name, element) + + elements = [_element(x) for x in computation_proto.struct.element] + return cls(elements) def __init__(self, elements, container_type=None): """Constructs a struct from the given list of elements. @@ -552,7 +561,11 @@ def from_proto( arg = None return cls(fn, arg) - def __init__(self, fn, arg=None): + def __init__( + self, + fn: ComputationBuildingBlock, + arg: Optional[ComputationBuildingBlock] = None, + ): """Creates a call to 'fn' with argument 'arg'. Args: @@ -566,29 +579,34 @@ def __init__(self, fn, arg=None): py_typecheck.check_type(fn, ComputationBuildingBlock) if arg is not None: py_typecheck.check_type(arg, ComputationBuildingBlock) - if not isinstance(fn.type_signature, computation_types.FunctionType): + function_type = fn.type_signature + # TODO: b/224484886 - Downcasting to all handled types. + function_type = typing.cast( + Union[computation_types.FunctionType], function_type + ) + if not isinstance(function_type, computation_types.FunctionType): raise TypeError( 'Expected fn to be of a functional type, ' - 'but found that its type is {}.'.format(fn.type_signature) + 'but found that its type is {}.'.format(function_type) ) - if fn.type_signature.parameter is not None: + if function_type.parameter is not None: if arg is None: raise TypeError( 'The invoked function expects an argument of type {}, ' - 'but got None instead.'.format(fn.type_signature.parameter) + 'but got None instead.'.format(function_type.parameter) ) - if not fn.type_signature.parameter.is_assignable_from(arg.type_signature): + if not function_type.parameter.is_assignable_from(arg.type_signature): raise TypeError( 'The parameter of the invoked function is expected to be of ' 'type {}, but the supplied argument is of an incompatible ' - 'type {}.'.format(fn.type_signature.parameter, arg.type_signature) + 'type {}.'.format(function_type.parameter, arg.type_signature) ) elif arg is not None: raise TypeError( 'The invoked function does not expect any parameters, but got ' 'an argument of type {}.'.format(py_typecheck.type_string(type(arg))) ) - super().__init__(fn.type_signature.result) + super().__init__(function_type.result) self._function = fn self._argument = arg @@ -612,11 +630,11 @@ def children(self) -> Iterator[ComputationBuildingBlock]: yield self._argument @property - def function(self): + def function(self) -> ComputationBuildingBlock: return self._function @property - def argument(self): + def argument(self) -> Optional[ComputationBuildingBlock]: return self._argument def __repr__(self): @@ -636,10 +654,7 @@ class Lambda(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Lambda'], - computation_proto: pb.Computation, - ) -> 'Lambda': + def from_proto(cls, computation_proto: pb.Computation) -> 'Lambda': _check_computation_oneof(computation_proto, 'lambda') the_lambda = getattr(computation_proto, 'lambda') if computation_proto.type.function.HasField('parameter'): @@ -658,7 +673,7 @@ def from_proto( def __init__( self, parameter_name: Optional[str], - parameter_type: Optional[Any], + parameter_type: Optional[object], result: ComputationBuildingBlock, ): """Creates a lambda expression. @@ -780,10 +795,7 @@ class Block(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Block'], - computation_proto: pb.Computation, - ) -> 'Block': + def from_proto(cls, computation_proto: pb.Computation) -> 'Block': _check_computation_oneof(computation_proto, 'block') return cls( [ @@ -881,10 +893,7 @@ class Intrinsic(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Intrinsic'], - computation_proto: pb.Computation, - ) -> 'Intrinsic': + def from_proto(cls, computation_proto: pb.Computation) -> 'Intrinsic': _check_computation_oneof(computation_proto, 'intrinsic') return cls( computation_proto.intrinsic.uri, @@ -952,17 +961,14 @@ class Data(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Data'], - computation_proto: pb.Computation, - ) -> 'Data': + def from_proto(cls, computation_proto: pb.Computation) -> 'Data': _check_computation_oneof(computation_proto, 'data') return cls( computation_proto.data.uri, type_serialization.deserialize_type(computation_proto.type), ) - def __init__(self, uri: str, type_spec: Any): + def __init__(self, uri: str, type_spec: object): """Creates a representation of data. Args: @@ -1077,10 +1083,7 @@ class Placement(ComputationBuildingBlock): """ @classmethod - def from_proto( - cls: type['Placement'], - computation_proto: pb.Computation, - ) -> 'Placement': + def from_proto(cls, computation_proto: pb.Computation) -> 'Placement': _check_computation_oneof(computation_proto, 'placement') return cls( placements.uri_to_placement_literal( @@ -1423,7 +1426,6 @@ def _fit_with_inset(left, right, inset): for left_line, right_line in zip(left, right): if inset > 0: left_inset = 0 - right_inset = 0 trailing_padding = _get_trailing_padding(left_line) if trailing_padding > 0: left_inset = min(trailing_padding, inset)