Skip to content

Commit

Permalink
🩹 修改优化自动迁移相关代码,实现 GetTables 方法
Browse files Browse the repository at this point in the history
Signed-off-by: liutianqi <[email protected]>
  • Loading branch information
iTanken committed Jan 18, 2024
1 parent 7c62962 commit 953a73d
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,38 @@ func (m Migrator) HasTable(value interface{}) bool {
var count int64

_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownerTable := strings.Split(stmt.Schema.Table, ".")
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownerTable[0], ownerTable[1]).Row().Scan(&count)
if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownerName, tableName).Row().Scan(&count)
} else {
return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count)
return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", tableName).Row().Scan(&count)
}
})

return count > 0
}

func (m Migrator) getSchemaTable(stmt *gorm.Statement) (ownerName, tableName string) {
if stmt == nil {
return
}
if stmt.Schema == nil {
tableName = stmt.Table
} else {
tableName = stmt.Schema.Table
if strings.Contains(tableName, ".") {
ownerTable := strings.Split(tableName, ".")
ownerName, tableName = ownerTable[0], ownerTable[1]
}
}
return
}

// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Schema.Table).Where("ROWNUM = 1").Rows()
_, tableName := m.getSchemaTable(stmt)
rows, err := m.DB.Session(&gorm.Session{}).Table(tableName).Where("ROWNUM = 1").Rows()
if err != nil {
return err
}
Expand Down Expand Up @@ -234,6 +250,15 @@ func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
).Error
}

func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw(`SELECT TABLE_NAME FROM USER_TABLES
WHERE TABLESPACE_NAME IS NOT NULL AND TABLESPACE_NAME <> 'SYSAUX'
AND TABLE_NAME NOT LIKE 'AQ$%' AND TABLE_NAME NOT LIKE 'MVIEW$%' AND TABLE_NAME NOT LIKE 'ROLLING$%'
AND TABLE_NAME NOT IN ('HELP', 'SQLPLUS_PRODUCT_PROFILE', 'LOGSTDBY$PARAMETERS', 'LOGMNRGGC_GTCS', 'LOGMNRGGC_GTLO', 'LOGMNR_PARAMETER$', 'LOGMNR_SESSION$', 'SCHEDULER_JOB_ARGS_TBL', 'SCHEDULER_PROGRAM_ARGS_TBL')
`).Scan(&tableList).Error
return
}

// AddColumn create "name" column for value
func (m Migrator) AddColumn(value interface{}, name string) (err error) {
if err = m.Migrator.AddColumn(value, name); err != nil {
Expand Down Expand Up @@ -263,9 +288,10 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {

return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
_, tableName := m.getSchemaTable(stmt)
return m.DB.Exec(
"ALTER TABLE ? MODIFY ? ?",
clause.Table{Name: stmt.Schema.Table},
clause.Table{Name: tableName},
clause.Column{Name: field.DBName},
m.AlterDataTypeOf(stmt, field),
).Error
Expand All @@ -277,11 +303,10 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownerTable := strings.Split(stmt.Schema.Table, ".")
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerTable[0], ownerTable[1], field).Row().Scan(&count)
if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field).Row().Scan(&count)
} else {
return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count)
return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", tableName, field).Row().Scan(&count)
}

}) == nil && count > 0
Expand All @@ -295,16 +320,15 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy

return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var description string
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownerTable := strings.Split(stmt.Schema.Table, ".")
if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
_ = m.DB.Raw(
"SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE OWNER = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?",
ownerTable[0], ownerTable[1], field.DBName,
ownerName, tableName, field.DBName,
).Row().Scan(&description)
} else {
_ = m.DB.Raw(
"SELECT COMMENTS FROM USER_COL_COMMENTS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?",
stmt.Table, field.DBName,
tableName, field.DBName,
).Row().Scan(&description)
}
if comment := field.Comment; comment != "" && comment != description {
Expand All @@ -320,11 +344,10 @@ func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (ex
expr.SQL = m.DataTypeOf(field)

var nullable = ""
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownerTable := strings.Split(stmt.Schema.Table, ".")
_ = m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerTable[0], ownerTable[1], field.DBName).Row().Scan(&nullable)
if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
_ = m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field.DBName).Row().Scan(&nullable)
} else {
_ = m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field.DBName).Row().Scan(&nullable)
_ = m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", tableName, field.DBName).Row().Scan(&nullable)
}

if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
Expand Down Expand Up @@ -354,18 +377,19 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
//goland:noinspection SqlNoDataSourceInspection
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
_, tableName := m.getSchemaTable(stmt)
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if chk.Name == name {
return m.DB.Exec(
"ALTER TABLE ? DROP CHECK ?",
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
clause.Table{Name: tableName}, clause.Column{Name: name},
).Error
}
}

return m.DB.Exec(
"ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
clause.Table{Name: tableName}, clause.Column{Name: name},
).Error
})
}
Expand All @@ -384,8 +408,9 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
_, tableName := m.getSchemaTable(stmt)

return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Schema.Table}).Error
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: tableName}).Error
})
}

Expand Down

0 comments on commit 953a73d

Please sign in to comment.