Skip to content

Commit

Permalink
refactor: scalar projection (#510)
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly authored Sep 22, 2024
1 parent 6397e6a commit 30bb58b
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 99 deletions.
40 changes: 24 additions & 16 deletions pkg/apis/policy/v1alpha1/any.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package v1alpha1

import (
"github.com/kyverno/kyverno-json/pkg/core/projection"
hashutils "github.com/kyverno/kyverno-json/pkg/utils/hash"
"k8s.io/apimachinery/pkg/util/json"
)

Expand All @@ -10,27 +12,18 @@ import (
// +kubebuilder:validation:Type:=""
type Any struct {
_value any
_hash string
}

func NewAny(value any) Any {
return Any{value}
}

func (t *Any) Value() any {
return t._value
}

func (in *Any) DeepCopyInto(out *Any) {
out._value = deepCopy(in._value)
return Any{
_value: value,
_hash: hashutils.Hash(value),
}
}

func (in *Any) DeepCopy() *Any {
if in == nil {
return nil
}
out := new(Any)
in.DeepCopyInto(out)
return out
func (t *Any) Compile(compiler func(string, any, string) (projection.ScalarHandler, error), defaultCompiler string) (projection.ScalarHandler, error) {
return compiler(t._hash, t._value, defaultCompiler)
}

func (a *Any) MarshalJSON() ([]byte, error) {
Expand All @@ -44,5 +37,20 @@ func (a *Any) UnmarshalJSON(data []byte) error {
return err
}
a._value = v
a._hash = hashutils.Hash(a._value)
return nil
}

func (in *Any) DeepCopyInto(out *Any) {
out._value = deepCopy(in._value)
out._hash = in._hash
}

func (in *Any) DeepCopy() *Any {
if in == nil {
return nil
}
out := new(Any)
in.DeepCopyInto(out)
return out
}
26 changes: 11 additions & 15 deletions pkg/apis/policy/v1alpha1/any_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,36 @@ import (
func TestAny_DeepCopyInto(t *testing.T) {
tests := []struct {
name string
in *Any
out *Any
in Any
}{{
name: "nil",
in: &Any{nil},
out: &Any{nil},
in: NewAny(nil),
}, {
name: "int",
in: &Any{42},
out: &Any{nil},
in: NewAny(42),
}, {
name: "string",
in: &Any{"foo"},
out: &Any{nil},
in: NewAny("foo"),
}, {
name: "slice",
in: &Any{[]any{42, "string"}},
out: &Any{nil},
in: NewAny([]any{42, "string"}),
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.in.DeepCopyInto(tt.out)
assert.Equal(t, tt.in, tt.out)
var out Any
tt.in.DeepCopyInto(&out)
assert.Equal(t, tt.in, out)
})
}
{
inner := map[string]any{
"foo": 42,
}
in := Any{map[string]any{"inner": inner}}
in := NewAny(map[string]any{"inner": inner})
out := in.DeepCopy()
inPtr := in.Value().(map[string]any)["inner"].(map[string]any)
inPtr := in._value.(map[string]any)["inner"].(map[string]any)
inPtr["foo"] = 55
outPtr := out.Value().(map[string]any)["inner"].(map[string]any)
outPtr := out._value.(map[string]any)["inner"].(map[string]any)
assert.NotEqual(t, inPtr, outPtr)
}
}
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pkg/commands/jp/query/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func loadInput(cmd *cobra.Command, file string) (any, error) {
}

func evaluate(input any, query string) (any, error) {
result, err := compilers.Execute(query, input, nil, compilers.DefaultCompiler.Jp)
result, err := compilers.Execute(query, input, nil, compilers.DefaultCompilers.Jp)
if err != nil {
if syntaxError, ok := err.(parsing.SyntaxError); ok {
return nil, fmt.Errorf("%s\n%s", syntaxError, syntaxError.HighlightLocation())
Expand Down
2 changes: 1 addition & 1 deletion pkg/commands/scan/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *options) run(cmd *cobra.Command, _ []string) error {
}
out.println("Pre processing ...")
for _, preprocessor := range c.preprocessors {
result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompiler.Jp)
result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompilers.Jp)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func parseMap(assertion any, compiler compilers.Compilers, defaultCompiler strin
}
entry := assertions[key]
entry.node = assertion
entry.Projection = projection.Parse(key, compiler, defaultCompiler)
entry.Projection = projection.ParseMapKey(key, compiler, defaultCompiler)
assertions[key] = entry
}
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
Expand Down
12 changes: 6 additions & 6 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
tassert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/validation/field"
)

Expand Down Expand Up @@ -49,16 +49,16 @@ func TestAssert(t *testing.T) {
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompiler
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler, expression.CompilerJP)
tassert.NoError(t, err)
assert.NoError(t, err)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
tassert.Error(t, err)
assert.Error(t, err)
} else {
tassert.NoError(t, err)
assert.NoError(t, err)
}
tassert.Equal(t, tt.want, got)
assert.Equal(t, tt.want, got)
})
}
}
43 changes: 1 addition & 42 deletions pkg/core/compilers/compilers.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package compilers

import (
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers/cel"
"github.com/kyverno/kyverno-json/pkg/core/compilers/jp"
"github.com/kyverno/kyverno-json/pkg/core/expression"
"k8s.io/apimachinery/pkg/util/validation/field"
)

var DefaultCompiler = Compilers{
var DefaultCompilers = Compilers{
Jp: jp.NewCompiler(),
Cel: cel.NewCompiler(),
}
Expand All @@ -32,40 +28,3 @@ func (c Compilers) Compiler(compiler string) Compiler {
return c.Jp
}
}

