Skip to content

Commit

Permalink
Make federated_min and federated_max TFF intrinsics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557695447
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed Aug 17, 2023
1 parent a13d9bb commit 6f31d6f
Show file tree
Hide file tree
Showing 14 changed files with 421 additions and 20 deletions.
2 changes: 2 additions & 0 deletions tensorflow_federated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_broadcast
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_eval
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_map
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_max
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_mean
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_min
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_modular_sum
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_select
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_sum
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/aggregators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
"//tensorflow_federated/python/common_libs:deprecation",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
Expand Down
23 changes: 11 additions & 12 deletions tensorflow_federated/python/aggregators/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import tensorflow as tf

from tensorflow_federated.python.common_libs import deprecation
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.federated_context import intrinsics
Expand Down Expand Up @@ -103,6 +104,10 @@ def zeros_fn():
return zeros_fn()


@deprecation.deprecated(
'`tff.aggregators.federated_min` is deprecated, use `tff.federated_min`'
' instead.'
)
def federated_min(value):
"""Computes the minimum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`.
Expand All @@ -120,14 +125,13 @@ def federated_min(value):
A representation of the min of the member constituents of `value` placed at
`tff.SERVER`.
"""
_validate_value_on_clients(value)
member_type = value.type_signature.member
# Explicit cast because v.dtype.max returns a Python constant, which could be
# implicitly converted to a tensor of different dtype by TensorFlow.
zeros = _initial_values(lambda v: tf.cast(v.dtype.max, v.dtype), member_type)
return _federated_reduce_with_func(value, tf.minimum, zeros)
return intrinsics.federated_min(value)


@deprecation.deprecated(
'`tff.aggregators.federated_max` is deprecated, use `tff.federated_max`'
' instead.'
)
def federated_max(value):
"""Computes the maximum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`.
Expand All @@ -145,12 +149,7 @@ def federated_max(value):
A representation of the max of the member constituents of `value` placed at
`tff.SERVER`.
"""
_validate_value_on_clients(value)
member_type = value.type_signature.member
# Explicit cast because v.dtype.min returns a Python constant, which could be
# implicitly converted to a tensor of different dtype by TensorFlow.
zeros = _initial_values(lambda v: tf.cast(v.dtype.min, v.dtype), member_type)
return _federated_reduce_with_func(value, tf.maximum, zeros)
return intrinsics.federated_max(value)


class _Samples(NamedTuple):
Expand Down
20 changes: 12 additions & 8 deletions tensorflow_federated/python/aggregators/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,11 @@ def call_federated_min(value):
self.assertAllClose(value, [[1.0, -1.0], -12.0])

def test_federated_min_wrong_type(self):
with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'):

with self.assertRaisesRegex(
TypeError,
'The value type {bool}@CLIENTS is not compatible with the min'
' operator.',
):
@federated_computation.federated_computation(
computation_types.at_clients(tf.bool)
)
Expand All @@ -138,9 +141,8 @@ def call_federated_min(value):

def test_federated_min_wrong_placement(self):
with self.assertRaisesRegex(
TypeError, r'.* argument must be a tff.Value placed at CLIENTS'
TypeError, r'.*value to take min of should be placed at CLIENTS'
):

@federated_computation.federated_computation(
computation_types.at_server(tf.int32)
)
Expand Down Expand Up @@ -239,8 +241,11 @@ def call_federated_max(value):
self.assertAllClose(value, [[5.0, 6.0], 8.0])

def test_federated_max_wrong_type(self):
with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'):

with self.assertRaisesRegex(
TypeError,
'The value type {bool}@CLIENTS is not compatible with the max'
' operator.',
):
@federated_computation.federated_computation(
computation_types.at_clients(tf.bool)
)
Expand All @@ -251,9 +256,8 @@ def call_federated_max(value):

def test_federated_max_wrong_placement(self):
with self.assertRaisesRegex(
TypeError, r'.*argument must be a tff.Value placed at CLIENTS.*'
TypeError, r'.*value to take max of should be placed at CLIENTS.*'
):

