Skip to content

Commit

Permalink
Make federated_min and federated_max TFF intrinsics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557695447
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed Aug 18, 2023
1 parent e9e4cdc commit 6097747
Show file tree
Hide file tree
Showing 14 changed files with 420 additions and 82 deletions.
2 changes: 2 additions & 0 deletions tensorflow_federated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_broadcast
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_eval
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_map
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_max
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_mean
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_min
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_modular_sum
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_select
from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_sum
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/aggregators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
"//tensorflow_federated/python/common_libs:deprecation",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
Expand Down
81 changes: 11 additions & 70 deletions tensorflow_federated/python/aggregators/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import tensorflow as tf

from tensorflow_federated.python.common_libs import deprecation
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.federated_context import intrinsics
Expand Down Expand Up @@ -45,64 +46,10 @@ def _validate_dtype_is_min_max_compatible(dtype):
)


def _federated_reduce_with_func(value, tf_func, zeros):
"""Applies to `tf_func` to accumulated `value`s.
This utility provides a generic aggregation for accumulating a value and
applying a simple aggregation (like minimum or maximum aggregations).
Args:
value: A `tff.Value` placed on the `tff.CLIENTS`.
tf_func: A function to be applied to the accumulated values. Must be a
binary operation where both parameters are of type `U` and the return type
is also `U`.
zeros: The zero of the same type as `value` in the algebra of reduction
operators.
Returns:
A representation on the `tff.SERVER` of the result of aggregating `value`.
"""
value_type = value.type_signature.member

@tensorflow_computation.tf_computation(value_type, value_type)
def accumulate(current, value):
return tf.nest.map_structure(tf_func, current, value)

@tensorflow_computation.tf_computation(value_type)
def report(value):
return value

return intrinsics.federated_aggregate(
value, zeros, accumulate, accumulate, report
)


def _initial_values(initial_value_fn, member_type):
"""Create a nested structure of initial values for the reduction.
Args:
initial_value_fn: A function that maps a tff.TensorType to a specific value
constant for initialization.
member_type: A `tff.Type` representing the member components of the
federated type.
Returns:
A function of the result of reducing a value with no constituents.
"""

def validate_and_fill(type_spec: computation_types.TensorType) -> tf.Tensor:
_validate_dtype_is_min_max_compatible(type_spec.dtype)
return tf.fill(dims=type_spec.shape, value=initial_value_fn(type_spec))

@tensorflow_computation.tf_computation
def zeros_fn():
return type_conversions.structure_from_tensor_type_tree(
validate_and_fill, member_type
)

return zeros_fn()


@deprecation.deprecated(
'`tff.aggregators.federated_min` is deprecated, use `tff.federated_min`'
' instead.'
)
def federated_min(value):
"""Computes the minimum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`.
Expand All @@ -120,14 +67,13 @@ def federated_min(value):
A representation of the min of the member constituents of `value` placed at
`tff.SERVER`.
"""
_validate_value_on_clients(value)
member_type = value.type_signature.member
# Explicit cast because v.dtype.max returns a Python constant, which could be
# implicitly converted to a tensor of different dtype by TensorFlow.
zeros = _initial_values(lambda v: tf.cast(v.dtype.max, v.dtype), member_type)
return _federated_reduce_with_func(value, tf.minimum, zeros)
return intrinsics.federated_min(value_impl.to_value(value, type_spec=None))


@deprecation.deprecated(
'`tff.aggregators.federated_max` is deprecated, use `tff.federated_max`'
' instead.'
)
def federated_max(value):
"""Computes the maximum at `tff.SERVER` of a `value` placed at `tff.CLIENTS`.
Expand All @@ -145,12 +91,7 @@ def federated_max(value):
A representation of the max of the member constituents of `value` placed at
`tff.SERVER`.
"""
_validate_value_on_clients(value)
member_type = value.type_signature.member
# Explicit cast because v.dtype.min returns a Python constant, which could be
# implicitly converted to a tensor of different dtype by TensorFlow.
zeros = _initial_values(lambda v: tf.cast(v.dtype.min, v.dtype), member_type)
return _federated_reduce_with_func(value, tf.maximum, zeros)
return intrinsics.federated_max(value_impl.to_value(value, type_spec=None))


class _Samples(NamedTuple):
Expand Down
16 changes: 4 additions & 12 deletions tensorflow_federated/python/aggregators/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def call_federated_min(value):
self.assertAllClose(value, [[1.0, -1.0], -12.0])

def test_federated_min_wrong_type(self):
with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'):

with self.assertRaises(ValueError):
@federated_computation.federated_computation(
computation_types.at_clients(tf.bool)
)
Expand All @@ -137,10 +136,7 @@ def call_federated_min(value):
call_federated_min([False])

def test_federated_min_wrong_placement(self):
with self.assertRaisesRegex(
TypeError, r'.* argument must be a tff.Value placed at CLIENTS'
):

with self.assertRaises(TypeError):
@federated_computation.federated_computation(
computation_types.at_server(tf.int32)
)
Expand Down Expand Up @@ -239,8 +235,7 @@ def call_federated_max(value):
self.assertAllClose(value, [[5.0, 6.0], 8.0])

