From 30bb58b7226c201e74346a612c818c6e52f6b8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Charles-Edouard=20Br=C3=A9t=C3=A9ch=C3=A9?= Date: Sun, 22 Sep 2024 23:36:20 +0200 Subject: [PATCH] refactor: scalar projection (#510) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Charles-Edouard Brétéché --- pkg/apis/policy/v1alpha1/any.go | 40 +++++++++------- pkg/apis/policy/v1alpha1/any_test.go | 26 +++++----- .../v1alpha1/{engine.go => compiler.go} | 0 pkg/commands/jp/query/command.go | 2 +- pkg/commands/scan/options.go | 2 +- pkg/core/assertion/assertion.go | 2 +- pkg/core/assertion/assertion_test.go | 12 ++--- pkg/core/compilers/compilers.go | 43 +---------------- pkg/core/projection/projection.go | 48 +++++++++++++++++-- pkg/core/projection/projection_test.go | 16 +++---- pkg/json-engine/engine.go | 27 +++++++++-- pkg/matching/compiler.go | 6 +++ pkg/server/playground/handler.go | 2 +- pkg/server/scan/handler.go | 2 +- 14 files changed, 129 insertions(+), 99 deletions(-) rename pkg/apis/policy/v1alpha1/{engine.go => compiler.go} (100%) diff --git a/pkg/apis/policy/v1alpha1/any.go b/pkg/apis/policy/v1alpha1/any.go index 0fe20218..bea3c96d 100644 --- a/pkg/apis/policy/v1alpha1/any.go +++ b/pkg/apis/policy/v1alpha1/any.go @@ -1,6 +1,8 @@ package v1alpha1 import ( + "github.com/kyverno/kyverno-json/pkg/core/projection" + hashutils "github.com/kyverno/kyverno-json/pkg/utils/hash" "k8s.io/apimachinery/pkg/util/json" ) @@ -10,27 +12,18 @@ import ( // +kubebuilder:validation:Type:="" type Any struct { _value any + _hash string } func NewAny(value any) Any { - return Any{value} -} - -func (t *Any) Value() any { - return t._value -} - -func (in *Any) DeepCopyInto(out *Any) { - out._value = deepCopy(in._value) + return Any{ + _value: value, + _hash: hashutils.Hash(value), + } } -func (in *Any) DeepCopy() *Any { - if in == nil { - return nil - } - out := new(Any) - in.DeepCopyInto(out) - return out +func (t *Any) Compile(compiler func(string, any, string) (projection.ScalarHandler, error), defaultCompiler string) (projection.ScalarHandler, error) { + return compiler(t._hash, t._value, defaultCompiler) } func (a *Any) MarshalJSON() ([]byte, error) { @@ -44,5 +37,20 @@ func (a *Any) UnmarshalJSON(data []byte) error { return err } a._value = v + a._hash = hashutils.Hash(a._value) return nil } + +func (in *Any) DeepCopyInto(out *Any) { + out._value = deepCopy(in._value) + out._hash = in._hash +} + +func (in *Any) DeepCopy() *Any { + if in == nil { + return nil + } + out := new(Any) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/apis/policy/v1alpha1/any_test.go b/pkg/apis/policy/v1alpha1/any_test.go index 668984d0..efc148fd 100644 --- a/pkg/apis/policy/v1alpha1/any_test.go +++ b/pkg/apis/policy/v1alpha1/any_test.go @@ -9,40 +9,36 @@ import ( func TestAny_DeepCopyInto(t *testing.T) { tests := []struct { name string - in *Any - out *Any + in Any }{{ name: "nil", - in: &Any{nil}, - out: &Any{nil}, + in: NewAny(nil), }, { name: "int", - in: &Any{42}, - out: &Any{nil}, + in: NewAny(42), }, { name: "string", - in: &Any{"foo"}, - out: &Any{nil}, + in: NewAny("foo"), }, { name: "slice", - in: &Any{[]any{42, "string"}}, - out: &Any{nil}, + in: NewAny([]any{42, "string"}), }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.in.DeepCopyInto(tt.out) - assert.Equal(t, tt.in, tt.out) + var out Any + tt.in.DeepCopyInto(&out) + assert.Equal(t, tt.in, out) }) } { inner := map[string]any{ "foo": 42, } - in := Any{map[string]any{"inner": inner}} + in := NewAny(map[string]any{"inner": inner}) out := in.DeepCopy() - inPtr := in.Value().(map[string]any)["inner"].(map[string]any) + inPtr := in._value.(map[string]any)["inner"].(map[string]any) inPtr["foo"] = 55 - outPtr := out.Value().(map[string]any)["inner"].(map[string]any) + outPtr := out._value.(map[string]any)["inner"].(map[string]any) assert.NotEqual(t, inPtr, outPtr) } } diff --git a/pkg/apis/policy/v1alpha1/engine.go b/pkg/apis/policy/v1alpha1/compiler.go similarity index 100% rename from pkg/apis/policy/v1alpha1/engine.go rename to pkg/apis/policy/v1alpha1/compiler.go diff --git a/pkg/commands/jp/query/command.go b/pkg/commands/jp/query/command.go index 73aa2ccc..49309e78 100644 --- a/pkg/commands/jp/query/command.go +++ b/pkg/commands/jp/query/command.go @@ -155,7 +155,7 @@ func loadInput(cmd *cobra.Command, file string) (any, error) { } func evaluate(input any, query string) (any, error) { - result, err := compilers.Execute(query, input, nil, compilers.DefaultCompiler.Jp) + result, err := compilers.Execute(query, input, nil, compilers.DefaultCompilers.Jp) if err != nil { if syntaxError, ok := err.(parsing.SyntaxError); ok { return nil, fmt.Errorf("%s\n%s", syntaxError, syntaxError.HighlightLocation()) diff --git a/pkg/commands/scan/options.go b/pkg/commands/scan/options.go index 15049be4..46993edc 100644 --- a/pkg/commands/scan/options.go +++ b/pkg/commands/scan/options.go @@ -77,7 +77,7 @@ func (c *options) run(cmd *cobra.Command, _ []string) error { } out.println("Pre processing ...") for _, preprocessor := range c.preprocessors { - result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompiler.Jp) + result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompilers.Jp) if err != nil { return err } diff --git a/pkg/core/assertion/assertion.go b/pkg/core/assertion/assertion.go index cc26e62c..f1714e3d 100644 --- a/pkg/core/assertion/assertion.go +++ b/pkg/core/assertion/assertion.go @@ -91,7 +91,7 @@ func parseMap(assertion any, compiler compilers.Compilers, defaultCompiler strin } entry := assertions[key] entry.node = assertion - entry.Projection = projection.Parse(key, compiler, defaultCompiler) + entry.Projection = projection.ParseMapKey(key, compiler, defaultCompiler) assertions[key] = entry } return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { diff --git a/pkg/core/assertion/assertion_test.go b/pkg/core/assertion/assertion_test.go index 94b9f404..eb885d3d 100644 --- a/pkg/core/assertion/assertion_test.go +++ b/pkg/core/assertion/assertion_test.go @@ -6,7 +6,7 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/core/compilers" "github.com/kyverno/kyverno-json/pkg/core/expression" - tassert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -49,16 +49,16 @@ func TestAssert(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - compiler := compilers.DefaultCompiler + compiler := compilers.DefaultCompilers parsed, err := Parse(tt.assertion, compiler, expression.CompilerJP) - tassert.NoError(t, err) + assert.NoError(t, err) got, err := parsed.Assert(nil, tt.value, tt.bindings) if tt.wantErr { - tassert.Error(t, err) + assert.Error(t, err) } else { - tassert.NoError(t, err) + assert.NoError(t, err) } - tassert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/core/compilers/compilers.go b/pkg/core/compilers/compilers.go index d1e61f83..e795abe3 100644 --- a/pkg/core/compilers/compilers.go +++ b/pkg/core/compilers/compilers.go @@ -1,16 +1,12 @@ package compilers import ( - "sync" - - "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/core/compilers/cel" "github.com/kyverno/kyverno-json/pkg/core/compilers/jp" "github.com/kyverno/kyverno-json/pkg/core/expression" - "k8s.io/apimachinery/pkg/util/validation/field" ) -var DefaultCompiler = Compilers{ +var DefaultCompilers = Compilers{ Jp: jp.NewCompiler(), Cel: cel.NewCompiler(), } @@ -32,40 +28,3 @@ func (c Compilers) Compiler(compiler string) Compiler { return c.Jp } } - -func (c Compilers) NewBinding(path *field.Path, value any, bindings binding.Bindings, template any, compiler string) binding.Binding { - return binding.NewDelegate( - sync.OnceValues( - func() (any, error) { - switch typed := template.(type) { - case string: - expr := expression.Parse(compiler, typed) - if expr.Foreach { - return nil, field.Invalid(path.Child("variable"), typed, "foreach is not supported in context") - } - if expr.Binding != "" { - return nil, field.Invalid(path.Child("variable"), typed, "binding is not supported in context") - } - switch expr.Compiler { - case expression.CompilerJP: - projected, err := Execute(expr.Statement, value, bindings, c.Jp) - if err != nil { - return nil, field.InternalError(path.Child("variable"), err) - } - return projected, nil - case expression.CompilerCEL: - projected, err := Execute(expr.Statement, value, bindings, c.Cel) - if err != nil { - return nil, field.InternalError(path.Child("variable"), err) - } - return projected, nil - default: - return expr.Statement, nil - } - default: - return typed, nil - } - }, - ), - ) -} diff --git a/pkg/core/projection/projection.go b/pkg/core/projection/projection.go index d4b23e75..44dd1a23 100644 --- a/pkg/core/projection/projection.go +++ b/pkg/core/projection/projection.go @@ -11,7 +11,10 @@ import ( reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" ) -type Handler = func(value any, bindings binding.Bindings) (any, bool, error) +type ( + ScalarHandler = func(value any, bindings binding.Bindings) (any, error) + MapKeyHandler = func(value any, bindings binding.Bindings) (any, bool, error) +) type Info struct { Foreach bool @@ -21,10 +24,10 @@ type Info struct { type Projection struct { Info - Handler + Handler MapKeyHandler } -func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projection Projection) { +func ParseMapKey(in any, compiler compilers.Compilers, defaultCompiler string) (projection Projection) { switch typed := in.(type) { case string: // 1. if we have a string, parse the expression @@ -47,7 +50,7 @@ func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projec if err != nil { return nil, false, err } - return projected, true, err + return projected, true, nil } } else { projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) { @@ -82,3 +85,40 @@ func Parse(in any, compiler compilers.Compilers, defaultCompiler string) (projec } return } + +func ParseScalar(in any, compiler compilers.Compilers, defaultCompiler string) (ScalarHandler, error) { + switch typed := in.(type) { + case string: + expr := expression.Parse(defaultCompiler, typed) + if expr.Foreach { + return nil, errors.New("foreach is not supported in scalar projections") + } + if expr.Binding != "" { + return nil, errors.New("binding is not supported in scalar projections") + } + if compiler := compiler.Compiler(expr.Compiler); compiler != nil { + compile := sync.OnceValues(func() (compilers.Program, error) { + return compiler.Compile(expr.Statement) + }) + return func(value any, bindings binding.Bindings) (any, error) { + program, err := compile() + if err != nil { + return nil, err + } + projected, err := program(value, bindings) + if err != nil { + return nil, err + } + return projected, nil + }, nil + } else { + return func(value any, bindings binding.Bindings) (any, error) { + return expr.Statement, nil + }, nil + } + default: + return func(value any, bindings binding.Bindings) (any, error) { + return typed, nil + }, nil + } +} diff --git a/pkg/core/projection/projection_test.go b/pkg/core/projection/projection_test.go index 401746fb..53c9946f 100644 --- a/pkg/core/projection/projection_test.go +++ b/pkg/core/projection/projection_test.go @@ -6,10 +6,10 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/core/compilers" "github.com/kyverno/kyverno-json/pkg/core/expression" - tassert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -func TestProjection(t *testing.T) { +func TestParseMap(t *testing.T) { tests := []struct { name string key any @@ -89,16 +89,16 @@ func TestProjection(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - compiler := compilers.DefaultCompiler - proj := Parse(tt.key, compiler, expression.CompilerJP) + compiler := compilers.DefaultCompilers + proj := ParseMapKey(tt.key, compiler, expression.CompilerJP) got, found, err := proj.Handler(tt.value, tt.bindings) if tt.wantErr { - tassert.Error(t, err) + assert.Error(t, err) } else { - tassert.NoError(t, err) + assert.NoError(t, err) } - tassert.Equal(t, tt.wantFound, found) - tassert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantFound, found) + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/json-engine/engine.go b/pkg/json-engine/engine.go index 97656218..5fc42e53 100644 --- a/pkg/json-engine/engine.go +++ b/pkg/json-engine/engine.go @@ -3,8 +3,10 @@ package jsonengine import ( "context" "fmt" + "sync" "time" + "github.com/jmespath-community/go-jmespath/pkg/binding" jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" "github.com/kyverno/kyverno-json/pkg/core/compilers" @@ -68,7 +70,7 @@ func New() engine.Engine[Request, Response] { resource any bindings jpbinding.Bindings } - compiler := matching.NewCompiler(compilers.DefaultCompiler, 256) + compiler := matching.NewCompiler(compilers.DefaultCompilers, 256) ruleEngine := builder. Function(func(ctx context.Context, r ruleRequest) []RuleResponse { bindings := r.bindings.Register("$rule", jpbinding.NewBinding(r.rule)) @@ -82,12 +84,31 @@ func New() engine.Engine[Request, Response] { // TODO: this doesn't seem to be the right path var path *field.Path path = path.Child("context") - for i, entry := range r.rule.Context { + for _, entry := range r.rule.Context { defaultCompiler := defaultCompiler if entry.Compiler != nil { defaultCompiler = string(*entry.Compiler) } - bindings = bindings.Register("$"+entry.Name, compiler.NewBinding(path.Index(i), r.resource, bindings, entry.Variable.Value(), defaultCompiler)) + bindings = func(variable v1alpha1.Any, bindings jpbinding.Bindings) jpbinding.Bindings { + return bindings.Register( + "$"+entry.Name, + binding.NewDelegate( + sync.OnceValues( + func() (any, error) { + handler, err := variable.Compile(compiler.CompileProjection, defaultCompiler) + if err != nil { + return nil, field.InternalError(path.Child("variable"), err) + } + projected, err := handler(r.resource, bindings) + if err != nil { + return nil, field.InternalError(path.Child("variable"), err) + } + return projected, nil + }, + ), + ), + ) + }(entry.Variable, bindings) } identifier := "" if r.rule.Identifier != "" { diff --git a/pkg/matching/compiler.go b/pkg/matching/compiler.go index 9f919f64..0ac3f8a7 100644 --- a/pkg/matching/compiler.go +++ b/pkg/matching/compiler.go @@ -7,6 +7,7 @@ import ( "github.com/elastic/go-freelru" "github.com/kyverno/kyverno-json/pkg/core/assertion" "github.com/kyverno/kyverno-json/pkg/core/compilers" + "github.com/kyverno/kyverno-json/pkg/core/projection" ) type _compilers = compilers.Compilers @@ -44,3 +45,8 @@ func (c Compiler) CompileAssertion(hash string, value any, defaultCompiler strin } return entry() } + +func (c Compiler) CompileProjection(hash string, value any, defaultCompiler string) (projection.ScalarHandler, error) { + // TODO: cache + return projection.ParseScalar(value, c._compilers, defaultCompiler) +} diff --git a/pkg/server/playground/handler.go b/pkg/server/playground/handler.go index c4939ae7..d136f131 100644 --- a/pkg/server/playground/handler.go +++ b/pkg/server/playground/handler.go @@ -34,7 +34,7 @@ func newHandler() (gin.HandlerFunc, error) { } // apply pre processors for _, preprocessor := range in.Preprocessors { - result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompiler.Jp) + result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompilers.Jp) if err != nil { return nil, fmt.Errorf("failed to execute prepocessor (%s) - %w", preprocessor, err) } diff --git a/pkg/server/scan/handler.go b/pkg/server/scan/handler.go index 194e1928..2f19b7a7 100644 --- a/pkg/server/scan/handler.go +++ b/pkg/server/scan/handler.go @@ -26,7 +26,7 @@ func newHandler(policyProvider PolicyProvider) (gin.HandlerFunc, error) { payload := in.Payload // apply pre processors for _, preprocessor := range in.Preprocessors { - result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompiler.Jp) + result, err := compilers.Execute(preprocessor, payload, nil, compilers.DefaultCompilers.Jp) if err != nil { return nil, fmt.Errorf("failed to execute prepocessor (%s) - %w", preprocessor, err) }