Skip to content

Commit

Permalink
refactor: add context functions for better encapsulation
Browse files Browse the repository at this point in the history
  • Loading branch information
Aldo Fuster Turpin committed Apr 24, 2024
1 parent d5615c4 commit e3a5202
Show file tree
Hide file tree
Showing 16 changed files with 370 additions and 75 deletions.
11 changes: 0 additions & 11 deletions frontend/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,10 @@ package main
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

type contextKey int

const (
// APIVersionKey is the request parameter name for the API version.
APIVersionKey = "api-version"

// Keys for request-scoped data in http.Request contexts
ContextKeyOriginalPath contextKey = iota
ContextKeyBody
ContextKeyLogger
ContextKeyVersion
ContextKeyCorrelationData
ContextKeySystemData
ContextKeySubscriptionState

// Wildcard path segment names for request multiplexing, must be lowercase as we lowercase the request URL pattern when registering handlers
PageSegmentLocation = "location"
PathSegmentSubscriptionID = "subscriptionid"
Expand Down
97 changes: 97 additions & 0 deletions frontend/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package main

import (
"context"
"fmt"
"log/slog"

"github.com/Azure/ARO-HCP/internal/api"
"github.com/Azure/ARO-HCP/internal/api/arm"
)

type ContextError struct {
got any
}

func (c *ContextError) Error() string {
return fmt.Sprintf(
"error retrieving value from context, value obtained was '%v' and type obtained was '%T'",
c.got,
c.got)
}

type contextKey int

const (
// Keys for request-scoped data in http.Request contexts
contextKeyOriginalPath contextKey = iota
contextKeyBody
contextKeyLogger
contextKeyVersion
contextKeyCorrelationData
contextKeySystemData
contextKeySubscriptionState
)

func ContextWithOriginalPath(ctx context.Context, originalPath string) context.Context {
return context.WithValue(ctx, contextKeyOriginalPath, originalPath)
}

func OriginalPathFromContext(ctx context.Context) (string, bool) {
originalPath, ok := ctx.Value(contextKeyOriginalPath).(string)
return originalPath, ok
}

func ContextWithBody(ctx context.Context, body []byte) context.Context {
return context.WithValue(ctx, contextKeyBody, body)
}

func BodyFromContext(ctx context.Context) ([]byte, bool) {
body, ok := ctx.Value(contextKeyBody).([]byte)
return body, ok
}

func ContextWithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, contextKeyLogger, logger)
}

func LoggerFromContext(ctx context.Context) (*slog.Logger, bool) {
logger, ok := ctx.Value(contextKeyLogger).(*slog.Logger)
return logger, ok
}

func ContextWithVersion(ctx context.Context, version api.Version) context.Context {
return context.WithValue(ctx, contextKeyVersion, version)
}

func VersionFromContext(ctx context.Context) (api.Version, bool) {
version, ok := ctx.Value(contextKeyVersion).(api.Version)
return version, ok
}

func ContextWithCorrelationData(ctx context.Context, correlationData *arm.CorrelationData) context.Context {
return context.WithValue(ctx, contextKeyCorrelationData, correlationData)
}

func CorrelationDataFromContext(ctx context.Context) (*arm.CorrelationData, bool) {
correlationData, ok := ctx.Value(contextKeyCorrelationData).(*arm.CorrelationData)
return correlationData, ok
}

func ContextWithSystemData(ctx context.Context, systemData *arm.SystemData) context.Context {
return context.WithValue(ctx, contextKeySystemData, systemData)
}

func SystemDataFromContext(ctx context.Context) (*arm.SystemData, bool) {
systemData, ok := ctx.Value(contextKeySystemData).(*arm.SystemData)
return systemData, ok
}

func ContextWithSubscriptionState(ctx context.Context, subscriptionState arm.RegistrationState) context.Context {
return context.WithValue(ctx, contextKeySubscriptionState, subscriptionState)
}