def test_federated_max_wrong_type(self):
with self.assertRaisesRegex(TypeError, 'Unsupported dtype.'):

with self.assertRaises(ValueError):
@federated_computation.federated_computation(
computation_types.at_clients(tf.bool)
)
Expand All @@ -250,10 +245,7 @@ def call_federated_max(value):
call_federated_max([True, False])

def test_federated_max_wrong_placement(self):
with self.assertRaisesRegex(
TypeError, r'.*argument must be a tff.Value placed at CLIENTS.*'
):

with self.assertRaises(TypeError):
@federated_computation.federated_computation(
computation_types.at_server(tf.float32)
)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ py_library(
":building_block_factory",
":building_blocks",
":intrinsic_defs",
":tensorflow_computation_factory",
":transformation_utils",
":tree_analysis",
"//tensorflow_federated/python/common_libs:py_typecheck",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,72 @@ def create_federated_mean(
return building_blocks.Call(intrinsic, value)


def create_federated_min(
value: building_blocks.ComputationBuildingBlock,
) -> building_blocks.Call:
r"""Creates a called federated min.
Call
/ \
Intrinsic Comp
Args:
value: A `building_blocks.ComputationBuildingBlock` to use as the value.
Returns:
A `building_blocks.Call`.
Raises:
ValueError: If any of the types do not match.
"""
if not isinstance(value.type_signature, computation_types.FederatedType):
raise ValueError('Expected a federated value.')
result_type = computation_types.FederatedType(
value.type_signature.member,
placements.SERVER,
)
intrinsic_type = computation_types.FunctionType(
type_conversions.type_to_non_all_equal(value.type_signature), result_type
)
intrinsic = building_blocks.Intrinsic(
intrinsic_defs.FEDERATED_MIN.uri, intrinsic_type
)
return building_blocks.Call(intrinsic, value)


def create_federated_max(
value: building_blocks.ComputationBuildingBlock,
) -> building_blocks.Call:
r"""Creates a called federated max.
Call
/ \
Intrinsic Comp
Args:
value: A `building_blocks.ComputationBuildingBlock` to use as the value.
Returns:
A `building_blocks.Call`.
Raises:
ValueError: If any of the types do not match.
"""
if not isinstance(value.type_signature, computation_types.FederatedType):
raise ValueError('Expected a federated value.')
result_type = computation_types.FederatedType(
value.type_signature.member,
placements.SERVER,
)
intrinsic_type = computation_types.FunctionType(
type_conversions.type_to_non_all_equal(value.type_signature), result_type
)
intrinsic = building_blocks.Intrinsic(
intrinsic_defs.FEDERATED_MAX.uri, intrinsic_type
)
return building_blocks.Call(intrinsic, value)


def create_null_federated_secure_modular_sum():
return create_federated_secure_modular_sum(
create_federated_value(building_blocks.Struct([]), placements.CLIENTS),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,30 @@ def test_returns_federated_weighted_mean(self):
self.assertEqual(str(comp.type_signature), 'int32@SERVER')


class CreateFederatedMinTest(absltest.TestCase):

def test_returns_federated_min(self):
value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
value = building_blocks.Data('v', value_type)
comp = building_block_factory.create_federated_min(value)
self.assertEqual(comp.compact_representation(), 'federated_min(v)')
self.assertEqual(
comp.type_signature.compact_representation(), 'int32@SERVER'
)


class CreateFederatedMaxTest(absltest.TestCase):

def test_returns_federated_max(self):
value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
value = building_blocks.Data('v', value_type)
comp = building_block_factory.create_federated_max(value)
self.assertEqual(comp.compact_representation(), 'federated_max(v)')
self.assertEqual(
comp.type_signature.compact_representation(), 'int32@SERVER'
)


class CreateFederatedSecureModularSumTest(absltest.TestCase):

def test_raises_type_error_with_none_value(self):
Expand Down
32 changes: 32 additions & 0 deletions tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,38 @@ def __repr__(self):
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the min of client values on the server. Only supported for numeric
# types, or nested structures made up of numeric computation_types.
#
# Type signature: {T}@CLIENTS -> T@SERVER
FEDERATED_MIN = IntrinsicDef(
'FEDERATED_MIN',
'federated_min',
computation_types.FunctionType(
parameter=computation_types.at_clients(
computation_types.AbstractType('T')
),
result=computation_types.at_server(computation_types.AbstractType('T')),
),
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the max of client values on the server. Only supported for numeric
# types, or nested structures made up of numeric computation_types.
#
# Type signature: {T}@CLIENTS -> T@SERVER
FEDERATED_MAX = IntrinsicDef(
'FEDERATED_MAX',
'federated_max',
computation_types.FunctionType(
parameter=computation_types.at_clients(
computation_types.AbstractType('T')
),
result=computation_types.at_server(computation_types.AbstractType('T')),
),
aggregation_kind=AggregationKind.DEFAULT,
)

# Computes the modular sum of client values on the server, securely. Only
# supported for integers or nested structures of integers.
#
Expand Down
Loading

0 comments on commit 6097747

Please sign in to comment.