diff --git a/create.go b/create.go index b315433..7c4dcb1 100644 --- a/create.go +++ b/create.go @@ -10,21 +10,13 @@ import ( ) func Create(db *gorm.DB) { - if db.Error != nil { + if db.Error != nil || db.Statement == nil { return } stmt := db.Statement - if stmt == nil { - return - } - stmtSchema := stmt.Schema - if stmtSchema == nil { - return - } - - if !stmt.Unscoped { + if stmtSchema != nil && !stmt.Unscoped { for _, c := range stmtSchema.CreateClauses { stmt.AddClause(c) } @@ -37,7 +29,7 @@ func Create(db *gorm.DB) { ) if hasConflict { - if len(stmtSchema.PrimaryFields) > 0 { + if stmtSchema != nil && len(stmtSchema.PrimaryFields) > 0 { columnsMap := map[string]bool{} for _, column := range createValues.Columns { columnsMap[column.Name] = true @@ -53,37 +45,14 @@ func Create(db *gorm.DB) { } } - hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0 if hasConflict { MergeCreate(db, onConflict, createValues) } else { - stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}}) + stmt.AddClauseIfNotExists(clause.Insert{}) stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}}) - if hasDefaultValues { - columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue)) - for idx, field := range stmtSchema.FieldsWithDefaultDBValue { - columns[idx] = clause.Column{Name: field.DBName} - } - stmt.AddClauseIfNotExists(clause.Returning{Columns: columns}) - } - stmt.Build("INSERT", "VALUES", "RETURNING") - - if hasDefaultValues { - _, _ = stmt.WriteString(" INTO ") - for idx, field := range stmtSchema.FieldsWithDefaultDBValue { - if idx > 0 { - _ = stmt.WriteByte(',') - } - - outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()} - if field.Size > 0 { - outVar.Size = field.Size - } - stmt.AddVar(stmt, outVar) - } - _, _ = stmt.WriteString(" /*-go_ora.Out{}-*/") - } + stmt.Build("INSERT", "VALUES") + _ = outputInserted(db) } if !db.DryRun && db.Error == nil { @@ -107,9 +76,10 @@ func Create(db *gorm.DB) { result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) if db.AddError(err) == nil { - db.RowsAffected, _ = result.RowsAffected() + rowsAffected, _ := result.RowsAffected() + db.RowsAffected += rowsAffected - if hasDefaultValues { + if stmtSchema != nil && len(stmtSchema.FieldsWithDefaultDBValue) > 0 { getDefaultValues(db, idx) } } @@ -119,14 +89,39 @@ func Create(db *gorm.DB) { } } -func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { - var dummyTable string - switch d := ptrDereference(db.Dialector).(type) { - case Dialector: - dummyTable = d.DummyTableName() - default: - dummyTable = "DUAL" +func outputInserted(db *gorm.DB) (lenDefaultValue int) { + stmtSchema := db.Statement.Schema + if stmtSchema == nil { + return } + lenDefaultValue = len(stmtSchema.FieldsWithDefaultDBValue) + if lenDefaultValue > 0 { + columns := make([]clause.Column, lenDefaultValue) + for idx, field := range stmtSchema.FieldsWithDefaultDBValue { + columns[idx] = clause.Column{Name: field.DBName} + } + db.Statement.AddClauseIfNotExists(clause.Returning{Columns: columns}) + } + db.Statement.Build("RETURNING") + + _, _ = db.Statement.WriteString(" INTO ") + for idx, field := range stmtSchema.FieldsWithDefaultDBValue { + if idx > 0 { + _ = db.Statement.WriteByte(',') + } + + outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()} + if field.Size > 0 { + outVar.Size = field.Size + } + db.Statement.AddVar(db.Statement, outVar) + } + _, _ = db.Statement.WriteString(" /*-go_ora.Out{}-*/") + return +} + +func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { + dummyTable := getDummyTable(db) _, _ = db.Statement.WriteString("MERGE INTO ") db.Statement.WriteQuoted(db.Statement.Table) @@ -220,7 +215,20 @@ func convertValue(val interface{}) interface{} { return val } +func getDummyTable(db *gorm.DB) (dummyTable string) { + switch d := ptrDereference(db.Dialector).(type) { + case Dialector: + dummyTable = d.DummyTableName() + default: + dummyTable = "DUAL" + } + return +} + func getDefaultValues(db *gorm.DB, idx int) { + if db.Statement.Schema == nil || len(db.Statement.Schema.FieldsWithDefaultDBValue) == 0 { + return + } insertTo := db.Statement.ReflectValue switch insertTo.Kind() { case reflect.Slice, reflect.Array: @@ -238,23 +246,24 @@ func getDefaultValues(db *gorm.DB, idx int) { case reflect.Slice, reflect.Array: for i := insertTo.Len() - 1; i >= 0; i-- { rv := insertTo.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) - if isZero { - _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest)) + switch reflect.Indirect(rv).Kind() { + case reflect.Struct: + setStructFieldValue(db, rv, v) + default: } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo) - if isZero { - _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest)) - } + setStructFieldValue(db, insertTo, v) default: } default: } } } + +func setStructFieldValue(db *gorm.DB, insertTo reflect.Value, out go_ora.Out) { + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo); !isZero { + return + } + _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, out.Dest)) +}