Skip to content

Commit

Permalink
Added Statement caching (#403)
Browse files Browse the repository at this point in the history
- added parameter `stmt_cache_size` to DSN string
  • Loading branch information
stampy88 authored Oct 20, 2020
1 parent 3a51645 commit b5e671b
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 26 deletions.
27 changes: 23 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,40 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
go conn.ociBreakDone(ctx, done)
defer func() { close(done) }()

if conn.stmtCacheSize == 0 {
if rv := C.OCIStmtPrepare2(
conn.svc, // service context handle
stmt, // pointer to the statement handle returned
conn.errHandle, // error handle
queryP, // statement text
C.ub4(len(query)), // statement text length
nil, // key to be used for searching the statement in the statement cache
C.ub4(0), // length of the key
C.ub4(C.OCI_NTV_SYNTAX), // syntax - OCI_NTV_SYNTAX: syntax depends upon the version of the server
C.ub4(C.OCI_DEFAULT), // mode
); rv != C.OCI_SUCCESS {
return nil, conn.getError(rv)
}

return &Stmt{conn: conn, stmt: *stmt, ctx: ctx, releaseMode: C.OCI_DEFAULT}, nil
}

if rv := C.OCIStmtPrepare2(
conn.svc, // service context handle
stmt, // pointer to the statement handle returned
conn.errHandle, // error handle
queryP, // statement text
C.ub4(len(query)), // statement text length
nil, // key to be used for searching the statement in the statement cache
C.ub4(0), // length of the key
queryP, // key to be used for searching the statement in the statement cache
C.ub4(len(query)), // length of the key
C.ub4(C.OCI_NTV_SYNTAX), // syntax - OCI_NTV_SYNTAX: syntax depends upon the version of the server
C.ub4(C.OCI_DEFAULT), // mode
); rv != C.OCI_SUCCESS {
); rv != C.OCI_SUCCESS && rv != C.OCI_SUCCESS_WITH_INFO {
// Note that C.OCI_SUCCESS_WITH_INFO is returned the first time a statement it put into the cache
return nil, conn.getError(rv)
}

return &Stmt{conn: conn, stmt: *stmt, ctx: ctx}, nil
return &Stmt{conn: conn, stmt: *stmt, ctx: ctx, releaseMode: C.OCI_DEFAULT, cacheKey: query}, nil
}

// Begin starts a transaction
Expand Down
12 changes: 8 additions & 4 deletions globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type (
transactionMode C.ub4
enableQMPlaceholders bool
operationMode C.ub4
stmtCacheSize C.ub4
}

