Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and update type annotations for BuildingBlocks. #4066

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 61 additions & 59 deletions tensorflow_federated/python/core/impl/compiler/building_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import abc
from collections.abc import Iterable, Iterator
import enum
from typing import Any, Optional
import typing
from typing import Optional, Union
import zlib

from tensorflow_federated.proto.v0 import computation_pb2 as pb
Expand Down Expand Up @@ -242,17 +243,14 @@ class Reference(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Reference'],
computation_proto: pb.Computation,
) -> 'Reference':
def from_proto(cls, computation_proto: pb.Computation) -> 'Reference':
_check_computation_oneof(computation_proto, 'reference')
return cls(
str(computation_proto.reference.name),
type_serialization.deserialize_type(computation_proto.type),
)

def __init__(self, name, type_spec, context=None):
def __init__(self, name: str, type_spec: object, context=None):
"""Creates a reference to 'name' of type 'type_spec' in context 'context'.

Args:
Expand Down Expand Up @@ -282,7 +280,7 @@ def children(self) -> Iterator[ComputationBuildingBlock]:
return iter(())

@property
def name(self):
def name(self) -> str:
return self._name

@property
Expand Down Expand Up @@ -319,7 +317,12 @@ def from_proto(
)
return cls(selection, index=computation_proto.selection.index)

def __init__(self, source, name=None, index=None):
def __init__(
self,
source: ComputationBuildingBlock,
name: Optional[str] = None,
index: Optional[int] = None,
):
"""A selection from 'source' by a string or numeric 'name_or_index'.

