Skip to content

Commit

Permalink
🐛 修复创建数据冲突时的 MERGE 语法
Browse files Browse the repository at this point in the history
Signed-off-by: liutianqi <[email protected]>
  • Loading branch information
iTanken committed Jan 4, 2024
1 parent f51f3ff commit 62d9d90
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 172 deletions.
49 changes: 0 additions & 49 deletions clauses/merge.go

This file was deleted.

10 changes: 0 additions & 10 deletions clauses/returning_into.go

This file was deleted.

39 changes: 0 additions & 39 deletions clauses/when_matched.go

This file was deleted.

32 changes: 0 additions & 32 deletions clauses/when_not_matched.go

This file was deleted.

125 changes: 83 additions & 42 deletions create.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package oracle

import (
"bytes"

"github.com/godoes/gorm-oracle/clauses"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -54,45 +51,7 @@ func Create(db *gorm.DB) {
}
}
if hasConflict {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Columns: func() (columns []clause.Column) {
// HACK: I can not come up with a better alternative for now
// I want to add a value to the list of variable and then capture the bind variable position as well
columns = values.Columns
for i, column := range columns {
buf := bytes.NewBufferString("")
stmt.Vars = append(stmt.Vars, values.Values[0][i])
stmt.BindVarTo(buf, stmt, nil)

column.Alias = column.Name
// then the captured bind var will be the name
column.Name = buf.String()
columns[i] = column
}
return
}(),
},
clause.From{
Tables: []clause.Table{{Name: db.Dialector.(*Dialector).DummyTableName()}},
},
},
On: func() (onExpr []clause.Expression) {
onExpr = make([]clause.Expression, len(stmtSchema.PrimaryFields))
for i, field := range stmtSchema.PrimaryFields {
onExpr[i] = clause.Eq{
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
}
}
return
}(),
})
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})

stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
MergeCreate(db, onConflict, values)
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
Expand Down Expand Up @@ -133,3 +92,85 @@ func Create(db *gorm.DB) {
}
}
}

func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
var dummyTable string
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}

_, _ = db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
_, _ = db.Statement.WriteString(" USING (")

for idx, value := range values.Values {
if idx > 0 {
_, _ = db.Statement.WriteString(" UNION ALL ")
}

_, _ = db.Statement.WriteString("SELECT ")
for i, v := range value {
if i > 0 {
_ = db.Statement.WriteByte(',')
}
column := values.Columns[i]
db.Statement.AddVar(db.Statement, v)
_, _ = db.Statement.WriteString(" AS ")
db.Statement.WriteQuoted(column.Name)
}
_, _ = db.Statement.WriteString(" FROM ")
_, _ = db.Statement.WriteString(dummyTable)
}

_, _ = db.Statement.WriteString(`) `)
db.Statement.WriteQuoted("excluded")
_, _ = db.Statement.WriteString(" ON (")

var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "excluded", Name: field.DBName},
})
}
where.Build(db.Statement)
_ = db.Statement.WriteByte(')')

if len(onConflict.DoUpdates) > 0 {
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}

_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")

written := false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(column.Name)
}
}

_, _ = db.Statement.WriteString(") VALUES (")

written = false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(clause.Column{
Table: "excluded",
Name: column.Name,
})
}
}
_, _ = db.Statement.WriteString(")")
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
)

retract (
v1.5.12
v1.5.1
v1.5.0
)

0 comments on commit 62d9d90

Please sign in to comment.