From 71207f326a479cbe24444070b02a3dcfeab261f5 Mon Sep 17 00:00:00 2001 From: Nathan Baulch Date: Mon, 15 Apr 2024 14:41:29 +1000 Subject: [PATCH] feat: expose server implementation on transport --- api/metadata/metadata_http.pb.go | 2 ++ cmd/protoc-gen-go-http/httpTemplate.tpl | 1 + .../testdata/helloworld/helloworld_http.pb.go | 1 + middleware/auth/jwt/jwt_test.go | 4 ++++ middleware/circuitbreaker/circuitbreaker_test.go | 4 ++++ middleware/logging/logging_test.go | 4 ++++ middleware/metadata/metadata_test.go | 1 + middleware/selector/selector_test.go | 4 ++++ middleware/tracing/tracing_test.go | 1 + transport/grpc/interceptor.go | 2 ++ transport/grpc/transport.go | 6 ++++++ transport/grpc/transport_test.go | 8 ++++++++ transport/http/transport.go | 15 +++++++++++++++ transport/http/transport_test.go | 10 ++++++++++ transport/transport.go | 2 ++ transport/transport_test.go | 5 +++++ 16 files changed, 70 insertions(+) diff --git a/api/metadata/metadata_http.pb.go b/api/metadata/metadata_http.pb.go index 91636b48d07..4a2eed3393f 100644 --- a/api/metadata/metadata_http.pb.go +++ b/api/metadata/metadata_http.pb.go @@ -34,6 +34,7 @@ func _Metadata_ListServices0_HTTP_Handler(srv MetadataHTTPServer) func(ctx http. if err := ctx.BindQuery(&in); err != nil { return err } + http.SetServer(ctx, srv) http.SetOperation(ctx, "/kratos.api.Metadata/ListServices") h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { return srv.ListServices(ctx, req.(*ListServicesRequest)) @@ -56,6 +57,7 @@ func _Metadata_GetServiceDesc0_HTTP_Handler(srv MetadataHTTPServer) func(ctx htt if err := ctx.BindVars(&in); err != nil { return err } + http.SetServer(ctx, srv) http.SetOperation(ctx, "/kratos.api.Metadata/GetServiceDesc") h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { return srv.GetServiceDesc(ctx, req.(*GetServiceDescRequest)) diff --git a/cmd/protoc-gen-go-http/httpTemplate.tpl b/cmd/protoc-gen-go-http/httpTemplate.tpl index ec8477f92a7..ad46d5dc2f3 100644 --- a/cmd/protoc-gen-go-http/httpTemplate.tpl +++ b/cmd/protoc-gen-go-http/httpTemplate.tpl @@ -38,6 +38,7 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu return err } {{- end}} + http.SetServer(ctx, srv) http.SetOperation(ctx,Operation{{$svrType}}{{.OriginalName}}) h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { return srv.{{.Name}}(ctx, req.(*{{.Request}})) diff --git a/internal/testdata/helloworld/helloworld_http.pb.go b/internal/testdata/helloworld/helloworld_http.pb.go index b77f02ec260..60695d75e59 100644 --- a/internal/testdata/helloworld/helloworld_http.pb.go +++ b/internal/testdata/helloworld/helloworld_http.pb.go @@ -37,6 +37,7 @@ func _Greeter_SayHello0_HTTP_Handler(srv GreeterHTTPServer) func(ctx http.Contex if err := ctx.BindVars(&in); err != nil { return err } + http.SetServer(ctx, srv) http.SetOperation(ctx, OperationGreeterSayHello) h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { return srv.SayHello(ctx, req.(*HelloRequest)) diff --git a/middleware/auth/jwt/jwt_test.go b/middleware/auth/jwt/jwt_test.go index 5f2a3ba5ec4..815d4258044 100644 --- a/middleware/auth/jwt/jwt_test.go +++ b/middleware/auth/jwt/jwt_test.go @@ -61,6 +61,10 @@ func (tr *Transport) Endpoint() string { return tr.endpoint } +func (tr *Transport) Server() interface{} { + return nil +} + func (tr *Transport) Operation() string { return tr.operation } diff --git a/middleware/circuitbreaker/circuitbreaker_test.go b/middleware/circuitbreaker/circuitbreaker_test.go index 3b27da1a3c8..0d4c7279758 100644 --- a/middleware/circuitbreaker/circuitbreaker_test.go +++ b/middleware/circuitbreaker/circuitbreaker_test.go @@ -28,6 +28,10 @@ func (tr *transportMock) Endpoint() string { return tr.endpoint } +func (tr *transportMock) Server() interface{} { + return nil +} + func (tr *transportMock) Operation() string { return tr.operation } diff --git a/middleware/logging/logging_test.go b/middleware/logging/logging_test.go index 1e9495188b6..d06679d5a66 100644 --- a/middleware/logging/logging_test.go +++ b/middleware/logging/logging_test.go @@ -27,6 +27,10 @@ func (tr *Transport) Endpoint() string { return tr.endpoint } +func (tr *Transport) Server() interface{} { + return nil +} + func (tr *Transport) Operation() string { return tr.operation } diff --git a/middleware/metadata/metadata_test.go b/middleware/metadata/metadata_test.go index 9e9436d75b2..a1eff7770f4 100644 --- a/middleware/metadata/metadata_test.go +++ b/middleware/metadata/metadata_test.go @@ -37,6 +37,7 @@ type testTransport struct{ header headerCarrier } func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP } func (tr *testTransport) Endpoint() string { return "" } +func (tr *testTransport) Server() interface{} { return nil } func (tr *testTransport) Operation() string { return "" } func (tr *testTransport) RequestHeader() transport.Header { return tr.header } func (tr *testTransport) ReplyHeader() transport.Header { return tr.header } diff --git a/middleware/selector/selector_test.go b/middleware/selector/selector_test.go index 0835ef84414..e91c74fa560 100644 --- a/middleware/selector/selector_test.go +++ b/middleware/selector/selector_test.go @@ -27,6 +27,10 @@ func (tr *Transport) Endpoint() string { return tr.endpoint } +func (tr *Transport) Server() interface{} { + return nil +} + func (tr *Transport) Operation() string { return tr.operation } diff --git a/middleware/tracing/tracing_test.go b/middleware/tracing/tracing_test.go index 72740bafa24..3c26a631921 100644 --- a/middleware/tracing/tracing_test.go +++ b/middleware/tracing/tracing_test.go @@ -58,6 +58,7 @@ type mockTransport struct { func (tr *mockTransport) Kind() transport.Kind { return tr.kind } func (tr *mockTransport) Endpoint() string { return tr.endpoint } +func (tr *mockTransport) Server() interface{} { return nil } func (tr *mockTransport) Operation() string { return tr.operation } func (tr *mockTransport) RequestHeader() transport.Header { return tr.header } func (tr *mockTransport) ReplyHeader() transport.Header { return tr.header } diff --git a/transport/grpc/interceptor.go b/transport/grpc/interceptor.go index 6cc331547c6..38392a516c1 100644 --- a/transport/grpc/interceptor.go +++ b/transport/grpc/interceptor.go @@ -20,6 +20,7 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { replyHeader := grpcmd.MD{} tr := &Transport{ operation: info.FullMethod, + server: info.Server, reqHeader: headerCarrier(md), replyHeader: headerCarrier(replyHeader), } @@ -71,6 +72,7 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { replyHeader := grpcmd.MD{} ctx = transport.NewServerContext(ctx, &Transport{ endpoint: s.endpoint.String(), + server: srv, operation: info.FullMethod, reqHeader: headerCarrier(md), replyHeader: headerCarrier(replyHeader), diff --git a/transport/grpc/transport.go b/transport/grpc/transport.go index 56e21a86365..1e9e83fbb75 100644 --- a/transport/grpc/transport.go +++ b/transport/grpc/transport.go @@ -12,6 +12,7 @@ var _ transport.Transporter = (*Transport)(nil) // Transport is a gRPC transport. type Transport struct { endpoint string + server interface{} operation string reqHeader headerCarrier replyHeader headerCarrier @@ -28,6 +29,11 @@ func (tr *Transport) Endpoint() string { return tr.endpoint } +// Server returns the transport server. +func (tr *Transport) Server() interface{} { + return tr.server +} + // Operation returns the transport operation. func (tr *Transport) Operation() string { return tr.operation diff --git a/transport/grpc/transport_test.go b/transport/grpc/transport_test.go index 270243f052e..d3a66097d17 100644 --- a/transport/grpc/transport_test.go +++ b/transport/grpc/transport_test.go @@ -23,6 +23,14 @@ func TestTransport_Endpoint(t *testing.T) { } } +func TestTransport_Server(t *testing.T) { + v := struct{}{} + o := &Transport{server: v} + if !reflect.DeepEqual(v, o.Server()) { + t.Errorf("expect %v, got %v", v, o.Server()) + } +} + func TestTransport_Operation(t *testing.T) { v := "hello" o := &Transport{operation: v} diff --git a/transport/http/transport.go b/transport/http/transport.go index 0400ea833a1..757fb75be2c 100644 --- a/transport/http/transport.go +++ b/transport/http/transport.go @@ -18,6 +18,7 @@ type Transporter interface { // Transport is an HTTP transport. type Transport struct { + server interface{} endpoint string operation string reqHeader headerCarrier @@ -37,6 +38,11 @@ func (tr *Transport) Endpoint() string { return tr.endpoint } +// Server returns the transport server. +func (tr *Transport) Server() interface{} { + return tr.server +} + // Operation returns the transport operation. func (tr *Transport) Operation() string { return tr.operation @@ -71,6 +77,15 @@ func SetOperation(ctx context.Context, op string) { } } +// SetServer sets the transport server. +func SetServer(ctx context.Context, srv interface{}) { + if tr, ok := transport.FromServerContext(ctx); ok { + if tr, ok := tr.(*Transport); ok { + tr.server = srv + } + } +} + // SetCookie adds a Set-Cookie header to the provided [ResponseWriter]'s headers. // The provided cookie must have a valid Name. Invalid cookies may be // silently dropped. diff --git a/transport/http/transport_test.go b/transport/http/transport_test.go index 01095c7ea26..6c245fd6152 100644 --- a/transport/http/transport_test.go +++ b/transport/http/transport_test.go @@ -84,6 +84,16 @@ func TestHeaderCarrier_Keys(t *testing.T) { } } +func TestSetServer(t *testing.T) { + tr := &Transport{} + ctx := transport.NewServerContext(context.Background(), tr) + srv := struct{}{} + SetServer(ctx, srv) + if !reflect.DeepEqual(tr.server, srv) { + t.Errorf("expect %v, got %v", srv, tr.server) + } +} + func TestSetOperation(t *testing.T) { tr := &Transport{} ctx := transport.NewServerContext(context.Background(), tr) diff --git a/transport/transport.go b/transport/transport.go index c1a5396f16f..3afe17ac8b7 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -42,6 +42,8 @@ type Transporter interface { // Server Transport: grpc://127.0.0.1:9000 // Client Transport: discovery:///provider-demo Endpoint() string + // Server return server implementation + Server() interface{} // Operation Service full method selector generated by protobuf // example: /helloworld.Greeter/SayHello Operation() string diff --git a/transport/transport_test.go b/transport/transport_test.go index cba7f2ae15f..0d23b1a39c5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -22,6 +22,11 @@ func (tr *mockTransport) Endpoint() string { return tr.endpoint } +// Server returns the transport server. +func (tr *mockTransport) Server() interface{} { + return nil +} + // Operation returns the transport operation. func (tr *mockTransport) Operation() string { return tr.operation