From 3ec54859f7288e1bb8ca07f28e5235fab537eb00 Mon Sep 17 00:00:00 2001 From: Bowen Xiao Date: Wed, 30 Oct 2024 18:23:30 -0700 Subject: [PATCH] add test for database.go --- tools/cli/app_test.go | 7 +- tools/cli/database_test.go | 235 +++++++++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) diff --git a/tools/cli/app_test.go b/tools/cli/app_test.go index ab1e31855e1..83aa4b979ca 100644 --- a/tools/cli/app_test.go +++ b/tools/cli/app_test.go @@ -22,6 +22,7 @@ package cli import ( "bytes" + "fmt" "io" "os" "strings" @@ -66,6 +67,7 @@ var _ ClientFactory = (*clientFactoryMock)(nil) type clientFactoryMock struct { serverFrontendClient frontend.Client serverAdminClient admin.Client + config *config.Config } func (m *clientFactoryMock) ServerFrontendClient(c *cli.Context) (frontend.Client, error) { @@ -89,7 +91,10 @@ func (m *clientFactoryMock) ElasticSearchClient(c *cli.Context) (*elastic.Client } func (m *clientFactoryMock) ServerConfig(c *cli.Context) (*config.Config, error) { - panic("not implemented") + if m.config != nil { + return m.config, nil + } + return nil, fmt.Errorf("config not set") } var commands = []string{ diff --git a/tools/cli/database_test.go b/tools/cli/database_test.go index b52324b5de6..10896e8ff17 100644 --- a/tools/cli/database_test.go +++ b/tools/cli/database_test.go @@ -23,7 +23,17 @@ package cli import ( + "flag" "fmt" + "github.com/stretchr/testify/require" + "github.com/uber/cadence/client/admin" + "github.com/uber/cadence/client/frontend" + "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra" + "github.com/uber/cadence/common/persistence/sql" + "github.com/uber/cadence/common/persistence/sql/sqlplugin" + "github.com/uber/cadence/common/reconciliation/invariant" + commonFlag "github.com/uber/cadence/tools/common/flag" "testing" "github.com/golang/mock/gomock" @@ -159,3 +169,228 @@ func TestDefaultManagerFactory(t *testing.T) { }) } } + +func newClientFactoryMock() *clientFactoryMock { + return &clientFactoryMock{ + serverFrontendClient: frontend.NewMockClient(gomock.NewController(nil)), + serverAdminClient: admin.NewMockClient(gomock.NewController(nil)), + config: &config.Config{ + Persistence: config.Persistence{ + DefaultStore: "default", + DataStores: map[string]config.DataStore{ + "default": {NoSQL: &config.NoSQL{PluginName: cassandra.PluginName}}, + }, + }, + ClusterGroupMetadata: &config.ClusterGroupMetadata{ + CurrentClusterName: "current-cluster", + }, + }, + } +} + +func TestInitPersistenceFactory(t *testing.T) { + ctrl := gomock.NewController(t) + + // Mock the ManagerFactory and ClientFactory + mockClientFactory := NewMockClientFactory(ctrl) + mockPersistenceFactory := client.NewMockFactory(ctrl) + + // Set up the context and app + set := flag.NewFlagSet("test", 0) + app := NewCliApp(mockClientFactory) + c := cli.NewContext(app, set, nil) + + // Mock ServerConfig to return an error + mockClientFactory.EXPECT().ServerConfig(gomock.Any()).Return(nil, fmt.Errorf("config error")).Times(1) + + // Initialize the ManagerFactory with the mock ClientFactory + managerFactory := defaultManagerFactory{ + persistenceFactory: mockPersistenceFactory, + } + + // Call initPersistenceFactory and validate results + factory, err := managerFactory.initPersistenceFactory(c) + + // Assert that no error occurred and a default config was used + assert.NoError(t, err) + assert.NotNil(t, factory) +} + +func TestInitializeInvariantManager(t *testing.T) { + // Create an instance of defaultManagerFactory + factory := &defaultManagerFactory{} + + // Define some fake invariants for testing + invariants := []invariant.Invariant{} + + // Call initializeInvariantManager + manager, err := factory.initializeInvariantManager(invariants) + + // Check that no error is returned + require.NoError(t, err, "Expected no error from initializeInvariantManager") + + // Check that the returned Manager is not nil + require.NotNil(t, manager, "Expected non-nil invariant.Manager") +} + +func TestOverrideDataStore(t *testing.T) { + tests := []struct { + name string + setupContext func(app *cli.App) *cli.Context + inputDataStore config.DataStore + expectedError string + expectedSQL *config.SQL + }{ + { + name: "OverrideDBType_Cassandra", + setupContext: func(app *cli.App) *cli.Context { + set := flag.NewFlagSet("test", 0) + set.String(FlagDBType, cassandra.PluginName, "DB type flag") + require.NoError(t, set.Set(FlagDBType, cassandra.PluginName)) // Set DBType to Cassandra + return cli.NewContext(app, set, nil) + }, + inputDataStore: config.DataStore{}, // Empty DataStore to trigger createDataStore + expectedError: "", + expectedSQL: nil, // No SQL expected for Cassandra + }, + { + name: "OverrideSQLDataStore", + setupContext: func(app *cli.App) *cli.Context { + // Create a new mock SQL plugin using gomock + ctrl := gomock.NewController(t) + mockSQLPlugin := sqlplugin.NewMockPlugin(ctrl) + + // Register the mock SQL plugin for "mysql" + sql.RegisterPlugin("mysql", mockSQLPlugin) + + set := flag.NewFlagSet("test", 0) + set.String(FlagDBType, "mysql", "DB type flag") // Set SQL database type + set.String(FlagDBAddress, "127.0.0.1", "DB address flag") + set.String(FlagDBPort, "3306", "DB port flag") + set.String(FlagUsername, "testuser", "DB username flag") + set.String(FlagPassword, "testpass", "DB password flag") + connAttr := &commonFlag.StringMap{} + require.NoError(t, connAttr.Set("attr1=value1")) + require.NoError(t, connAttr.Set("attr2=value2")) + set.Var(connAttr, FlagConnectionAttributes, "Connection attributes flag") + require.NoError(t, set.Set(FlagDBType, "mysql")) + require.NoError(t, set.Set(FlagDBAddress, "127.0.0.1")) + require.NoError(t, set.Set(FlagDBPort, "3306")) + require.NoError(t, set.Set(FlagUsername, "testuser")) + require.NoError(t, set.Set(FlagPassword, "testpass")) + + return cli.NewContext(app, set, nil) + }, + expectedError: "", + expectedSQL: &config.SQL{ + PluginName: "mysql", + ConnectAddr: "127.0.0.1:3306", + User: "testuser", + Password: "testpass", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up app and context + app := cli.NewApp() + c := tt.setupContext(app) + + // Call overrideDataStore with initial DataStore and capture result + result, err := overrideDataStore(c, tt.inputDataStore) + + if tt.expectedError != "" { + assert.ErrorContains(t, err, tt.expectedError) + } else { + assert.NoError(t, err) + // Validate SQL DataStore settings if expected + if tt.expectedSQL != nil && result.SQL != nil { + assert.Equal(t, tt.expectedSQL.PluginName, result.SQL.PluginName) + assert.Equal(t, tt.expectedSQL.ConnectAddr, result.SQL.ConnectAddr) + assert.Equal(t, tt.expectedSQL.User, result.SQL.User) + assert.Equal(t, tt.expectedSQL.Password, result.SQL.Password) + } + } + }) + } +} + +func TestOverrideTLS(t *testing.T) { + tests := []struct { + name string + setupContext func(app *cli.App) *cli.Context + expectedTLS config.TLS + }{ + { + name: "AllTLSFlagsSet", + setupContext: func(app *cli.App) *cli.Context { + set := flag.NewFlagSet("test", 0) + set.Bool(FlagEnableTLS, true, "Enable TLS flag") + set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path") + set.String(FlagTLSKeyPath, "/path/to/key", "TLS Key Path") + set.String(FlagTLSCaPath, "/path/to/ca", "TLS CA Path") + set.Bool(FlagTLSEnableHostVerification, true, "Enable Host Verification") + + require.NoError(t, set.Set(FlagEnableTLS, "true")) + require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert")) + require.NoError(t, set.Set(FlagTLSKeyPath, "/path/to/key")) + require.NoError(t, set.Set(FlagTLSCaPath, "/path/to/ca")) + require.NoError(t, set.Set(FlagTLSEnableHostVerification, "true")) + + return cli.NewContext(app, set, nil) + }, + expectedTLS: config.TLS{ + Enabled: true, + CertFile: "/path/to/cert", + KeyFile: "/path/to/key", + CaFile: "/path/to/ca", + EnableHostVerification: true, + }, + }, + { + name: "PartialTLSFlagsSet", + setupContext: func(app *cli.App) *cli.Context { + set := flag.NewFlagSet("test", 0) + set.Bool(FlagEnableTLS, true, "Enable TLS flag") + set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path") + + require.NoError(t, set.Set(FlagEnableTLS, "true")) + require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert")) + + return cli.NewContext(app, set, nil) + }, + expectedTLS: config.TLS{ + Enabled: true, + CertFile: "/path/to/cert", + }, + }, + { + name: "NoTLSFlagsSet", + setupContext: func(app *cli.App) *cli.Context { + set := flag.NewFlagSet("test", 0) + return cli.NewContext(app, set, nil) + }, + expectedTLS: config.TLS{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up app and context + app := cli.NewApp() + c := tt.setupContext(app) + + // Initialize an empty TLS config and apply overrideTLS + tlsConfig := &config.TLS{} + overrideTLS(c, tlsConfig) + + // Validate TLS config settings + assert.Equal(t, tt.expectedTLS.Enabled, tlsConfig.Enabled) + assert.Equal(t, tt.expectedTLS.CertFile, tlsConfig.CertFile) + assert.Equal(t, tt.expectedTLS.KeyFile, tlsConfig.KeyFile) + assert.Equal(t, tt.expectedTLS.CaFile, tlsConfig.CaFile) + assert.Equal(t, tt.expectedTLS.EnableHostVerification, tlsConfig.EnableHostVerification) + }) + } +}