Skip to content

Commit

Permalink
Change the output tensor name and output tensors field in PlanResult …
Browse files Browse the repository at this point in the history
…to a map of QuantizedTensors, and move the logic of creating QuantizedTensors into plan engines.

PiperOrigin-RevId: 682117166
  • Loading branch information
chunxiangzheng authored and copybara-github committed Oct 4, 2024
1 parent 0e7d219 commit 20770c2
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 81 deletions.
2 changes: 0 additions & 2 deletions fcp/client/eligibility_decider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ TEST_F(EligibilityDeciderTest, TfCustomPolicyPreparesNeetContextIterator) {

engine::PlanResult plan_result(engine::PlanOutcome::kSuccess,
absl::OkStatus());
plan_result.output_tensors = {};

TaskEligibilityInfo tf_custom_policy_output;
tf_custom_policy_output.set_version(1);
Expand Down Expand Up @@ -590,7 +589,6 @@ TEST_F(EligibilityDeciderTest, TfCustomPolicyParseOutputsLogsNonfatal) {

engine::PlanResult plan_result(engine::PlanOutcome::kSuccess,
absl::OkStatus());
plan_result.output_tensors = {};

auto execution_error = absl::InternalError("cripes!");
EXPECT_CALL(mock_eet_plan_runner_, RunPlan(_))
Expand Down
7 changes: 6 additions & 1 deletion fcp/client/engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,16 @@ cc_library(
deps = [
":engine_cc_proto",
"//fcp/base",
"//fcp/client:federated_protocol",
"//fcp/client:interfaces",
"//fcp/protos:federated_api_cc_proto",
"//fcp/protos:plan_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
Expand All @@ -339,11 +340,15 @@ cc_library(
srcs = ["tensorflow_utils.cc"],
hdrs = ["tensorflow_utils.h"],
deps = [
"//fcp/client:federated_protocol",
"//fcp/protos:federated_api_cc_proto",
"//fcp/protos:plan_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core/platform:tstring",
],
)
Expand Down
9 changes: 4 additions & 5 deletions fcp/client/engine/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "fcp/client/engine/engine.pb.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/client/stats.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/plan.pb.h"
#include "tensorflow/core/framework/tensor.h"

namespace fcp {
namespace client {
Expand All @@ -52,10 +53,8 @@ struct PlanResult {

// The outcome of the plan execution.
PlanOutcome outcome;
// Only set if `outcome` is `kSuccess`, otherwise this is empty.
std::vector<tensorflow::Tensor> output_tensors;
// Only set if `outcome` is `kSuccess`, otherwise this is empty.
std::vector<std::string> output_names;
// The secagg tensors from the plan execution.
absl::flat_hash_map<std::string, QuantizedTensor> secagg_tensor_map;
// Only set if 'outcome' is 'kSuccess' and the federated compute wire format
// is enabled, otherwise this is empty.
absl::Cord federated_compute_checkpoint;
Expand Down
9 changes: 7 additions & 2 deletions fcp/client/engine/simple_plan_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ PlanResult SimplePlanEngine::RunPlan(
plan_result.task_eligibility_info =
ParseEligibilityEvalPlanOutput(tf_result.value());
} else {
plan_result.output_names = output_names;
plan_result.output_tensors = std::move(tf_result).value();
auto secagg_tensor_map =
CreateQuantizedTensorMap(output_names, *tf_result, tensorflow_spec);
if (!secagg_tensor_map.ok()) {
return PlanResult(PlanOutcome::kTensorflowError,
secagg_tensor_map.status());
}
plan_result.secagg_tensor_map = std::move(*secagg_tensor_map);
}
plan_result.example_stats = {
.example_count = total_example_count,
Expand Down
86 changes: 86 additions & 0 deletions fcp/client/engine/tensorflow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,105 @@
*/
#include "fcp/client/engine/tensorflow_utils.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/plan.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace fcp::client::engine {

using ::google::internal::federated::plan::TensorflowSpec;
using ::google::internal::federatedml::v2::TaskEligibilityInfo;

// Appends every element of `tensor` (read through a flattened view of
// element type T) onto the end of `quantized->values`, preserving order.
template <typename T>
void AddValuesToQuantized(QuantizedTensor* quantized,
                          const tensorflow::Tensor& tensor) {
  const auto flat_view = tensor.flat<T>();
  const auto element_count = flat_view.size();
  // Reserve up front so the loop performs no reallocations.
  quantized->values.reserve(quantized->values.size() + element_count);
  for (decltype(flat_view.size()) idx = 0; idx < element_count; ++idx) {
    quantized->values.push_back(flat_view(idx));
  }
}

// Converts a tensorflow::Tensor to a QuantizedTensor. The tensor shape
// (`dimensions`) is not filled in by this method; see
// CreateQuantizedTensorMap, which copies it from the TensorflowSpec.
//
// Returns InvalidArgumentError for dtypes with no quantized mapping
// (e.g. floating point and string tensors).
absl::StatusOr<QuantizedTensor> TfTensorToQuantizedTensor(
    const tensorflow::Tensor& tensor) {
  QuantizedTensor quantized;
  // NOTE(review): the per-dtype bitwidths below mirror the previous
  // implementation that lived in fl_runner.cc — presumably SecAgg
  // quantization limits; confirm before changing.
  switch (tensor.dtype()) {
    case tensorflow::DT_INT8:
      AddValuesToQuantized<int8_t>(&quantized, tensor);
      quantized.bitwidth = 7;
      break;
    case tensorflow::DT_UINT8:
      AddValuesToQuantized<uint8_t>(&quantized, tensor);
      quantized.bitwidth = 8;
      break;
    case tensorflow::DT_INT16:
      AddValuesToQuantized<int16_t>(&quantized, tensor);
      quantized.bitwidth = 15;
      break;
    case tensorflow::DT_UINT16:
      AddValuesToQuantized<uint16_t>(&quantized, tensor);
      quantized.bitwidth = 16;
      break;
    case tensorflow::DT_INT32:
      AddValuesToQuantized<int32_t>(&quantized, tensor);
      quantized.bitwidth = 31;
      break;
    case tensorflow::DT_INT64:
      AddValuesToQuantized<int64_t>(&quantized, tensor);
      quantized.bitwidth = 62;
      break;
    default:
      // Bug fix: the original StrCat had no separating spaces, yielding
      // messages like "Tensor of typeDT_FLOATcould not be converted...".
      return absl::InvalidArgumentError(absl::StrCat(
          "Tensor of type ", tensorflow::DataType_Name(tensor.dtype()),
          " could not be converted to quantized value"));
  }
  return quantized;
}

// Builds a map from output tensor name to its QuantizedTensor.
//
// `tensor_names` and `tensors` are parallel vectors (name i labels tensor i);
// the QuantizedTensor dimensions are copied from the matching entry in
// `tensorflow_spec.output_tensor_specs()`.
//
// Returns InvalidArgumentError if the two vectors differ in length, if any
// tensor has a dtype with no quantized mapping, or if a spec names a tensor
// that was not produced. Tensors without a matching spec keep empty
// `dimensions`.
absl::StatusOr<absl::flat_hash_map<std::string, QuantizedTensor>>
CreateQuantizedTensorMap(const std::vector<std::string>& tensor_names,
                         const std::vector<tensorflow::Tensor>& tensors,
                         const TensorflowSpec& tensorflow_spec) {
  // Robustness fix: the loop below indexes `tensors` by the names' index,
  // which would read out of bounds if the vectors were mismatched.
  if (tensor_names.size() != tensors.size()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Mismatched number of tensor names (", tensor_names.size(),
                     ") and tensors (", tensors.size(), ")"));
  }
  absl::flat_hash_map<std::string, QuantizedTensor> quantized_tensor_map;
  quantized_tensor_map.reserve(tensor_names.size());
  for (size_t i = 0; i < tensor_names.size(); ++i) {
    absl::StatusOr<QuantizedTensor> quantized =
        TfTensorToQuantizedTensor(tensors[i]);
    if (!quantized.ok()) {
      return quantized.status();
    }
    // insert_or_assign avoids operator[]'s default-construct-then-assign.
    quantized_tensor_map.insert_or_assign(tensor_names[i],
                                          std::move(*quantized));
  }
  // Add dimensions to QuantizedTensors, taken from the TensorflowSpec.
  for (const tensorflow::TensorSpecProto& tensor_spec :
       tensorflow_spec.output_tensor_specs()) {
    // Single lookup instead of contains() followed by operator[].
    auto it = quantized_tensor_map.find(tensor_spec.name());
    if (it == quantized_tensor_map.end()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Tensor spec not found for tensor name: ", tensor_spec.name()));
    }
    for (const tensorflow::TensorShapeProto_Dim& dim :
         tensor_spec.shape().dim()) {
      it->second.dimensions.push_back(dim.size());
    }
  }
  return quantized_tensor_map;
}

absl::StatusOr<TaskEligibilityInfo> ParseEligibilityEvalPlanOutput(
const std::vector<tensorflow::Tensor>& output_tensors) {
auto output_size = output_tensors.size();
Expand Down
13 changes: 13 additions & 0 deletions fcp/client/engine/tensorflow_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
#ifndef FCP_CLIENT_ENGINE_TENSORFLOW_UTILS_H_
#define FCP_CLIENT_ENGINE_TENSORFLOW_UTILS_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/plan.pb.h"
#include "tensorflow/core/framework/tensor.h"

namespace fcp::client::engine {
Expand All @@ -33,6 +37,15 @@ absl::StatusOr<google::internal::federatedml::v2::TaskEligibilityInfo>
ParseEligibilityEvalPlanOutput(
const std::vector<tensorflow::Tensor>& output_tensors);

// Converts a vector of tensor name and a vector of tensorflow::Tensor to a
// map of tensor name to QuantizedTensor. The shape of the tensor is taken from
// the TensorflowSpec.
absl::StatusOr<absl::flat_hash_map<std::string, QuantizedTensor>>
CreateQuantizedTensorMap(
const std::vector<std::string>& tensor_names,
const std::vector<tensorflow::Tensor>& tensors,
const google::internal::federated::plan::TensorflowSpec& tensorflow_spec);

} // namespace fcp::client::engine

#endif // FCP_CLIENT_ENGINE_TENSORFLOW_UTILS_H_
16 changes: 12 additions & 4 deletions fcp/client/engine/tflite_plan_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,23 @@ namespace {
PlanResult CreatePlanResultFromOutput(
absl::StatusOr<OutputTensors> output, std::atomic<int>* total_example_count,
std::atomic<int64_t>* total_example_size_bytes,
absl::Status example_iterator_status, bool is_eligibility_eval_plan) {
absl::Status example_iterator_status, bool is_eligibility_eval_plan,
const TensorflowSpec& tensorflow_spec) {
switch (output.status().code()) {
case absl::StatusCode::kOk: {
PlanResult plan_result(PlanOutcome::kSuccess, absl::OkStatus());
if (is_eligibility_eval_plan) {
plan_result.task_eligibility_info =
ParseEligibilityEvalPlanOutput(output->output_tensors);
} else {
plan_result.output_names = std::move(output->output_tensor_names);
plan_result.output_tensors = std::move(output->output_tensors);
auto secagg_tensor_map =
CreateQuantizedTensorMap(output->output_tensor_names,
output->output_tensors, tensorflow_spec);
if (!secagg_tensor_map.ok()) {
return PlanResult(PlanOutcome::kTensorflowError,
secagg_tensor_map.status());
}
plan_result.secagg_tensor_map = std::move(*secagg_tensor_map);
}
plan_result.example_stats = {
.example_count = *total_example_count,
Expand Down Expand Up @@ -150,7 +157,8 @@ PlanResult TfLitePlanEngine::RunPlan(
output_names, CreateOptions(flags_), flags_.num_threads_for_tflite());
PlanResult plan_result = CreatePlanResultFromOutput(
std::move(output), &total_example_count, &total_example_size_bytes,
example_iterator_status.GetStatus(), is_eligibility_eval_plan);
example_iterator_status.GetStatus(), is_eligibility_eval_plan,
tensorflow_spec);
return plan_result;
}

Expand Down
79 changes: 12 additions & 67 deletions fcp/client/fl_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,6 @@ using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;

namespace {

template <typename T>
void AddValuesToQuantized(QuantizedTensor* quantized,
const tensorflow::Tensor& tensor) {
auto flat_tensor = tensor.flat<T>();
quantized->values.reserve(quantized->values.size() + flat_tensor.size());
for (int i = 0; i < flat_tensor.size(); i++) {
quantized->values.push_back(flat_tensor(i));
}
}

struct PlanResultAndCheckpointFile {
explicit PlanResultAndCheckpointFile(engine::PlanResult plan_result)
: plan_result(std::move(plan_result)) {}
Expand All @@ -141,63 +131,16 @@ struct PlanResultAndCheckpointFile {
// `tensorflow_spec != nullptr`.
absl::StatusOr<ComputationResults> CreateComputationResults(
const TensorflowSpec* tensorflow_spec,
const PlanResultAndCheckpointFile& plan_result_and_checkpoint_file,
PlanResultAndCheckpointFile& plan_result_and_checkpoint_file,
const Flags* flags) {
const auto& [plan_result, checkpoint_filename] =
plan_result_and_checkpoint_file;
auto& [plan_result, checkpoint_filename] = plan_result_and_checkpoint_file;
if (plan_result.outcome != engine::PlanOutcome::kSuccess) {
return absl::InvalidArgumentError("Computation failed.");
}
ComputationResults computation_results;
if (tensorflow_spec != nullptr) {
for (int i = 0; i < plan_result.output_names.size(); i++) {
QuantizedTensor quantized;
const auto& output_tensor = plan_result.output_tensors[i];
switch (output_tensor.dtype()) {
case tensorflow::DT_INT8:
AddValuesToQuantized<int8_t>(&quantized, output_tensor);
quantized.bitwidth = 7;
break;
case tensorflow::DT_UINT8:
AddValuesToQuantized<uint8_t>(&quantized, output_tensor);
quantized.bitwidth = 8;
break;
case tensorflow::DT_INT16:
AddValuesToQuantized<int16_t>(&quantized, output_tensor);
quantized.bitwidth = 15;
break;
case tensorflow::DT_UINT16:
AddValuesToQuantized<uint16_t>(&quantized, output_tensor);
quantized.bitwidth = 16;
break;
case tensorflow::DT_INT32:
AddValuesToQuantized<int32_t>(&quantized, output_tensor);
quantized.bitwidth = 31;
break;
case tensorflow::DT_INT64:
AddValuesToQuantized<int64_t>(&quantized, output_tensor);
quantized.bitwidth = 62;
break;
default:
return absl::InvalidArgumentError(
absl::StrCat("Tensor of type",
tensorflow::DataType_Name(output_tensor.dtype()),
"could not be converted to quantized value"));
}
computation_results[plan_result.output_names[i]] = std::move(quantized);
}

// Add dimensions to QuantizedTensors.
for (const tensorflow::TensorSpecProto& tensor_spec :
tensorflow_spec->output_tensor_specs()) {
if (computation_results.find(tensor_spec.name()) !=
computation_results.end()) {
for (const tensorflow::TensorShapeProto_Dim& dim :
tensor_spec.shape().dim()) {
std::get<QuantizedTensor>(computation_results[tensor_spec.name()])
.dimensions.push_back(dim.size());
}
}
for (auto& [name, tensor] : plan_result.secagg_tensor_map) {
computation_results[name] = std::move(tensor);
}
}

Expand Down Expand Up @@ -2323,6 +2266,10 @@ FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
if (native_task_eligibility_info.ok()) {
plan_result =
engine::PlanResult(engine::PlanOutcome::kSuccess, absl::OkStatus());
if (native_task_eligibility_info->has_value()) {
result.set_task_eligibility_info(
native_task_eligibility_info->value().SerializeAsString());
}
} else {
plan_result = engine::PlanResult(engine::PlanOutcome::kTensorflowError,
native_task_eligibility_info.status());
Expand All @@ -2333,6 +2280,10 @@ FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
example_iterator_factories, should_abort, log_manager,
opstats_logger.get(), flags, client_plan, checkpoint_input_filename,
timing_config, run_plan_start_time, reference_time);
if (plan_result.task_eligibility_info.ok()) {
result.set_task_eligibility_info(
plan_result.task_eligibility_info->SerializeAsString());
}
}

} else {
Expand All @@ -2349,12 +2300,6 @@ FLRunnerTensorflowSpecResult RunPlanWithTensorflowSpecForTesting(
result.set_outcome(
engine::ConvertPlanOutcomeToPhaseOutcome(plan_result.outcome));
if (plan_result.outcome == engine::PlanOutcome::kSuccess) {
for (int i = 0; i < plan_result.output_names.size(); i++) {
tensorflow::TensorProto output_tensor_proto;
plan_result.output_tensors[i].AsProtoField(&output_tensor_proto);
(*result.mutable_output_tensors())[plan_result.output_names[i]] =
std::move(output_tensor_proto);
}
phase_logger.LogComputationCompleted(
plan_result.example_stats,
// Empty network stats, since no network protocol is actually used in
Expand Down

0 comments on commit 20770c2

Please sign in to comment.