Skip to content

Commit

Permalink
add unit tests for MiddlewareLoggingPostMux function (to handle Corre…
Browse files Browse the repository at this point in the history
…lationData)
  • Loading branch information
Aldo Fuster Turpin committed May 2, 2024
1 parent b912bfb commit 5d35d92
Show file tree
Hide file tree
Showing 3 changed files with 368 additions and 24 deletions.
67 changes: 43 additions & 24 deletions frontend/middleware_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,52 +76,71 @@ func MiddlewareLogging(w http.ResponseWriter, r *http.Request, next http.Handler
}

func MiddlewareLoggingPostMux(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
var pathValue string

ctx := r.Context()

correlationData := arm.NewCorrelationData(r)
ctx = ContextWithCorrelationData(ctx, correlationData)

setHeaders(w, r, correlationData)

attrs := getLogAttrs(correlationData, r)

logger, err := LoggerFromContext(ctx)
if err != nil {
DefaultLogger().Error(err.Error())
arm.WriteInternalServerError(w)
return
}

handler := logger.Handler()
loggerWithAttrs := slog.New(handler.WithAttrs(attrs))
ctx = ContextWithLogger(ctx, loggerWithAttrs)

reqWithContext := r.WithContext(ctx)

next(w, reqWithContext)
}

// setHeaders writes the appropriate headers in the response writer
// based on the request and the correlation data.
func setHeaders(w http.ResponseWriter, r *http.Request, correlationData *arm.CorrelationData) {
w.Header().Set(arm.HeaderNameRequestID, correlationData.RequestID.String())

if strings.EqualFold(r.Header.Get(arm.HeaderNameReturnClientRequestID), "true") {
returnClientRequestId := r.Header.Get(arm.HeaderNameReturnClientRequestID)
if strings.EqualFold(returnClientRequestId, "true") {
w.Header().Set(arm.HeaderNameClientRequestID, correlationData.ClientRequestID)
}
}

// getLogAttrs returns the appropiate Logging Attributes based on correlationData and a request.
func getLogAttrs(correlationData *arm.CorrelationData, r *http.Request) []slog.Attr {
attrs := []slog.Attr{
slog.String("request_id", correlationData.RequestID.String()),
slog.String("client_request_id", correlationData.ClientRequestID),
slog.String("correlation_request_id", correlationData.CorrelationRequestID),
}

if pathValue = r.PathValue(PathSegmentSubscriptionID); pathValue != "" {
attrs = append(attrs, slog.String("subscription_id", pathValue))
subscriptionID := r.PathValue(PathSegmentSubscriptionID)
if subscriptionID != "" {
attrs = append(attrs, slog.String("subscription_id", subscriptionID))
}

if pathValue = r.PathValue(PathSegmentResourceGroupName); pathValue != "" {
attrs = append(attrs, slog.String("resource_group", pathValue))
resourceGroup := r.PathValue(PathSegmentResourceGroupName)
if resourceGroup != "" {
attrs = append(attrs, slog.String("resource_group", resourceGroup))
}

if pathValue = r.PathValue(PathSegmentResourceName); pathValue != "" {
attrs = append(attrs, slog.String("resource_name", pathValue))
resource_id := fmt.Sprintf("/subscriptions/%s/resourcegroups/%s/providers/%s/%s",
r.PathValue(PathSegmentSubscriptionID),
r.PathValue(PathSegmentResourceGroupName),
api.ResourceType,
pathValue)
attrs = append(attrs, slog.String("resource_id", resource_id))
resourceName := r.PathValue(PathSegmentResourceName)
if resourceName != "" {
attrs = append(attrs, slog.String("resource_name", resourceName))
}

logger, err := LoggerFromContext(ctx)
if err != nil {
DefaultLogger().Error(err.Error())
arm.WriteInternalServerError(w)
return
wholePath := subscriptionID != "" && resourceGroup != "" && resourceName != ""
if wholePath {
format := "/subscriptions/%s/resourcegroups/%s/providers/%s/%s"
resource_id := fmt.Sprintf(format, subscriptionID, resourceGroup, api.ResourceType, resourceName)
attrs = append(attrs, slog.String("resource_id", resource_id))
}

handler := logger.Handler()
ctx = ContextWithLogger(ctx, slog.New(handler.WithAttrs(attrs)))

next(w, r.WithContext(ctx))
return attrs
}
273 changes: 273 additions & 0 deletions frontend/middleware_logging_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
package main

import (
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"

"github.com/Azure/ARO-HCP/internal/api"
"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/google/uuid"
)

const (
client_request_id = "random_client_request_id"
correlation_request_id string = "random_correlation_request_id"
)

func TestMiddlewareLoggingPostMux(t *testing.T) {
type testCase struct {
name string
header http.Header
}

tt := testCase{
name: "is able to process and forward the values from request's header to context",
header: http.Header{
arm.HeaderNameClientRequestID: []string{client_request_id},
arm.HeaderNameCorrelationRequestID: []string{correlation_request_id},
arm.HeaderNameRequestID: []string{uuid.NewString()},
},
}

t.Run(tt.name, func(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "", nil)
if err != nil {
t.Fatal(err)
}

request.Header = tt.header

// we assume the request carries a logger, we set it explicitly to not fail
ctx := ContextWithLogger(request.Context(), DefaultLogger())
request = request.WithContext(ctx)

next := func(w http.ResponseWriter, r *http.Request) {
request = r // capture modified request
w.WriteHeader(http.StatusOK)
}

writer := httptest.NewRecorder()
MiddlewareLoggingPostMux(writer, request, next)

result, err := CorrelationDataFromContext(request.Context())
if err != nil {
t.Fatal(err)
}

if result.ClientRequestID != client_request_id {
t.Fatalf("ClientRequestID from header was not propperly propagated to requestcontext, expected %v, but got %v",
client_request_id,
result.ClientRequestID)
}
})

}

// ReqPathModifier is an alias to a function that receives a request
// and it should modify its Path value as needed, for testing purposes.
type ReqPathModifier func(req *http.Request)

// noModifyReqfunc is a function that receives a request and does not modify it.
func noModifyReqfunc(req *http.Request) {
// empty on purpose
}

func Test_getLogAttrs(t *testing.T) {
var expectedRequestID = uuid.New()

fakeSubscriptionId := "the_subscription_id"
fakeResourceGroupName := "the_resource_group_name"
fakeResourceName := "the_resource_name"

sampleCorrelationData := &arm.CorrelationData{
RequestID: expectedRequestID,
ClientRequestID: client_request_id,
CorrelationRequestID: correlation_request_id,
RequestTime: time.Now(),
}

commonAttrs := []slog.Attr{
slog.String("request_id", expectedRequestID.String()),
slog.String("client_request_id", client_request_id),
slog.String("correlation_request_id", correlation_request_id),
}

type testCase struct {
name string
correlationData *arm.CorrelationData
req *http.Request
want []slog.Attr
setReqPathValue ReqPathModifier
}

tests := []testCase{
{
name: "handles the common logging attributes",
correlationData: sampleCorrelationData,
req: &http.Request{},
want: commonAttrs,
setReqPathValue: noModifyReqfunc,
},
{
name: "handles the common attributes and the attributes for the subscription_id segment path",
correlationData: sampleCorrelationData,
req: &http.Request{},
want: append(commonAttrs, slog.String("subscription_id", fakeSubscriptionId)),
setReqPathValue: func(req *http.Request) {
req.SetPathValue(PathSegmentSubscriptionID, fakeSubscriptionId)
},
},
{
name: "handles the common attributes and the attributes for the resourcegroupname path",
correlationData: sampleCorrelationData,
req: &http.Request{},
want: append(commonAttrs, slog.String("resource_group", fakeResourceGroupName)),
setReqPathValue: func(req *http.Request) {
req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName)
},
},
{
name: "handles the common attributes and the attributes for the resourcegroupname path",
correlationData: sampleCorrelationData,
req: &http.Request{},
want: append(commonAttrs, slog.String("resource_group", fakeResourceGroupName)),
setReqPathValue: func(req *http.Request) {
req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName)
},
},
{
name: "handles the common attributes and the attributes for the resourcename path, and produces the correct resourceID attribute",
correlationData: sampleCorrelationData,
req: &http.Request{},
want: append(
commonAttrs,
slog.String("subscription_id", fakeSubscriptionId),
slog.String("resource_group", fakeResourceGroupName),
slog.String("resource_name", fakeResourceName),
slog.String(
"resource_id",
fmt.Sprintf(
"/subscriptions/%s/resourcegroups/%s/providers/%s/%s",
fakeSubscriptionId,
fakeResourceGroupName,
api.ResourceType,
fakeResourceName)),
),
setReqPathValue: func(req *http.Request) {
// assuming the PathSegmentResourceName is present in the Path
req.SetPathValue(PathSegmentResourceName, fakeResourceName)

// assuming the PathSegmentSubscriptionID is present in the Path
req.SetPathValue(PathSegmentSubscriptionID, fakeSubscriptionId)

// assuming the PathSegmentResourceGroupName is present in the Path
req.SetPathValue(PathSegmentResourceGroupName, fakeResourceGroupName)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setReqPathValue(tt.req)
got := getLogAttrs(tt.correlationData, tt.req)
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("want %v, but got %v", tt.want, got)
}
})
}
}

func Test_setHeaders(t *testing.T) {
var expectedRequestId = uuid.New()
const expectedClientRequestId = "the_client_request_id"

type testCase struct {
name string
w http.ResponseWriter
r *http.Request
correlationData *arm.CorrelationData
expectedHeaders http.Header
}

tests := []testCase{
{
name: "should set the requestId header to the value of correlation data",
w: &httptest.ResponseRecorder{},
r: &http.Request{},
correlationData: &arm.CorrelationData{RequestID: expectedRequestId},
expectedHeaders: http.Header{
arm.HeaderNameRequestID: []string{expectedRequestId.String()},
},
},
{
name: "should set the clientRequestId header to the value of correlation data when the 'should return client request id' header is true",
w: &httptest.ResponseRecorder{},
r: &http.Request{
Header: http.Header{
arm.HeaderNameReturnClientRequestID: []string{"true"},
},
},
correlationData: &arm.CorrelationData{
RequestID: expectedRequestId,
ClientRequestID: expectedClientRequestId,
},
expectedHeaders: http.Header{
arm.HeaderNameRequestID: []string{expectedRequestId.String()},
arm.HeaderNameClientRequestID: []string{expectedClientRequestId},
},
},
{
name: "should not set the clientRequestId header to the value of correlation data when the 'should return client request id' header is false",
w: &httptest.ResponseRecorder{},
r: &http.Request{
Header: http.Header{
arm.HeaderNameReturnClientRequestID: []string{"false"},
},
},
correlationData: &arm.CorrelationData{
RequestID: expectedRequestId,
ClientRequestID: expectedClientRequestId,
},
expectedHeaders: http.Header{
arm.HeaderNameRequestID: []string{expectedRequestId.String()},
},
},
{
name: "should not set the clientRequestId header to the value from correlation data when header is empty",
w: &httptest.ResponseRecorder{},
r: &http.Request{},
correlationData: &arm.CorrelationData{
RequestID: expectedRequestId,
ClientRequestID: expectedClientRequestId,
},
expectedHeaders: http.Header{
arm.HeaderNameRequestID: []string{expectedRequestId.String()},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setHeaders(tt.w, tt.r, tt.correlationData)
assertAllHeadersAreWritten(t, tt.expectedHeaders, tt.w)
})
}
}

// assertAllHeadersAreWritten asserts that all the headers h are written in w
func assertAllHeadersAreWritten(t *testing.T, h http.Header, w http.ResponseWriter) {
for expectedKey, expectedValues := range h {
valueInHeader := w.Header().Get(expectedKey)
if valueInHeader == "" {
t.Fatalf("header with key %v is not present in response writer\n", expectedKey)
}

if valueInHeader != expectedValues[0] {
t.Fatalf("header with key %v and value %v is different than expected value %v in response writer\n", expectedKey, valueInHeader, expectedValues[0])
}
}
}
Loading

0 comments on commit 5d35d92

Please sign in to comment.