Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for database.go #6453

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tools/cli/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package cli

import (
"bytes"
"fmt"
"io"
"os"
"strings"
Expand Down Expand Up @@ -66,6 +67,7 @@ var _ ClientFactory = (*clientFactoryMock)(nil)
type clientFactoryMock struct {
serverFrontendClient frontend.Client
serverAdminClient admin.Client
config *config.Config
}

func (m *clientFactoryMock) ServerFrontendClient(c *cli.Context) (frontend.Client, error) {
Expand All @@ -89,7 +91,10 @@ func (m *clientFactoryMock) ElasticSearchClient(c *cli.Context) (*elastic.Client
}

func (m *clientFactoryMock) ServerConfig(c *cli.Context) (*config.Config, error) {
panic("not implemented")
if m.config != nil {
return m.config, nil
}
return nil, fmt.Errorf("config not set")
}

var commands = []string{
Expand Down
235 changes: 235 additions & 0 deletions tools/cli/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@
package cli

import (
"flag"
"fmt"
"github.com/stretchr/testify/require"
"github.com/uber/cadence/client/admin"
"github.com/uber/cadence/client/frontend"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra"
"github.com/uber/cadence/common/persistence/sql"
"github.com/uber/cadence/common/persistence/sql/sqlplugin"
"github.com/uber/cadence/common/reconciliation/invariant"
commonFlag "github.com/uber/cadence/tools/common/flag"
"testing"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -159,3 +169,228 @@ func TestDefaultManagerFactory(t *testing.T) {
})
}
}

func newClientFactoryMock() *clientFactoryMock {
bowenxia marked this conversation as resolved.
Show resolved Hide resolved
return &clientFactoryMock{
serverFrontendClient: frontend.NewMockClient(gomock.NewController(nil)),
serverAdminClient: admin.NewMockClient(gomock.NewController(nil)),
config: &config.Config{
Persistence: config.Persistence{
DefaultStore: "default",
DataStores: map[string]config.DataStore{
"default": {NoSQL: &config.NoSQL{PluginName: cassandra.PluginName}},
},
},
ClusterGroupMetadata: &config.ClusterGroupMetadata{
CurrentClusterName: "current-cluster",
},
},
}
}

func TestInitPersistenceFactory(t *testing.T) {
ctrl := gomock.NewController(t)

// Mock the ManagerFactory and ClientFactory
mockClientFactory := NewMockClientFactory(ctrl)
mockPersistenceFactory := client.NewMockFactory(ctrl)

// Set up the context and app
set := flag.NewFlagSet("test", 0)
app := NewCliApp(mockClientFactory)
c := cli.NewContext(app, set, nil)

// Mock ServerConfig to return an error
mockClientFactory.EXPECT().ServerConfig(gomock.Any()).Return(nil, fmt.Errorf("config error")).Times(1)

// Initialize the ManagerFactory with the mock ClientFactory
managerFactory := defaultManagerFactory{
persistenceFactory: mockPersistenceFactory,
}

// Call initPersistenceFactory and validate results
factory, err := managerFactory.initPersistenceFactory(c)

// Assert that no error occurred and a default config was used
assert.NoError(t, err)
assert.NotNil(t, factory)
}

func TestInitializeInvariantManager(t *testing.T) {
// Create an instance of defaultManagerFactory
factory := &defaultManagerFactory{}

// Define some fake invariants for testing
invariants := []invariant.Invariant{}

// Call initializeInvariantManager
manager, err := factory.initializeInvariantManager(invariants)

// Check that no error is returned
require.NoError(t, err, "Expected no error from initializeInvariantManager")

// Check that the returned Manager is not nil
require.NotNil(t, manager, "Expected non-nil invariant.Manager")
}

func TestOverrideDataStore(t *testing.T) {
tests := []struct {
name string
setupContext func(app *cli.App) *cli.Context
inputDataStore config.DataStore
expectedError string
expectedSQL *config.SQL
}{
{
name: "OverrideDBType_Cassandra",
setupContext: func(app *cli.App) *cli.Context {
set := flag.NewFlagSet("test", 0)
set.String(FlagDBType, cassandra.PluginName, "DB type flag")
require.NoError(t, set.Set(FlagDBType, cassandra.PluginName)) // Set DBType to Cassandra
return cli.NewContext(app, set, nil)
},
inputDataStore: config.DataStore{}, // Empty DataStore to trigger createDataStore
expectedError: "",
expectedSQL: nil, // No SQL expected for Cassandra
},
{
name: "OverrideSQLDataStore",
setupContext: func(app *cli.App) *cli.Context {
// Create a new mock SQL plugin using gomock
ctrl := gomock.NewController(t)
mockSQLPlugin := sqlplugin.NewMockPlugin(ctrl)

// Register the mock SQL plugin for "mysql"
sql.RegisterPlugin("mysql", mockSQLPlugin)

set := flag.NewFlagSet("test", 0)
set.String(FlagDBType, "mysql", "DB type flag") // Set SQL database type
set.String(FlagDBAddress, "127.0.0.1", "DB address flag")
set.String(FlagDBPort, "3306", "DB port flag")
set.String(FlagUsername, "testuser", "DB username flag")
set.String(FlagPassword, "testpass", "DB password flag")
connAttr := &commonFlag.StringMap{}
require.NoError(t, connAttr.Set("attr1=value1"))
require.NoError(t, connAttr.Set("attr2=value2"))
set.Var(connAttr, FlagConnectionAttributes, "Connection attributes flag")
require.NoError(t, set.Set(FlagDBType, "mysql"))
require.NoError(t, set.Set(FlagDBAddress, "127.0.0.1"))
require.NoError(t, set.Set(FlagDBPort, "3306"))
require.NoError(t, set.Set(FlagUsername, "testuser"))
require.NoError(t, set.Set(FlagPassword, "testpass"))

return cli.NewContext(app, set, nil)
},
expectedError: "",
expectedSQL: &config.SQL{
PluginName: "mysql",
ConnectAddr: "127.0.0.1:3306",
User: "testuser",
Password: "testpass",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up app and context
app := cli.NewApp()
c := tt.setupContext(app)

// Call overrideDataStore with initial DataStore and capture result
result, err := overrideDataStore(c, tt.inputDataStore)

if tt.expectedError != "" {
assert.ErrorContains(t, err, tt.expectedError)
} else {
assert.NoError(t, err)
// Validate SQL DataStore settings if expected
if tt.expectedSQL != nil && result.SQL != nil {
assert.Equal(t, tt.expectedSQL.PluginName, result.SQL.PluginName)
assert.Equal(t, tt.expectedSQL.ConnectAddr, result.SQL.ConnectAddr)
assert.Equal(t, tt.expectedSQL.User, result.SQL.User)
assert.Equal(t, tt.expectedSQL.Password, result.SQL.Password)
}
}
})
}
}

func TestOverrideTLS(t *testing.T) {
tests := []struct {
name string
setupContext func(app *cli.App) *cli.Context
expectedTLS config.TLS
}{
{
name: "AllTLSFlagsSet",
setupContext: func(app *cli.App) *cli.Context {
set := flag.NewFlagSet("test", 0)
set.Bool(FlagEnableTLS, true, "Enable TLS flag")
set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path")
set.String(FlagTLSKeyPath, "/path/to/key", "TLS Key Path")
set.String(FlagTLSCaPath, "/path/to/ca", "TLS CA Path")
set.Bool(FlagTLSEnableHostVerification, true, "Enable Host Verification")

require.NoError(t, set.Set(FlagEnableTLS, "true"))
require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert"))
require.NoError(t, set.Set(FlagTLSKeyPath, "/path/to/key"))
require.NoError(t, set.Set(FlagTLSCaPath, "/path/to/ca"))
require.NoError(t, set.Set(FlagTLSEnableHostVerification, "true"))

return cli.NewContext(app, set, nil)
},
expectedTLS: config.TLS{
Enabled: true,
CertFile: "/path/to/cert",
KeyFile: "/path/to/key",
CaFile: "/path/to/ca",
EnableHostVerification: true,
},
},
{
name: "PartialTLSFlagsSet",
setupContext: func(app *cli.App) *cli.Context {
set := flag.NewFlagSet("test", 0)
set.Bool(FlagEnableTLS, true, "Enable TLS flag")
set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path")

require.NoError(t, set.Set(FlagEnableTLS, "true"))
require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert"))

return cli.NewContext(app, set, nil)
},
expectedTLS: config.TLS{
Enabled: true,
CertFile: "/path/to/cert",
},
},
{
name: "NoTLSFlagsSet",
setupContext: func(app *cli.App) *cli.Context {
set := flag.NewFlagSet("test", 0)
return cli.NewContext(app, set, nil)
},
expectedTLS: config.TLS{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up app and context
app := cli.NewApp()
c := tt.setupContext(app)

// Initialize an empty TLS config and apply overrideTLS
tlsConfig := &config.TLS{}
overrideTLS(c, tlsConfig)

// Validate TLS config settings
assert.Equal(t, tt.expectedTLS.Enabled, tlsConfig.Enabled)
assert.Equal(t, tt.expectedTLS.CertFile, tlsConfig.CertFile)
assert.Equal(t, tt.expectedTLS.KeyFile, tlsConfig.KeyFile)
assert.Equal(t, tt.expectedTLS.CaFile, tlsConfig.CaFile)
assert.Equal(t, tt.expectedTLS.EnableHostVerification, tlsConfig.EnableHostVerification)
})
}
}
Loading