From 2887db97616ec0880cf6944d69ed5b75b58a61e1 Mon Sep 17 00:00:00 2001 From: bowen xiao Date: Tue, 29 Oct 2024 15:47:01 -0700 Subject: [PATCH] Refactor persistence manager initialization (#6441) * Refactor persistence manager initialization * add persistencyManagerFactory mock * add initPersistenceFactory into persistenceManagerFactory as well * add persistence factory into defaultPersistenceManagerFactory to avoid global variables --- tools/cli/admin_commands.go | 12 +-- tools/cli/admin_db_clean_command.go | 4 +- tools/cli/admin_db_scan_command.go | 6 +- tools/cli/admin_timers.go | 3 +- tools/cli/app.go | 4 +- tools/cli/database.go | 48 ++++++---- tools/cli/database_mock.go | 135 ++++++++++++++++++++++++++++ tools/cli/domain_utils.go | 2 +- 8 files changed, 182 insertions(+), 32 deletions(-) create mode 100644 tools/cli/database_mock.go diff --git a/tools/cli/admin_commands.go b/tools/cli/admin_commands.go index ff8a15ea808..f551684369f 100644 --- a/tools/cli/admin_commands.go +++ b/tools/cli/admin_commands.go @@ -61,7 +61,7 @@ func AdminShowWorkflow(c *cli.Context) error { var history []*persistence.DataBlob if len(tid) != 0 { thriftrwEncoder := codec.NewThriftRWEncoder() - histV2, err := initializeHistoryManager(c) + histV2, err := getDeps(c).initializeHistoryManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -303,12 +303,12 @@ func AdminDeleteWorkflow(c *cli.Context) error { if err != nil { return commoncli.Problem("strconv.Atoi(shardID) err", err) } - histV2, err := initializeHistoryManager(c) + histV2, err := getDeps(c).initializeHistoryManager(c) defer histV2.Close() if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } - exeStore, err := initializeExecutionStore(c, shardIDInt) + exeStore, err := getDeps(c).initializeExecutionStore(c, shardIDInt) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -387,7 +387,7 @@ func AdminGetDomainIDOrName(c *cli.Context) error { return commoncli.Problem("Need either domainName or domainID", nil) } - domainManager, err := initializeDomainManager(c) + domainManager, err := getDeps(c).initializeDomainManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -487,7 +487,7 @@ func AdminDescribeShard(c *cli.Context) error { if err != nil { return commoncli.Problem("Error in creating context: ", err) } - shardManager, err := initializeShardManager(c) + shardManager, err := getDeps(c).initializeShardManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -516,7 +516,7 @@ func AdminSetShardRangeID(c *cli.Context) error { if err != nil { return commoncli.Problem("Error in creating context: ", err) } - shardManager, err := initializeShardManager(c) + shardManager, err := getDeps(c).initializeShardManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } diff --git a/tools/cli/admin_db_clean_command.go b/tools/cli/admin_db_clean_command.go index 2992ec8a187..56a6cdde4d5 100644 --- a/tools/cli/admin_db_clean_command.go +++ b/tools/cli/admin_db_clean_command.go @@ -131,12 +131,12 @@ func fixExecution( invariants []executions.InvariantFactory, execution *store.ScanOutputEntity, ) (invariant.ManagerFixResult, error) { - execManager, err := initializeExecutionStore(c, execution.Execution.(entity.Entity).GetShardID()) + execManager, err := getDeps(c).initializeExecutionStore(c, execution.Execution.(entity.Entity).GetShardID()) defer execManager.Close() if err != nil { return invariant.ManagerFixResult{}, fmt.Errorf("Error in fix execution: %w", err) } - historyV2Mgr, err := initializeHistoryManager(c) + historyV2Mgr, err := getDeps(c).initializeHistoryManager(c) defer historyV2Mgr.Close() if err != nil { return invariant.ManagerFixResult{}, fmt.Errorf("Error in fix execution: %w", err) diff --git a/tools/cli/admin_db_scan_command.go b/tools/cli/admin_db_scan_command.go index c041e9d1c84..5781cb067e0 100644 --- a/tools/cli/admin_db_scan_command.go +++ b/tools/cli/admin_db_scan_command.go @@ -139,12 +139,12 @@ func checkExecution( invariants []executions.InvariantFactory, fetcher executions.ExecutionFetcher, ) (interface{}, invariant.ManagerCheckResult, error) { - execManager, err := initializeExecutionStore(c, common.WorkflowIDToHistoryShard(req.WorkflowID, numberOfShards)) + execManager, err := getDeps(c).initializeExecutionStore(c, common.WorkflowIDToHistoryShard(req.WorkflowID, numberOfShards)) defer execManager.Close() if err != nil { return nil, invariant.ManagerCheckResult{}, fmt.Errorf("Error in execution check: %w", err) } - historyV2Mgr, err := initializeHistoryManager(c) + historyV2Mgr, err := getDeps(c).initializeHistoryManager(c) defer historyV2Mgr.Close() if err != nil { return nil, invariant.ManagerCheckResult{}, fmt.Errorf("Error in execution check: %w", err) @@ -200,7 +200,7 @@ func listExecutionsByShardID( outputFile *os.File, ) error { - client, err := initializeExecutionStore(c, shardID) + client, err := getDeps(c).initializeExecutionStore(c, shardID) defer client.Close() if err != nil { commoncli.Problem("Error in Admin DB unsupported WF scan: ", err) diff --git a/tools/cli/admin_timers.go b/tools/cli/admin_timers.go index d32d3bb827c..a09772ffaea 100644 --- a/tools/cli/admin_timers.go +++ b/tools/cli/admin_timers.go @@ -79,7 +79,8 @@ func NewDBLoadCloser(c *cli.Context) (LoadCloser, error) { if err != nil { return nil, fmt.Errorf("error in NewDBLoadCloser: failed to get shard ID: %w", err) } - executionManager, err := initializeExecutionStore(c, shardID) + + executionManager, err := getDeps(c).initializeExecutionStore(c, shardID) if err != nil { return nil, fmt.Errorf("error in NewDBLoadCloser: failed to initialize execution store: %w", err) } diff --git a/tools/cli/app.go b/tools/cli/app.go index 49e274ed880..7ba4cde1f23 100644 --- a/tools/cli/app.go +++ b/tools/cli/app.go @@ -63,7 +63,7 @@ func NewCliApp(cf ClientFactory, opts ...CLIAppOptions) *cli.App { app.Usage = "A command-line tool for cadence users" app.Version = version app.Metadata = map[string]any{ - depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}}, + depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}, PersistenceManagerFactory: &defaultPersistenceManagerFactory{}}, } app.Flags = []cli.Flag{ &cli.StringFlag{ @@ -255,6 +255,7 @@ func getDeps(ctx *cli.Context) cliDeps { type cliDeps interface { ClientFactory IOHandler + PersistenceManagerFactory } type IOHandler interface { @@ -305,4 +306,5 @@ var _ cliDeps = &deps{} type deps struct { ClientFactory IOHandler + PersistenceManagerFactory } diff --git a/tools/cli/database.go b/tools/cli/database.go index dd65e86b9db..4cb4c787960 100644 --- a/tools/cli/database.go +++ b/tools/cli/database.go @@ -18,6 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination database_mock.go -self_package github.com/uber/cadence/tools/cli + package cli import ( @@ -150,20 +152,32 @@ func getDBFlags() []cli.Flag { } } -func initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { - factory, err := getPersistenceFactory(c) +type PersistenceManagerFactory interface { + initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) + initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) + initializeShardManager(c *cli.Context) (persistence.ShardManager, error) + initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) + initPersistenceFactory(c *cli.Context) (client.Factory, error) +} + +type defaultPersistenceManagerFactory struct { + persistenceFactory client.Factory +} + +func (f *defaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } - historyManager, err := factory.NewExecutionManager(shardID) + executionManager, err := factory.NewExecutionManager(shardID) if err != nil { return nil, fmt.Errorf("Failed to initialize history manager %w", err) } - return historyManager, nil + return executionManager, nil } -func initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { - factory, err := getPersistenceFactory(c) +func (f *defaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -174,8 +188,8 @@ func initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error return historyManager, nil } -func initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { - factory, err := getPersistenceFactory(c) +func (f *defaultPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -186,8 +200,8 @@ func initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { return shardManager, nil } -func initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { - factory, err := getPersistenceFactory(c) +func (f *defaultPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -198,20 +212,18 @@ func initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) return domainManager, nil } -var persistenceFactory client.Factory - -func getPersistenceFactory(c *cli.Context) (client.Factory, error) { +func (f *defaultPersistenceManagerFactory) getPersistenceFactory(c *cli.Context) (client.Factory, error) { var err error - if persistenceFactory == nil { - persistenceFactory, err = initPersistenceFactory(c) + if f.persistenceFactory == nil { + f.persistenceFactory, err = getDeps(c).initPersistenceFactory(c) if err != nil { - return persistenceFactory, fmt.Errorf("%w", err) + return f.persistenceFactory, fmt.Errorf("%w", err) } } - return persistenceFactory, nil + return f.persistenceFactory, nil } -func initPersistenceFactory(c *cli.Context) (client.Factory, error) { +func (f *defaultPersistenceManagerFactory) initPersistenceFactory(c *cli.Context) (client.Factory, error) { cfg, err := getDeps(c).ServerConfig(c) if err != nil { diff --git a/tools/cli/database_mock.go b/tools/cli/database_mock.go new file mode 100644 index 00000000000..045175a8c05 --- /dev/null +++ b/tools/cli/database_mock.go @@ -0,0 +1,135 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: database.go + +// Package cli is a generated GoMock package. +package cli + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + cli "github.com/urfave/cli/v2" + + persistence "github.com/uber/cadence/common/persistence" + client "github.com/uber/cadence/common/persistence/client" +) + +// MockPersistenceManagerFactory is a mock of PersistenceManagerFactory interface. +type MockPersistenceManagerFactory struct { + ctrl *gomock.Controller + recorder *MockPersistenceManagerFactoryMockRecorder +} + +// MockPersistenceManagerFactoryMockRecorder is the mock recorder for MockPersistenceManagerFactory. +type MockPersistenceManagerFactoryMockRecorder struct { + mock *MockPersistenceManagerFactory +} + +// NewMockPersistenceManagerFactory creates a new mock instance. +func NewMockPersistenceManagerFactory(ctrl *gomock.Controller) *MockPersistenceManagerFactory { + mock := &MockPersistenceManagerFactory{ctrl: ctrl} + mock.recorder = &MockPersistenceManagerFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPersistenceManagerFactory) EXPECT() *MockPersistenceManagerFactoryMockRecorder { + return m.recorder +} + +// initPersistenceFactory mocks base method. +func (m *MockPersistenceManagerFactory) initPersistenceFactory(c *cli.Context) (client.Factory, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initPersistenceFactory", c) + ret0, _ := ret[0].(client.Factory) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initPersistenceFactory indicates an expected call of initPersistenceFactory. +func (mr *MockPersistenceManagerFactoryMockRecorder) initPersistenceFactory(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initPersistenceFactory", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initPersistenceFactory), c) +} + +// initializeDomainManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeDomainManager", c) + ret0, _ := ret[0].(persistence.DomainManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeDomainManager indicates an expected call of initializeDomainManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeDomainManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeDomainManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeDomainManager), c) +} + +// initializeExecutionStore mocks base method. +func (m *MockPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeExecutionStore", c, shardID) + ret0, _ := ret[0].(persistence.ExecutionManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeExecutionStore indicates an expected call of initializeExecutionStore. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeExecutionStore(c, shardID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeExecutionStore", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeExecutionStore), c, shardID) +} + +// initializeHistoryManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeHistoryManager", c) + ret0, _ := ret[0].(persistence.HistoryManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeHistoryManager indicates an expected call of initializeHistoryManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeHistoryManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeHistoryManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeHistoryManager), c) +} + +// initializeShardManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeShardManager", c) + ret0, _ := ret[0].(persistence.ShardManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeShardManager indicates an expected call of initializeShardManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeShardManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeShardManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeShardManager), c) +} diff --git a/tools/cli/domain_utils.go b/tools/cli/domain_utils.go index af358911817..01324cfb705 100644 --- a/tools/cli/domain_utils.go +++ b/tools/cli/domain_utils.go @@ -291,7 +291,7 @@ func initializeAdminDomainHandler(c *cli.Context) (domain.Handler, error) { return nil, fmt.Errorf("Error in init admin domain handler: %w", err) } clusterMetadata := initializeClusterMetadata(configuration, metricsClient, logger) - metadataMgr, err := initializeDomainManager(c) + metadataMgr, err := getDeps(c).initializeDomainManager(c) if err != nil { return nil, fmt.Errorf("Error in init admin domain handler: %w", err) }