@federated_computation.federated_computation(
computation_types.at_server(tf.float32)
)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ py_library(
":building_block_factory",
":building_blocks",
":intrinsic_defs",
":tensorflow_computation_factory",
":transformation_utils",
":tree_analysis",
"//tensorflow_federated/python/common_libs:py_typecheck",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,70 @@ def create_federated_mean(
return building_blocks.Call(intrinsic, value)


def create_federated_min(
value: building_blocks.ComputationBuildingBlock,
) -> building_blocks.Call:
r"""Creates a called federated min.
Call
/ \
Intrinsic Comp
Args:
value: A `building_blocks.ComputationBuildingBlock` to use as the value.
Returns:
A `building_blocks.Call`.
Raises:
TypeError: If any of the types do not match.
"""
py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock)
result_type = computation_types.FederatedType(
value.type_signature.member, # pytype: disable=attribute-error
placements.SERVER,
)
intrinsic_type = computation_types.FunctionType(
type_conversions.type_to_non_all_equal(value.type_signature), result_type
)
intrinsic = building_blocks.Intrinsic(
intrinsic_defs.FEDERATED_MIN.uri, intrinsic_type
)
return building_blocks.Call(intrinsic, value)


def create_federated_max(
value: building_blocks.ComputationBuildingBlock,
) -> building_blocks.Call:
r"""Creates a called federated max.
Call
/ \
Intrinsic Comp
Args:
value: A `building_blocks.ComputationBuildingBlock` to use as the value.
Returns:
A `building_blocks.Call`.
Raises:
TypeError: If any of the types do not match.
"""
py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock)
result_type = computation_types.FederatedType(
value.type_signature.member, # pytype: disable=attribute-error
placements.SERVER,
)
intrinsic_type = computation_types.FunctionType(
type_conversions.type_to_non_all_equal(value.type_signature), result_type
)
intrinsic = building_blocks.Intrinsic(
intrinsic_defs.FEDERATED_MAX.uri, intrinsic_type
)
return building_blocks.Call(intrinsic, value)