func (c Compilers) NewBinding(path *field.Path, value any, bindings binding.Bindings, template any, compiler string) binding.Binding {
return binding.NewDelegate(
sync.OnceValues(
func() (any, error) {
switch typed := template.(type) {
case string:
expr := expression.Parse(compiler, typed)
if expr.Foreach {
return nil, field.Invalid(path.Child("variable"), typed, "foreach is not supported in context")
}
if expr.Binding != "" {
return nil, field.Invalid(path.Child("variable"), typed, "binding is not supported in context")
}
switch expr.Compiler {
case expression.CompilerJP:
projected, err := Execute(expr.Statement, value, bindings, c.Jp)
if err != nil {
return nil, field.InternalError(path.Child("variable"), err)
}
return projected, nil
case expression.CompilerCEL:
projected, err := Execute(expr.Statement, value, bindings, c.Cel)
if err != nil {
return nil, field.InternalError(path.Child("variable"), err)
}
return projected, nil
default:
return expr.Statement, nil
}
default:
return typed, nil
}
},
),
)
}
48 changes: 44 additions & 4 deletions pkg/core/projection/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
)

type Handler = func(value any, bindings binding.Bindings) (any, bool, error)
type (
ScalarHandler = func(value any, bindings binding.Bindings) (any, error)
MapKeyHandler = func(value any, bindings binding.Bindings) (any, bool, error)
)

type Info struct {
Foreach bool
Expand All @@ -21,10 +24,10 @@ type Info struct {

type Projection struct {
Info
Handler
Handler MapKeyHandler
}

func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projection Projection) {
func ParseMapKey(in any, compiler compilers.Compilers, defaultCompiler string) (projection Projection) {
switch typed := in.(type) {
case string:
// 1. if we have a string, parse the expression
Expand All @@ -47,7 +50,7 @@ func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projec
if err != nil {
return nil, false, err
}
return projected, true, err
return projected, true, nil
}
} else {
projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) {
Expand Down Expand Up @@ -82,3 +85,40 @@ func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projec
}
return
}

func ParseScalar(in any, compiler compilers.Compilers, defaultCompiler string) (ScalarHandler, error) {
switch typed := in.(type) {
case string:
expr := expression.Parse(defaultCompiler, typed)
if expr.Foreach {
return nil, errors.New("foreach is not supported in scalar projections")
}
if expr.Binding != "" {
return nil, errors.New("binding is not supported in scalar projections")
}
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
return func(value any, bindings binding.Bindings) (any, error) {
program, err := compile()
if err != nil {
return nil, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, err
}
return projected, nil
}, nil
} else {
return func(value any, bindings binding.Bindings) (any, error) {
return expr.Statement, nil
}, nil
}
default:
return func(value any, bindings binding.Bindings) (any, error) {
return typed, nil
}, nil
}
}
16 changes: 8 additions & 8 deletions pkg/core/projection/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
tassert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
)

func TestProjection(t *testing.T) {
func TestParseMap(t *testing.T) {
tests := []struct {
name string
key any
Expand Down Expand Up @@ -89,16 +89,16 @@ func TestProjection(t *testing.T) {
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompiler
proj := Parse(tt.key, compiler, expression.CompilerJP)
compiler := compilers.DefaultCompilers
proj := ParseMapKey(tt.key, compiler, expression.CompilerJP)
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
tassert.Error(t, err)
assert.Error(t, err)
} else {
tassert.NoError(t, err)
assert.NoError(t, err)
}
tassert.Equal(t, tt.wantFound, found)
tassert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
})
}
}
27 changes: 24 additions & 3 deletions pkg/json-engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package jsonengine
import (
"context"
"fmt"
"sync"
"time"

"github.com/jmespath-community/go-jmespath/pkg/binding"
jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
Expand Down Expand Up @@ -68,7 +70,7 @@ func New() engine.Engine[Request, Response] {
resource any
bindings jpbinding.Bindings
}
compiler := matching.NewCompiler(compilers.DefaultCompiler, 256)
compiler := matching.NewCompiler(compilers.DefaultCompilers, 256)
ruleEngine := builder.
Function(func(ctx context.Context, r ruleRequest) []RuleResponse {
bindings := r.bindings.Register("$rule", jpbinding.NewBinding(r.rule))
Expand All @@ -82,12 +84,31 @@ func New() engine.Engine[Request, Response] {
// TODO: this doesn't seem to be the right path
var path *field.Path
path = path.Child("context")
for i, entry := range r.rule.Context {
for _, entry := range r.rule.Context {
defaultCompiler := defaultCompiler
if entry.Compiler != nil {
defaultCompiler = string(*entry.Compiler)
}
bindings = bindings.Register("$"+entry.Name, compiler.NewBinding(path.Index(i), r.resource, bindings, entry.Variable.Value(), defaultCompiler))
bindings = func(variable v1alpha1.Any, bindings jpbinding.Bindings) jpbinding.Bindings {
return bindings.Register(
"$"+entry.Name,
binding.NewDelegate(
sync.OnceValues(
func() (any, error) {
handler, err := variable.Compile(compiler.CompileProjection, defaultCompiler)
if err != nil {
return nil, field.InternalError(path.Child("variable"), err)
}
projected, err := handler(r.resource, bindings)
if err != nil {
return nil, field.InternalError(path.Child("variable"), err)
}
return projected, nil
},
),
),
)
}(entry.Variable, bindings)
}
identifier := ""
if r.rule.Identifier != "" {
Expand Down
Loading

0 comments on commit 30bb58b

Please sign in to comment.