Skip to content

Commit

Permalink
Basic support for insert with on conflict (#21)
Browse files Browse the repository at this point in the history
* Support insert with on conflict

* fix function name

* update adapter implementation

* support mysql on conflict ignore

* fix tests

* bump rel
  • Loading branch information
Fs02 authored Mar 12, 2022
1 parent e0a1447 commit aab59ad
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 60 deletions.
4 changes: 2 additions & 2 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ type QueryBuilder interface {
}

type InsertBuilder interface {
Build(table string, primaryField string, mutates map[string]rel.Mutate) (string, []interface{})
Build(table string, primaryField string, mutates map[string]rel.Mutate, onConflict rel.OnConflict) (string, []interface{})
}

type InsertAllBuilder interface {
Build(table string, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate) (string, []interface{})
Build(table string, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate, onConflict rel.OnConflict) (string, []interface{})
}

type UpdateBuilder interface {
Expand Down
34 changes: 17 additions & 17 deletions builder/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ type Filter struct{}
func (f Filter) Write(buffer *Buffer, table string, filter rel.FilterQuery, queryWriter QueryWriter) {
switch filter.Type {
case rel.FilterAndOp:
f.BuildLogical(buffer, table, "AND", filter.Inner, queryWriter)
f.WriteLogical(buffer, table, "AND", filter.Inner, queryWriter)
case rel.FilterOrOp:
f.BuildLogical(buffer, table, "OR", filter.Inner, queryWriter)
f.WriteLogical(buffer, table, "OR", filter.Inner, queryWriter)
case rel.FilterNotOp:
buffer.WriteString("NOT ")
f.BuildLogical(buffer, table, "AND", filter.Inner, queryWriter)
f.WriteLogical(buffer, table, "AND", filter.Inner, queryWriter)
case rel.FilterEqOp,
rel.FilterNeOp,
rel.FilterLtOp,
rel.FilterLteOp,
rel.FilterGtOp,
rel.FilterGteOp:
f.BuildComparison(buffer, table, filter, queryWriter)
f.WriteComparison(buffer, table, filter, queryWriter)
case rel.FilterNilOp:
buffer.WriteField(table, filter.Field)
buffer.WriteString(" IS NULL")
Expand All @@ -32,7 +32,7 @@ func (f Filter) Write(buffer *Buffer, table string, filter rel.FilterQuery, quer
buffer.WriteString(" IS NOT NULL")
case rel.FilterInOp,
rel.FilterNinOp:
f.BuildInclusion(buffer, table, filter, queryWriter)
f.WriteInclusion(buffer, table, filter, queryWriter)
case rel.FilterLikeOp:
buffer.WriteField(table, filter.Field)
buffer.WriteString(" LIKE ")
Expand All @@ -49,8 +49,8 @@ func (f Filter) Write(buffer *Buffer, table string, filter rel.FilterQuery, quer
}
}

// BuildLogical SQL to buffer.
func (f Filter) BuildLogical(buffer *Buffer, table, op string, inner []rel.FilterQuery, queryWriter QueryWriter) {
// WriteLogical SQL to buffer.
func (f Filter) WriteLogical(buffer *Buffer, table, op string, inner []rel.FilterQuery, queryWriter QueryWriter) {
var (
length = len(inner)
)
Expand All @@ -74,8 +74,8 @@ func (f Filter) BuildLogical(buffer *Buffer, table, op string, inner []rel.Filte
}
}

// BuildComparison SQL to buffer.
func (f Filter) BuildComparison(buffer *Buffer, table string, filter rel.FilterQuery, queryWriter QueryWriter) {
// WriteComparison SQL to buffer.
func (f Filter) WriteComparison(buffer *Buffer, table string, filter rel.FilterQuery, queryWriter QueryWriter) {
buffer.WriteField(table, filter.Field)

switch filter.Type {
Expand All @@ -96,18 +96,18 @@ func (f Filter) BuildComparison(buffer *Buffer, table string, filter rel.FilterQ
switch v := filter.Value.(type) {
case rel.SubQuery:
// For warped sub-queries
f.buildSubQuery(buffer, v, queryWriter)
f.WriteSubQuery(buffer, v, queryWriter)
case rel.Query:
// For sub-queries without warp
f.buildSubQuery(buffer, rel.SubQuery{Query: v}, queryWriter)
f.WriteSubQuery(buffer, rel.SubQuery{Query: v}, queryWriter)
default:
// For simple values
buffer.WriteValue(filter.Value)
}
}

// BuildInclusion SQL to buffer.
func (f Filter) BuildInclusion(buffer *Buffer, table string, filter rel.FilterQuery, queryWriter QueryWriter) {
// WriteInclusion SQL to buffer.
func (f Filter) WriteInclusion(buffer *Buffer, table string, filter rel.FilterQuery, queryWriter QueryWriter) {
var (
values = filter.Value.([]interface{})
)
Expand All @@ -127,14 +127,14 @@ func (f Filter) BuildInclusion(buffer *Buffer, table string, filter rel.FilterQu
buffer.WriteString(" NOT IN ")
}

f.buildInclusionValues(buffer, values, queryWriter)
f.WriteInclusionValues(buffer, values, queryWriter)
}
}

func (f Filter) buildInclusionValues(buffer *Buffer, values []interface{}, queryWriter QueryWriter) {
func (f Filter) WriteInclusionValues(buffer *Buffer, values []interface{}, queryWriter QueryWriter) {
if len(values) == 1 {
if value, ok := values[0].(rel.Query); ok {
f.buildSubQuery(buffer, rel.SubQuery{Query: value}, queryWriter)
f.WriteSubQuery(buffer, rel.SubQuery{Query: value}, queryWriter)
return
}
}
Expand All @@ -149,7 +149,7 @@ func (f Filter) buildInclusionValues(buffer *Buffer, values []interface{}, query
buffer.WriteByte(')')
}

func (f Filter) buildSubQuery(buffer *Buffer, sub rel.SubQuery, queryWriter QueryWriter) {
func (f Filter) WriteSubQuery(buffer *Buffer, sub rel.SubQuery, queryWriter QueryWriter) {
buffer.WriteString(sub.Prefix)
buffer.WriteByte('(')
queryWriter.Write(buffer, sub.Query)
Expand Down
2 changes: 1 addition & 1 deletion builder/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (i Index) WriteCreateIndex(buffer *Buffer, index rel.Index) {
buffer.WriteString(")")
if !index.Filter.None() {
if !i.SupportFilter {
log.Print("[WARN] Adapter does not support filtered/partial indexes")
log.Print("[REL] Adapter does not support filtered/partial indexes")
return
}
buffer.WriteString(" WHERE ")
Expand Down
27 changes: 21 additions & 6 deletions builder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,34 @@ type Insert struct {
BufferFactory BufferFactory
ReturningPrimaryValue bool
InsertDefaultValues bool
OnConflict OnConflict
}

// Build sql query and its arguments.
func (i Insert) Build(table string, primaryField string, mutates map[string]rel.Mutate) (string, []interface{}) {
func (i Insert) Build(table string, primaryField string, mutates map[string]rel.Mutate, onConflict rel.OnConflict) (string, []interface{}) {
var (
buffer = i.BufferFactory.Create()
count = len(mutates)
)

i.WriteInsertInto(&buffer, table)
i.WriteValues(&buffer, mutates)
i.OnConflict.WriteMutates(&buffer, mutates, onConflict)
i.WriteReturning(&buffer, primaryField)

buffer.WriteString(";")

return buffer.String(), buffer.Arguments()
}

func (i Insert) WriteInsertInto(buffer *Buffer, table string) {
buffer.WriteString("INSERT INTO ")
buffer.WriteEscape(table)
}

func (i Insert) WriteValues(buffer *Buffer, mutates map[string]rel.Mutate) {
var (
count = len(mutates)
)

if count == 0 && i.InsertDefaultValues {
buffer.WriteString(" DEFAULT VALUES")
Expand Down Expand Up @@ -55,13 +72,11 @@ func (i Insert) Build(table string, primaryField string, mutates map[string]rel.
buffer.AddArguments(arguments...)
buffer.WriteByte(')')
}
}

func (i Insert) WriteReturning(buffer *Buffer, primaryField string) {
if i.ReturningPrimaryValue && primaryField != "" {
buffer.WriteString(" RETURNING ")
buffer.WriteEscape(primaryField)
}

buffer.WriteString(";")

return buffer.String(), buffer.Arguments()
}
32 changes: 23 additions & 9 deletions builder/insert_all.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,35 @@ import (
type InsertAll struct {
BufferFactory BufferFactory
ReturningPrimaryValue bool
OnConflict OnConflict
}

// Build SQL string and its arguments.
func (ia InsertAll) Build(table string, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate) (string, []interface{}) {
func (ia InsertAll) Build(table string, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate, onConflict rel.OnConflict) (string, []interface{}) {
var (
buffer = ia.BufferFactory.Create()
fieldsCount = len(fields)
mutatesCount = len(bulkMutates)
buffer = ia.BufferFactory.Create()
)

ia.WriteInsertInto(&buffer, table)
ia.WriteValues(&buffer, fields, bulkMutates)
ia.OnConflict.Write(&buffer, fields, onConflict)
ia.WriteReturning(&buffer, primaryField)
buffer.WriteString(";")

return buffer.String(), buffer.Arguments()
}

func (ia InsertAll) WriteInsertInto(buffer *Buffer, table string) {
buffer.WriteString("INSERT INTO ")
buffer.WriteEscape(table)
}

func (ia InsertAll) WriteValues(buffer *Buffer, fields []string, bulkMutates []map[string]rel.Mutate) {
var (
fieldsCount = len(fields)
mutatesCount = len(bulkMutates)
)

buffer.WriteString(" (")

for i := range fields {
Expand Down Expand Up @@ -53,14 +70,11 @@ func (ia InsertAll) Build(table string, primaryField string, fields []string, bu
buffer.WriteByte(')')
}
}
}

func (ia InsertAll) WriteReturning(buffer *Buffer, primaryField string) {
if ia.ReturningPrimaryValue && primaryField != "" {
buffer.WriteString(" RETURNING ")
buffer.WriteEscape(primaryField)
}

buffer.WriteString(";")

return buffer.String(), buffer.Arguments()

}
87 changes: 82 additions & 5 deletions builder/insert_all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func BenchmarkInsertAll_Build(b *testing.B) {
)

for n := 0; n < b.N; n++ {
insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates)
insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates, rel.OnConflict{})
}
}

Expand All @@ -50,12 +50,12 @@ func TestInsertAll_Build(t *testing.T) {
}
)

statement, args := insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates)
statement, args := insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates, rel.OnConflict{})
assert.Equal(t, "INSERT INTO `users` (`name`) VALUES (?),(DEFAULT),(?);", statement)
assert.Equal(t, []interface{}{"foo", "boo"}, args)

// with age
statement, args = insertAllBuilder.Build("users", "id", []string{"name", "age"}, bulkMutates)
statement, args = insertAllBuilder.Build("users", "id", []string{"name", "age"}, bulkMutates, rel.OnConflict{})
assert.Equal(t, "INSERT INTO `users` (`name`,`age`) VALUES (?,DEFAULT),(DEFAULT,?),(?,?);", statement)
assert.Equal(t, []interface{}{"foo", 10, "boo", 20}, args)
}
Expand All @@ -80,12 +80,89 @@ func TestInsertAll_Build_ordinal(t *testing.T) {
}
)

statement, args := insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates)
statement, args := insertAllBuilder.Build("users", "id", []string{"name"}, bulkMutates, rel.OnConflict{})
assert.Equal(t, "INSERT INTO \"users\" (\"name\") VALUES ($1),(DEFAULT),($2) RETURNING \"id\";", statement)
assert.Equal(t, []interface{}{"foo", "boo"}, args)

// with age
statement, args = insertAllBuilder.Build("users", "id", []string{"name", "age"}, bulkMutates)
statement, args = insertAllBuilder.Build("users", "id", []string{"name", "age"}, bulkMutates, rel.OnConflict{})
assert.Equal(t, "INSERT INTO \"users\" (\"name\",\"age\") VALUES ($1,DEFAULT),(DEFAULT,$2),($3,$4) RETURNING \"id\";", statement)
assert.Equal(t, []interface{}{"foo", 10, "boo", 20}, args)
}

func TestInsertAll_Build_onConflictIgnore(t *testing.T) {
var (
insertAllBuilder = InsertAll{
BufferFactory: BufferFactory{ArgumentPlaceholder: "?", Quoter: Quote{IDPrefix: "`", IDSuffix: "`", IDSuffixEscapeChar: "`", ValueQuote: "'", ValueQuoteEscapeChar: "'"}},
OnConflict: OnConflict{
Statement: "ON CONFLICT",
IgnoreStatement: "IGNORE",
SupportKey: true,
},
}
bulkMutates = []map[string]rel.Mutate{
{
"id": rel.Set("id", 1),
},
{
"id": rel.Set("id", 2),
},
}
onConflict = rel.OnConflict{Keys: []string{"id"}, Ignore: true}
qs, args = insertAllBuilder.Build("users", "id", []string{"id"}, bulkMutates, onConflict)
)

assert.Equal(t, "INSERT INTO `users` (`id`) VALUES (?),(?) ON CONFLICT(`id`) IGNORE;", qs)
assert.Equal(t, []interface{}{1, 2}, args)
}

func TestInsertAll_Build_onConflictReplace(t *testing.T) {
var (
insertAllBuilder = InsertAll{
BufferFactory: BufferFactory{ArgumentPlaceholder: "?", Quoter: Quote{IDPrefix: "`", IDSuffix: "`", IDSuffixEscapeChar: "`", ValueQuote: "'", ValueQuoteEscapeChar: "'"}},
OnConflict: OnConflict{
Statement: "ON CONFLICT",
UpdateStatement: "DO UPDATE SET",
TableQualifier: "EXCLUDED",
SupportKey: true,
},
}
bulkMutates = []map[string]rel.Mutate{
{
"id": rel.Set("id", 1),
},
{
"id": rel.Set("id", 2),
},
}
onConflict = rel.OnConflict{Keys: []string{"id", "username"}, Replace: true}
qs, args = insertAllBuilder.Build("users", "id", []string{"id"}, bulkMutates, onConflict)
)

assert.Equal(t, "INSERT INTO `users` (`id`) VALUES (?),(?) ON CONFLICT(`id`,`username`) DO UPDATE SET `id`=`EXCLUDED`.`id`;", qs)
assert.Equal(t, []interface{}{1, 2}, args)
}

func TestInsertAll_Build_onConflictFragment(t *testing.T) {
var (
insertAllBuilder = InsertAll{
BufferFactory: BufferFactory{ArgumentPlaceholder: "?", Quoter: Quote{IDPrefix: "`", IDSuffix: "`", IDSuffixEscapeChar: "`", ValueQuote: "'", ValueQuoteEscapeChar: "'"}},
OnConflict: OnConflict{
Statement: "ON CONFLICT",
},
}
bulkMutates = []map[string]rel.Mutate{
{
"id": rel.Set("id", 1),
},
{
"id": rel.Set("id", 2),
},
}
onConflict = rel.OnConflict{Fragment: "SET `name`=?", FragmentArgs: []interface{}{"foo"}}
qs, args = insertAllBuilder.Build("users", "id", []string{"id"}, bulkMutates, onConflict)
)

assert.Equal(t, "INSERT INTO `users` (`id`) VALUES (?),(?) ON CONFLICT SET `name`=?;", qs)
assert.Equal(t, []interface{}{1, 2, "foo"}, args)
}
Loading

0 comments on commit aab59ad

Please sign in to comment.