// DriverStruct is Oracle driver struct
Expand All @@ -66,6 +67,7 @@ type (
prefetchMemory C.ub4
transactionMode C.ub4
operationMode C.ub4
stmtCacheSize C.ub4
inTransaction bool
enableQMPlaceholders bool
closed bool
Expand All @@ -80,10 +82,12 @@ type (

// Stmt is Oracle statement
Stmt struct {
conn *Conn
stmt *C.OCIStmt
closed bool
ctx context.Context
conn *Conn
stmt *C.OCIStmt
closed bool
ctx context.Context
cacheKey string // if statement caching is enabled, this is the key for this statement into the cache
releaseMode C.ub4
}

// Rows is Oracle rows
Expand Down
17 changes: 16 additions & 1 deletion oci8.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func ParseDSN(dsnString string) (dsn *DSN, err error) {
dsn = &DSN{
prefetchRows: 0,
prefetchMemory: 4096,
stmtCacheSize: 0,
operationMode: C.OCI_DEFAULT,
timeLocation: time.UTC,
}
Expand Down Expand Up @@ -119,7 +120,12 @@ func ParseDSN(dsnString string) (dsn *DSN, err error) {
default:
return nil, fmt.Errorf("Invalid as: %v", v[0])
}

case "stmt_cache_size":
z, err := strconv.ParseUint(v[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid stmt_cache_size: %v", v[0])
}
dsn.stmtCacheSize = C.ub4(z)
}
}

Expand Down Expand Up @@ -162,6 +168,7 @@ func (drv *DriverStruct) Open(dsnString string) (driver.Conn, error) {

conn := Conn{
operationMode: dsn.operationMode,
stmtCacheSize: dsn.stmtCacheSize,
logger: drv.Logger,
}
if conn.logger == nil {
Expand Down Expand Up @@ -353,6 +360,14 @@ func (drv *DriverStruct) Open(dsnString string) (driver.Conn, error) {
return nil, fmt.Errorf("authentication context attribute set error: %v", err)
}

if dsn.stmtCacheSize > 0 {
stmtCacheSize := dsn.stmtCacheSize
err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, unsafe.Pointer(&stmtCacheSize), 0, C.OCI_ATTR_STMTCACHESIZE)
if err != nil {
return nil, fmt.Errorf("stmt cache size attribute set error: %v", err)
}
}

} else {

var svcCtxP *C.OCISvcCtx
Expand Down
84 changes: 84 additions & 0 deletions oci8_sql_go_113_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// +build go1.13

package oci8

import (
"context"
"testing"
)

// TestStatementCaching tests to ensure statement caching is working
func TestStatementCaching(t *testing.T) {
if TestDisableDatabase {
t.SkipNow()
}

t.Parallel()

var err error

db := testGetDB("?stmt_cache_size=10")
if db == nil {
t.Fatal("db is null")
}

defer func() {
err = db.Close()
if err != nil {
t.Fatal("db close error:", err)
}
}()

ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout)
conn, err := db.Conn(ctx)
cancel()
// we need to get access to the raw connection so we can access the different fields on the oci8.Stmt
var rawConn *Conn
// NOTE that conn.Raw() is only available with Go >= 1.13
_ = conn.Raw(func(driverConn interface{}) error {
rawConn = driverConn.(*Conn)
return nil
})

ctx, cancel = context.WithTimeout(context.Background(), TestContextTimeout)
stmt, err := rawConn.PrepareContext(ctx, "select ?, ?, ? from dual")
cancel()
if err != nil {
t.Fatal("prepare error:", err)
}

rawStmt := stmt.(*Stmt)
if rawStmt.cacheKey != "select ?, ?, ? from dual" {
err := stmt.Close()
if err != nil {
t.Fatal("stmt close error:", err)
}
t.Fatalf("cacheKey not equal: expected %s, but got %s", "select ?, ?, ? from dual", rawStmt.cacheKey)
}

// closing the statement should put the statement into the cache
err = stmt.Close()
if err != nil {
t.Fatal("stmt close error:", err)
}

ctx, cancel = context.WithTimeout(context.Background(), TestContextTimeout)
stmt, err = rawConn.PrepareContext(ctx, "select ?, ?, ? from dual")
cancel()
if err != nil {
t.Fatal("prepare error:", err)
}

rawStmt = stmt.(*Stmt)
if rawStmt.cacheKey != "select ?, ?, ? from dual" {
err := stmt.Close()
if err != nil {
t.Fatal("stmt close error:", err)
}
t.Fatalf("cacheKey not equal: expected %s, but got %s", "select ?, ?, ? from dual", rawStmt.cacheKey)
}
err = stmt.Close()
if err != nil {
t.Fatal("stmt close error:", err)
}
}
104 changes: 101 additions & 3 deletions oci8_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func testExecQuery(t *testing.T, query string, args []interface{}) {
}

// testGetRows runs a statement and returns the rows as [][]interface{}
func testGetRows(t *testing.T, stmt *sql.Stmt, args []interface{}) ([][]interface{}, error) {
func testGetRows(t testing.TB, stmt *sql.Stmt, args []interface{}) ([][]interface{}, error) {
// get rows
ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout)
defer cancel()
Expand Down Expand Up @@ -468,7 +468,7 @@ end;
_, err = stmt.ExecContext(ctx)
cancel()
expected := "ORA-01013"
if err == nil || len(err.Error()) < len(expected) || err.Error()[:len(expected)] != expected {
if err == nil || len(err.Error()) < len(expected) || !strings.Contains(err.Error(), expected) {
t.Fatalf("stmt exec - expected: %v - received: %v", expected, err)
}

Expand All @@ -488,7 +488,7 @@ end;
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
_, err = stmt.QueryContext(ctx)
cancel()
if err == nil || len(err.Error()) < len(expected) || err.Error()[:len(expected)] != expected {
if err == nil || len(err.Error()) < len(expected) || !strings.Contains(err.Error(), expected) {
t.Fatalf("stmt query - expected: %v - received: %v", expected, err)
}

Expand Down Expand Up @@ -1475,6 +1475,104 @@ func benchmarkPrefetchSelect(b *testing.B, prefetchRows int64, prefetchMemory in
}
}

// TestSelectParallelWithStatementCaching checks parallel select from dual but with statement caching enabled
func TestSelectParallelWithStatementCaching(t *testing.T) {
if TestDisableDatabase {
t.SkipNow()
}
db := testGetDB("?stmt_cache_size=100")
if db == nil {
t.Fatal("db is null")
}

var waitGroup sync.WaitGroup
waitGroup.Add(50)

for i := 0; i < 50; i++ {
go func(num int) {
defer waitGroup.Done()

selectNumFromDual(t, db, float64(num))
}(i)
}

waitGroup.Wait()
}

// selectNumFromDual will execute a "select :1 from dual" where the parameter is the num param of this function
func selectNumFromDual(t testing.TB, db *sql.DB, num float64) {
ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout)
stmt, err := db.PrepareContext(ctx, "select :1 from dual")
cancel()
if err != nil {
t.Fatal("prepare error:", err)
}
defer func() {
if stmt != nil {
err := stmt.Close()
if err != nil {
t.Fatal("stmt close error:", err)
}
}
}()

var result [][]interface{}
result, err = testGetRows(t, stmt, []interface{}{num})
if err != nil {
t.Fatal("get rows error:", err)
}
if result == nil {
t.Fatal("result is nil")
}
if len(result) != 1 {
t.Fatal("len result not equal to 1")
}
if len(result[0]) != 1 {
t.Fatal("len result[0] not equal to 1")
}
data, ok := result[0][0].(float64)
if !ok {
t.Fatal("result not float64")
}
if data != num {
t.Fatal("result not equal to:", num)
}
}

func BenchmarkSelectNoCaching(b *testing.B) {
if TestDisableDatabase || TestDisableDestructive {
b.SkipNow()
}
for i := 0; i < b.N; i++ {
selectNumFromDual(b, TestDB, float64(i))
}
}

func BenchmarkSelectWithCaching(b *testing.B) {
b.StopTimer()

if TestDisableDatabase || TestDisableDestructive {
b.SkipNow()
}

db := testGetDB("?stmt_cache_size=100")
if db == nil {
b.Fatal("db is null")
}

defer func() {
err := db.Close()
if err != nil {
b.Fatal("db close error:", err)
}
}()

b.StartTimer()
for i := 0; i < b.N; i++ {
selectNumFromDual(b, db, float64(i))
}
}

func BenchmarkPrefetchR0M32768(b *testing.B) {
b.StopTimer()

Expand Down
12 changes: 7 additions & 5 deletions oci8_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,18 @@ func TestParseDSN(t *testing.T) {

const prefetchRows = 0
const prefetchMemory = 4096
const stmtCacheSize = 0

var dsnTests = []struct {
dsnString string
expectedDSN *DSN
}{
{"oracle://xxmc:[email protected]:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5]}},
{"xxmc/[email protected]:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5]}},
{"sys/[email protected]:1521/ORCL?loc=America%2FPhoenix&as=sysdba", &DSN{Username: "sys", Password: "syspwd", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5], operationMode: 0x00000002}}, // with operationMode: 0x00000002 = C.OCI_SYDBA
{"xxmc/[email protected]:1521/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: time.UTC}},
{"xxmc/[email protected]/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: time.UTC}},
{"oracle://xxmc:[email protected]:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5]}},
{"xxmc/[email protected]:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5]}},
{"sys/[email protected]:1521/ORCL?loc=America%2FPhoenix&as=sysdba", &DSN{Username: "sys", Password: "syspwd", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5], operationMode: 0x00000002}}, // with operationMode: 0x00000002 = C.OCI_SYDBA
{"xxmc/[email protected]:1521/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: time.UTC}},
{"xxmc/[email protected]/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: time.UTC}},
{"xxmc/[email protected]/ORCL?stmt_cache_size=50", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: 50, timeLocation: time.UTC}},
}

for _, tt := range dsnTests {
Expand Down
2 changes: 1 addition & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (rows *Rows) Next(dest []driver.Value) error {
// SQLT_RSET - ref cursor
case C.SQLT_RSET:
stmtP := (**C.OCIStmt)(rows.defines[i].pbuf)
subStmt := &Stmt{conn: rows.stmt.conn, stmt: *stmtP, ctx: rows.stmt.ctx}
subStmt := &Stmt{conn: rows.stmt.conn, stmt: *stmtP, ctx: rows.stmt.ctx, releaseMode: C.ub4(C.OCI_DEFAULT)}
if rows.defines[i].subDefines == nil {
var err error
rows.defines[i].subDefines, err = subStmt.makeDefines()
Expand Down
Loading

0 comments on commit b5e671b

Please sign in to comment.