Skip to content

Commit

Permalink
Use correct key to retrieve PerTaskInfo from task_info_map_ when assi…
Browse files Browse the repository at this point in the history
…gning access policy.

PiperOrigin-RevId: 679804935
  • Loading branch information
chunxiangzheng authored and copybara-github committed Sep 28, 2024
1 parent e2e41a0 commit a35e5de
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 6 deletions.
12 changes: 9 additions & 3 deletions fcp/client/http/http_federated_protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1135,9 +1135,15 @@ HttpFederatedProtocol::HandleMultipleTaskAssignmentsInnerResponse(
std::move((*payloads)->confidential_data_access_policy);
// Store the serialized data access policy in the PerTaskInfo, since
// we need to calculate a hash over it at upload time.
task_info_map_[task_assignment.aggregation_session_id]
.confidential_data_access_policy =
task_assignment.confidential_agg_info->data_access_policy;
if (flags_->create_task_identifier()) {
task_info_map_[task_assignment.task_identifier]
.confidential_data_access_policy =
task_assignment.confidential_agg_info->data_access_policy;
} else {
task_info_map_[task_assignment.aggregation_session_id]
.confidential_data_access_policy =
task_assignment.confidential_agg_info->data_access_policy;
}
}
result.task_assignments[task_assignment.task_name] =
std::move(task_assignment);
Expand Down
85 changes: 82 additions & 3 deletions fcp/client/http/http_federated_protocol_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,10 @@ class HttpFederatedProtocolTest : public ::testing::Test {
}

absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
RunSuccessfulMultipleTaskAssignments(bool eligibility_eval_enabled = true) {
RunSuccessfulMultipleTaskAssignments(
bool eligibility_eval_enabled = true,
bool enable_confidential_aggregation = false,
std::optional<Resource> confidential_data_access_policy = std::nullopt) {
if (eligibility_eval_enabled) {
std::string report_eet_request_uri =
"https://initial.uri/v1/populations/TEST%2FPOPULATION/"
Expand All @@ -755,6 +758,10 @@ class HttpFederatedProtocolTest : public ::testing::Test {
request.mutable_client_version()->set_version_code(kClientVersion);
request.mutable_resource_capabilities()->add_supported_compression_formats(
ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
if (enable_confidential_aggregation) {
request.mutable_resource_capabilities()
->set_supports_confidential_aggregation(true);
}
for (const auto& task_name : task_names) {
request.add_task_names(task_name);
}
Expand All @@ -771,7 +778,10 @@ class HttpFederatedProtocolTest : public ::testing::Test {
plan_1, checkpoint_1, kFederatedSelectUriTemplate,
kMultiTaskClientSessionId_1, kMultiTaskAggregationSessionId_1,
kMultiTaskId_1, kAggregationTargetUri,
kMinimumClientsInServerVisibleAggregate);
enable_confidential_aggregation
? 0
: kMinimumClientsInServerVisibleAggregate,
confidential_data_access_policy);
Resource plan_2;
std::string plan_uri = "https://fake.uri/plan";
plan_2.set_uri(plan_uri);
Expand All @@ -782,7 +792,10 @@ class HttpFederatedProtocolTest : public ::testing::Test {
plan_2, checkpoint_2, kFederatedSelectUriTemplate,
kMultiTaskClientSessionId_2, kMultiTaskAggregationSessionId_2,
kMultiTaskId_2, kAggregationTargetUri,
kMinimumClientsInServerVisibleAggregate);
enable_confidential_aggregation
? 0
: kMinimumClientsInServerVisibleAggregate,
confidential_data_access_policy);
std::string expected_plan_2 = "expected_plan_2";
std::string expected_checkpoint_2 = "expected_checkpoint_2";

Expand Down Expand Up @@ -2388,6 +2401,72 @@ TEST_F(HttpFederatedProtocolTest,
_);
}

TEST_F(HttpFederatedProtocolTest,
TestMultipleTaskAssignmentsWithConfidentialAggregation) {
EXPECT_CALL(mock_flags_, enable_confidential_aggregation)
.WillRepeatedly(Return(true));
EXPECT_CALL(mock_flags_, create_task_identifier)
.WillRepeatedly(Return(false));

ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
/*eligibility_eval_enabled=*/true,
/*enable_confidential_aggregation=*/true));
std::string serialized_access_policy = "the access policy";
Resource access_policy_resource;
access_policy_resource.mutable_inline_resource()->set_data(
serialized_access_policy);
auto result = RunSuccessfulMultipleTaskAssignments(
/*eligibility_eval_enabled*/ true,
/*enable_confidential_aggregation=*/true,
/*confidential_data_access_policy=*/access_policy_resource);
ASSERT_OK(result);
EXPECT_THAT(result->task_assignments, testing::SizeIs(2));
absl::Cord expected_access_policy(serialized_access_policy);
auto task_assignment_1 = result->task_assignments[kMultiTaskId_1];
ASSERT_OK(task_assignment_1);
EXPECT_EQ(task_assignment_1->confidential_agg_info.value().data_access_policy,
expected_access_policy);

auto task_assignment_2 = result->task_assignments[kMultiTaskId_2];
ASSERT_OK(task_assignment_2);
EXPECT_EQ(task_assignment_2->confidential_agg_info.value().data_access_policy,
expected_access_policy);
}

TEST_F(
HttpFederatedProtocolTest,
TestMultipleTaskAssignmentsWithConfidentialAggregationAndTaskIdentifier) {
EXPECT_CALL(mock_flags_, enable_confidential_aggregation)
.WillRepeatedly(Return(true));
EXPECT_CALL(mock_flags_, create_task_identifier).WillRepeatedly(Return(true));

ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
/*eligibility_eval_enabled=*/true,
/*enable_confidential_aggregation=*/true));
std::string serialized_access_policy = "the access policy";
Resource access_policy_resource;
access_policy_resource.mutable_inline_resource()->set_data(
serialized_access_policy);
auto result = RunSuccessfulMultipleTaskAssignments(
/*eligibility_eval_enabled*/ true,
/*enable_confidential_aggregation=*/true,
/*confidential_data_access_policy=*/access_policy_resource);
ASSERT_OK(result);
EXPECT_THAT(result->task_assignments, testing::SizeIs(2));
absl::Cord expected_access_policy(serialized_access_policy);
auto task_assignment_1 = result->task_assignments[kMultiTaskId_1];
ASSERT_OK(task_assignment_1);
EXPECT_EQ(task_assignment_1->confidential_agg_info.value().data_access_policy,
expected_access_policy);
EXPECT_EQ(task_assignment_1->task_identifier, "task_0");

auto task_assignment_2 = result->task_assignments[kMultiTaskId_2];
ASSERT_OK(task_assignment_2);
EXPECT_EQ(task_assignment_2->confidential_agg_info.value().data_access_policy,
expected_access_policy);
EXPECT_EQ(task_assignment_2->task_identifier, "task_1");
}

// Ensures that polling the Operation returned by a StartTaskAssignmentRequest
// works as expected. This serves mostly as a high-level check. Further
// polling-specific behavior is tested in more detail in
Expand Down

0 comments on commit a35e5de

Please sign in to comment.