Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused aggregators from tff.learning. #4942

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ and this project adheres to
### Removed

* `tff.types.tensorflow_to_type`, this function is no longer used.
* `tff.learning.dp_aggregator` removed. Prefer using the class methods on
`tff.aggregators.DifferentiallyPrivateFactory`.
* `tff.learning.ddp_secure_aggregator` and `tff.learning.secure_aggregator`
removed.

## Release 0.88.0

Expand Down
4 changes: 0 additions & 4 deletions tensorflow_federated/python/learning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ py_library(
name = "model_update_aggregator",
srcs = ["model_update_aggregator.py"],
deps = [
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:distributed_dp",
"//tensorflow_federated/python/aggregators:encoded",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/aggregators:quantile_estimation",
Expand All @@ -88,7 +85,6 @@ py_test(
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/test:static_assert",
],
)

Expand Down
4 changes: 0 additions & 4 deletions tensorflow_federated/python/learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,4 @@
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements_with_mixed_dtype
from tensorflow_federated.python.learning.loop_builder import LoopImplementation
from tensorflow_federated.python.learning.model_update_aggregator import compression_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import ddp_secure_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import dp_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import robust_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import secure_aggregator
9 changes: 0 additions & 9 deletions tensorflow_federated/python/learning/algorithms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,9 @@ py_cpu_gpu_test(
deps = [
":fed_avg",
"//tensorflow_federated/python/aggregators:factory_utils",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_examples",
"//tensorflow_federated/python/learning/models:test_models",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
Expand Down Expand Up @@ -118,7 +115,6 @@ py_cpu_gpu_test(
shard_count = 10,
deps = [
":fed_avg_with_optimizer_schedule",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
Expand Down Expand Up @@ -164,10 +160,8 @@ py_cpu_gpu_test(
":fed_prox",
"//tensorflow_federated/python/aggregators:factory_utils",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_examples",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/models:test_models",
Expand Down Expand Up @@ -330,7 +324,6 @@ py_cpu_gpu_test(
deps = [
":fed_sgd",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
Expand Down Expand Up @@ -429,11 +422,9 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:counters",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:keras_utils",
Expand Down
66 changes: 3 additions & 63 deletions tensorflow_federated/python/learning/algorithms/fed_avg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,25 @@
from absl.testing import parameterized

from tensorflow_federated.python.aggregators import factory_utils
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_avg
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.models import model_examples
from tensorflow_federated.python.learning.models import test_models
from tensorflow_federated.python.learning.optimizers import sgdm


class FedAvgTest(parameterized.TestCase):
"""Tests construction of the FedAvg training process."""

@parameterized.product(
optimizer_fn=[
sgdm.build_sgdm(learning_rate=0.1),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
fed_avg.build_weighted_fed_avg(
model_fn=mock_model_fn,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -125,34 +112,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self):
model_aggregator=model_aggregator,
)

def test_weighted_fed_avg_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg.build_weighted_fed_avg(
model_fn,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_unweighted_fed_avg_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg.build_unweighted_fed_avg(
model_fn,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=False
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


class FunctionalFedAvgTest(parameterized.TestCase):
"""Tests construction of the FedAvg training process."""
Expand All @@ -167,25 +126,6 @@ def test_raises_on_non_callable_or_functional_model(self, constructor):
model_fn=0, client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1)
)

@parameterized.named_parameters(
('weighted', fed_avg.build_weighted_fed_avg),
('unweighted', fed_avg.build_unweighted_fed_avg),
)
def test_weighted_fed_avg_with_only_secure_aggregation(self, constructor):
model = test_models.build_functional_linear_regression()
learning_process = constructor(
model_fn=model,
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
server_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=constructor is fed_avg.build_weighted_fed_avg
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from absl.testing import parameterized
import tensorflow as tf

from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_avg_with_optimizer_schedule
Expand All @@ -30,17 +29,7 @@

class ClientScheduledFedAvgTest(parameterized.TestCase):

@parameterized.product(
optimizer_fn=[
lambda x: sgdm.build_sgdm(learning_rate=x),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
Expand All @@ -49,8 +38,8 @@ def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule(
model_fn=mock_model_fn,
client_learning_rate_fn=learning_rate_fn,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=lambda lr: sgdm.build_sgdm(learning_rate=lr),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -143,21 +132,6 @@ def test_raises_on_non_callable_model_fn(self):
client_optimizer_fn=lambda _: sgdm.build_sgdm(),
)

def test_construction_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule(
model_fn,
client_learning_rate_fn=lambda x: 0.5,
client_optimizer_fn=lambda x: sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_measurements_include_client_learning_rate(self):
client_work = fed_avg_with_optimizer_schedule.build_scheduled_client_work(
model_fn=model_examples.LinearRegression,
Expand Down
48 changes: 3 additions & 45 deletions tensorflow_federated/python/learning/algorithms/fed_prox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@

from tensorflow_federated.python.aggregators import factory_utils
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_prox
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.models import model_examples
from tensorflow_federated.python.learning.models import model_weights
from tensorflow_federated.python.learning.models import test_models
Expand All @@ -34,26 +32,16 @@
class FedProxConstructionTest(parameterized.TestCase):
"""Tests construction of the FedProx training process."""

@parameterized.product(
optimizer_fn=[
sgdm.build_sgdm(learning_rate=0.1),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
fed_prox.build_weighted_fed_prox(
model_fn=mock_model_fn,
proximal_strength=1.0,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -160,36 +148,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self):
model_aggregator=model_aggregator,
)

def test_weighted_fed_prox_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_prox.build_weighted_fed_prox(
model_fn,
proximal_strength=1.0,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_unweighted_fed_prox_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_prox.build_unweighted_fed_prox(
model_fn,
proximal_strength=1.0,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=False
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
absltest.main()
28 changes: 0 additions & 28 deletions tensorflow_federated/python/learning/algorithms/fed_sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import tensorflow as tf

from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_sgd
Expand Down Expand Up @@ -160,11 +159,6 @@ def test_client_tf_dataset_reduce_fn(self, loop_implementation, mock_method):

@parameterized.named_parameters(
('robust_aggregator', model_update_aggregator.robust_aggregator),
(
'compression_aggregator',
model_update_aggregator.compression_aggregator,
),
('secure_aggreagtor', model_update_aggregator.secure_aggregator),
)
def test_construction_calls_model_fn(self, aggregation_factory):
# Assert that the process building does not call `model_fn` too many times.
Expand All @@ -177,17 +171,6 @@ def test_construction_calls_model_fn(self, aggregation_factory):
# TODO: b/186451541 - reduce the number of calls to model_fn.
self.assertEqual(mock_model_fn.call_count, 3)

def test_no_unsecure_aggregation_with_secure_aggregator(self):
model_fn = model_examples.LinearRegression
learning_process = fed_sgd.build_fed_sgd(
model_fn,
model_aggregator=model_update_aggregator.secure_aggregator(),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


class FunctionalFederatedSgdTest(tf.test.TestCase, parameterized.TestCase):

Expand Down Expand Up @@ -276,17 +259,6 @@ def test_build_functional_fed_sgd_succeeds(self):
model = _build_functional_model()
fed_sgd.build_fed_sgd(model_fn=model)

def test_no_unsecure_aggregation_with_secure_aggregator(self):
model = _build_functional_model()
learning_process = fed_sgd.build_fed_sgd(
model,
model_aggregator=model_update_aggregator.secure_aggregator(),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
tf.test.main()
Loading