Exactly one of 'name' or 'index' must be specified (not None).
Expand All @@ -337,20 +340,18 @@ def __init__(self, source, name=None, index=None):
both are defined (not None).
"""
py_typecheck.check_type(source, ComputationBuildingBlock)
if name is None and index is None:
raise ValueError(
'Must define either a name or index, and neither was specified.'
)
if name is not None and index is not None:
raise ValueError(
'Cannot simultaneously specify a name and an index, choose one.'
)
source_type = source.type_signature
# TODO: b/224484886 - Downcasting to all handled types.
source_type = typing.cast(Union[computation_types.StructType], source_type)
if not isinstance(source_type, computation_types.StructType):
raise TypeError(
'Expected the source of selection to be a TFF struct, '
'instead found it to be of type {}.'.format(source_type)
)
if name is not None and index is not None:
raise ValueError(
'Cannot simultaneously specify a name and an index, choose one.'
)
if name is not None:
py_typecheck.check_type(name, str)
if not name:
Expand All @@ -363,7 +364,7 @@ def __init__(self, source, name=None, index=None):
f'whose only named fields are {structure.name_list(source_type)}.'
)
type_signature = source_type[name]
else:
elif index is not None:
py_typecheck.check_type(index, int)
length = len(source_type)
if index < 0 or index >= length:
Expand All @@ -372,6 +373,10 @@ def __init__(self, source, name=None, index=None):
f'struct type: 0..{length}'
)
type_signature = source_type[index]
else:
raise ValueError(
'Must define either a name or index, and neither was specified.'
)
super().__init__(type_signature)
self._source = source
self._name = name
Expand Down Expand Up @@ -435,15 +440,19 @@ def from_proto(
computation_proto: pb.Computation,
) -> 'Struct':
_check_computation_oneof(computation_proto, 'struct')
return cls(
[
(
str(e.name) if e.name else None,
ComputationBuildingBlock.from_proto(e.value),
)
for e in computation_proto.struct.element
]
)

def _element(
proto: pb.Struct.Element,
) -> tuple[Optional[str], ComputationBuildingBlock]:
if proto.name:
name = str(proto.name)
else:
name = None
element = ComputationBuildingBlock.from_proto(proto.value)
return (name, element)

elements = [_element(x) for x in computation_proto.struct.element]
return cls(elements)

def __init__(self, elements, container_type=None):
"""Constructs a struct from the given list of elements.
Expand Down Expand Up @@ -552,7 +561,11 @@ def from_proto(
arg = None
return cls(fn, arg)

def __init__(self, fn, arg=None):
def __init__(
self,
fn: ComputationBuildingBlock,
arg: Optional[ComputationBuildingBlock] = None,
):
"""Creates a call to 'fn' with argument 'arg'.

Args:
Expand All @@ -566,29 +579,34 @@ def __init__(self, fn, arg=None):
py_typecheck.check_type(fn, ComputationBuildingBlock)
if arg is not None:
py_typecheck.check_type(arg, ComputationBuildingBlock)
if not isinstance(fn.type_signature, computation_types.FunctionType):
function_type = fn.type_signature
# TODO: b/224484886 - Downcasting to all handled types.
function_type = typing.cast(
Union[computation_types.FunctionType], function_type
)
if not isinstance(function_type, computation_types.FunctionType):
raise TypeError(
'Expected fn to be of a functional type, '
'but found that its type is {}.'.format(fn.type_signature)
'but found that its type is {}.'.format(function_type)
)
if fn.type_signature.parameter is not None:
if function_type.parameter is not None:
if arg is None:
raise TypeError(
'The invoked function expects an argument of type {}, '
'but got None instead.'.format(fn.type_signature.parameter)
'but got None instead.'.format(function_type.parameter)
)
if not fn.type_signature.parameter.is_assignable_from(arg.type_signature):
if not function_type.parameter.is_assignable_from(arg.type_signature):
raise TypeError(
'The parameter of the invoked function is expected to be of '
'type {}, but the supplied argument is of an incompatible '
'type {}.'.format(fn.type_signature.parameter, arg.type_signature)
'type {}.'.format(function_type.parameter, arg.type_signature)
)
elif arg is not None:
raise TypeError(
'The invoked function does not expect any parameters, but got '
'an argument of type {}.'.format(py_typecheck.type_string(type(arg)))
)
super().__init__(fn.type_signature.result)
super().__init__(function_type.result)
self._function = fn
self._argument = arg

Expand All @@ -612,11 +630,11 @@ def children(self) -> Iterator[ComputationBuildingBlock]:
yield self._argument

@property
def function(self):
def function(self) -> ComputationBuildingBlock:
return self._function

@property
def argument(self):
def argument(self) -> Optional[ComputationBuildingBlock]:
return self._argument

def __repr__(self):
Expand All @@ -636,10 +654,7 @@ class Lambda(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Lambda'],
computation_proto: pb.Computation,
) -> 'Lambda':
def from_proto(cls, computation_proto: pb.Computation) -> 'Lambda':
_check_computation_oneof(computation_proto, 'lambda')
the_lambda = getattr(computation_proto, 'lambda')
if computation_proto.type.function.HasField('parameter'):
Expand All @@ -658,7 +673,7 @@ def from_proto(
def __init__(
self,
parameter_name: Optional[str],
parameter_type: Optional[Any],
parameter_type: Optional[object],
result: ComputationBuildingBlock,
):
"""Creates a lambda expression.
Expand Down Expand Up @@ -780,10 +795,7 @@ class Block(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Block'],
computation_proto: pb.Computation,
) -> 'Block':
def from_proto(cls, computation_proto: pb.Computation) -> 'Block':
_check_computation_oneof(computation_proto, 'block')
return cls(
[
Expand Down Expand Up @@ -881,10 +893,7 @@ class Intrinsic(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Intrinsic'],
computation_proto: pb.Computation,
) -> 'Intrinsic':
def from_proto(cls, computation_proto: pb.Computation) -> 'Intrinsic':
_check_computation_oneof(computation_proto, 'intrinsic')
return cls(
computation_proto.intrinsic.uri,
Expand Down Expand Up @@ -952,17 +961,14 @@ class Data(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Data'],
computation_proto: pb.Computation,
) -> 'Data':
def from_proto(cls, computation_proto: pb.Computation) -> 'Data':
_check_computation_oneof(computation_proto, 'data')
return cls(
computation_proto.data.uri,
type_serialization.deserialize_type(computation_proto.type),
)

def __init__(self, uri: str, type_spec: Any):
def __init__(self, uri: str, type_spec: object):
"""Creates a representation of data.

Args:
Expand Down Expand Up @@ -1077,10 +1083,7 @@ class Placement(ComputationBuildingBlock):
"""

@classmethod
def from_proto(
cls: type['Placement'],
computation_proto: pb.Computation,
) -> 'Placement':
def from_proto(cls, computation_proto: pb.Computation) -> 'Placement':
_check_computation_oneof(computation_proto, 'placement')
return cls(
placements.uri_to_placement_literal(
Expand Down Expand Up @@ -1423,7 +1426,6 @@ def _fit_with_inset(left, right, inset):
for left_line, right_line in zip(left, right):
if inset > 0:
left_inset = 0
right_inset = 0
trailing_padding = _get_trailing_padding(left_line)
if trailing_padding > 0:
left_inset = min(trailing_padding, inset)
Expand Down