From 62a4250dea3b47e094f2b67bcd88cf84a4bd08b3 Mon Sep 17 00:00:00 2001 From: Bowen Xiao Date: Tue, 29 Oct 2024 14:43:18 -0700 Subject: [PATCH 1/4] Refactor persistence manager initialization --- tools/cli/admin_commands.go | 12 ++++++------ tools/cli/admin_db_clean_command.go | 4 ++-- tools/cli/admin_db_scan_command.go | 6 +++--- tools/cli/admin_timers.go | 3 ++- tools/cli/app.go | 4 +++- tools/cli/database.go | 21 +++++++++++++++------ tools/cli/domain_utils.go | 2 +- 7 files changed, 32 insertions(+), 20 deletions(-) diff --git a/tools/cli/admin_commands.go b/tools/cli/admin_commands.go index ff8a15ea808..f551684369f 100644 --- a/tools/cli/admin_commands.go +++ b/tools/cli/admin_commands.go @@ -61,7 +61,7 @@ func AdminShowWorkflow(c *cli.Context) error { var history []*persistence.DataBlob if len(tid) != 0 { thriftrwEncoder := codec.NewThriftRWEncoder() - histV2, err := initializeHistoryManager(c) + histV2, err := getDeps(c).initializeHistoryManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -303,12 +303,12 @@ func AdminDeleteWorkflow(c *cli.Context) error { if err != nil { return commoncli.Problem("strconv.Atoi(shardID) err", err) } - histV2, err := initializeHistoryManager(c) + histV2, err := getDeps(c).initializeHistoryManager(c) defer histV2.Close() if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } - exeStore, err := initializeExecutionStore(c, shardIDInt) + exeStore, err := getDeps(c).initializeExecutionStore(c, shardIDInt) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -387,7 +387,7 @@ func AdminGetDomainIDOrName(c *cli.Context) error { return commoncli.Problem("Need either domainName or domainID", nil) } - domainManager, err := initializeDomainManager(c) + domainManager, err := getDeps(c).initializeDomainManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -487,7 +487,7 @@ func AdminDescribeShard(c *cli.Context) error { if err != nil { return commoncli.Problem("Error in creating context: ", err) } - shardManager, err := initializeShardManager(c) + shardManager, err := getDeps(c).initializeShardManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } @@ -516,7 +516,7 @@ func AdminSetShardRangeID(c *cli.Context) error { if err != nil { return commoncli.Problem("Error in creating context: ", err) } - shardManager, err := initializeShardManager(c) + shardManager, err := getDeps(c).initializeShardManager(c) if err != nil { return commoncli.Problem("Error in Admin delete WF: ", err) } diff --git a/tools/cli/admin_db_clean_command.go b/tools/cli/admin_db_clean_command.go index 2992ec8a187..56a6cdde4d5 100644 --- a/tools/cli/admin_db_clean_command.go +++ b/tools/cli/admin_db_clean_command.go @@ -131,12 +131,12 @@ func fixExecution( invariants []executions.InvariantFactory, execution *store.ScanOutputEntity, ) (invariant.ManagerFixResult, error) { - execManager, err := initializeExecutionStore(c, execution.Execution.(entity.Entity).GetShardID()) + execManager, err := getDeps(c).initializeExecutionStore(c, execution.Execution.(entity.Entity).GetShardID()) defer execManager.Close() if err != nil { return invariant.ManagerFixResult{}, fmt.Errorf("Error in fix execution: %w", err) } - historyV2Mgr, err := initializeHistoryManager(c) + historyV2Mgr, err := getDeps(c).initializeHistoryManager(c) defer historyV2Mgr.Close() if err != nil { return invariant.ManagerFixResult{}, fmt.Errorf("Error in fix execution: %w", err) diff --git a/tools/cli/admin_db_scan_command.go b/tools/cli/admin_db_scan_command.go index c041e9d1c84..5781cb067e0 100644 --- a/tools/cli/admin_db_scan_command.go +++ b/tools/cli/admin_db_scan_command.go @@ -139,12 +139,12 @@ func checkExecution( invariants []executions.InvariantFactory, fetcher executions.ExecutionFetcher, ) (interface{}, invariant.ManagerCheckResult, error) { - execManager, err := initializeExecutionStore(c, common.WorkflowIDToHistoryShard(req.WorkflowID, numberOfShards)) + execManager, err := getDeps(c).initializeExecutionStore(c, common.WorkflowIDToHistoryShard(req.WorkflowID, numberOfShards)) defer execManager.Close() if err != nil { return nil, invariant.ManagerCheckResult{}, fmt.Errorf("Error in execution check: %w", err) } - historyV2Mgr, err := initializeHistoryManager(c) + historyV2Mgr, err := getDeps(c).initializeHistoryManager(c) defer historyV2Mgr.Close() if err != nil { return nil, invariant.ManagerCheckResult{}, fmt.Errorf("Error in execution check: %w", err) @@ -200,7 +200,7 @@ func listExecutionsByShardID( outputFile *os.File, ) error { - client, err := initializeExecutionStore(c, shardID) + client, err := getDeps(c).initializeExecutionStore(c, shardID) defer client.Close() if err != nil { commoncli.Problem("Error in Admin DB unsupported WF scan: ", err) diff --git a/tools/cli/admin_timers.go b/tools/cli/admin_timers.go index d32d3bb827c..a09772ffaea 100644 --- a/tools/cli/admin_timers.go +++ b/tools/cli/admin_timers.go @@ -79,7 +79,8 @@ func NewDBLoadCloser(c *cli.Context) (LoadCloser, error) { if err != nil { return nil, fmt.Errorf("error in NewDBLoadCloser: failed to get shard ID: %w", err) } - executionManager, err := initializeExecutionStore(c, shardID) + + executionManager, err := getDeps(c).initializeExecutionStore(c, shardID) if err != nil { return nil, fmt.Errorf("error in NewDBLoadCloser: failed to initialize execution store: %w", err) } diff --git a/tools/cli/app.go b/tools/cli/app.go index 49e274ed880..836bfd46a71 100644 --- a/tools/cli/app.go +++ b/tools/cli/app.go @@ -63,7 +63,7 @@ func NewCliApp(cf ClientFactory, opts ...CLIAppOptions) *cli.App { app.Usage = "A command-line tool for cadence users" app.Version = version app.Metadata = map[string]any{ - depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}}, + depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}, PersistenceManagerFactory: &DefaultPersistenceManagerFactory{}}, } app.Flags = []cli.Flag{ &cli.StringFlag{ @@ -255,6 +255,7 @@ func getDeps(ctx *cli.Context) cliDeps { type cliDeps interface { ClientFactory IOHandler + PersistenceManagerFactory } type IOHandler interface { @@ -305,4 +306,5 @@ var _ cliDeps = &deps{} type deps struct { ClientFactory IOHandler + PersistenceManagerFactory } diff --git a/tools/cli/database.go b/tools/cli/database.go index dd65e86b9db..0ae70f822e1 100644 --- a/tools/cli/database.go +++ b/tools/cli/database.go @@ -150,19 +150,28 @@ func getDBFlags() []cli.Flag { } } -func initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { +type PersistenceManagerFactory interface { + initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) + initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) + initializeShardManager(c *cli.Context) (persistence.ShardManager, error) + initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) +} + +type DefaultPersistenceManagerFactory struct{} + +func (f *DefaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } - historyManager, err := factory.NewExecutionManager(shardID) + executionManager, err := factory.NewExecutionManager(shardID) if err != nil { return nil, fmt.Errorf("Failed to initialize history manager %w", err) } - return historyManager, nil + return executionManager, nil } -func initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { +func (f *DefaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -174,7 +183,7 @@ func initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error return historyManager, nil } -func initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { +func (f *DefaultPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -186,7 +195,7 @@ func initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { return shardManager, nil } -func initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { +func (f *DefaultPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) diff --git a/tools/cli/domain_utils.go b/tools/cli/domain_utils.go index af358911817..01324cfb705 100644 --- a/tools/cli/domain_utils.go +++ b/tools/cli/domain_utils.go @@ -291,7 +291,7 @@ func initializeAdminDomainHandler(c *cli.Context) (domain.Handler, error) { return nil, fmt.Errorf("Error in init admin domain handler: %w", err) } clusterMetadata := initializeClusterMetadata(configuration, metricsClient, logger) - metadataMgr, err := initializeDomainManager(c) + metadataMgr, err := getDeps(c).initializeDomainManager(c) if err != nil { return nil, fmt.Errorf("Error in init admin domain handler: %w", err) } From 648d2d472a73052eff71506c02132e2e4431f747 Mon Sep 17 00:00:00 2001 From: Bowen Xiao Date: Tue, 29 Oct 2024 14:54:25 -0700 Subject: [PATCH 2/4] add persistencyManagerFactory mock --- tools/cli/database.go | 2 + tools/cli/database_mock.go | 119 +++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 tools/cli/database_mock.go diff --git a/tools/cli/database.go b/tools/cli/database.go index 0ae70f822e1..dd1b6a3e45d 100644 --- a/tools/cli/database.go +++ b/tools/cli/database.go @@ -18,6 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination database_mock.go -self_package github.com/uber/cadence/tools/cli + package cli import ( diff --git a/tools/cli/database_mock.go b/tools/cli/database_mock.go new file mode 100644 index 00000000000..80102d4acb5 --- /dev/null +++ b/tools/cli/database_mock.go @@ -0,0 +1,119 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: database.go + +// Package cli is a generated GoMock package. +package cli + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + cli "github.com/urfave/cli/v2" + + persistence "github.com/uber/cadence/common/persistence" +) + +// MockPersistenceManagerFactory is a mock of PersistenceManagerFactory interface. +type MockPersistenceManagerFactory struct { + ctrl *gomock.Controller + recorder *MockPersistenceManagerFactoryMockRecorder +} + +// MockPersistenceManagerFactoryMockRecorder is the mock recorder for MockPersistenceManagerFactory. +type MockPersistenceManagerFactoryMockRecorder struct { + mock *MockPersistenceManagerFactory +} + +// NewMockPersistenceManagerFactory creates a new mock instance. +func NewMockPersistenceManagerFactory(ctrl *gomock.Controller) *MockPersistenceManagerFactory { + mock := &MockPersistenceManagerFactory{ctrl: ctrl} + mock.recorder = &MockPersistenceManagerFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPersistenceManagerFactory) EXPECT() *MockPersistenceManagerFactoryMockRecorder { + return m.recorder +} + +// initializeDomainManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeDomainManager", c) + ret0, _ := ret[0].(persistence.DomainManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeDomainManager indicates an expected call of initializeDomainManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeDomainManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeDomainManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeDomainManager), c) +} + +// initializeExecutionStore mocks base method. +func (m *MockPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeExecutionStore", c, shardID) + ret0, _ := ret[0].(persistence.ExecutionManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeExecutionStore indicates an expected call of initializeExecutionStore. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeExecutionStore(c, shardID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeExecutionStore", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeExecutionStore), c, shardID) +} + +// initializeHistoryManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeHistoryManager", c) + ret0, _ := ret[0].(persistence.HistoryManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeHistoryManager indicates an expected call of initializeHistoryManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeHistoryManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeHistoryManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeHistoryManager), c) +} + +// initializeShardManager mocks base method. +func (m *MockPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initializeShardManager", c) + ret0, _ := ret[0].(persistence.ShardManager) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initializeShardManager indicates an expected call of initializeShardManager. +func (mr *MockPersistenceManagerFactoryMockRecorder) initializeShardManager(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initializeShardManager", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initializeShardManager), c) +} From 994e2697be44fda53bce914c955daa9daecae3e5 Mon Sep 17 00:00:00 2001 From: Bowen Xiao Date: Tue, 29 Oct 2024 15:04:47 -0700 Subject: [PATCH 3/4] add initPersistenceFactory into persistenceManagerFactory as well --- tools/cli/app.go | 2 +- tools/cli/database.go | 15 ++++++++------- tools/cli/database_mock.go | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tools/cli/app.go b/tools/cli/app.go index 836bfd46a71..7ba4cde1f23 100644 --- a/tools/cli/app.go +++ b/tools/cli/app.go @@ -63,7 +63,7 @@ func NewCliApp(cf ClientFactory, opts ...CLIAppOptions) *cli.App { app.Usage = "A command-line tool for cadence users" app.Version = version app.Metadata = map[string]any{ - depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}, PersistenceManagerFactory: &DefaultPersistenceManagerFactory{}}, + depsKey: &deps{ClientFactory: cf, IOHandler: &defaultIOHandler{app: app}, PersistenceManagerFactory: &defaultPersistenceManagerFactory{}}, } app.Flags = []cli.Flag{ &cli.StringFlag{ diff --git a/tools/cli/database.go b/tools/cli/database.go index dd1b6a3e45d..b0887568ce3 100644 --- a/tools/cli/database.go +++ b/tools/cli/database.go @@ -157,11 +157,12 @@ type PersistenceManagerFactory interface { initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) + initPersistenceFactory(c *cli.Context) (client.Factory, error) } -type DefaultPersistenceManagerFactory struct{} +type defaultPersistenceManagerFactory struct{} -func (f *DefaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { +func (f *defaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -173,7 +174,7 @@ func (f *DefaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Conte return executionManager, nil } -func (f *DefaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { +func (f *defaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -185,7 +186,7 @@ func (f *DefaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Conte return historyManager, nil } -func (f *DefaultPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { +func (f *defaultPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -197,7 +198,7 @@ func (f *DefaultPersistenceManagerFactory) initializeShardManager(c *cli.Context return shardManager, nil } -func (f *DefaultPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { +func (f *defaultPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { factory, err := getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) @@ -214,7 +215,7 @@ var persistenceFactory client.Factory func getPersistenceFactory(c *cli.Context) (client.Factory, error) { var err error if persistenceFactory == nil { - persistenceFactory, err = initPersistenceFactory(c) + persistenceFactory, err = getDeps(c).initPersistenceFactory(c) if err != nil { return persistenceFactory, fmt.Errorf("%w", err) } @@ -222,7 +223,7 @@ func getPersistenceFactory(c *cli.Context) (client.Factory, error) { return persistenceFactory, nil } -func initPersistenceFactory(c *cli.Context) (client.Factory, error) { +func (f *defaultPersistenceManagerFactory) initPersistenceFactory(c *cli.Context) (client.Factory, error) { cfg, err := getDeps(c).ServerConfig(c) if err != nil { diff --git a/tools/cli/database_mock.go b/tools/cli/database_mock.go index 80102d4acb5..045175a8c05 100644 --- a/tools/cli/database_mock.go +++ b/tools/cli/database_mock.go @@ -33,6 +33,7 @@ import ( cli "github.com/urfave/cli/v2" persistence "github.com/uber/cadence/common/persistence" + client "github.com/uber/cadence/common/persistence/client" ) // MockPersistenceManagerFactory is a mock of PersistenceManagerFactory interface. @@ -58,6 +59,21 @@ func (m *MockPersistenceManagerFactory) EXPECT() *MockPersistenceManagerFactoryM return m.recorder } +// initPersistenceFactory mocks base method. +func (m *MockPersistenceManagerFactory) initPersistenceFactory(c *cli.Context) (client.Factory, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "initPersistenceFactory", c) + ret0, _ := ret[0].(client.Factory) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// initPersistenceFactory indicates an expected call of initPersistenceFactory. +func (mr *MockPersistenceManagerFactoryMockRecorder) initPersistenceFactory(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "initPersistenceFactory", reflect.TypeOf((*MockPersistenceManagerFactory)(nil).initPersistenceFactory), c) +} + // initializeDomainManager mocks base method. func (m *MockPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { m.ctrl.T.Helper() From e3b43c9577147c4bd525f4fa6550620878591d3c Mon Sep 17 00:00:00 2001 From: Bowen Xiao Date: Tue, 29 Oct 2024 15:12:30 -0700 Subject: [PATCH 4/4] add persistence factory into defaultPersistenceManagerFactory to avoid global variables --- tools/cli/database.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tools/cli/database.go b/tools/cli/database.go index b0887568ce3..4cb4c787960 100644 --- a/tools/cli/database.go +++ b/tools/cli/database.go @@ -160,10 +160,12 @@ type PersistenceManagerFactory interface { initPersistenceFactory(c *cli.Context) (client.Factory, error) } -type defaultPersistenceManagerFactory struct{} +type defaultPersistenceManagerFactory struct { + persistenceFactory client.Factory +} func (f *defaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Context, shardID int) (persistence.ExecutionManager, error) { - factory, err := getPersistenceFactory(c) + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -175,7 +177,7 @@ func (f *defaultPersistenceManagerFactory) initializeExecutionStore(c *cli.Conte } func (f *defaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Context) (persistence.HistoryManager, error) { - factory, err := getPersistenceFactory(c) + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -187,7 +189,7 @@ func (f *defaultPersistenceManagerFactory) initializeHistoryManager(c *cli.Conte } func (f *defaultPersistenceManagerFactory) initializeShardManager(c *cli.Context) (persistence.ShardManager, error) { - factory, err := getPersistenceFactory(c) + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -199,7 +201,7 @@ func (f *defaultPersistenceManagerFactory) initializeShardManager(c *cli.Context } func (f *defaultPersistenceManagerFactory) initializeDomainManager(c *cli.Context) (persistence.DomainManager, error) { - factory, err := getPersistenceFactory(c) + factory, err := f.getPersistenceFactory(c) if err != nil { return nil, fmt.Errorf("Failed to get persistence factory: %w", err) } @@ -210,17 +212,15 @@ func (f *defaultPersistenceManagerFactory) initializeDomainManager(c *cli.Contex return domainManager, nil } -var persistenceFactory client.Factory - -func getPersistenceFactory(c *cli.Context) (client.Factory, error) { +func (f *defaultPersistenceManagerFactory) getPersistenceFactory(c *cli.Context) (client.Factory, error) { var err error - if persistenceFactory == nil { - persistenceFactory, err = getDeps(c).initPersistenceFactory(c) + if f.persistenceFactory == nil { + f.persistenceFactory, err = getDeps(c).initPersistenceFactory(c) if err != nil { - return persistenceFactory, fmt.Errorf("%w", err) + return f.persistenceFactory, fmt.Errorf("%w", err) } } - return persistenceFactory, nil + return f.persistenceFactory, nil } func (f *defaultPersistenceManagerFactory) initPersistenceFactory(c *cli.Context) (client.Factory, error) {