def create_null_federated_secure_modular_sum():
return create_federated_secure_modular_sum(
create_federated_value(building_blocks.Struct([]), placements.CLIENTS),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,34 @@ def test_returns_federated_weighted_mean(self):
self.assertEqual(str(comp.type_signature), 'int32@SERVER')


class CreateFederatedMinTest(absltest.TestCase):

def test_raises_type_error_with_none_value(self):
with self.assertRaises(TypeError):
building_block_factory.create_federated_min(None)

def test_returns_federated_min(self):
value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
value = building_blocks.Data('v', value_type)
comp = building_block_factory.create_federated_min(value)
self.assertEqual(comp.compact_representation(), 'federated_min(v)')
self.assertEqual(str(comp.type_signature), 'int32@SERVER')


class CreateFederatedMaxTest(absltest.TestCase):

def test_raises_type_error_with_none_value(self):
with self.assertRaises(TypeError):
building_block_factory.create_federated_max(None)

def test_returns_federated_max(self):
value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
value = building_blocks.Data('v', value_type)
comp = building_block_factory.create_federated_max(value)
self.assertEqual(comp.compact_representation(), 'federated_max(v)')
self.assertEqual(str(comp.type_signature), 'int32@SERVER')


class CreateFederatedSecureModularSumTest(absltest.TestCase):

def test_raises_type_error_with_none_value(self):
Expand Down
32 changes: 32 additions & 0 deletions tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,38 @@ def __repr__(self):
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the min of client values on the server. Only supported for numeric
# types, or nested structures made up of numeric computation_types.
#
# Type signature: {T}@CLIENTS -> T@SERVER
FEDERATED_MIN = IntrinsicDef(
'FEDERATED_MIN',
'federated_min',
computation_types.FunctionType(
parameter=computation_types.at_clients(
computation_types.AbstractType('T')
),
result=computation_types.at_server(computation_types.AbstractType('T')),
),
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the max of client values on the server. Only supported for numeric
# types, or nested structures made up of numeric computation_types.
#
# Type signature: {T}@CLIENTS -> T@SERVER
FEDERATED_MAX = IntrinsicDef(
'FEDERATED_MAX',
'federated_max',
computation_types.FunctionType(
parameter=computation_types.at_clients(
computation_types.AbstractType('T')
),
result=computation_types.at_server(computation_types.AbstractType('T')),
),
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the modular sum of client values on the server, securely. Only
# supported for integers or nested structures of integers.
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory
from tensorflow_federated.python.core.impl.compiler import transformation_utils
from tensorflow_federated.python.core.impl.compiler import tree_analysis
from tensorflow_federated.python.core.impl.types import computation_types
Expand Down Expand Up @@ -712,6 +713,43 @@ def _apply_generic_op(op, arg):
return building_block_factory.apply_binary_operator_with_upcast(arg, op)


def _initial_values(initial_value_fn, member_type):
"""Create a nested structure of initial values.
Args:
initial_value_fn: A function that maps a tff.TensorType to a specific value
constant for initialization.
member_type: A `tff.Type` representing the member components of the
federated type.
Returns:
A building_blocks.ComputationBuildingBlock representing the initial values.
"""

def _fill(tensor_type):
proto, function_type = tensorflow_computation_factory.create_constant(
initial_value_fn(tensor_type), tensor_type
)
compiled = building_blocks.CompiledComputation(
proto, type_signature=function_type
)
return building_blocks.Call(compiled, None)

def _structify_bb(inner_value):
if isinstance(inner_value, dict):
return building_blocks.Struct(
[(k, _structify_bb(v)) for k, v in inner_value.items()]
)
if isinstance(inner_value, tuple) or isinstance(inner_value, list):
return building_blocks.Struct([_structify_bb(v) for v in inner_value])
assert isinstance(inner_value, building_blocks.ComputationBuildingBlock)
return inner_value

return _structify_bb(
type_conversions.structure_from_tensor_type_tree(_fill, member_type)
)


def get_intrinsic_reductions() -> (
dict[
str,
Expand Down Expand Up @@ -786,6 +824,36 @@ def federated_mean(arg):
mean_arg = building_blocks.Struct([(None, arg), (None, one)])
return federated_weighted_mean(mean_arg)

def federated_min(x):
py_typecheck.check_type(x, building_blocks.ComputationBuildingBlock)
operand_type = x.type_signature.member # pytype: disable=attribute-error
zero = _initial_values(lambda v: v.dtype.max, operand_type)
min_op = (
building_block_factory.create_tensorflow_binary_operator_with_upcast(
tf.minimum,
computation_types.StructType([operand_type, operand_type]),
)
)
identity = building_block_factory.create_identity(operand_type)
return building_block_factory.create_federated_aggregate(
x, zero, min_op, min_op, identity
)

def federated_max(x):
py_typecheck.check_type(x, building_blocks.ComputationBuildingBlock)
operand_type = x.type_signature.member # pytype: disable=attribute-error
zero = _initial_values(lambda v: v.dtype.min, operand_type)
max_op = (
building_block_factory.create_tensorflow_binary_operator_with_upcast(
tf.maximum,
computation_types.StructType([operand_type, operand_type]),
)
)
identity = building_block_factory.create_identity(operand_type)
return building_block_factory.create_federated_aggregate(
x, zero, max_op, max_op, identity
)

def federated_sum(x):
py_typecheck.check_type(x, building_blocks.ComputationBuildingBlock)
operand_type = x.type_signature.member # pytype: disable=attribute-error
Expand Down Expand Up @@ -860,6 +928,8 @@ def federated_sum(x):
intrinsic_bodies_by_uri = collections.OrderedDict([
(intrinsic_defs.FEDERATED_MEAN.uri, federated_mean),
(intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, federated_weighted_mean),
(intrinsic_defs.FEDERATED_MIN.uri, federated_min),
(intrinsic_defs.FEDERATED_MAX.uri, federated_max),
(intrinsic_defs.FEDERATED_SUM.uri, federated_sum),
(intrinsic_defs.GENERIC_DIVIDE.uri, generic_divide),
(intrinsic_defs.GENERIC_MULTIPLY.uri, generic_multiply),
Expand Down
Loading

0 comments on commit 6f31d6f

Please sign in to comment.