Skip to content

Commit

Permalink
Merge branch 'main' into feat/migrate-envs-lib
Browse files Browse the repository at this point in the history
  • Loading branch information
fredmaggiowski authored Oct 24, 2024
2 parents 45ca8dc + 686c4e0 commit dd0cc4a
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 289 deletions.
12 changes: 10 additions & 2 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/url"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/rond-authz/rond/logging"
"github.com/rond-authz/rond/types"
)
Expand Down Expand Up @@ -93,11 +94,13 @@ type RegoInputOptions struct {
EnableResourcePermissionsMapOptimization bool
}

type EvalInput *ast.Term

func CreateRegoQueryInput(
logger logging.Logger,
input Input,
options RegoInputOptions,
) ([]byte, error) {
) (EvalInput, error) {
opaInputCreationTime := time.Now()

input.buildOptimizedResourcePermissionsMap(logger, options.EnableResourcePermissionsMapOptimization)
Expand All @@ -109,5 +112,10 @@ func CreateRegoQueryInput(
logger.
WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()).
Trace("input creation time")
return inputBytes, nil

astInput, err := ast.ParseTerm(string(inputBytes))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err)
}
return EvalInput(astInput), nil
}
2 changes: 1 addition & 1 deletion core/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCreateRegoInput(t *testing.T) {
t.Run("returns correctly", func(t *testing.T) {
actual, err := CreateRegoQueryInput(log, Input{}, RegoInputOptions{})
require.NoError(t, err)
require.Equal(t, "{\"request\":{\"method\":\"\",\"path\":\"\"},\"response\":{},\"user\":{}}", string(actual))
require.Equal(t, "{\"request\":{\"method\":\"\",\"path\":\"\"},\"response\":{},\"user\":{}}", string(actual.Location.Text))
})

t.Run("buildOptimizedResourcePermissionsMap", func(t *testing.T) {
Expand Down
101 changes: 29 additions & 72 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/rond-authz/rond/custom_builtins"
Expand All @@ -29,7 +28,6 @@ import (
"github.com/rond-authz/rond/metrics"
"github.com/rond-authz/rond/types"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand Down Expand Up @@ -60,17 +58,12 @@ type PermissionOptions struct {
IgnoreTrailingSlash bool `json:"ignoreTrailingSlash,omitempty"`
}

type Evaluator interface {
Eval(ctx context.Context) (rego.ResultSet, error)
Partial(ctx context.Context) (*rego.PartialQueries, error)
}

var Unknowns = []string{"data.resources"}

type OPAEvaluator struct {
PolicyEvaluator Evaluator
PolicyName string
PolicyName string

evaluator PartialEvaluator
context context.Context
mongoClient custom_builtins.IMongoClient
generateQuery bool
Expand All @@ -83,77 +76,34 @@ type OPAEvaluatorOptions struct {
Logger logging.Logger
}

func newQueryOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
func (opaEval *OPAEvaluator) partiallyEvaluate(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (primitive.M, error) {
if options == nil {
options = &OPAEvaluatorOptions{}
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
options = &PolicyEvaluationOptions{}
}
opaEvaluationTimeStart := time.Now()

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
queryString := fmt.Sprintf("data.policies.%s", sanitizedPolicy)
query := rego.New(
rego.Query(queryString),
rego.Module(opaModuleConfig.Name, opaModuleConfig.Content),
rego.ParsedInput(inputTerm.Value),
rego.Unknowns(Unknowns),
rego.Capabilities(ast.CapabilitiesForThisVersion()),
rego.EnablePrintStatements(options.EnablePrintStatements),
rego.PrintHook(NewPrintHook(os.Stdout, policy)),
custom_builtins.GetHeaderFunction,
custom_builtins.MongoFindOne,
custom_builtins.MongoFindMany,
)

return &OPAEvaluator{
PolicyEvaluator: query,
PolicyName: policy,

context: ctx,
mongoClient: options.MongoClient,
generateQuery: true,
logger: options.Logger,
}, nil
}

func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger logging.Logger, policy string, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
logger.WithFields(map[string]any{
"policyName": policy,
}).Info("Policy to be evaluated")

opaEvaluatorInstanceTime := time.Now()
evaluator, err := newQueryOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithField("error", err).Error(ErrEvaluatorCreationFailed)
return nil, err
if opaEval.evaluator.preparedPartialQuery == nil {
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, "preparedPartialQuery is nil")
}
logger.
WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()).
Trace("evaluator creation time")
return evaluator, nil
}

func (evaluator *OPAEvaluator) partiallyEvaluate(logger logging.Logger, options *PolicyEvaluationOptions) (primitive.M, error) {
if options == nil {
options = &PolicyEvaluationOptions{}
}
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.getContext())
partialResults, err := opaEval.evaluator.preparedPartialQuery.Partial(
opaEval.getContext(),
rego.EvalParsedInput(input.Value),
rego.EvalPrintHook(NewPrintHook(os.Stdout, opaEval.PolicyName)),
)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)

options.metrics().PolicyEvaluationDurationMilliseconds.With(metrics.Labels{
"policy_name": evaluator.PolicyName,
"policy_name": opaEval.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

fields := map[string]any{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"partialEval": true,
"allowed": true,
}
Expand All @@ -175,27 +125,34 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger logging.Logger, options
return q, nil
}

func (evaluator *OPAEvaluator) Evaluate(logger logging.Logger, options *PolicyEvaluationOptions) (interface{}, error) {
func (opaEval *OPAEvaluator) Evaluate(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (interface{}, error) {
if options == nil {
options = &PolicyEvaluationOptions{}
}
if opaEval.evaluator.preparedEvalQuery == nil {
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, "preparedEvalQuery is nil")
}

opaEvaluationTimeStart := time.Now()

results, err := evaluator.PolicyEvaluator.Eval(evaluator.getContext())
results, err := opaEval.evaluator.preparedEvalQuery.Eval(
opaEval.getContext(),
rego.EvalParsedInput(input.Value),
rego.EvalPrintHook(NewPrintHook(os.Stdout, opaEval.PolicyName)),
)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
options.metrics().PolicyEvaluationDurationMilliseconds.With(metrics.Labels{
"policy_name": evaluator.PolicyName,
"policy_name": opaEval.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

allowed, responseBodyOverwriter := processResults(results)
fields := map[string]any{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"partialEval": false,
"allowed": allowed,
"resultsLength": len(results),
Expand All @@ -205,7 +162,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger logging.Logger, options *PolicyEv
logger.WithFields(fields).Debug("policy evaluation completed")

logger.WithFields(map[string]any{
"policyName": evaluator.PolicyName,
"policyName": opaEval.PolicyName,
"allowed": allowed,
}).Info("policy result")

Expand Down Expand Up @@ -241,12 +198,12 @@ func (evaluator *PolicyEvaluationOptions) metrics() *metrics.Metrics {
return metrics.NoOpMetrics()
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger logging.Logger, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
func (evaluator *OPAEvaluator) PolicyEvaluation(logger logging.Logger, input EvalInput, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
if evaluator.generateQuery {
query, err := evaluator.partiallyEvaluate(logger, options)
query, err := evaluator.partiallyEvaluate(logger, input, options)
return nil, query, err
}
dataFromEvaluation, err := evaluator.Evaluate(logger, options)
dataFromEvaluation, err := evaluator.Evaluate(logger, input, options)
if err != nil {
return nil, nil, err
}
Expand Down
132 changes: 24 additions & 108 deletions core/opaevaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package core

import (
"context"
"encoding/json"
"net/http"
"testing"

"github.com/rond-authz/rond/custom_builtins"
Expand All @@ -28,22 +26,6 @@ import (
"github.com/stretchr/testify/require"
)

func TestNewOPAEvaluator(t *testing.T) {
input := map[string]interface{}{}
inputBytes, _ := json.Marshal(input)
t.Run("policy sanitization", func(t *testing.T) {
evaluator, _ := newQueryOPAEvaluator(context.Background(), "very.composed.policy", &OPAModuleConfig{Content: "package policies very_composed_policy {true}"}, inputBytes, nil)

result, err := evaluator.PolicyEvaluator.Eval(context.TODO())
require.Nil(t, err, "unexpected error")
require.True(t, result.Allowed(), "Unexpected failing policy")

parialResult, err := evaluator.PolicyEvaluator.Partial(context.TODO())
require.Nil(t, err, "unexpected error")
require.Equal(t, 1, len(parialResult.Queries), "Unexpected failing policy")
})
}

func TestOPAEvaluator(t *testing.T) {
t.Run("get context", func(t *testing.T) {
t.Run("no context", func(t *testing.T) {
Expand Down Expand Up @@ -100,6 +82,30 @@ func TestOPAEvaluator(t *testing.T) {
require.Equal(t, log, actualLog)
})
})

t.Run("PolicyEvaluation", func(t *testing.T) {
t.Run("with empty evaluator - generate query", func(t *testing.T) {
opaEval := OPAEvaluator{
generateQuery: true,
}
logger := logging.NewNoOpLogger()
result, query, err := opaEval.PolicyEvaluation(logger, nil, nil)

require.EqualError(t, err, "partial policy evaluation failed: preparedPartialQuery is nil")
require.Nil(t, result)
require.Empty(t, query)
})

t.Run("with empty evaluator - eval query", func(t *testing.T) {
opaEval := OPAEvaluator{}
logger := logging.NewNoOpLogger()
result, query, err := opaEval.PolicyEvaluation(logger, nil, nil)

require.EqualError(t, err, "policy evaluation failed: preparedEvalQuery is nil")
require.Nil(t, result)
require.Empty(t, query)
})
})
}

func TestBuildRolesMap(t *testing.T) {
Expand All @@ -120,93 +126,3 @@ func TestBuildRolesMap(t *testing.T) {
}
require.Equal(t, expected, result)
}

func TestCreateQueryEvaluator(t *testing.T) {
policy := `package policies
allow {
true
}
column_policy{
false
}
`
permission := RondConfig{
RequestFlow: RequestFlow{
PolicyName: "allow",
},
ResponseFlow: ResponseFlow{
PolicyName: "column_policy",
},
}

opaModuleConfig := &OPAModuleConfig{Name: "mypolicy.rego", Content: policy}

logger := logging.NewNoOpLogger()

input := Input{Request: InputRequest{}, Response: InputResponse{}}
inputBytes, _ := json.Marshal(input)

t.Run("create evaluator with allowPolicy", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.RequestFlow.PolicyName, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})

t.Run("create evaluator with policy for column filtering", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.ResponseFlow.PolicyName, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})
}

func TestGetHeaderFunction(t *testing.T) {
headerKeyMocked := "exampleKey"
headerValueMocked := "value"

opaModule := &OPAModuleConfig{
Name: "example.rego",
Content: `package policies
todo { get_header("ExAmPlEkEy", input.headers) == "value" }`,
}
queryString := "todo"

t.Run("if header key exists", func(t *testing.T) {
headers := http.Header{}
headers.Add(headerKeyMocked, headerValueMocked)
input := map[string]interface{}{
"headers": headers,
}
inputBytes, _ := json.Marshal(input)

opaEvaluator, err := newQueryOPAEvaluator(context.Background(), queryString, opaModule, inputBytes, nil)
require.NoError(t, err, "Unexpected error during creation of opaEvaluator")

results, err := opaEvaluator.PolicyEvaluator.Eval(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")
require.True(t, results.Allowed(), "The input is not allowed by rego")

partialResults, err := opaEvaluator.PolicyEvaluator.Partial(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")

require.Len(t, partialResults.Queries, 1, "Rego policy allows illegal input")
})

t.Run("if header key not exists", func(t *testing.T) {
input := map[string]interface{}{
"headers": http.Header{},
}
inputBytes, _ := json.Marshal(input)

opaEvaluator, err := newQueryOPAEvaluator(context.Background(), queryString, opaModule, inputBytes, nil)
require.NoError(t, err, "Unexpected error during creation of opaEvaluator")

results, err := opaEvaluator.PolicyEvaluator.Eval(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")
require.True(t, !results.Allowed(), "Rego policy allows illegal input")

partialResults, err := opaEvaluator.PolicyEvaluator.Partial(context.TODO())
require.NoError(t, err, "Unexpected error during rego validation")

require.Len(t, partialResults.Queries, 0, "Rego policy allows illegal input")
})
}
Loading

0 comments on commit dd0cc4a

Please sign in to comment.