Skip to content

Commit

Permalink
✨ 实现修改数据支持自定义时间类型
Browse files Browse the repository at this point in the history
Signed-off-by: liutianqi <[email protected]>
  • Loading branch information
iTanken committed Dec 15, 2023
1 parent bc985d3 commit 1f9dc23
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 45 deletions.
49 changes: 7 additions & 42 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ package oracle

import (
"bytes"
"database/sql"
"reflect"
"time"

"github.com/godoes/gorm-oracle/clauses"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
gormSchema "gorm.io/gorm/schema"
"gorm.io/gorm/schema"
)

func Create(db *gorm.DB) {
Expand All @@ -24,26 +22,26 @@ func Create(db *gorm.DB) {
return
}

schema := stmt.Schema
if schema == nil {
stmtSchema := stmt.Schema
if stmtSchema == nil {
return
}

if !stmt.Unscoped {
for _, c := range schema.CreateClauses {
for _, c := range stmtSchema.CreateClauses {
stmt.AddClause(c)
}
}

if stmt.SQL.String() == "" {
if stmt.SQL.Len() == 0 {
var (
values = callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
)
// are all columns in value the primary fields in schema only?
if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
funk.Map(stmtSchema.PrimaryFields, func(field *schema.Field) string { return field.DBName }),
) {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
Expand All @@ -65,7 +63,7 @@ func Create(db *gorm.DB) {
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
},
},
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
On: funk.Map(stmtSchema.PrimaryFields, func(field *schema.Field) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
Expand Down Expand Up @@ -122,36 +120,3 @@ func Create(db *gorm.DB) {
}
}
}

func convertCustomType(val interface{}) interface{} {
rv := reflect.ValueOf(val)
ri := rv.Interface()
typeName := reflect.TypeOf(ri).Name()
if reflect.TypeOf(val).Kind() == reflect.Ptr {
if rv.IsNil() {
typeName = rv.Type().Elem().Name()
} else {
for rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
ri = rv.Interface()
typeName = reflect.TypeOf(ri).Name()
}
}
if typeName == "DeletedAt" {
// gorm.DeletedAt
if rv.IsZero() {
val = sql.NullTime{}
} else {
val = ri.(gorm.DeletedAt).Time
}
} else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 {
// custom time type
for _, result := range m.Call([]reflect.Value{}) {
if reflect.TypeOf(result.Interface()).Name() == "Time" {
val = result.Interface().(time.Time)
}
}
}
return val
}
44 changes: 41 additions & 3 deletions oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"database/sql"
"fmt"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"

"github.com/sijms/go-ora/v2"
"github.com/thoas/go-funk"
Expand Down Expand Up @@ -52,6 +54,39 @@ func BuildUrl(server string, port int, service, user, password string, options m
return go_ora.BuildUrl(server, port, service, user, password, options)
}

func convertCustomType(val interface{}) interface{} {
rv := reflect.ValueOf(val)
ri := rv.Interface()
typeName := reflect.TypeOf(ri).Name()
if reflect.TypeOf(val).Kind() == reflect.Ptr {
if rv.IsNil() {
typeName = rv.Type().Elem().Name()
} else {
for rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
ri = rv.Interface()
typeName = reflect.TypeOf(ri).Name()
}
}
if typeName == "DeletedAt" {
// gorm.DeletedAt
if rv.IsZero() {
val = sql.NullTime{}
} else {
val = ri.(gorm.DeletedAt).Time
}
} else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 {
// custom time type
for _, result := range m.Call([]reflect.Value{}) {
if reflect.TypeOf(result.Interface()).Name() == "Time" {
val = result.Interface().(time.Time)
}
}
}
return val
}

func (d Dialector) DummyTableName() string {
return "DUAL"
}
Expand All @@ -65,12 +100,12 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
d.DefaultStringSize = 1024

// register callbacks
//callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true})
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
config := &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
})
}
callbacks.RegisterDefaultCallbacks(db, config)

d.DriverName = "oracle"

Expand All @@ -97,6 +132,9 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
return
}
if err = db.Callback().Update().Replace("gorm:update", Update(config)); err != nil {
return
}

for k, v := range d.ClauseBuilders() {
db.ClauseBuilders[k] = v
Expand Down
Loading

0 comments on commit 1f9dc23

Please sign in to comment.