diff --git a/cc/client/grpc_client_cli.cc b/cc/client/grpc_client_cli.cc index aceec04a77b..7aad3174394 100644 --- a/cc/client/grpc_client_cli.cc +++ b/cc/client/grpc_client_cli.cc @@ -61,8 +61,8 @@ int main(int argc, char* argv[]) { // Create Oak Client. LOG(INFO) << "creating Oak Client"; - std::unique_ptr transport = std::make_unique( - GrpcStreamingTransport(std::move(channel_reader_writer))); + std::unique_ptr transport = + std::make_unique(std::move(channel_reader_writer)); InsecureAttestationVerifier verifier = InsecureAttestationVerifier(); absl::StatusOr> oak_client = OakClient::Create(std::move(transport), verifier); diff --git a/cc/transport/BUILD b/cc/transport/BUILD index 75101aa1a17..50489b6d405 100644 --- a/cc/transport/BUILD +++ b/cc/transport/BUILD @@ -31,6 +31,7 @@ cc_library( cc_library( name = "grpc_streaming_transport", + srcs = ["grpc_streaming_transport.cc"], hdrs = ["grpc_streaming_transport.h"], deps = [ ":transport", @@ -38,6 +39,8 @@ cc_library( "//oak_remote_attestation/proto/v1:service_streaming_cc_grpc", "//oak_remote_attestation/proto/v1:service_streaming_cc_proto", "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], diff --git a/cc/transport/grpc_streaming_transport.cc b/cc/transport/grpc_streaming_transport.cc new file mode 100644 index 00000000000..5df90ec2b2e --- /dev/null +++ b/cc/transport/grpc_streaming_transport.cc @@ -0,0 +1,118 @@ +/* + * Copyright 2023 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cc/transport/grpc_streaming_transport.h" + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/grpcpp.h" + +namespace oak::transport { + +namespace { +using ::oak::session::v1::AttestationBundle; +using ::oak::session::v1::GetPublicKeyRequest; +using ::oak::session::v1::InvokeRequest; +using ::oak::session::v1::RequestWrapper; +using ::oak::session::v1::ResponseWrapper; +} // namespace + +absl::StatusOr GrpcStreamingTransport::GetEvidence() { + // Create request. + RequestWrapper request; + GetPublicKeyRequest get_public_key_request; + *request.mutable_get_public_key_request() = get_public_key_request; + + // Send request. + auto response = Send(request); + if (!response.ok()) { + return response.status(); + } + + // Process response. + switch (response->response_case()) { + case ResponseWrapper::kGetPublicKeyResponseFieldNumber: + return response->get_public_key_response().attestation_bundle(); + case ResponseWrapper::kInvokeResponseFieldNumber: + return absl::InternalError("received InvokeResponse instead of GetPublicKeyResponse"); + case ResponseWrapper::RESPONSE_NOT_SET: + default: + return absl::InternalError("received unsupported response: " + response->DebugString()); + } +} + +absl::StatusOr GrpcStreamingTransport::Invoke(absl::string_view request_bytes) { + // Create request. + RequestWrapper request; + InvokeRequest* invoke_request = request.mutable_invoke_request(); + invoke_request->set_encrypted_body(request_bytes); + + // Send request. + auto response = Send(request); + if (!response.ok()) { + return response.status(); + } + + // Process response. + switch (response->response_case()) { + case ResponseWrapper::kGetPublicKeyResponseFieldNumber: + return absl::InternalError("received GetPublicKeyResponse instead of InvokeResponse"); + case ResponseWrapper::kInvokeResponseFieldNumber: + return response->invoke_response().encrypted_body(); + case ResponseWrapper::RESPONSE_NOT_SET: + default: + return absl::InternalError("received unsupported response: " + response->DebugString()); + } +} + +GrpcStreamingTransport::~GrpcStreamingTransport() { + absl::Status status = Close(); + if (!status.ok()) { + LOG(WARNING) << "couldn't stop gRPC stream: " << status.message(); + } +} + +absl::StatusOr GrpcStreamingTransport::Send(const RequestWrapper& request) { + // Send a request. + if (!channel_reader_writer_->Write(request)) { + return absl::InternalError("couldn't send request"); + } + + // Receive a response. + ResponseWrapper response; + if (!channel_reader_writer_->Read(&response)) { + return absl::InternalError("couldn't receive response"); + } + return response; +} + +absl::Status GrpcStreamingTransport::Close() { + if (!channel_reader_writer_->WritesDone()) { + return absl::InternalError("couldn't close writing stream"); + } + ::grpc::Status status = channel_reader_writer_->Finish(); + if (!status.ok()) { + return absl::InternalError("couldn't close reading stream: " + status.error_message()); + } + return absl::OkStatus(); +} + +} // namespace oak::transport diff --git a/cc/transport/grpc_streaming_transport.h b/cc/transport/grpc_streaming_transport.h index 31d7665496e..f1095aecf68 100644 --- a/cc/transport/grpc_streaming_transport.h +++ b/cc/transport/grpc_streaming_transport.h @@ -19,13 +19,10 @@ #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "cc/transport/transport.h" -#include "grpcpp/channel.h" -#include "grpcpp/client_context.h" -#include "grpcpp/create_channel.h" -#include "grpcpp/grpcpp.h" #include "oak_remote_attestation/proto/v1/messages.pb.h" #include "oak_remote_attestation/proto/v1/service_streaming.grpc.pb.h" #include "oak_remote_attestation/proto/v1/service_streaming.pb.h" @@ -40,53 +37,10 @@ class GrpcStreamingTransport : public TransportWrapper { channel_reader_writer) : channel_reader_writer_(std::move(channel_reader_writer)) {} - absl::StatusOr<::oak::session::v1::AttestationBundle> GetEvidence() override { - // Create request. - ::oak::session::v1::RequestWrapper request; - ::oak::session::v1::GetPublicKeyRequest get_public_key_request; - *request.mutable_get_public_key_request() = get_public_key_request; + absl::StatusOr<::oak::session::v1::AttestationBundle> GetEvidence() override; + absl::StatusOr Invoke(absl::string_view request_bytes) override; - // Send request. - auto response = Send(request); - if (!response.ok()) { - return response.status(); - } - - // Process response. - switch (response->response_case()) { - case ::oak::session::v1::ResponseWrapper::kGetPublicKeyResponseFieldNumber: - return response->get_public_key_response().attestation_bundle(); - case ::oak::session::v1::ResponseWrapper::kInvokeResponseFieldNumber: - return absl::InternalError("received InvokeResponse instead of GetPublicKeyResponse"); - case ::oak::session::v1::ResponseWrapper::RESPONSE_NOT_SET: - default: - return absl::InternalError("received unsupported response: " + response->DebugString()); - } - } - - absl::StatusOr Invoke(absl::string_view request_bytes) override { - // Create request. - ::oak::session::v1::RequestWrapper request; - ::oak::session::v1::InvokeRequest* invoke_request = request.mutable_invoke_request(); - invoke_request->set_encrypted_body(request_bytes); - - // Send request. - auto response = Send(request); - if (!response.ok()) { - return response.status(); - } - - // Process response. - switch (response->response_case()) { - case ::oak::session::v1::ResponseWrapper::kGetPublicKeyResponseFieldNumber: - return absl::InternalError("received GetPublicKeyResponse instead of InvokeResponse"); - case ::oak::session::v1::ResponseWrapper::kInvokeResponseFieldNumber: - return response->invoke_response().encrypted_body(); - case ::oak::session::v1::ResponseWrapper::RESPONSE_NOT_SET: - default: - return absl::InternalError("received unsupported response: " + response->DebugString()); - } - } + ~GrpcStreamingTransport() override; private: std::unique_ptr<::grpc::ClientReaderWriter<::oak::session::v1::RequestWrapper, @@ -94,21 +48,8 @@ class GrpcStreamingTransport : public TransportWrapper { channel_reader_writer_; absl::StatusOr<::oak::session::v1::ResponseWrapper> Send( - const ::oak::session::v1::RequestWrapper& request) { - // Send a request. - channel_reader_writer_->Write(request); - channel_reader_writer_->WritesDone(); - - // Receive a response. - ::oak::session::v1::ResponseWrapper response; - channel_reader_writer_->Read(&response); - ::grpc::Status status = channel_reader_writer_->Finish(); - if (status.ok()) { - return response; - } else { - return absl::InternalError("couldn't send request: " + status.error_message()); - } - } + const ::oak::session::v1::RequestWrapper& request); + absl::Status Close(); }; } // namespace oak::transport diff --git a/cc/transport/transport.h b/cc/transport/transport.h index 8a58e145ab8..49fce94a770 100644 --- a/cc/transport/transport.h +++ b/cc/transport/transport.h @@ -44,7 +44,10 @@ class Transport { }; // Wrapper for `EvidenceProvider` and `Transport` abstract classes. -class TransportWrapper : public EvidenceProvider, public Transport {}; +class TransportWrapper : public EvidenceProvider, public Transport { + public: + virtual ~TransportWrapper() = default; +}; } // namespace oak::transport diff --git a/oak_functions_launcher/tests/integration_test.rs b/oak_functions_launcher/tests/integration_test.rs index 784231abe2f..30355cbcfc9 100644 --- a/oak_functions_launcher/tests/integration_test.rs +++ b/oak_functions_launcher/tests/integration_test.rs @@ -126,6 +126,7 @@ async fn test_launcher_weather_lookup_virtual() { r#"{"temperature_degrees_celsius":29}"# ); + // TODO(#4177): Check response in the integration test. // Run Java client via Bazel. let status = tokio::process::Command::new("bazel") .arg("run") @@ -140,6 +141,23 @@ async fn test_launcher_weather_lookup_virtual() { .expect("failed to wait for bazel"); eprintln!("bazel status: {:?}", status); assert!(status.success()); + + // TODO(#4177): Check response in the integration test. + // Run C++ client via Bazel. + let status = tokio::process::Command::new("bazel") + .arg("run") + .arg("//cc/client:cli") + .arg("--") + .arg(format!("--address=localhost:{port}")) + .arg("--request={\"lat\":0,\"lng\":0}") + .current_dir(workspace_path(&[])) + .spawn() + .expect("failed to spawn bazel") + .wait() + .await + .expect("failed to wait for bazel"); + eprintln!("bazel status: {:?}", status); + assert!(status.success()); } #[tokio::test]