Skip to content

Commit

Permalink
Merge pull request #1128 from openmeterio/add-regression-tests
Browse files Browse the repository at this point in the history
feat: fix issues when grant expires the same time when a reset happens
  • Loading branch information
turip authored Jul 1, 2024
2 parents f2044dc + 5363ecd commit 84fb0f0
Show file tree
Hide file tree
Showing 20 changed files with 472 additions and 68 deletions.
17 changes: 12 additions & 5 deletions internal/credit/balance_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/openmeterio/openmeter/internal/streaming"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/recurrence"
Expand Down Expand Up @@ -218,7 +219,7 @@ func (m *balanceConnector) GetBalanceHistoryOfOwner(ctx context.Context, owner N

func (m *balanceConnector) ResetUsageForOwner(ctx context.Context, owner NamespacedGrantOwner, params ResetUsageForOwnerParams) (*GrantBalanceSnapshot, error) {
// Cannot reset for the future
if params.At.After(time.Now()) {
if params.At.After(clock.Now()) {
return nil, &models.GenericUserError{Message: fmt.Sprintf("cannot reset at %s in the future", params.At)}
}

Expand All @@ -230,7 +231,7 @@ func (m *balanceConnector) ResetUsageForOwner(ctx context.Context, owner Namespa
at := params.At.Truncate(ownerMeter.WindowSize.Duration())

// check if reset is possible (after last reset)
periodStart, err := m.ownerConnector.GetUsagePeriodStartAt(ctx, owner, time.Now())
periodStart, err := m.ownerConnector.GetUsagePeriodStartAt(ctx, owner, clock.Now())
if err != nil {
if _, ok := err.(*OwnerNotFoundError); ok {
return nil, err
Expand Down Expand Up @@ -286,15 +287,21 @@ func (m *balanceConnector) ResetUsageForOwner(ctx context.Context, owner Namespa
grantMap[grant.ID] = grant
}

// We have to roll over the grants and save the starting balance for the next period
// at the reset time.
startingBalance := endingBalance.Copy()
// We have to roll over the grants and save the starting balance for the next period at the reset time.
// Engine treates the output balance as a period end (exclusive), but we need to treat it as a period start (inclusive).
startingBalance := GrantBalanceMap{}
for grantID, grantBalance := range endingBalance {
grant, ok := grantMap[grantID]
// inconsistency check, shouldn't happen
if !ok {
return nil, fmt.Errorf("attempting to roll over unknown grant %s", grantID)
}

// grants might become inactive at the reset time, in which case they're irrelevant for the next period
if !grant.ActiveAt(at) {
continue
}

startingBalance.Set(grantID, grant.RolloverBalance(grantBalance))
}

Expand Down
3 changes: 2 additions & 1 deletion internal/credit/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ type Grant struct {

// Expiration The expiration configuration.
Expiration ExpirationPeriod `json:"expiration"`
// ExpiresAt contains the exact expiration date calculated from effectiveAt and Expiration for rendering
// ExpiresAt contains the exact expiration date calculated from effectiveAt and Expiration for rendering.
// ExpiresAt is exclusive, meaning that the grant is no longer active after this time, but it is still active at the time.
ExpiresAt time.Time `json:"expiresAt"`

Metadata map[string]string `json:"metadata,omitempty"`
Expand Down
5 changes: 3 additions & 2 deletions internal/credit/grant_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"time"

"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/recurrence"
Expand Down Expand Up @@ -109,7 +110,7 @@ func (m *grantConnector) CreateGrant(ctx context.Context, owner NamespacedGrantO
if input.Recurrence != nil {
input.Recurrence.Anchor = input.Recurrence.Anchor.Truncate(granularity)
}
periodStart, err := m.ownerConnector.GetUsagePeriodStartAt(ctx, owner, time.Now())
periodStart, err := m.ownerConnector.GetUsagePeriodStartAt(ctx, owner, clock.Now())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -176,7 +177,7 @@ func (m *grantConnector) VoidGrant(ctx context.Context, grantID models.Namespace
if err != nil {
return nil, err
}
now := time.Now()
now := clock.Now()
err = m.grantRepo.WithTx(ctx, tx).VoidGrant(ctx, grantID, now)
if err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion internal/credit/postgresadapter/balance_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/openmeterio/openmeter/internal/credit"
"github.com/openmeterio/openmeter/internal/credit/postgresadapter/ent/db"
db_balancesnapshot "github.com/openmeterio/openmeter/internal/credit/postgresadapter/ent/db/balancesnapshot"
"github.com/openmeterio/openmeter/pkg/clock"
)

// naive implementation of the BalanceSnapshotConnector
Expand All @@ -25,7 +26,7 @@ func NewPostgresBalanceSnapshotRepo(db *db.Client) credit.BalanceSnapshotConnect
func (b *balanceSnapshotAdapter) InvalidateAfter(ctx context.Context, owner credit.NamespacedGrantOwner, at time.Time) error {
return b.db.BalanceSnapshot.Update().
Where(db_balancesnapshot.OwnerID(owner.ID), db_balancesnapshot.Namespace(owner.Namespace), db_balancesnapshot.AtGT(at)).
SetDeletedAt(time.Now()).
SetDeletedAt(clock.Now()).
Exec(ctx)
}

Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/boolean/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/openmeterio/openmeter/internal/entitlement"
"github.com/openmeterio/openmeter/internal/productcatalog"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/recurrence"
)

Expand Down Expand Up @@ -43,7 +44,7 @@ func (c *connector) BeforeCreate(model entitlement.CreateEntitlementInputs, feat
if model.UsagePeriod != nil {
usagePeriod = model.UsagePeriod

calculatedPeriod, err := usagePeriod.GetCurrentPeriodAt(time.Now())
calculatedPeriod, err := usagePeriod.GetCurrentPeriodAt(clock.Now())
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/openmeterio/openmeter/internal/meter"
"github.com/openmeterio/openmeter/internal/productcatalog"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/models"
)
Expand Down Expand Up @@ -81,7 +82,7 @@ func (c *entitlementConnector) CreateEntitlement(ctx context.Context, input Crea
if err != nil || feature == nil {
return nil, &productcatalog.FeatureNotFoundError{ID: *idOrFeatureKey}
}
if feature.ArchivedAt != nil && feature.ArchivedAt.Before(time.Now()) {
if feature.ArchivedAt != nil && feature.ArchivedAt.Before(clock.Now()) {
return nil, &models.GenericUserError{Message: "Feature is archived"}
}

Expand Down
9 changes: 5 additions & 4 deletions internal/entitlement/httpdriver/entitlement.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
meteredentitlement "github.com/openmeterio/openmeter/internal/entitlement/metered"
staticentitlement "github.com/openmeterio/openmeter/internal/entitlement/static"
"github.com/openmeterio/openmeter/internal/namespace/namespacedriver"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/defaultx"
"github.com/openmeterio/openmeter/pkg/framework/commonhttp"
Expand Down Expand Up @@ -83,7 +84,7 @@ func (h *entitlementHandler) CreateEntitlement() CreateEntitlementHandler {
IsSoftLimit: v.IsSoftLimit,
IssueAfterReset: v.IssueAfterReset,
UsagePeriod: &entitlement.UsagePeriod{
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, time.Now()), // TODO: shouldn't we truncate this?
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, clock.Now()), // TODO: shouldn't we truncate this?
Interval: recurrence.RecurrenceInterval(v.UsagePeriod.Interval),
},
}
Expand All @@ -101,7 +102,7 @@ func (h *entitlementHandler) CreateEntitlement() CreateEntitlementHandler {
}
if v.UsagePeriod != nil {
request.UsagePeriod = &entitlement.UsagePeriod{
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, time.Now()), // TODO: shouldn't we truncate this?
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, clock.Now()), // TODO: shouldn't we truncate this?
Interval: recurrence.RecurrenceInterval(v.UsagePeriod.Interval),
}
}
Expand All @@ -118,7 +119,7 @@ func (h *entitlementHandler) CreateEntitlement() CreateEntitlementHandler {
}
if v.UsagePeriod != nil {
request.UsagePeriod = &entitlement.UsagePeriod{
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, time.Now()), // TODO: shouldn't we truncate this?
Anchor: defaultx.WithDefault(v.UsagePeriod.Anchor, clock.Now()), // TODO: shouldn't we truncate this?
Interval: recurrence.RecurrenceInterval(v.UsagePeriod.Interval),
}
}
Expand Down Expand Up @@ -173,7 +174,7 @@ func (h *entitlementHandler) GetEntitlementValue() GetEntitlementValueHandler {
SubjectKey: params.SubjectKey,
EntitlementIdOrFeatureKey: params.EntitlementIdOrFeatureKey,
Namespace: ns,
At: defaultx.WithDefault(params.Params.Time, time.Now()),
At: defaultx.WithDefault(params.Params.Time, clock.Now()),
}, nil
},
func(ctx context.Context, request GetEntitlementValueHandlerRequest) (api.EntitlementValue, error) {
Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/httpdriver/metered.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/openmeterio/openmeter/internal/entitlement"
meteredentitlement "github.com/openmeterio/openmeter/internal/entitlement/metered"
"github.com/openmeterio/openmeter/internal/namespace/namespacedriver"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/defaultx"
"github.com/openmeterio/openmeter/pkg/framework/commonhttp"
Expand Down Expand Up @@ -216,7 +217,7 @@ func (h *meteredEntitlementHandler) ResetEntitlementUsage() ResetEntitlementUsag
EntitlementID: params.EntitlementID,
Namespace: ns,
SubjectID: params.SubjectKey,
At: defaultx.WithDefault(body.EffectiveAt, time.Now()),
At: defaultx.WithDefault(body.EffectiveAt, clock.Now()),
RetainAnchor: defaultx.WithDefault(body.RetainAnchor, false),
}, nil
},
Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/metered/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/openmeterio/openmeter/internal/credit"
"github.com/openmeterio/openmeter/internal/entitlement"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/slicesx"
Expand Down Expand Up @@ -115,7 +116,7 @@ func (e *connector) GetEntitlementBalanceHistory(ctx context.Context, entitlemen
}

if params.To == nil {
params.To = convert.ToPointer(time.Now())
params.To = convert.ToPointer(clock.Now())
}

// query period cannot be before start of measuring usage
Expand Down
82 changes: 45 additions & 37 deletions internal/entitlement/metered/balance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,43 +993,6 @@ func TestResetEntitlementUsage(t *testing.T) {
assert.Equal(t, resetTime.Format(time.RFC3339), ent.LastReset.Format(time.RFC3339))
},
},
{
name: "Should return proper last reset time after reset",
run: func(t *testing.T, connector meteredentitlement.Connector, deps *testDependencies) {
ctx := context.Background()
startTime := testutils.GetRFC3339Time(t, "2024-03-01T00:00:00Z")

// create featute in db
feature, err := deps.featureDB.CreateFeature(ctx, exampleFeature)
assert.NoError(t, err)

// create entitlement in db
inp := getEntitlement(t, feature)
inp.MeasureUsageFrom = &startTime
ent, err := deps.entitlementDB.CreateEntitlement(ctx, inp)
assert.NoError(t, err)

ent, err = deps.entitlementDB.GetEntitlement(ctx, models.NamespacedID{Namespace: namespace, ID: ent.ID})
assert.NoError(t, err)
assert.Equal(t, startTime.Format(time.RFC3339), ent.LastReset.Format(time.RFC3339))

deps.streaming.AddSimpleEvent(meterSlug, 600, startTime.Add(time.Minute))

// resetTime before snapshot
resetTime := startTime.Add(time.Hour * 5)
_, err = connector.ResetEntitlementUsage(ctx,
models.NamespacedID{Namespace: namespace, ID: ent.ID},
meteredentitlement.ResetEntitlementUsageParams{
At: resetTime,
})
assert.NoError(t, err)

// validate that lastReset time is properly set
ent, err = deps.entitlementDB.GetEntitlement(ctx, models.NamespacedID{Namespace: namespace, ID: ent.ID})
assert.NoError(t, err)
assert.Equal(t, resetTime.Format(time.RFC3339), ent.LastReset.Format(time.RFC3339))
},
},
{
name: "Should calculate balance for grants taking effect after last saved snapshot",
run: func(t *testing.T, connector meteredentitlement.Connector, deps *testDependencies) {
Expand Down Expand Up @@ -1212,6 +1175,51 @@ func TestResetEntitlementUsage(t *testing.T) {
assert.Equal(t, g2.Amount, balanceAfterReset.Balance) // 1000 - 0 = 1000
},
},
{
name: "Should properly handle grants expiring the same time as reset",
run: func(t *testing.T, connector meteredentitlement.Connector, deps *testDependencies) {
ctx := context.Background()
startTime := testutils.GetRFC3339Time(t, "2024-03-01T00:00:00Z")
resetTime := startTime.AddDate(0, 0, 3)

// create featute in db
feature, err := deps.featureDB.CreateFeature(ctx, exampleFeature)
assert.NoError(t, err)

// add 0 usage so meter is found in mock
deps.streaming.AddSimpleEvent(meterSlug, 0, startTime)

// create entitlement in db
inp := getEntitlement(t, feature)
inp.MeasureUsageFrom = &startTime
ent, err := deps.entitlementDB.CreateEntitlement(ctx, inp)
assert.NoError(t, err)

// issue grants
_, err = deps.grantDB.CreateGrant(ctx, credit.GrantRepoCreateGrantInput{
OwnerID: credit.GrantOwner(ent.ID),
Namespace: namespace,
Amount: 1000,
Priority: 1,
EffectiveAt: startTime.Add(time.Hour * 2),
ExpiresAt: resetTime,
ResetMaxRollover: 1000, // full amount can be rolled over
})
assert.NoError(t, err)

// do a reset
balanceAfterReset, err := connector.ResetEntitlementUsage(ctx,
models.NamespacedID{Namespace: namespace, ID: ent.ID},
meteredentitlement.ResetEntitlementUsageParams{
At: resetTime,
})

// assert balance after reset is 0 for grant
assert.NoError(t, err)
assert.Equal(t, 0.0, balanceAfterReset.UsageInPeriod) // 0 usage right after reset
assert.Equal(t, 0.0, balanceAfterReset.Balance) // Grant expires at reset time so we should see no balance
},
},
{
name: "Should reseting without anchor update keeps the next reset time intact",
run: func(t *testing.T, connector meteredentitlement.Connector, deps *testDependencies) {
Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/metered/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/openmeterio/openmeter/internal/entitlement"
"github.com/openmeterio/openmeter/internal/productcatalog"
"github.com/openmeterio/openmeter/internal/streaming"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/defaultx"
"github.com/openmeterio/openmeter/pkg/models"
Expand Down Expand Up @@ -117,7 +118,7 @@ func (c *connector) BeforeCreate(model entitlement.CreateEntitlementInputs, feat
return nil, &entitlement.InvalidFeatureError{FeatureID: feature.ID, Message: "Feature has no meter"}
}

model.MeasureUsageFrom = convert.ToPointer(defaultx.WithDefault(model.MeasureUsageFrom, time.Now().Truncate(c.granularity)))
model.MeasureUsageFrom = convert.ToPointer(defaultx.WithDefault(model.MeasureUsageFrom, clock.Now().Truncate(c.granularity)))
model.IsSoftLimit = convert.ToPointer(defaultx.WithDefault(model.IsSoftLimit, false))
model.IssueAfterReset = convert.ToPointer(defaultx.WithDefault(model.IssueAfterReset, 0.0))

Expand Down
3 changes: 2 additions & 1 deletion internal/entitlement/metered/entitlement_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/openmeterio/openmeter/internal/credit"
"github.com/openmeterio/openmeter/internal/entitlement"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/models"
)
Expand Down Expand Up @@ -86,7 +87,7 @@ type EntitlementGrant struct {
func GrantFromCreditGrant(grant credit.Grant) (*EntitlementGrant, error) {
g := &EntitlementGrant{}
if grant.Recurrence != nil {
next, err := grant.Recurrence.NextAfter(time.Now())
next, err := grant.Recurrence.NextAfter(clock.Now())
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions internal/entitlement/postgresadapter/entitlement.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/openmeterio/openmeter/internal/entitlement/postgresadapter/ent/db"
db_entitlement "github.com/openmeterio/openmeter/internal/entitlement/postgresadapter/ent/db/entitlement"
"github.com/openmeterio/openmeter/internal/entitlement/postgresadapter/ent/db/usagereset"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/recurrence"
Expand All @@ -32,7 +33,7 @@ func (a *entitlementDBAdapter) GetEntitlement(ctx context.Context, entitlementID
Where(
db_entitlement.ID(entitlementID.ID),
db_entitlement.Namespace(entitlementID.Namespace),
db_entitlement.Or(db_entitlement.DeletedAtGT(time.Now()), db_entitlement.DeletedAtIsNil()),
db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()),
).
First(ctx)

Expand All @@ -49,7 +50,7 @@ func (a *entitlementDBAdapter) GetEntitlement(ctx context.Context, entitlementID
func (a *entitlementDBAdapter) GetEntitlementOfSubject(ctx context.Context, namespace string, subjectKey string, idOrFeatureKey string) (*entitlement.Entitlement, error) {
res, err := withLatestUsageReset(a.db.Entitlement.Query()).
Where(
db_entitlement.Or(db_entitlement.DeletedAtGT(time.Now()), db_entitlement.DeletedAtIsNil()),
db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()),
db_entitlement.SubjectKey(string(subjectKey)),
db_entitlement.Namespace(namespace),
db_entitlement.Or(db_entitlement.ID(idOrFeatureKey), db_entitlement.FeatureKey(idOrFeatureKey)),
Expand Down Expand Up @@ -115,7 +116,7 @@ func (a *entitlementDBAdapter) CreateEntitlement(ctx context.Context, entitlemen
func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlementID models.NamespacedID) error {
affectedCount, err := a.db.Entitlement.Update().
Where(db_entitlement.ID(entitlementID.ID), db_entitlement.Namespace(entitlementID.Namespace)).
SetDeletedAt(time.Now()).
SetDeletedAt(clock.Now()).
Save(ctx)
if err != nil {
return err
Expand All @@ -129,7 +130,7 @@ func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlemen
func (a *entitlementDBAdapter) GetEntitlementsOfSubject(ctx context.Context, namespace string, subjectKey models.SubjectKey) ([]entitlement.Entitlement, error) {
res, err := withLatestUsageReset(a.db.Entitlement.Query()).
Where(
db_entitlement.Or(db_entitlement.DeletedAtGT(time.Now()), db_entitlement.DeletedAtIsNil()),
db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()),
db_entitlement.SubjectKey(string(subjectKey)),
db_entitlement.Namespace(namespace),
).
Expand Down Expand Up @@ -161,7 +162,7 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
}

if !params.IncludeDeleted {
query = query.Where(db_entitlement.Or(db_entitlement.DeletedAtGT(time.Now()), db_entitlement.DeletedAtIsNil()))
query = query.Where(db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()))
}

if params.Limit > 0 {
Expand Down
Loading

0 comments on commit 84fb0f0

Please sign in to comment.