Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce visibility struct tags #89

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,10 @@ func (c *HcpOpenShiftClusterResource) ValidateStatic(current api.VersionedHCPOpe
"Content validation failed on multiple fields")
cloudError.Details = make([]arm.CloudErrorBody, 0)

// FIXME Validate visibility tags by comparing the new cluster (c) to current.
errorDetails = api.ValidateVisibility(c, current, clusterStructTagMap, updating)
if errorDetails != nil {
cloudError.Details = append(cloudError.Details, errorDetails...)
}

c.Normalize(&normalized)

Expand Down
16 changes: 15 additions & 1 deletion internal/api/v20240610preview/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ func (v version) String() string {
return "2024-06-10-preview"
}

var validate = api.NewValidator()
var (
validate = api.NewValidator()
clusterStructTagMap = api.NewStructTagMap[api.HCPOpenShiftCluster]()
)

func EnumValidateTag[S ~string](values ...S) string {
s := make([]string, len(values))
Expand All @@ -28,6 +31,17 @@ func EnumValidateTag[S ~string](values ...S) string {
}

func init() {
// NOTE: If future versions of the API expand field visibility, such as
// a field with @visibility("read","create") becoming updatable,
// then earlier versions of the API will need to override their
// StructTagMap to maintain the original visibility flags. This
// is where such overrides should happen, along with a comment
// about what changed and when. For example:
//
// // This field became updatable in version YYYY-MM-DD.
// clusterStructTagMap["Properties.Spec.FieldName"] = reflect.StructTag("visibility:\"read create\"")
//

api.Register(version{})

// Register enum type validations
Expand Down
292 changes: 292 additions & 0 deletions internal/api/visibility.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
package api

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

import (
"fmt"
"reflect"
"strings"

"github.com/Azure/ARO-HCP/internal/api/arm"
)

// Property visibility meanings:
// https://azure.github.io/typespec-azure/docs/howtos/ARM/resource-type#property-visibility-and-other-constraints
//
// Field mutability guidelines:
// https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#resource-schema--field-mutability

const VisibilityStructTagKey = "visibility"

// VisibilityFlags holds a visibility struct tag value as bit flags.
type VisibilityFlags uint8

const (
VisibilityRead VisibilityFlags = 1 << iota
VisibilityCreate
VisibilityUpdate

// option flags
VisibilityCaseInsensitive

VisibilityDefault = VisibilityRead | VisibilityCreate | VisibilityUpdate
)

func (f VisibilityFlags) ReadOnly() bool {
return f&(VisibilityRead|VisibilityCreate|VisibilityUpdate) == VisibilityRead
}

func (f VisibilityFlags) CanUpdate() bool {
return f&VisibilityUpdate != 0
}

func (f VisibilityFlags) CaseInsensitive() bool {
return f&VisibilityCaseInsensitive != 0
}

func (f VisibilityFlags) String() string {
s := []string{}
if f&VisibilityRead != 0 {
s = append(s, "read")
}
if f&VisibilityCreate != 0 {
s = append(s, "create")
}
if f&VisibilityUpdate != 0 {
s = append(s, "update")
}
if f&VisibilityCaseInsensitive != 0 {
s = append(s, "nocase")
}
return strings.Join(s, " ")
}

func GetVisibilityFlags(tag reflect.StructTag) (VisibilityFlags, bool) {
var flags VisibilityFlags

tagValue, ok := tag.Lookup(VisibilityStructTagKey)
if ok {
for _, v := range strings.Fields(tagValue) {
switch strings.ToLower(v) {
case "read":
flags |= VisibilityRead
case "create":
flags |= VisibilityCreate
case "update":
flags |= VisibilityUpdate
case "nocase":
flags |= VisibilityCaseInsensitive
default:
panic(fmt.Sprintf("Unknown visibility tag value '%s'", v))
}
}
}

return flags, ok
}

func join(ns, name string) string {
res := ns
if res != "" {
res += "."
}
res += name
return res
}

type StructTagMap map[string]reflect.StructTag

func buildStructTagMap(structTagMap StructTagMap, t reflect.Type, path string) {
switch t.Kind() {
case reflect.Pointer, reflect.Slice:
buildStructTagMap(structTagMap, t.Elem(), path)

case reflect.Struct:
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
subpath := join(path, field.Name)

if len(field.Tag) > 0 {
structTagMap[subpath] = field.Tag
}

buildStructTagMap(structTagMap, field.Type, subpath)
}
}
}

// NewStructTagMap returns a mapping of dot-separated struct field names
// to struct tags for the given type. Each versioned API should create
// its own visibiilty map for tracked resource types.
//
// Note: This assumes field names for internal and versioned structs are
// identical where visibility is explicitly specified. If some divergence
// emerges, one workaround could be to pass a field name override map.
func NewStructTagMap[T any]() StructTagMap {
structTagMap := StructTagMap{}
buildStructTagMap(structTagMap, reflect.TypeFor[T](), "")
return structTagMap
}

type validateVisibility struct {
structTagMap StructTagMap
updating bool
errs []arm.CloudErrorBody
}

// ValidateVisibility compares the new value (newVal) to the current value
// (curVal) and returns any violations of visibility restrictions as defined
// by structTagMap.
func ValidateVisibility(newVal, curVal interface{}, structTagMap StructTagMap, updating bool) []arm.CloudErrorBody {
vv := validateVisibility{
structTagMap: structTagMap,
updating: updating,
}
vv.recurse(reflect.ValueOf(newVal), reflect.ValueOf(curVal), "", "", "", VisibilityDefault)
return vv.errs
}

// mapKey is a lookup key for the StructTagMap. It DOES NOT include subscripts
// for arrays, maps or slices since all elements are the same type.
//
// namespace is the struct field path up to but not including the field being
// evaluated, analogous to path.Dir. It DOES include subscripts for arrays,
// maps and slices since its purpose is for error reporting.
//
// fieldname is the current field being evaluated, analgous to path.Base. It
// also includes subscripts for arrays, maps and slices when evaluating their
// immediate elements.
func (vv *validateVisibility) recurse(newVal, curVal reflect.Value, mapKey, namespace, fieldname string, implicitVisibility VisibilityFlags) {
flags, ok := GetVisibilityFlags(vv.structTagMap[mapKey])
if !ok {
flags = implicitVisibility
}

if newVal.Type() != curVal.Type() {
panic(fmt.Sprintf("%s: value types differ (%s vs %s)", join(namespace, fieldname), newVal.Type().Name(), curVal.Type().Name()))
}

// Generated API structs are all pointer fields. A nil pointer in
// the incoming request (newVal) means the value is absent, which
// is always acceptable for visibility validation.
switch newVal.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
if newVal.IsNil() {
return
}
}

switch newVal.Kind() {
case reflect.Bool:
if newVal.Bool() != curVal.Bool() {
vv.checkFlags(flags, namespace, fieldname)
}

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if newVal.Int() != curVal.Int() {
vv.checkFlags(flags, namespace, fieldname)
}

case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if newVal.Uint() != curVal.Uint() {
vv.checkFlags(flags, namespace, fieldname)
}

case reflect.Float32, reflect.Float64:
if newVal.Float() != curVal.Float() {
vv.checkFlags(flags, namespace, fieldname)
}

case reflect.Complex64, reflect.Complex128:
if newVal.Complex() != curVal.Complex() {
vv.checkFlags(flags, namespace, fieldname)
}

case reflect.String:
if flags.CaseInsensitive() {
if !strings.EqualFold(newVal.String(), curVal.String()) {
vv.checkFlags(flags, namespace, fieldname)
}
} else {
if newVal.String() != curVal.String() {
vv.checkFlags(flags, namespace, fieldname)
}
}

case reflect.Slice:
// We already know that newVal is not nil.
if curVal.IsNil() {
vv.checkFlags(flags, namespace, fieldname)
return
}

fallthrough

case reflect.Array:
if newVal.Len() != curVal.Len() {
vv.checkFlags(flags, namespace, fieldname)
} else {
for i := 0; i < min(newVal.Len(), curVal.Len()); i++ {
subscript := fmt.Sprintf("[%d]", i)
vv.recurse(newVal.Index(i), curVal.Index(i), mapKey, namespace, fieldname+subscript, flags)
}
}

case reflect.Interface, reflect.Pointer:
// We already know that newVal is not nil.
if curVal.IsNil() {
vv.checkFlags(flags, namespace, fieldname)
} else {
vv.recurse(newVal.Elem(), curVal.Elem(), mapKey, namespace, fieldname, flags)
}

case reflect.Map:
// We already know that newVal is not nil.
if curVal.IsNil() || newVal.Len() != curVal.Len() {
vv.checkFlags(flags, namespace, fieldname)
} else {
iter := newVal.MapRange()
for iter.Next() {
k := iter.Key()

subscript := fmt.Sprintf("[%q]", k.Interface())
if curVal.MapIndex(k).IsValid() {
vv.recurse(newVal.MapIndex(k), curVal.MapIndex(k), mapKey, namespace, fieldname+subscript, flags)
} else {
vv.checkFlags(flags, namespace, fieldname+subscript)
}
}
}

case reflect.Struct:
for i := 0; i < newVal.NumField(); i++ {
structField := newVal.Type().Field(i)
mapKeyNext := join(mapKey, structField.Name)
namespaceNext := join(namespace, fieldname)
fieldnameNext := GetJSONTagName(vv.structTagMap[mapKeyNext])
if fieldnameNext == "" {
fieldnameNext = structField.Name
}
vv.recurse(newVal.Field(i), curVal.Field(i), mapKeyNext, namespaceNext, fieldnameNext, flags)
}
}
}

func (vv *validateVisibility) checkFlags(flags VisibilityFlags, namespace, fieldname string) {
if flags.ReadOnly() {
vv.errs = append(vv.errs,
arm.CloudErrorBody{
Code: arm.CloudErrorCodeInvalidRequestContent,
Message: fmt.Sprintf("Field '%s' is read-only", fieldname),
Target: join(namespace, fieldname),
})
} else if vv.updating && !flags.CanUpdate() {
vv.errs = append(vv.errs,
arm.CloudErrorBody{
Code: arm.CloudErrorCodeInvalidRequestContent,
Message: fmt.Sprintf("Field '%s' cannot be updated", fieldname),
Target: join(namespace, fieldname),
})
}
}
Loading
Loading