Skip to content

Commit

Permalink
🚧 修改优化创建数据相关代码
Browse files Browse the repository at this point in the history
  • Loading branch information
iTanken committed Nov 15, 2024
1 parent fc9a6ba commit b3a6c10
Showing 1 changed file with 66 additions and 57 deletions.
123 changes: 66 additions & 57 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,13 @@ import (
)

func Create(db *gorm.DB) {
if db.Error != nil {
if db.Error != nil || db.Statement == nil {
return
}

stmt := db.Statement
if stmt == nil {
return
}

stmtSchema := stmt.Schema
if stmtSchema == nil {
return
}

if !stmt.Unscoped {
if stmtSchema != nil && !stmt.Unscoped {
for _, c := range stmtSchema.CreateClauses {
stmt.AddClause(c)
}
Expand All @@ -37,7 +29,7 @@ func Create(db *gorm.DB) {
)

if hasConflict {
if len(stmtSchema.PrimaryFields) > 0 {
if stmtSchema != nil && len(stmtSchema.PrimaryFields) > 0 {
columnsMap := map[string]bool{}
for _, column := range createValues.Columns {
columnsMap[column.Name] = true
Expand All @@ -53,37 +45,14 @@ func Create(db *gorm.DB) {
}
}

hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
if hasConflict {
MergeCreate(db, onConflict, createValues)
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClauseIfNotExists(clause.Insert{})
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})

if hasDefaultValues {
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
}
stmt.Build("INSERT", "VALUES", "RETURNING")

if hasDefaultValues {
_, _ = stmt.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
if idx > 0 {
_ = stmt.WriteByte(',')
}

outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()}
if field.Size > 0 {
outVar.Size = field.Size
}
stmt.AddVar(stmt, outVar)
}
_, _ = stmt.WriteString(" /*-go_ora.Out{}-*/")
}
stmt.Build("INSERT", "VALUES")
_ = outputInserted(db)
}

if !db.DryRun && db.Error == nil {
Expand All @@ -107,9 +76,10 @@ func Create(db *gorm.DB) {

result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
rowsAffected, _ := result.RowsAffected()
db.RowsAffected += rowsAffected

if hasDefaultValues {
if stmtSchema != nil && len(stmtSchema.FieldsWithDefaultDBValue) > 0 {
getDefaultValues(db, idx)
}
}
Expand All @@ -119,14 +89,39 @@ func Create(db *gorm.DB) {
}
}

func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
var dummyTable string
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
func outputInserted(db *gorm.DB) (lenDefaultValue int) {
stmtSchema := db.Statement.Schema
if stmtSchema == nil {
return
}
lenDefaultValue = len(stmtSchema.FieldsWithDefaultDBValue)
if lenDefaultValue > 0 {
columns := make([]clause.Column, lenDefaultValue)
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
columns[idx] = clause.Column{Name: field.DBName}
}
db.Statement.AddClauseIfNotExists(clause.Returning{Columns: columns})
}
db.Statement.Build("RETURNING")

_, _ = db.Statement.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
if idx > 0 {
_ = db.Statement.WriteByte(',')
}

outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()}
if field.Size > 0 {
outVar.Size = field.Size
}
db.Statement.AddVar(db.Statement, outVar)
}
_, _ = db.Statement.WriteString(" /*-go_ora.Out{}-*/")
return
}

func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
dummyTable := getDummyTable(db)

_, _ = db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
Expand Down Expand Up @@ -220,7 +215,20 @@ func convertValue(val interface{}) interface{} {
return val
}

func getDummyTable(db *gorm.DB) (dummyTable string) {
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}
return
}

func getDefaultValues(db *gorm.DB, idx int) {
if db.Statement.Schema == nil || len(db.Statement.Schema.FieldsWithDefaultDBValue) == 0 {
return
}
insertTo := db.Statement.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
Expand All @@ -238,23 +246,24 @@ func getDefaultValues(db *gorm.DB, idx int) {
case reflect.Slice, reflect.Array:
for i := insertTo.Len() - 1; i >= 0; i-- {
rv := insertTo.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}

_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
switch reflect.Indirect(rv).Kind() {
case reflect.Struct:
setStructFieldValue(db, rv, v)
default:
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
}
setStructFieldValue(db, insertTo, v)
default:
}
default:
}
}
}

func setStructFieldValue(db *gorm.DB, insertTo reflect.Value, out go_ora.Out) {
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo); !isZero {
return
}
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, out.Dest))
}

0 comments on commit b3a6c10

Please sign in to comment.