From 62d9d907c95b1e8e58093f067ee849c79b1d8bd7 Mon Sep 17 00:00:00 2001 From: liutianqi Date: Wed, 3 Jan 2024 18:42:12 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20=E4=BF=AE=E5=A4=8D=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E6=95=B0=E6=8D=AE=E5=86=B2=E7=AA=81=E6=97=B6=E7=9A=84?= =?UTF-8?q?=20MERGE=20=E8=AF=AD=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: liutianqi --- clauses/merge.go | 49 -------------- clauses/returning_into.go | 10 --- clauses/when_matched.go | 39 ----------- clauses/when_not_matched.go | 32 --------- create.go | 125 ++++++++++++++++++++++++------------ go.mod | 1 + 6 files changed, 84 insertions(+), 172 deletions(-) delete mode 100644 clauses/merge.go delete mode 100644 clauses/returning_into.go delete mode 100644 clauses/when_matched.go delete mode 100644 clauses/when_not_matched.go diff --git a/clauses/merge.go b/clauses/merge.go deleted file mode 100644 index 62812a1..0000000 --- a/clauses/merge.go +++ /dev/null @@ -1,49 +0,0 @@ -package clauses - -import ( - "gorm.io/gorm/clause" -) - -type Merge struct { - Table clause.Table - Using []clause.Interface - On []clause.Expression -} - -func (merge Merge) Name() string { - return "MERGE" -} - -func MergeDefaultExcludeName() string { - return "exclude" -} - -// Build from clause -func (merge Merge) Build(builder clause.Builder) { - clause.Insert{}.Build(builder) - _, _ = builder.WriteString(" USING (") - for idx, v := range merge.Using { - if idx > 0 { - _ = builder.WriteByte(' ') - } - _, _ = builder.WriteString(v.Name()) - _ = builder.WriteByte(' ') - v.Build(builder) - } - _, _ = builder.WriteString(") ") - _, _ = builder.WriteString(MergeDefaultExcludeName()) - _, _ = builder.WriteString(" ON (") - for idx, on := range merge.On { - if idx > 0 { - _, _ = builder.WriteString(", ") - } - on.Build(builder) - } - _, _ = builder.WriteString(")") -} - -// MergeClause merge values clauses -func (merge Merge) MergeClause(clause *clause.Clause) { - clause.Name = merge.Name() - clause.Expression = merge -} diff --git a/clauses/returning_into.go b/clauses/returning_into.go deleted file mode 100644 index 0f2857b..0000000 --- a/clauses/returning_into.go +++ /dev/null @@ -1,10 +0,0 @@ -package clauses - -import ( - "gorm.io/gorm/clause" -) - -type ReturningInto struct { - Variables []clause.Column - Into []*clause.Values -} diff --git a/clauses/when_matched.go b/clauses/when_matched.go deleted file mode 100644 index 14c522d..0000000 --- a/clauses/when_matched.go +++ /dev/null @@ -1,39 +0,0 @@ -package clauses - -import ( - "gorm.io/gorm/clause" -) - -type WhenMatched struct { - clause.Set - Where, Delete clause.Where -} - -func (w WhenMatched) Name() string { - return "WHEN MATCHED" -} - -func (w WhenMatched) Build(builder clause.Builder) { - if len(w.Set) > 0 { - _, _ = builder.WriteString(" THEN") - _, _ = builder.WriteString(" UPDATE ") - _, _ = builder.WriteString(w.Name()) - _ = builder.WriteByte(' ') - w.Build(builder) - - buildWhere := func(where clause.Where) { - _, _ = builder.WriteString(where.Name()) - _ = builder.WriteByte(' ') - where.Build(builder) - } - - if len(w.Where.Exprs) > 0 { - buildWhere(w.Where) - } - - if len(w.Delete.Exprs) > 0 { - _, _ = builder.WriteString(" DELETE ") - buildWhere(w.Delete) - } - } -} diff --git a/clauses/when_not_matched.go b/clauses/when_not_matched.go deleted file mode 100644 index 631a922..0000000 --- a/clauses/when_not_matched.go +++ /dev/null @@ -1,32 +0,0 @@ -package clauses - -import ( - "gorm.io/gorm/clause" -) - -type WhenNotMatched struct { - clause.Values - Where clause.Where -} - -func (w WhenNotMatched) Name() string { - return "WHEN NOT MATCHED" -} - -func (w WhenNotMatched) Build(builder clause.Builder) { - if len(w.Columns) > 0 { - if len(w.Values.Values) != 1 { - panic("cannot insert more than one rows due to Oracle SQL language restriction") - } - - _, _ = builder.WriteString(" THEN") - _, _ = builder.WriteString(" INSERT ") - w.Build(builder) - - if len(w.Where.Exprs) > 0 { - _, _ = builder.WriteString(w.Where.Name()) - _ = builder.WriteByte(' ') - w.Where.Build(builder) - } - } -} diff --git a/create.go b/create.go index 128a4c2..8f044c0 100644 --- a/create.go +++ b/create.go @@ -1,9 +1,6 @@ package oracle import ( - "bytes" - - "github.com/godoes/gorm-oracle/clauses" "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" @@ -54,45 +51,7 @@ func Create(db *gorm.DB) { } } if hasConflict { - stmt.AddClauseIfNotExists(clauses.Merge{ - Using: []clause.Interface{ - clause.Select{ - Columns: func() (columns []clause.Column) { - // HACK: I can not come up with a better alternative for now - // I want to add a value to the list of variable and then capture the bind variable position as well - columns = values.Columns - for i, column := range columns { - buf := bytes.NewBufferString("") - stmt.Vars = append(stmt.Vars, values.Values[0][i]) - stmt.BindVarTo(buf, stmt, nil) - - column.Alias = column.Name - // then the captured bind var will be the name - column.Name = buf.String() - columns[i] = column - } - return - }(), - }, - clause.From{ - Tables: []clause.Table{{Name: db.Dialector.(*Dialector).DummyTableName()}}, - }, - }, - On: func() (onExpr []clause.Expression) { - onExpr = make([]clause.Expression, len(stmtSchema.PrimaryFields)) - for i, field := range stmtSchema.PrimaryFields { - onExpr[i] = clause.Eq{ - Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName}, - Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName}, - } - } - return - }(), - }) - stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) - stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) - - stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") + MergeCreate(db, onConflict, values) } else { stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}}) stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) @@ -133,3 +92,85 @@ func Create(db *gorm.DB) { } } } + +func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { + var dummyTable string + switch d := ptrDereference(db.Dialector).(type) { + case Dialector: + dummyTable = d.DummyTableName() + default: + dummyTable = "DUAL" + } + + _, _ = db.Statement.WriteString("MERGE INTO ") + db.Statement.WriteQuoted(db.Statement.Table) + _, _ = db.Statement.WriteString(" USING (") + + for idx, value := range values.Values { + if idx > 0 { + _, _ = db.Statement.WriteString(" UNION ALL ") + } + + _, _ = db.Statement.WriteString("SELECT ") + for i, v := range value { + if i > 0 { + _ = db.Statement.WriteByte(',') + } + column := values.Columns[i] + db.Statement.AddVar(db.Statement, v) + _, _ = db.Statement.WriteString(" AS ") + db.Statement.WriteQuoted(column.Name) + } + _, _ = db.Statement.WriteString(" FROM ") + _, _ = db.Statement.WriteString(dummyTable) + } + + _, _ = db.Statement.WriteString(`) `) + db.Statement.WriteQuoted("excluded") + _, _ = db.Statement.WriteString(" ON (") + + var where clause.Where + for _, field := range db.Statement.Schema.PrimaryFields { + where.Exprs = append(where.Exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, + Value: clause.Column{Table: "excluded", Name: field.DBName}, + }) + } + where.Build(db.Statement) + _ = db.Statement.WriteByte(')') + + if len(onConflict.DoUpdates) > 0 { + _, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") + onConflict.DoUpdates.Build(db.Statement) + } + + _, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") + + written := false + for _, column := range values.Columns { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { + if written { + _ = db.Statement.WriteByte(',') + } + written = true + db.Statement.WriteQuoted(column.Name) + } + } + + _, _ = db.Statement.WriteString(") VALUES (") + + written = false + for _, column := range values.Columns { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { + if written { + _ = db.Statement.WriteByte(',') + } + written = true + db.Statement.WriteQuoted(clause.Column{ + Table: "excluded", + Name: column.Name, + }) + } + } + _, _ = db.Statement.WriteString(")") +} diff --git a/go.mod b/go.mod index 81a529c..980817b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( ) retract ( + v1.5.12 v1.5.1 v1.5.0 )