diff --git a/fcp/client/http/http_federated_protocol.cc b/fcp/client/http/http_federated_protocol.cc index 56cc6d3..6f6ae9b 100644 --- a/fcp/client/http/http_federated_protocol.cc +++ b/fcp/client/http/http_federated_protocol.cc @@ -1135,9 +1135,15 @@ HttpFederatedProtocol::HandleMultipleTaskAssignmentsInnerResponse( std::move((*payloads)->confidential_data_access_policy); // Store the serialized data access policy in the PerTaskInfo, since // we need to calculate a hash over it at upload time. - task_info_map_[task_assignment.aggregation_session_id] - .confidential_data_access_policy = - task_assignment.confidential_agg_info->data_access_policy; + if (flags_->create_task_identifier()) { + task_info_map_[task_assignment.task_identifier] + .confidential_data_access_policy = + task_assignment.confidential_agg_info->data_access_policy; + } else { + task_info_map_[task_assignment.aggregation_session_id] + .confidential_data_access_policy = + task_assignment.confidential_agg_info->data_access_policy; + } } result.task_assignments[task_assignment.task_name] = std::move(task_assignment); diff --git a/fcp/client/http/http_federated_protocol_test.cc b/fcp/client/http/http_federated_protocol_test.cc index 1b61f5b..657c7c1 100644 --- a/fcp/client/http/http_federated_protocol_test.cc +++ b/fcp/client/http/http_federated_protocol_test.cc @@ -739,7 +739,10 @@ class HttpFederatedProtocolTest : public ::testing::Test { } absl::StatusOr - RunSuccessfulMultipleTaskAssignments(bool eligibility_eval_enabled = true) { + RunSuccessfulMultipleTaskAssignments( + bool eligibility_eval_enabled = true, + bool enable_confidential_aggregation = false, + std::optional confidential_data_access_policy = std::nullopt) { if (eligibility_eval_enabled) { std::string report_eet_request_uri = "https://initial.uri/v1/populations/TEST%2FPOPULATION/" @@ -755,6 +758,10 @@ class HttpFederatedProtocolTest : public ::testing::Test { request.mutable_client_version()->set_version_code(kClientVersion); request.mutable_resource_capabilities()->add_supported_compression_formats( ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP); + if (enable_confidential_aggregation) { + request.mutable_resource_capabilities() + ->set_supports_confidential_aggregation(true); + } for (const auto& task_name : task_names) { request.add_task_names(task_name); } @@ -771,7 +778,10 @@ class HttpFederatedProtocolTest : public ::testing::Test { plan_1, checkpoint_1, kFederatedSelectUriTemplate, kMultiTaskClientSessionId_1, kMultiTaskAggregationSessionId_1, kMultiTaskId_1, kAggregationTargetUri, - kMinimumClientsInServerVisibleAggregate); + enable_confidential_aggregation + ? 0 + : kMinimumClientsInServerVisibleAggregate, + confidential_data_access_policy); Resource plan_2; std::string plan_uri = "https://fake.uri/plan"; plan_2.set_uri(plan_uri); @@ -782,7 +792,10 @@ class HttpFederatedProtocolTest : public ::testing::Test { plan_2, checkpoint_2, kFederatedSelectUriTemplate, kMultiTaskClientSessionId_2, kMultiTaskAggregationSessionId_2, kMultiTaskId_2, kAggregationTargetUri, - kMinimumClientsInServerVisibleAggregate); + enable_confidential_aggregation + ? 0 + : kMinimumClientsInServerVisibleAggregate, + confidential_data_access_policy); std::string expected_plan_2 = "expected_plan_2"; std::string expected_checkpoint_2 = "expected_checkpoint_2"; @@ -2388,6 +2401,72 @@ TEST_F(HttpFederatedProtocolTest, _); } +TEST_F(HttpFederatedProtocolTest, + TestMultipleTaskAssignmentsWithConfidentialAggregation) { + EXPECT_CALL(mock_flags_, enable_confidential_aggregation) + .WillRepeatedly(Return(true)); + EXPECT_CALL(mock_flags_, create_task_identifier) + .WillRepeatedly(Return(false)); + + ASSERT_OK(RunSuccessfulEligibilityEvalCheckin( + /*eligibility_eval_enabled=*/true, + /*enable_confidential_aggregation=*/true)); + std::string serialized_access_policy = "the access policy"; + Resource access_policy_resource; + access_policy_resource.mutable_inline_resource()->set_data( + serialized_access_policy); + auto result = RunSuccessfulMultipleTaskAssignments( + /*eligibility_eval_enabled*/ true, + /*enable_confidential_aggregation=*/true, + /*confidential_data_access_policy=*/access_policy_resource); + ASSERT_OK(result); + EXPECT_THAT(result->task_assignments, testing::SizeIs(2)); + absl::Cord expected_access_policy(serialized_access_policy); + auto task_assignment_1 = result->task_assignments[kMultiTaskId_1]; + ASSERT_OK(task_assignment_1); + EXPECT_EQ(task_assignment_1->confidential_agg_info.value().data_access_policy, + expected_access_policy); + + auto task_assignment_2 = result->task_assignments[kMultiTaskId_2]; + ASSERT_OK(task_assignment_2); + EXPECT_EQ(task_assignment_2->confidential_agg_info.value().data_access_policy, + expected_access_policy); +} + +TEST_F( + HttpFederatedProtocolTest, + TestMultipleTaskAssignmentsWithConfidentialAggregationAndTaskIdentifier) { + EXPECT_CALL(mock_flags_, enable_confidential_aggregation) + .WillRepeatedly(Return(true)); + EXPECT_CALL(mock_flags_, create_task_identifier).WillRepeatedly(Return(true)); + + ASSERT_OK(RunSuccessfulEligibilityEvalCheckin( + /*eligibility_eval_enabled=*/true, + /*enable_confidential_aggregation=*/true)); + std::string serialized_access_policy = "the access policy"; + Resource access_policy_resource; + access_policy_resource.mutable_inline_resource()->set_data( + serialized_access_policy); + auto result = RunSuccessfulMultipleTaskAssignments( + /*eligibility_eval_enabled*/ true, + /*enable_confidential_aggregation=*/true, + /*confidential_data_access_policy=*/access_policy_resource); + ASSERT_OK(result); + EXPECT_THAT(result->task_assignments, testing::SizeIs(2)); + absl::Cord expected_access_policy(serialized_access_policy); + auto task_assignment_1 = result->task_assignments[kMultiTaskId_1]; + ASSERT_OK(task_assignment_1); + EXPECT_EQ(task_assignment_1->confidential_agg_info.value().data_access_policy, + expected_access_policy); + EXPECT_EQ(task_assignment_1->task_identifier, "task_0"); + + auto task_assignment_2 = result->task_assignments[kMultiTaskId_2]; + ASSERT_OK(task_assignment_2); + EXPECT_EQ(task_assignment_2->confidential_agg_info.value().data_access_policy, + expected_access_policy); + EXPECT_EQ(task_assignment_2->task_identifier, "task_1"); +} + // Ensures that polling the Operation returned by a StartTaskAssignmentRequest // works as expected. This serves mostly as a high-level check. Further // polling-specific behavior is tested in more detail in