func SubscriptionStateFromContext(ctx context.Context) (arm.RegistrationState, bool) {
subscriptionState, ok := ctx.Value(contextKeySubscriptionState).(arm.RegistrationState)
return subscriptionState, ok
}
194 changes: 175 additions & 19 deletions frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func NewFrontend(logger *slog.Logger, listener net.Listener, emitter metrics.Emi
server: http.Server{
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
BaseContext: func(net.Listener) context.Context {
return context.WithValue(context.Background(), ContextKeyLogger, logger)
return ContextWithLogger(context.Background(), logger)
},
},
cache: *NewCache(),
Expand Down Expand Up @@ -174,8 +174,25 @@ func (f *Frontend) HealthzReady(writer http.ResponseWriter, request *http.Reques

func (f *Frontend) ArmResourceListBySubscription(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceListBySubscription", versionedInterface))

Expand All @@ -184,8 +201,25 @@ func (f *Frontend) ArmResourceListBySubscription(writer http.ResponseWriter, req

func (f *Frontend) ArmResourceListByLocation(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceListByLocation", versionedInterface))

Expand All @@ -194,8 +228,25 @@ func (f *Frontend) ArmResourceListByLocation(writer http.ResponseWriter, request

func (f *Frontend) ArmResourceListByResourceGroup(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceListByResourceGroup", versionedInterface))

Expand All @@ -204,8 +255,26 @@ func (f *Frontend) ArmResourceListByResourceGroup(writer http.ResponseWriter, re

func (f *Frontend) ArmResourceRead(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceRead", versionedInterface))

// URL path is already lowercased by middleware.
Expand All @@ -231,16 +300,42 @@ func (f *Frontend) ArmResourceRead(writer http.ResponseWriter, request *http.Req

func (f *Frontend) ArmResourceCreateOrUpdate(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceCreateOrUpdate", versionedInterface))

// URL path is already lowercased by middleware.
resourceID := request.URL.Path
_, updating := f.cache.GetCluster(resourceID)
newCluster := api.NewDefaultHCPOpenShiftCluster()
body := ctx.Value(ContextKeyBody).([]byte)
body, ok := BodyFromContext(ctx)
if !ok {
err := &ContextError{
got: body,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

err := versionedInterface.UnmarshalHCPOpenShiftCluster(body, newCluster, request.Method, updating)
if err != nil {
f.logger.Error(err.Error())
Expand All @@ -266,8 +361,25 @@ func (f *Frontend) ArmResourceCreateOrUpdate(writer http.ResponseWriter, request

func (f *Frontend) ArmResourcePatch(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourcePatch", versionedInterface))

Expand All @@ -276,8 +388,26 @@ func (f *Frontend) ArmResourcePatch(writer http.ResponseWriter, request *http.Re

func (f *Frontend) ArmResourceDelete(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceDelete", versionedInterface))

// URL path is already lowercased by middleware.
Expand All @@ -294,8 +424,25 @@ func (f *Frontend) ArmResourceDelete(writer http.ResponseWriter, request *http.R

func (f *Frontend) ArmResourceAction(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
logger := ctx.Value(ContextKeyLogger).(*slog.Logger)
versionedInterface := ctx.Value(ContextKeyVersion).(api.Version)
logger, ok := LoggerFromContext(ctx)
if !ok {
err := &ContextError{
got: logger,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

versionedInterface, ok := VersionFromContext(ctx)
if !ok {
err := &ContextError{
got: versionedInterface,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

logger.Info(fmt.Sprintf("%s: ArmResourceAction", versionedInterface))

Expand All @@ -305,7 +452,16 @@ func (f *Frontend) ArmResourceAction(writer http.ResponseWriter, request *http.R
func (f *Frontend) ArmSubscriptionAction(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()

body := ctx.Value(ContextKeyBody).([]byte)
body, ok := BodyFromContext(ctx)
if !ok {
err := &ContextError{
got: body,
}
f.logger.Error(err.Error())
arm.WriteInternalServerError(writer)
return
}

var subscription arm.Subscription
err := json.Unmarshal(body, &subscription)
if err != nil {
Expand Down
Loading

0 comments on commit e3a5202

Please sign in to comment.