diff --git a/tensorflow_federated/__init__.py b/tensorflow_federated/__init__.py index cc1fa9327c..e5ca80a625 100644 --- a/tensorflow_federated/__init__.py +++ b/tensorflow_federated/__init__.py @@ -36,7 +36,9 @@ from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_broadcast from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_eval from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_map +from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_max from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_mean +from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_min from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_modular_sum from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_select from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_sum diff --git a/tensorflow_federated/python/aggregators/BUILD b/tensorflow_federated/python/aggregators/BUILD index 8769d0b282..7a56f54a98 100644 --- a/tensorflow_federated/python/aggregators/BUILD +++ b/tensorflow_federated/python/aggregators/BUILD @@ -465,6 +465,7 @@ py_library( name = "primitives", srcs = ["primitives.py"], deps = [ + "//tensorflow_federated/python/common_libs:deprecation", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/impl/federated_context:intrinsics", diff --git a/tensorflow_federated/python/aggregators/primitives.py b/tensorflow_federated/python/aggregators/primitives.py index 316cfec210..d4095d4806 100644 --- a/tensorflow_federated/python/aggregators/primitives.py +++ b/tensorflow_federated/python/aggregators/primitives.py @@ -17,6 +17,7 @@ import tensorflow as tf +from tensorflow_federated.python.common_libs import deprecation from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.federated_context import intrinsics @@ -45,64 +46,10 @@ def _validate_dtype_is_min_max_compatible(dtype): ) -def _federated_reduce_with_func(value, tf_func, zeros): - """Applies to `tf_func` to accumulated `value`s. - - This utility provides a generic aggregation for accumulating a value and - applying a simple aggregation (like minimum or maximum aggregations). - - Args: - value: A `tff.Value` placed on the `tff.CLIENTS`. - tf_func: A function to be applied to the accumulated values. Must be a - binary operation where both parameters are of type `U` and the return type - is also `U`. - zeros: The zero of the same type as `value` in the algebra of reduction - operators. - - Returns: - A representation on the `tff.SERVER` of the result of aggregating `value`. - """ - value_type = value.type_signature.member - - @tensorflow_computation.tf_computation(value_type, value_type) - def accumulate(current, value): - return tf.nest.map_structure(tf_func, current, value) - - @tensorflow_computation.tf_computation(value_type) - def report(value): - return value - - return intrinsics.federated_aggregate( - value, zeros, accumulate, accumulate, report - ) - - -def _initial_values(initial_value_fn, member_type): - """Create a nested structure of initial values for the reduction. - - Args: - initial_value_fn: A function that maps a tff.TensorType to a specific value - constant for initialization. - member_type: A `tff.Type` representing the member components of the - federated type. - - Returns: - A function of the result of reducing a value with no constituents. - """ - - def validate_and_fill(type_spec: computation_types.TensorType) -> tf.Tensor: - _validate_dtype_is_min_max_compatible(type_spec.dtype) - return tf.fill(dims=type_spec.shape, value=initial_value_fn(type_spec)) - - @tensorflow_computation.tf_computation - def zeros_fn(): - return type_conversions.structure_from_tensor_type_tree( - validate_and_fill, member_type - ) - - return zeros_fn() - - +@deprecation.deprecated( + '`tff.aggregators.federated_min` is deprecated, use `tff.federated_min`' + ' instead.' +) def federated_min(value): """Computes the minimum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`. @@ -120,14 +67,13 @@ def federated_min(value): A representation of the min of the member constituents of `value` placed at `tff.SERVER`. """ - _validate_value_on_clients(value) - member_type = value.type_signature.member - # Explicit cast because v.dtype.max returns a Python constant, which could be - # implicitly converted to a tensor of different dtype by TensorFlow. - zeros = _initial_values(lambda v: tf.cast(v.dtype.max, v.dtype), member_type) - return _federated_reduce_with_func(value, tf.minimum, zeros) + return intrinsics.federated_min(value_impl.to_value(value, type_spec=None)) +@deprecation.deprecated( + '`tff.aggregators.federated_max` is deprecated, use `tff.federated_max`' + ' instead.' +) def federated_max(value): """Computes the maximum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`. @@ -145,12 +91,7 @@ def federated_max(value): A representation of the max of the member constituents of `value` placed at `tff.SERVER`. """ - _validate_value_on_clients(value) - member_type = value.type_signature.member - # Explicit cast because v.dtype.min returns a Python constant, which could be - # implicitly converted to a tensor of different dtype by TensorFlow. - zeros = _initial_values(lambda v: tf.cast(v.dtype.min, v.dtype), member_type) - return _federated_reduce_with_func(value, tf.maximum, zeros) + return intrinsics.federated_max(value_impl.to_value(value, type_spec=None)) class _Samples(NamedTuple): diff --git a/tensorflow_federated/python/aggregators/primitives_test.py b/tensorflow_federated/python/aggregators/primitives_test.py index b5996202d2..93c8cf829f 100644 --- a/tensorflow_federated/python/aggregators/primitives_test.py +++ b/tensorflow_federated/python/aggregators/primitives_test.py @@ -126,8 +126,7 @@ def call_federated_min(value): self.assertAllClose(value, [[1.0, -1.0], -12.0]) def test_federated_min_wrong_type(self): - with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'): - + with self.assertRaises(ValueError): @federated_computation.federated_computation( computation_types.at_clients(tf.bool) ) @@ -137,10 +136,7 @@ def call_federated_min(value): call_federated_min([False]) def test_federated_min_wrong_placement(self): - with self.assertRaisesRegex( - TypeError, r'.* argument must be a tff.Value placed at CLIENTS' - ): - + with self.assertRaises(TypeError): @federated_computation.federated_computation( computation_types.at_server(tf.int32) ) @@ -239,8 +235,7 @@ def call_federated_max(value): self.assertAllClose(value, [[5.0, 6.0], 8.0]) def test_federated_max_wrong_type(self): - with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'): - + with self.assertRaises(ValueError): @federated_computation.federated_computation( computation_types.at_clients(tf.bool) ) @@ -250,10 +245,7 @@ def call_federated_max(value): call_federated_max([True, False]) def test_federated_max_wrong_placement(self): - with self.assertRaisesRegex( - TypeError, r'.*argument must be a tff.Value placed at CLIENTS.*' - ): - + with self.assertRaises(TypeError): @federated_computation.federated_computation( computation_types.at_server(tf.float32) ) diff --git a/tensorflow_federated/python/core/impl/compiler/BUILD b/tensorflow_federated/python/core/impl/compiler/BUILD index 6ad7b67c49..e657778c82 100644 --- a/tensorflow_federated/python/core/impl/compiler/BUILD +++ b/tensorflow_federated/python/core/impl/compiler/BUILD @@ -437,6 +437,7 @@ py_library( ":building_block_factory", ":building_blocks", ":intrinsic_defs", + ":tensorflow_computation_factory", ":transformation_utils", ":tree_analysis", "//tensorflow_federated/python/common_libs:py_typecheck", diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory.py b/tensorflow_federated/python/core/impl/compiler/building_block_factory.py index fb10cf30b3..2603c79b33 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory.py +++ b/tensorflow_federated/python/core/impl/compiler/building_block_factory.py @@ -1079,6 +1079,72 @@ def create_federated_mean( return building_blocks.Call(intrinsic, value) +def create_federated_min( + value: building_blocks.ComputationBuildingBlock, +) -> building_blocks.Call: + r"""Creates a called federated min. + + Call + / \ + Intrinsic Comp + + Args: + value: A `building_blocks.ComputationBuildingBlock` to use as the value. + + Returns: + A `building_blocks.Call`. + + Raises: + ValueError: If any of the types do not match. + """ + if not isinstance(value.type_signature, computation_types.FederatedType): + raise ValueError('Expected a federated value.') + result_type = computation_types.FederatedType( + value.type_signature.member, + placements.SERVER, + ) + intrinsic_type = computation_types.FunctionType( + type_conversions.type_to_non_all_equal(value.type_signature), result_type + ) + intrinsic = building_blocks.Intrinsic( + intrinsic_defs.FEDERATED_MIN.uri, intrinsic_type + ) + return building_blocks.Call(intrinsic, value) + + +def create_federated_max( + value: building_blocks.ComputationBuildingBlock, +) -> building_blocks.Call: + r"""Creates a called federated max. + + Call + / \ + Intrinsic Comp + + Args: + value: A `building_blocks.ComputationBuildingBlock` to use as the value. + + Returns: + A `building_blocks.Call`. + + Raises: + ValueError: If any of the types do not match. + """ + if not isinstance(value.type_signature, computation_types.FederatedType): + raise ValueError('Expected a federated value.') + result_type = computation_types.FederatedType( + value.type_signature.member, + placements.SERVER, + ) + intrinsic_type = computation_types.FunctionType( + type_conversions.type_to_non_all_equal(value.type_signature), result_type + ) + intrinsic = building_blocks.Intrinsic( + intrinsic_defs.FEDERATED_MAX.uri, intrinsic_type + ) + return building_blocks.Call(intrinsic, value) + + def create_null_federated_secure_modular_sum(): return create_federated_secure_modular_sum( create_federated_value(building_blocks.Struct([]), placements.CLIENTS), diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py index c29b8429d1..e5903e5bcd 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py +++ b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py @@ -678,6 +678,30 @@ def test_returns_federated_weighted_mean(self): self.assertEqual(str(comp.type_signature), 'int32@SERVER') +class CreateFederatedMinTest(absltest.TestCase): + + def test_returns_federated_min(self): + value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS) + value = building_blocks.Data('v', value_type) + comp = building_block_factory.create_federated_min(value) + self.assertEqual(comp.compact_representation(), 'federated_min(v)') + self.assertEqual( + comp.type_signature.compact_representation(), 'int32@SERVER' + ) + + +class CreateFederatedMaxTest(absltest.TestCase): + + def test_returns_federated_max(self): + value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS) + value = building_blocks.Data('v', value_type) + comp = building_block_factory.create_federated_max(value) + self.assertEqual(comp.compact_representation(), 'federated_max(v)') + self.assertEqual( + comp.type_signature.compact_representation(), 'int32@SERVER' + ) + + class CreateFederatedSecureModularSumTest(absltest.TestCase): def test_raises_type_error_with_none_value(self): diff --git a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py b/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py index 273b1dd0cc..d728b85e29 100644 --- a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py +++ b/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py @@ -360,6 +360,38 @@ def __repr__(self): aggregation_kind=AggregationKind.DEFAULT, ) +# Computes the min of client values on the server. Only supported for numeric +# types, or nested structures made up of numeric computation_types. +# +# Type signature: {T}@CLIENTS -> T@SERVER +FEDERATED_MIN = IntrinsicDef( + 'FEDERATED_MIN', + 'federated_min', + computation_types.FunctionType( + parameter=computation_types.at_clients( + computation_types.AbstractType('T') + ), + result=computation_types.at_server(computation_types.AbstractType('T')), + ), + aggregation_kind=AggregationKind.DEFAULT, +) + +# Computes the max of client values on the server. Only supported for numeric +# types, or nested structures made up of numeric computation_types. +# +# Type signature: {T}@CLIENTS -> T@SERVER +FEDERATED_MAX = IntrinsicDef( + 'FEDERATED_MAX', + 'federated_max', + computation_types.FunctionType( + parameter=computation_types.at_clients( + computation_types.AbstractType('T') + ), + result=computation_types.at_server(computation_types.AbstractType('T')), + ), + aggregation_kind=AggregationKind.DEFAULT, +) + # Computes the modular sum of client values on the server, securely. Only # supported for integers or nested structures of integers. # diff --git a/tensorflow_federated/python/core/impl/compiler/tree_transformations.py b/tensorflow_federated/python/core/impl/compiler/tree_transformations.py index de6de78926..c197bc88d6 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_transformations.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_transformations.py @@ -15,6 +15,7 @@ import collections from collections.abc import Callable, Sequence +from typing import Any import tensorflow as tf @@ -24,6 +25,7 @@ from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import intrinsic_defs +from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory from tensorflow_federated.python.core.impl.compiler import transformation_utils from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.types import computation_types @@ -712,6 +714,51 @@ def _apply_generic_op(op, arg): return building_block_factory.apply_binary_operator_with_upcast(arg, op) +def _initial_values( + initial_value_fn: Callable[[computation_types.TensorType], Any], + member_type: computation_types.Type, +) -> building_blocks.ComputationBuildingBlock: + """Create a nested structure of initial values. + + Args: + initial_value_fn: A function that maps a tff.TensorType to a specific value + constant for initialization. + member_type: A `tff.Type` representing the member components of the + federated type. + + Returns: + A building_blocks.ComputationBuildingBlock representing the initial values. + """ + + def _fill(tensor_type: computation_types.TensorType) -> building_blocks.Call: + computation_proto, function_type = ( + tensorflow_computation_factory.create_constant( + initial_value_fn(tensor_type), tensor_type + ) + ) + compiled = building_blocks.CompiledComputation( + computation_proto, type_signature=function_type + ) + return building_blocks.Call(compiled) + + def _structify_bb( + inner_value: Any, + ) -> building_blocks.ComputationBuildingBlock: + if isinstance(inner_value, dict): + return building_blocks.Struct( + [(k, _structify_bb(v)) for k, v in inner_value.items()] + ) + if isinstance(inner_value, (tuple, list)): + return building_blocks.Struct([_structify_bb(v) for v in inner_value]) + if not isinstance(inner_value, building_blocks.ComputationBuildingBlock): + raise ValueError('Encountered unexpected value: ' + str(inner_value)) + return inner_value + + return _structify_bb( + type_conversions.structure_from_tensor_type_tree(_fill, member_type) + ) + + def get_intrinsic_reductions() -> ( dict[ str, @@ -786,6 +833,38 @@ def federated_mean(arg): mean_arg = building_blocks.Struct([(None, arg), (None, one)]) return federated_weighted_mean(mean_arg) + def federated_min(x: building_blocks.ComputationBuildingBlock): + if not isinstance(x.type_signature, computation_types.FederatedType): + raise TypeError('Expected a federated value.') + operand_type = x.type_signature.member + zero = _initial_values(lambda v: v.dtype.max, operand_type) + min_op = ( + building_block_factory.create_tensorflow_binary_operator_with_upcast( + tf.minimum, + computation_types.StructType([operand_type, operand_type]), + ) + ) + identity = building_block_factory.create_identity(operand_type) + return building_block_factory.create_federated_aggregate( + x, zero, min_op, min_op, identity + ) + + def federated_max(x: building_blocks.ComputationBuildingBlock): + if not isinstance(x.type_signature, computation_types.FederatedType): + raise TypeError('Expected a federated value.') + operand_type = x.type_signature.member + zero = _initial_values(lambda v: v.dtype.min, operand_type) + max_op = ( + building_block_factory.create_tensorflow_binary_operator_with_upcast( + tf.maximum, + computation_types.StructType([operand_type, operand_type]), + ) + ) + identity = building_block_factory.create_identity(operand_type) + return building_block_factory.create_federated_aggregate( + x, zero, max_op, max_op, identity + ) + def federated_sum(x): py_typecheck.check_type(x, building_blocks.ComputationBuildingBlock) operand_type = x.type_signature.member # pytype: disable=attribute-error @@ -860,6 +939,8 @@ def federated_sum(x): intrinsic_bodies_by_uri = collections.OrderedDict([ (intrinsic_defs.FEDERATED_MEAN.uri, federated_mean), (intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, federated_weighted_mean), + (intrinsic_defs.FEDERATED_MIN.uri, federated_min), + (intrinsic_defs.FEDERATED_MAX.uri, federated_max), (intrinsic_defs.FEDERATED_SUM.uri, federated_sum), (intrinsic_defs.GENERIC_DIVIDE.uri, generic_divide), (intrinsic_defs.GENERIC_MULTIPLY.uri, generic_multiply), diff --git a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py index 78e978426f..1a46ca8170 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py @@ -1300,6 +1300,60 @@ def test_federated_weighted_mean_reduces_to_aggregate(self): self.assertEqual(count_means_after_reduction, 0) self.assertGreater(count_aggregations, 0) + def test_federated_min_reduces_to_aggregate(self): + uri = intrinsic_defs.FEDERATED_MIN.uri + + comp = building_blocks.Intrinsic( + uri, + computation_types.FunctionType( + computation_types.at_clients(tf.float32), + computation_types.at_server(tf.float32), + ), + ) + + count_min_before_reduction = _count_intrinsics(comp, uri) + reduced, modified = tree_transformations.replace_intrinsics_with_bodies( + comp + ) + count_min_after_reduction = _count_intrinsics(reduced, uri) + count_aggregations = _count_intrinsics( + reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + ) + self.assertTrue(modified) + type_test_utils.assert_types_identical( + comp.type_signature, reduced.type_signature + ) + self.assertGreater(count_min_before_reduction, 0) + self.assertEqual(count_min_after_reduction, 0) + self.assertGreater(count_aggregations, 0) + + def test_federated_max_reduces_to_aggregate(self): + uri = intrinsic_defs.FEDERATED_MAX.uri + + comp = building_blocks.Intrinsic( + uri, + computation_types.FunctionType( + computation_types.at_clients(tf.float32), + computation_types.at_server(tf.float32), + ), + ) + + count_max_before_reduction = _count_intrinsics(comp, uri) + reduced, modified = tree_transformations.replace_intrinsics_with_bodies( + comp + ) + count_max_after_reduction = _count_intrinsics(reduced, uri) + count_aggregations = _count_intrinsics( + reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + ) + self.assertTrue(modified) + type_test_utils.assert_types_identical( + comp.type_signature, reduced.type_signature + ) + self.assertGreater(count_max_before_reduction, 0) + self.assertEqual(count_max_after_reduction, 0) + self.assertGreater(count_aggregations, 0) + def test_federated_sum_reduces_to_aggregate(self): uri = intrinsic_defs.FEDERATED_SUM.uri diff --git a/tensorflow_federated/python/core/impl/federated_context/intrinsics.py b/tensorflow_federated/python/core/impl/federated_context/intrinsics.py index f90cd87a23..668cce62a1 100644 --- a/tensorflow_federated/python/core/impl/federated_context/intrinsics.py +++ b/tensorflow_federated/python/core/impl/federated_context/intrinsics.py @@ -13,6 +13,7 @@ # limitations under the License. """A factory of intrinsics for use in composing federated computations.""" +from typing import Any import warnings import tensorflow as tf @@ -415,6 +416,64 @@ def federated_mean(value, weight=None): return value_impl.Value(comp) +def federated_min(value: Any) -> value_impl.Value: + """Computes a min at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. + + Args: + value: A value of a TFF federated type placed at the `tff.CLIENTS`. + + Returns: + A representation of the min of the member constituents of `value` placed on + the `tff.SERVER`. + + Raises: + ValueError: If the argument is not a federated TFF value placed at + `tff.CLIENTS` compatible with min. + """ + value = value_impl.to_value(value, type_spec=None) + value = value_utils.ensure_federated_value( + value, placements.CLIENTS, 'value to take min of' + ) + if not type_analysis.is_min_max_compatible(value.type_signature): + raise ValueError( + 'The value type {} is not compatible with the min operator.'.format( + value.type_signature + ) + ) + comp = building_block_factory.create_federated_min(value.comp) + comp = _bind_comp_as_reference(comp) + return value_impl.Value(comp) + + +def federated_max(value: Any) -> value_impl.Value: + """Computes a max at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. + + Args: + value: A value of a TFF federated type placed at the `tff.CLIENTS`. + + Returns: + A representation of the max of the member constituents of `value` placed on + the `tff.SERVER`. + + Raises: + ValueError: If the argument is not a federated TFF value placed at + `tff.CLIENTS` compatible with max. + """ + value = value_impl.to_value(value, type_spec=None) + value = value_utils.ensure_federated_value( + value, placements.CLIENTS, 'value to take max of' + ) + if not type_analysis.is_min_max_compatible(value.type_signature): + raise ValueError( + 'The value type {} is not compatible with the max operator.'.format( + value.type_signature + ) + ) + comp = building_block_factory.create_federated_max(value.comp) + comp = _bind_comp_as_reference(comp) + return value_impl.Value(comp) + + def federated_sum(value): """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. diff --git a/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py b/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py index 1f04526bbc..c7939588cf 100644 --- a/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py +++ b/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py @@ -878,6 +878,40 @@ def test_federated_mean_with_string_weight_fails(self): intrinsics.federated_mean(values, weights) +class FederatedMinTest(parameterized.TestCase, IntrinsicTestBase): + + def test_federated_min_with_client_int(self): + x = _mock_data_of_type(computation_types.at_clients(tf.int32)) + val = intrinsics.federated_min(x) + self.assert_value(val, 'int32@SERVER') + + @parameterized.named_parameters( + ('client_string', computation_types.at_clients(tf.string)), + ('server_int', computation_types.at_server(tf.int32)), + ) + def test_federated_min_with_invalid_type(self, data_type): + x = _mock_data_of_type(data_type) + with self.assertRaises(Exception): + intrinsics.federated_min(x) + + +class FederatedMaxTest(parameterized.TestCase, IntrinsicTestBase): + + def test_federated_max_with_client_int(self): + x = _mock_data_of_type(computation_types.at_clients(tf.int32)) + val = intrinsics.federated_max(x) + self.assert_value(val, 'int32@SERVER') + + @parameterized.named_parameters( + ('client_string', computation_types.at_clients(tf.string)), + ('server_int', computation_types.at_server(tf.int32)), + ) + def test_federated_max_with_invalid_type(self, data_type): + x = _mock_data_of_type(data_type) + with self.assertRaises(Exception): + intrinsics.federated_max(x) + + class FederatedAggregateTest(IntrinsicTestBase): def test_federated_aggregate_with_client_int(self): diff --git a/tensorflow_federated/python/core/impl/types/type_analysis.py b/tensorflow_federated/python/core/impl/types/type_analysis.py index d89f220ee1..761ed2bad3 100644 --- a/tensorflow_federated/python/core/impl/types/type_analysis.py +++ b/tensorflow_federated/python/core/impl/types/type_analysis.py @@ -572,6 +572,30 @@ def is_average_compatible(type_spec: computation_types.Type) -> bool: return False +def is_min_max_compatible(type_spec: computation_types.Type) -> bool: + """Determines if `type_spec` is min/max compatible. + + Types that are min/max-compatible are composed of integer or floating tensor + types, possibly packaged into nested tuples and possibly federated. + + Args: + type_spec: a `computation_types.Type`. + + Returns: + `True` iff `type_spec` is min/max compatible, `False` otherwise. + """ + if isinstance(type_spec, computation_types.TensorType): + return type_spec.dtype.is_integer or type_spec.dtype.is_floating + elif isinstance(type_spec, computation_types.StructType): + return all( + is_min_max_compatible(v) for _, v in structure.iter_elements(type_spec) + ) + elif isinstance(type_spec, computation_types.FederatedType): + return is_min_max_compatible(type_spec.member) + else: + return False + + def is_struct_with_py_container(value, type_spec): return isinstance(value, structure.Struct) and isinstance( type_spec, computation_types.StructWithPythonType diff --git a/tensorflow_federated/python/core/impl/types/type_analysis_test.py b/tensorflow_federated/python/core/impl/types/type_analysis_test.py index b781a1c4d1..f101eb8cc7 100644 --- a/tensorflow_federated/python/core/impl/types/type_analysis_test.py +++ b/tensorflow_federated/python/core/impl/types/type_analysis_test.py @@ -267,6 +267,33 @@ def test_returns_false(self, type_spec): self.assertFalse(type_analysis.is_average_compatible(type_spec)) +class IsMinMaxCompatibleTest(parameterized.TestCase): + + @parameterized.named_parameters([ + ('tensor_type_int32', computation_types.TensorType(tf.int32)), + ('tensor_type_int64', computation_types.TensorType(tf.int64)), + ('tensor_type_float32', computation_types.TensorType(tf.float32)), + ('tensor_type_float64', computation_types.TensorType(tf.float64)), + ( + 'tuple_type', + computation_types.StructType([('x', tf.float32), ('y', tf.int64)]), + ), + ( + 'federated_type', + computation_types.FederatedType(tf.float32, placements.CLIENTS), + ), + ]) + def test_returns_true(self, type_spec): + self.assertTrue(type_analysis.is_min_max_compatible(type_spec)) + + @parameterized.named_parameters([ + ('tensor_type_complex', computation_types.TensorType(tf.complex128)), + ('sequence_type', computation_types.SequenceType(tf.float32)), + ]) + def test_returns_false(self, type_spec): + self.assertFalse(type_analysis.is_min_max_compatible(type_spec)) + + class CheckTypeTest(absltest.TestCase): def test_raises_type_error(self):