Skip to content

Commit

Permalink
✨ 添加批量设置和移除会话参数工具函数
Browse files Browse the repository at this point in the history
- AddSessionParams setting database connection session parameters
- DelSessionParams remove session parameters

Signed-off-by: liutianqi <[email protected]>
  • Loading branch information
iTanken committed Jan 3, 2024
1 parent 1c7e78f commit 63b93a3
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
38 changes: 38 additions & 0 deletions oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,44 @@ func GetStringExpr(value string, quotes ...bool) clause.Expr {
return gorm.Expr(value)
}

// AddSessionParams setting database connection session parameters
func AddSessionParams(db *sql.DB, params map[string]string) (keys []string, err error) {
if db == nil {
return
}
if _, ok := db.Driver().(*go_ora.OracleDriver); !ok {
return
}

for key, value := range params {
if key == "" || value == "" {
continue
}
if err = go_ora.AddSessionParam(db, key, value); err != nil {
return
}
keys = append(keys, key)
}
return
}

// DelSessionParams remove session parameters
func DelSessionParams(db *sql.DB, keys []string) {
if db == nil {
return
}
if _, ok := db.Driver().(*go_ora.OracleDriver); !ok {
return
}

for _, key := range keys {
if key == "" {
continue
}
go_ora.DelSessionParam(db, key)
}
}

func convertCustomType(val interface{}) interface{} {
rv := reflect.ValueOf(val)
ri := rv.Interface()
Expand Down
50 changes: 50 additions & 0 deletions oracle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oracle

import (
"database/sql"
"log"
"os"
"reflect"
Expand Down Expand Up @@ -111,3 +112,52 @@ func openTestConnection(ignoreCase, namingCase bool) (db *gorm.DB, err error) {
}
return
}

func TestAddSessionParams(t *testing.T) {
db, err := openTestConnection(true, false)
if err != nil {
t.Fatal(err)
}
var sqlDB *sql.DB
if sqlDB, err = db.DB(); err != nil {
t.Fatal(err)
}
type args struct {
params map[string]string
}
tests := []struct {
name string
args args
}{
{name: "TimeParams", args: args{params: map[string]string{
"TIME_ZONE": "+08:00", // alter session set TIME_ZONE = '+08:00';
"NLS_DATE_FORMAT": "YYYY-MM-DD", // alter session set NLS_DATE_FORMAT = 'YYYY-MM-DD';
"NLS_TIME_FORMAT": "HH24:MI:SSXFF", // alter session set NLS_TIME_FORMAT = 'HH24:MI:SS.FF3';
"NLS_TIMESTAMP_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF", // alter session set NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3';
"NLS_TIME_TZ_FORMAT": "HH24:MI:SS.FF TZR", // alter session set NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF3 TZR';
"NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // alter session set NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZR';
}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
//queryTime := `SELECT SYSDATE FROM DUAL`
queryTime := `SELECT CAST(SYSDATE AS VARCHAR(30)) AS D FROM DUAL`
var timeStr string
if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil {
t.Fatal(err)
}
t.Logf("SYSDATE 1: %s", timeStr)

var keys []string
if keys, err = AddSessionParams(sqlDB, tt.args.params); err != nil {
t.Fatalf("AddSessionParams() error = %v", err)
}
if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil {
t.Fatal(err)
}
defer DelSessionParams(sqlDB, keys)
t.Logf("SYSDATE 2: %s", timeStr)
t.Logf("keys: %#v", keys)
})
}
}

0 comments on commit 63b93a3

Please sign in to comment.