Skip to content

Commit

Permalink
Add tests for UpdateWorkflowExecution (#5718)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll authored Mar 5, 2024
1 parent 3bf06cb commit 08d84d1
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 28 deletions.
49 changes: 21 additions & 28 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ const (

type sqlExecutionStore struct {
sqlStore
shardID int
txExecuteShardLockedFn func(context.Context, int, string, int64, func(sqlplugin.Tx) error) error
lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error)
createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, p.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error
applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *p.InternalWorkflowSnapshot, serialization.Parser) error
shardID int
txExecuteShardLockedFn func(context.Context, int, string, int64, func(sqlplugin.Tx) error) error
lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error)
createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, p.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error
assertNotCurrentExecutionFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID) error
assertRunIDAndUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error
applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *p.InternalWorkflowSnapshot, serialization.Parser) error
applyWorkflowMutationTxFn func(context.Context, sqlplugin.Tx, int, *p.InternalWorkflowMutation, serialization.Parser) error
}

var _ p.ExecutionStore = (*sqlExecutionStore)(nil)
Expand All @@ -68,10 +71,13 @@ func NewSQLExecutionStore(
) (p.ExecutionStore, error) {

store := &sqlExecutionStore{
shardID: shardID,
lockCurrentExecutionIfExistsFn: lockCurrentExecutionIfExists,
createOrUpdateCurrentExecutionFn: createOrUpdateCurrentExecution,
applyWorkflowSnapshotTxAsNewFn: applyWorkflowSnapshotTxAsNew,
shardID: shardID,
lockCurrentExecutionIfExistsFn: lockCurrentExecutionIfExists,
createOrUpdateCurrentExecutionFn: createOrUpdateCurrentExecution,
assertNotCurrentExecutionFn: assertNotCurrentExecution,
assertRunIDAndUpdateCurrentExecutionFn: assertRunIDAndUpdateCurrentExecution,
applyWorkflowSnapshotTxAsNewFn: applyWorkflowSnapshotTxAsNew,
applyWorkflowMutationTxFn: applyWorkflowMutationTx,
sqlStore: sqlStore{
db: db,
logger: logger,
Expand Down Expand Up @@ -384,7 +390,7 @@ func (m *sqlExecutionStore) UpdateWorkflowExecution(
request *p.InternalUpdateWorkflowExecutionRequest,
) error {
dbShardID := sqlplugin.GetDBShardIDFromHistoryShardID(m.shardID, m.db.GetTotalNumDBShards())
return m.txExecuteShardLocked(ctx, dbShardID, "UpdateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error {
return m.txExecuteShardLockedFn(ctx, dbShardID, "UpdateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error {
return m.updateWorkflowExecutionTx(ctx, tx, request)
})
}
Expand Down Expand Up @@ -416,7 +422,7 @@ func (m *sqlExecutionStore) updateWorkflowExecutionTx(
case p.UpdateWorkflowModeIgnoreCurrent:
// no-op
case p.UpdateWorkflowModeBypassCurrent:
if err := assertNotCurrentExecution(
if err := m.assertNotCurrentExecutionFn(
ctx,
tx,
shardID,
Expand All @@ -440,7 +446,7 @@ func (m *sqlExecutionStore) updateWorkflowExecutionTx(
}
}

if err := assertRunIDAndUpdateCurrentExecution(
if err := m.assertRunIDAndUpdateCurrentExecutionFn(
ctx,
tx,
shardID,
Expand All @@ -459,7 +465,7 @@ func (m *sqlExecutionStore) updateWorkflowExecutionTx(
startVersion := updateWorkflow.StartVersion
lastWriteVersion := updateWorkflow.LastWriteVersion
// this is only to update the current record
if err := assertRunIDAndUpdateCurrentExecution(
if err := m.assertRunIDAndUpdateCurrentExecutionFn(
ctx,
tx,
shardID,
Expand All @@ -482,24 +488,11 @@ func (m *sqlExecutionStore) updateWorkflowExecutionTx(
}
}

if m.useAsyncTransaction() { // async transaction is enabled
// TODO: it's possible to merge some operations in the following 2 functions in a batch, should refactor the code later
if err := applyWorkflowMutationAsyncTx(ctx, tx, shardID, &updateWorkflow, m.parser); err != nil {
return err
}
if newWorkflow != nil {
if err := m.applyWorkflowSnapshotAsyncTxAsNew(ctx, tx, shardID, newWorkflow, m.parser); err != nil {
return err
}
}
return nil
}

if err := applyWorkflowMutationTx(ctx, tx, shardID, &updateWorkflow, m.parser); err != nil {
if err := m.applyWorkflowMutationTxFn(ctx, tx, shardID, &updateWorkflow, m.parser); err != nil {
return err
}
if newWorkflow != nil {
if err := applyWorkflowSnapshotTxAsNew(ctx, tx, shardID, newWorkflow, m.parser); err != nil {
if err := m.applyWorkflowSnapshotTxAsNewFn(ctx, tx, shardID, newWorkflow, m.parser); err != nil {
return err
}
}
Expand Down
239 changes: 239 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2254,3 +2254,242 @@ func TestCreateWorkflowExecution(t *testing.T) {
})
}
}

func TestUpdateWorkflowExecution(t *testing.T) {
testCases := []struct {
name string
req *persistence.InternalUpdateWorkflowExecutionRequest
assertNotCurrentExecutionFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID) error
assertRunIDAndUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error
applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error
applyWorkflowMutationTxFn func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error
wantErr bool
assertErr func(t *testing.T, err error)
}{
{
name: "Success - mode ignore current",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeIgnoreCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return nil
},
wantErr: false,
},
{
name: "Success - mode bypass current",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeBypassCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
},
assertNotCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID) error {
return nil
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return nil
},
wantErr: false,
},
{
name: "Success - mode update current, new workflow",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
NewWorkflowSnapshot: &persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCreated,
},
},
},
assertRunIDAndUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error {
return nil
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return nil
},
wantErr: false,
},
{
name: "Success - mode update current, no new workflow",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateRunning,
},
},
},
assertRunIDAndUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error {
return nil
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return nil
},
wantErr: false,
},
{
name: "Error - mode state validation failed",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateZombie,
},
},
},
wantErr: true,
},
{
name: "Error - assertNotCurrentExecution failed",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeBypassCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
},
assertNotCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID) error {
return errors.New("some random error")
},
wantErr: true,
},
{
name: "Error - domain ID mismatch",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
DomainID: "a8ead65c-9d0d-43a2-a6ad-dd17c99509af",
State: persistence.WorkflowStateCompleted,
},
},
NewWorkflowSnapshot: &persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
DomainID: "c3fab112-5175-4044-a096-a32e7badd4a8",
},
},
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.Equal(t, &types.InternalServiceError{
Message: "UpdateWorkflowExecution: cannot continue as new to another domain",
}, err)
},
},
{
name: "Error - assertRunIDAndUpdateCurrentExecution failed",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
},
assertRunIDAndUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error {
return errors.New("some random error")
},
wantErr: true,
},
{
name: "Error - applyWorkflowMutationTxFn failed",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
},
assertRunIDAndUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error {
return nil
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return errors.New("some random error")
},
wantErr: true,
},
{
name: "Error - applyWorkflowSnapshotTxAsNew failed",
req: &persistence.InternalUpdateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.UpdateWorkflowModeUpdateCurrent,
UpdateWorkflowMutation: persistence.InternalWorkflowMutation{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCompleted,
},
},
NewWorkflowSnapshot: &persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCreated,
},
},
},
assertRunIDAndUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string, serialization.UUID, serialization.UUID, string, int, int, int64, int64) error {
return nil
},
applyWorkflowMutationTxFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowMutation, serialization.Parser) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return errors.New("some random error")
},
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockDB := sqlplugin.NewMockDB(ctrl)
mockDB.EXPECT().GetTotalNumDBShards().Return(1)
s := &sqlExecutionStore{
shardID: 0,
sqlStore: sqlStore{
db: mockDB,
logger: testlogger.New(t),
},
txExecuteShardLockedFn: func(_ context.Context, _ int, _ string, _ int64, fn func(sqlplugin.Tx) error) error {
return fn(nil)
},
assertNotCurrentExecutionFn: tc.assertNotCurrentExecutionFn,
assertRunIDAndUpdateCurrentExecutionFn: tc.assertRunIDAndUpdateCurrentExecutionFn,
applyWorkflowMutationTxFn: tc.applyWorkflowMutationTxFn,
applyWorkflowSnapshotTxAsNewFn: tc.applyWorkflowSnapshotTxAsNewFn,
}

err := s.UpdateWorkflowExecution(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
if tc.assertErr != nil {
tc.assertErr(t, err)
}
} else {
assert.NoError(t, err, "Did not expect an error for test case")
}
})
}
}

0 comments on commit 08d84d1

Please sign in to comment.