Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: prepare partial evaluation at boot for better performances #394

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/url"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/rond-authz/rond/logging"
"github.com/rond-authz/rond/types"
)
Expand Down Expand Up @@ -93,11 +94,13 @@ type RegoInputOptions struct {
EnableResourcePermissionsMapOptimization bool
}

type EvalInput *ast.Term

func CreateRegoQueryInput(
logger logging.Logger,
input Input,
options RegoInputOptions,
) ([]byte, error) {
) (EvalInput, error) {
opaInputCreationTime := time.Now()

input.buildOptimizedResourcePermissionsMap(logger, options.EnableResourcePermissionsMapOptimization)
Expand All @@ -109,5 +112,10 @@ func CreateRegoQueryInput(
logger.
WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()).
Trace("input creation time")
return inputBytes, nil

astInput, err := ast.ParseTerm(string(inputBytes))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err)
}
return EvalInput(astInput), nil
}
2 changes: 1 addition & 1 deletion core/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCreateRegoInput(t *testing.T) {
t.Run("returns correctly", func(t *testing.T) {
actual, err := CreateRegoQueryInput(log, Input{}, RegoInputOptions{})
require.NoError(t, err)
require.Equal(t, "{\"request\":{\"method\":\"\",\"path\":\"\"},\"response\":{},\"user\":{}}", string(actual))
require.Equal(t, "{\"request\":{\"method\":\"\",\"path\":\"\"},\"response\":{},\"user\":{}}", string(actual.Location.Text))
})

t.Run("buildOptimizedResourcePermissionsMap", func(t *testing.T) {
Expand Down
101 changes: 29 additions & 72 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/rond-authz/rond/custom_builtins"
Expand All @@ -29,7 +28,6 @@ import (
"github.com/rond-authz/rond/metrics"
"github.com/rond-authz/rond/types"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand Down Expand Up @@ -60,17 +58,12 @@ type PermissionOptions struct {
IgnoreTrailingSlash bool `json:"ignoreTrailingSlash,omitempty"`
}

type Evaluator interface {
Eval(ctx context.Context) (rego.ResultSet, error)
Partial(ctx context.Context) (*rego.PartialQueries, error)
}

var Unknowns = []string{"data.resources"}

type OPAEvaluator struct {
PolicyEvaluator Evaluator
PolicyName string
PolicyName string

evaluator PartialEvaluator
context context.Context
mongoClient custom_builtins.IMongoClient
generateQuery bool
Expand All @@ -83,77 +76,34 @@ type OPAEvaluatorOptions struct {
Logger logging.Logger
}

func newQueryOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
func (opaEval *OPAEvaluator) partiallyEvaluate(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (primitive.M, error) {
if options == nil {
options = &OPAEvaluatorOptions{}
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
options = &PolicyEvaluationOptions{}
}
opaEvaluationTimeStart := time.Now()

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
queryString := fmt.Sprintf("data.policies.%s", sanitizedPolicy)
query := rego.New(
rego.Query(queryString),
rego.Module(opaModuleConfig.Name, opaModuleConfig.Content),
rego.ParsedInput(inputTerm.Value),
rego.Unknowns(Unknowns),
rego.Capabilities(ast.CapabilitiesForThisVersion()),
rego.EnablePrintStatements(options.EnablePrintStatements),
rego.PrintHook(NewPrintHook(os.Stdout, policy)),
custom_builtins.GetHeaderFunction,
custom_builtins.MongoFindOne,
custom_builtins.MongoFindMany,
)

return &OPAEvaluator{
PolicyEvaluator: query,
PolicyName: policy,

context: ctx,
mongoClient: options.MongoClient,
generateQuery: true,
logger: options.Logger,
}, nil
}

func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger logging.Logger, policy string, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
logger.WithFields(map[string]any{
"policyName": policy,
}).Info("Policy to be evaluated")

opaEvaluatorInstanceTime := time.Now()
evaluator, err := newQueryOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithField("error", err).Error(ErrEvaluatorCreationFailed)
return nil, err
if opaEval.evaluator.preparedPartialQuery == nil {
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, "preparedPartialQuery is nil")
}
logger.
WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()).
Trace("evaluator creation time")
return evaluator, nil
}

func (evaluator *OPAEvaluator) partiallyEvaluate(logger logging.Logger, options *PolicyEvaluationOptions) (primitive.M, error) {
if options == nil {
options = &PolicyEvaluationOptions{}
}
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.getContext())
partialResults, err := opaEval.evaluator.preparedPartialQuery.Partial(
fredmaggiowski marked this conversation as resolved.
Show resolved Hide resolved
opaEval.getContext(),
rego.EvalParsedInput(input.Value),
rego.EvalPrintHook(NewPrintHook(os.Stdout, opaEval.PolicyName)),
)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)

options.metrics().PolicyEvaluationDurationMilliseconds.With(metrics.Labels{
"policy_name": evaluator.PolicyName,
"policy_name": opaEval.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

fields := map[string]any{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"partialEval": true,
"allowed": true,
}
Expand All @@ -175,27 +125,34 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger logging.Logger, options
return q, nil
}

func (evaluator *OPAEvaluator) Evaluate(logger logging.Logger, options *PolicyEvaluationOptions) (interface{}, error) {
func (opaEval *OPAEvaluator) Evaluate(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (interface{}, error) {
if options == nil {
options = &PolicyEvaluationOptions{}
}
if opaEval.evaluator.preparedEvalQuery == nil {
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, "preparedEvalQuery is nil")
}

opaEvaluationTimeStart := time.Now()

results, err := evaluator.PolicyEvaluator.Eval(evaluator.getContext())
results, err := opaEval.evaluator.preparedEvalQuery.Eval(
opaEval.getContext(),
rego.EvalParsedInput(input.Value),
rego.EvalPrintHook(NewPrintHook(os.Stdout, opaEval.PolicyName)),
)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
options.metrics().PolicyEvaluationDurationMilliseconds.With(metrics.Labels{
"policy_name": evaluator.PolicyName,
"policy_name": opaEval.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

allowed, responseBodyOverwriter := processResults(results)
fields := map[string]any{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"partialEval": false,
"allowed": allowed,
"resultsLength": len(results),
Expand All @@ -205,7 +162,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger logging.Logger, options *PolicyEv
logger.WithFields(fields).Debug("policy evaluation completed")

logger.WithFields(map[string]any{
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"allowed": allowed,
}).Info("policy result")

Expand Down Expand Up @@ -241,12 +198,12 @@ func (evaluator *PolicyEvaluationOptions) metrics() *metrics.Metrics {
return metrics.NoOpMetrics()
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger logging.Logger, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
func (evaluator *OPAEvaluator) PolicyEvaluation(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
if evaluator.generateQuery {
query, err := evaluator.partiallyEvaluate(logger, options)
query, err := evaluator.partiallyEvaluate(logger, input, options)
return nil, query, err
}
dataFromEvaluation, err := evaluator.Evaluate(logger, options)
dataFromEvaluation, err := evaluator.Evaluate(logger, input, options)
if err != nil {
return nil, nil, err
}
Expand Down
132 changes: 24 additions & 108 deletions core/opaevaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package core

import (
"context"
"encoding/json"
"net/http"
"testing"

"github.com/rond-authz/rond/custom_builtins"
Expand All @@ -28,22 +26,6 @@ import (
"github.com/stretchr/testify/require"
)

func TestNewOPAEvaluator(t *testing.T) {
input := map[string]interface{}{}
inputBytes, _ := json.Marshal(input)
t.Run("policy sanitization", func(t *testing.T) {
evaluator, _ := newQueryOPAEvaluator(context.Background(), "very.composed.policy", &OPAModuleConfig{Content: "package policies very_composed_policy {true}"}, inputBytes, nil)

result, err := evaluator.PolicyEvaluator.Eval(context.TODO())
require.Nil(t, err, "unexpected error")
require.True(t, result.Allowed(), "Unexpected failing policy")

parialResult, err := evaluator.PolicyEvaluator.Partial(context.TODO())
require.Nil(t, err, "unexpected error")
require.Equal(t, 1, len(parialResult.Queries), "Unexpected failing policy")
})
}

func TestOPAEvaluator(t *testing.T) {
t.Run("get context", func(t *testing.T) {
t.Run("no context", func(t *testing.T) {
Expand Down Expand Up @@ -100,6 +82,30 @@ func TestOPAEvaluator(t *testing.T) {
require.Equal(t, log, actualLog)
})
})

t.Run("PolicyEvaluation", func(t *testing.T) {
t.Run("with empty evaluator - generate query", func(t *testing.T) {
opaEval := OPAEvaluator{
generateQuery: true,
}
logger := logging.NewNoOpLogger()
result, query, err := opaEval.PolicyEvaluation(logger, nil, nil)

require.EqualError(t, err, "partial policy evaluation failed: preparedPartialQuery is nil")
require.Nil(t, result)
require.Empty(t, query)
})

t.Run("with empty evaluator - eval query", func(t *testing.T) {
opaEval := OPAEvaluator{}
logger := logging.NewNoOpLogger()
result, query, err := opaEval.PolicyEvaluation(logger, nil, nil)

require.EqualError(t, err, "policy evaluation failed: preparedEvalQuery is nil")
require.Nil(t, result)
require.Empty(t, query)
})
})
}

func TestBuildRolesMap(t *testing.T) {
Expand All @@ -120,93 +126,3 @@ func TestBuildRolesMap(t *testing.T) {
}
require.Equal(t, expected, result)
}

func TestCreateQueryEvaluator(t *testing.T) {
policy := `package policies
allow {
true
}
column_policy{
false
}
`
permission := RondConfig{
RequestFlow: RequestFlow{
PolicyName: "allow",
},
ResponseFlow: ResponseFlow{
PolicyName: "column_policy",
},
}

opaModuleConfig := &OPAModuleConfig{Name: "mypolicy.rego", Content: policy}

logger := logging.NewNoOpLogger()

input := Input{Request: InputRequest{}, Response: InputResponse{}}
inputBytes, _ := json.Marshal(input)

t.Run("create evaluator with allowPolicy", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.RequestFlow.PolicyName, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})

t.Run("create evaluator with policy for column filtering", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.ResponseFlow.PolicyName, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})
}

func TestGetHeaderFunction(t *testing.T) {
headerKeyMocked := "exampleKey"
headerValueMocked := "value"

opaModule := &OPAModuleConfig{
Name: "example.rego",
Content: `package policies
todo { get_header("ExAmPlEkEy", input.headers) == "value" }`,
}
queryString := "todo"

t.Run("if header key exists", func(t *testing.T) {
headers := http.Header{}
headers.Add(headerKeyMocked, headerValueMocked)
input := map[string]interface{}{
"headers": headers,
}
inputBytes, _ := json.Marshal(input)

opaEvaluator, err := newQueryOPAEvaluator(context.Background(), queryString, opaModule, inputBytes, nil)
require.NoError(t, err, "Unexpected error during creation of opaEvaluator")

results, err := opaEvaluator.PolicyEvaluator.Eval(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")
require.True(t, results.Allowed(), "The input is not allowed by rego")

partialResults, err := opaEvaluator.PolicyEvaluator.Partial(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")

require.Len(t, partialResults.Queries, 1, "Rego policy allows illegal input")
})

t.Run("if header key not exists", func(t *testing.T) {
input := map[string]interface{}{
"headers": http.Header{},
}
inputBytes, _ := json.Marshal(input)

opaEvaluator, err := newQueryOPAEvaluator(context.Background(), queryString, opaModule, inputBytes, nil)
require.NoError(t, err, "Unexpected error during creation of opaEvaluator")

results, err := opaEvaluator.PolicyEvaluator.Eval(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")
require.True(t, !results.Allowed(), "Rego policy allows illegal input")

partialResults, err := opaEvaluator.PolicyEvaluator.Partial(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")

require.Len(t, partialResults.Queries, 0, "Rego policy allows illegal input")
})
}
Loading