Skip to content

Commit

Permalink
Implement Serialize and Deserialize methods for CheckpointAggregator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631908037
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed May 8, 2024
1 parent 1d2cf43 commit 9c5a51a
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 21 deletions.
11 changes: 11 additions & 0 deletions tensorflow_federated/cc/core/impl/aggregation/protocol/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ proto_library(
deps = ["//tensorflow_federated/cc/core/impl/aggregation/core:tensor_proto"],
)

# Proto schema (checkpoint_aggregator.proto) used for the serialized state of
# CheckpointAggregator.
proto_library(
name = "checkpoint_aggregator_proto",
srcs = ["checkpoint_aggregator.proto"],
)

# C++ bindings generated from :checkpoint_aggregator_proto.
cc_proto_library(
name = "checkpoint_aggregator_cc_proto",
deps = [":checkpoint_aggregator_proto"],
)

cc_proto_library(
name = "configuration_cc_proto",
visibility = ["//visibility:public"],
Expand Down Expand Up @@ -99,6 +109,7 @@ cc_library(
hdrs = ["checkpoint_aggregator.h"],
visibility = ["//visibility:public"],
deps = [
":checkpoint_aggregator_cc_proto",
":checkpoint_builder",
":checkpoint_parser",
":config_converter",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator_factory.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator_registry.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_spec.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.pb.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_builder.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_parser.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/config_converter.h"
Expand Down Expand Up @@ -67,13 +68,50 @@ absl::Status CheckpointAggregator::ValidateConfig(

// Builds a brand-new CheckpointAggregator from `configuration`; no previously
// serialized state is restored.
absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Create(const Configuration& configuration) {
  return CheckpointAggregator::CreateInternal(configuration,
                                              /*aggregator_state=*/nullptr);
}

// Builds a brand-new CheckpointAggregator from already-parsed `intrinsics`;
// no previously serialized state is restored.
absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Create(const std::vector<Intrinsic>* intrinsics) {
  return CheckpointAggregator::CreateInternal(intrinsics,
                                              /*aggregator_state=*/nullptr);
}

// Restores a CheckpointAggregator from `configuration` plus a string holding
// a serialized CheckpointAggregatorState.
// Returns InvalidArgument when `serialized_state` is not a parseable
// CheckpointAggregatorState.
// NOTE(review): the number of entries in the parsed state is not compared
// against the number of intrinsics derived from `configuration` before the
// state is handed to CreateInternal - confirm whether a mismatch is rejected
// there.
absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Deserialize(const Configuration& configuration,
                                  std::string serialized_state) {
  CheckpointAggregatorState restored_state;
  const bool parsed_ok = restored_state.ParseFromString(serialized_state);
  if (parsed_ok) {
    return CreateInternal(configuration, &restored_state);
  }
  return absl::InvalidArgumentError("Failed to parse serialized state.");
}

// Restores a CheckpointAggregator from already-parsed `intrinsics` plus a
// string holding a serialized CheckpointAggregatorState.
//
// Returns InvalidArgument when:
//   * `serialized_state` is not a parseable CheckpointAggregatorState, or
//   * the parsed state does not hold exactly one entry per intrinsic.
//
// The second check matters because CreateInternal reads
// `aggregator_state->aggregators(i)` for every intrinsic index `i`; a
// truncated or foreign state would otherwise cause an out-of-range access on
// the repeated field.
absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Deserialize(const std::vector<Intrinsic>* intrinsics,
                                  std::string serialized_state) {
  CheckpointAggregatorState aggregator_state;
  if (!aggregator_state.ParseFromString(serialized_state)) {
    return absl::InvalidArgumentError("Failed to parse serialized state.");
  }
  // Serialized input is untrusted: reject it up front instead of indexing past
  // the end of the repeated field later on.
  if (static_cast<size_t>(aggregator_state.aggregators_size()) !=
      intrinsics->size()) {
    return absl::InvalidArgumentError(
        "Serialized state does not match the number of intrinsics.");
  }
  return CreateInternal(intrinsics, &aggregator_state);
}

absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::CreateInternal(
const Configuration& configuration,
const CheckpointAggregatorState* aggregator_state) {
TFF_ASSIGN_OR_RETURN(std::vector<Intrinsic> intrinsics,
ParseFromConfig(configuration));

std::vector<std::unique_ptr<TensorAggregator>> aggregators;
for (const Intrinsic& intrinsic : intrinsics) {
for (int i = 0; i < intrinsics.size(); ++i) {
const Intrinsic& intrinsic = intrinsics[i];
const std::string* serialized_aggregator = nullptr;
if (aggregator_state != nullptr) {
serialized_aggregator = &aggregator_state->aggregators(i);
}
TFF_ASSIGN_OR_RETURN(std::unique_ptr<TensorAggregator> aggregator,
CreateAggregator(intrinsic));
CreateAggregator(intrinsic, serialized_aggregator));
aggregators.push_back(std::move(aggregator));
}

Expand All @@ -82,11 +120,18 @@ CheckpointAggregator::Create(const Configuration& configuration) {
}

absl::StatusOr<std::unique_ptr<CheckpointAggregator>>
CheckpointAggregator::Create(const std::vector<Intrinsic>* intrinsics) {
CheckpointAggregator::CreateInternal(
const std::vector<Intrinsic>* intrinsics,
const CheckpointAggregatorState* aggregator_state) {
std::vector<std::unique_ptr<TensorAggregator>> aggregators;
for (const Intrinsic& intrinsic : *intrinsics) {
for (int i = 0; i < intrinsics->size(); ++i) {
const Intrinsic& intrinsic = (*intrinsics)[i];
const std::string* serialized_aggregator = nullptr;
if (aggregator_state != nullptr) {
serialized_aggregator = &aggregator_state->aggregators(i);
}
TFF_ASSIGN_OR_RETURN(std::unique_ptr<TensorAggregator> aggregator,
CreateAggregator(intrinsic));
CreateAggregator(intrinsic, serialized_aggregator));
aggregators.push_back(std::move(aggregator));
}

Expand Down Expand Up @@ -201,14 +246,33 @@ absl::Status CheckpointAggregator::Report(

// Flags the aggregation as finished. Serialize() checks the same flag and
// returns an Aborted error once it is set.
void CheckpointAggregator::Abort() {
  aggregation_finished_ = true;
}

// Serializes the state of every nested tensor aggregator into a
// CheckpointAggregatorState message and returns that message as a string.
//
// Rvalue-qualified: each nested aggregator is serialized through
// `std::move(*aggregator).Serialize()`.
//
// Returns:
//   * Aborted if the aggregation has already been finished.
//   * The error of the first nested aggregator that fails to serialize.
//   * Otherwise the serialized CheckpointAggregatorState; entries are written
//     in the order of `aggregators_`.
absl::StatusOr<std::string> CheckpointAggregator::Serialize() && {
  absl::MutexLock lock(&aggregation_mu_);
  if (aggregation_finished_) {
    return absl::AbortedError("Aggregation has already been finished.");
  }
  CheckpointAggregatorState state;
  google::protobuf::RepeatedPtrField<std::string>* aggregators_proto =
      state.mutable_aggregators();
  aggregators_proto->Reserve(aggregators_.size());
  for (const auto& aggregator : aggregators_) {
    // Propagate a nested serialization failure to the caller instead of
    // unwrapping the result unchecked with .value().
    TFF_ASSIGN_OR_RETURN(std::string serialized_aggregator,
                         std::move(*aggregator).Serialize());
    aggregators_proto->Add(std::move(serialized_aggregator));
  }
  return state.SerializeAsString();
}

// Builds the TensorAggregator for a single intrinsic: looks up the factory
// registered for `intrinsic.uri`, then either creates a fresh aggregator (when
// `serialized_aggregator` is null) or restores one from
// `*serialized_aggregator` via the factory's Deserialize.
// Errors from the factory lookup are returned to the caller.
// NOTE(review): this listing looks like a rendered diff without +/- markers;
// the single-argument signature and the unconditional
// `return factory->Create(intrinsic);` below appear to be the removed
// pre-change lines - confirm against the repository.
absl::StatusOr<std::unique_ptr<TensorAggregator>>
CheckpointAggregator::CreateAggregator(const Intrinsic& intrinsic) {
CheckpointAggregator::CreateAggregator(
const Intrinsic& intrinsic, const std::string* serialized_aggregator) {
// Resolve the intrinsic_uri to the registered TensorAggregatorFactory.
TFF_ASSIGN_OR_RETURN(const TensorAggregatorFactory* factory,
GetAggregatorFactory(intrinsic.uri));

// Use the factory to create the TensorAggregator instance.
return factory->Create(intrinsic);
// No saved state was supplied: start from a freshly created aggregator.
if (serialized_aggregator == nullptr) {
return factory->Create(intrinsic);
}
// Otherwise restore the aggregator from its serialized form.
return factory->Deserialize(intrinsic, *serialized_aggregator);
}

std::vector<std::unique_ptr<TensorAggregator>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <atomic>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/base/attributes.h"
Expand All @@ -31,6 +32,7 @@
#include "absl/synchronization/mutex.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.pb.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_builder.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_parser.h"
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/configuration.pb.h"
Expand Down Expand Up @@ -62,6 +64,20 @@ class CheckpointAggregator {
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Create(
const std::vector<Intrinsic>* intrinsics ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Creates an instance of CheckpointAggregator based on the given
// configuration and serialized state.
// `serialized_state` is parsed as a CheckpointAggregatorState message (the
// format produced by Serialize()). Returns an InvalidArgument error if it
// cannot be parsed.
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Deserialize(
const Configuration& configuration, std::string serialized_state);

// Creates an instance of CheckpointAggregator based on the given intrinsics
// and serialized state.
// The `intrinsics` are expected to be created using `ParseFromConfig` which
// validates the configuration. CheckpointAggregator does not take any
// ownership, and `intrinsics` must outlive it.
// `serialized_state` is parsed as a CheckpointAggregatorState message (the
// format produced by Serialize()). Returns an InvalidArgument error if it
// cannot be parsed.
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> Deserialize(
const std::vector<Intrinsic>* intrinsics ABSL_ATTRIBUTE_LIFETIME_BOUND,
std::string serialized_state);

// Accumulates a checkpoint via nested tensor aggregators. The tensors are
// provided by the CheckpointParser instance.
absl::Status Accumulate(CheckpointParser& checkpoint_parser);
Expand All @@ -75,6 +91,8 @@ class CheckpointAggregator {
// Signal that the aggregation must be aborted and the report can't be
// produced.
void Abort();
// Serialize the internal state of the checkpoint aggregator as a string.
// The result is a serialized CheckpointAggregatorState holding one entry per
// nested tensor aggregator; it can be passed to Deserialize().
// Rvalue-qualified, i.e. invoked as `std::move(aggregator).Serialize()`; the
// nested aggregators are serialized through std::move as well, so presumably
// this object must not be used afterwards - TODO confirm.
// Returns an Aborted error if the aggregation has already been finished.
absl::StatusOr<std::string> Serialize() &&;

private:
CheckpointAggregator(
Expand All @@ -85,9 +103,18 @@ class CheckpointAggregator {
std::vector<Intrinsic> intrinsics,
std::vector<std::unique_ptr<TensorAggregator>> aggregators);

// Creates an aggregation intrinsic based on the intrinsic configuration.
// Creates an aggregation intrinsic based on the intrinsic configuration and
// optional serialized state.
static absl::StatusOr<std::unique_ptr<TensorAggregator>> CreateAggregator(
const Intrinsic& intrinsic);
const Intrinsic& intrinsic, const std::string* serialized_aggregator);

// Shared implementation behind Create and Deserialize for a Configuration.
// `aggregator_state` is null when a fresh aggregator is requested; otherwise
// its i-th `aggregators` entry is used to restore the tensor aggregator of
// the i-th intrinsic parsed from `configuration`.
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> CreateInternal(
const Configuration& configuration,
const CheckpointAggregatorState* aggregator_state);

// Shared implementation behind Create and Deserialize for pre-parsed
// intrinsics. `aggregator_state` is null when a fresh aggregator is requested;
// otherwise its i-th `aggregators` entry is used to restore the tensor
// aggregator of the i-th element of `*intrinsics`.
static absl::StatusOr<std::unique_ptr<CheckpointAggregator>> CreateInternal(
const std::vector<Intrinsic>* intrinsics,
const CheckpointAggregatorState* aggregator_state);

// Used by the implementation of Merge.
std::vector<std::unique_ptr<TensorAggregator>> TakeAggregators() &&;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
syntax = "proto3";

package tensorflow_federated.aggregation;

// Internal state representation of a CheckpointAggregator.
message CheckpointAggregatorState {
// Serialized state of each nested tensor aggregator, one entry per
// aggregator, in the order the aggregators are held by CheckpointAggregator
// (entry i is read back for intrinsic i on deserialization).
repeated bytes aggregators = 1;
}
Loading

0 comments on commit 9c5a51a

Please sign in to comment.