diff --git a/connection.go b/connection.go index f3837f5..f3a9452 100644 --- a/connection.go +++ b/connection.go @@ -109,21 +109,40 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt go conn.ociBreakDone(ctx, done) defer func() { close(done) }() + if conn.stmtCacheSize == 0 { + if rv := C.OCIStmtPrepare2( + conn.svc, // service context handle + stmt, // pointer to the statement handle returned + conn.errHandle, // error handle + queryP, // statement text + C.ub4(len(query)), // statement text length + nil, // key to be used for searching the statement in the statement cache + C.ub4(0), // length of the key + C.ub4(C.OCI_NTV_SYNTAX), // syntax - OCI_NTV_SYNTAX: syntax depends upon the version of the server + C.ub4(C.OCI_DEFAULT), // mode + ); rv != C.OCI_SUCCESS { + return nil, conn.getError(rv) + } + + return &Stmt{conn: conn, stmt: *stmt, ctx: ctx, releaseMode: C.OCI_DEFAULT}, nil + } + if rv := C.OCIStmtPrepare2( conn.svc, // service context handle stmt, // pointer to the statement handle returned conn.errHandle, // error handle queryP, // statement text C.ub4(len(query)), // statement text length - nil, // key to be used for searching the statement in the statement cache - C.ub4(0), // length of the key + queryP, // key to be used for searching the statement in the statement cache + C.ub4(len(query)), // length of the key C.ub4(C.OCI_NTV_SYNTAX), // syntax - OCI_NTV_SYNTAX: syntax depends upon the version of the server C.ub4(C.OCI_DEFAULT), // mode - ); rv != C.OCI_SUCCESS { + ); rv != C.OCI_SUCCESS && rv != C.OCI_SUCCESS_WITH_INFO { + // Note that C.OCI_SUCCESS_WITH_INFO is returned the first time a statement it put into the cache return nil, conn.getError(rv) } - return &Stmt{conn: conn, stmt: *stmt, ctx: ctx}, nil + return &Stmt{conn: conn, stmt: *stmt, ctx: ctx, releaseMode: C.OCI_DEFAULT, cacheKey: query}, nil } // Begin starts a transaction diff --git a/globals.go b/globals.go index cda6474..90a710d 100644 --- a/globals.go +++ b/globals.go @@ -40,6 +40,7 @@ type ( transactionMode C.ub4 enableQMPlaceholders bool operationMode C.ub4 + stmtCacheSize C.ub4 } // DriverStruct is Oracle driver struct @@ -66,6 +67,7 @@ type ( prefetchMemory C.ub4 transactionMode C.ub4 operationMode C.ub4 + stmtCacheSize C.ub4 inTransaction bool enableQMPlaceholders bool closed bool @@ -80,10 +82,12 @@ type ( // Stmt is Oracle statement Stmt struct { - conn *Conn - stmt *C.OCIStmt - closed bool - ctx context.Context + conn *Conn + stmt *C.OCIStmt + closed bool + ctx context.Context + cacheKey string // if statement caching is enabled, this is the key for this statement into the cache + releaseMode C.ub4 } // Rows is Oracle rows diff --git a/oci8.go b/oci8.go index 3a29234..c0a461d 100644 --- a/oci8.go +++ b/oci8.go @@ -51,6 +51,7 @@ func ParseDSN(dsnString string) (dsn *DSN, err error) { dsn = &DSN{ prefetchRows: 0, prefetchMemory: 4096, + stmtCacheSize: 0, operationMode: C.OCI_DEFAULT, timeLocation: time.UTC, } @@ -119,7 +120,12 @@ func ParseDSN(dsnString string) (dsn *DSN, err error) { default: return nil, fmt.Errorf("Invalid as: %v", v[0]) } - + case "stmt_cache_size": + z, err := strconv.ParseUint(v[0], 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid stmt_cache_size: %v", v[0]) + } + dsn.stmtCacheSize = C.ub4(z) } } @@ -162,6 +168,7 @@ func (drv *DriverStruct) Open(dsnString string) (driver.Conn, error) { conn := Conn{ operationMode: dsn.operationMode, + stmtCacheSize: dsn.stmtCacheSize, logger: drv.Logger, } if conn.logger == nil { @@ -353,6 +360,14 @@ func (drv *DriverStruct) Open(dsnString string) (driver.Conn, error) { return nil, fmt.Errorf("authentication context attribute set error: %v", err) } + if dsn.stmtCacheSize > 0 { + stmtCacheSize := dsn.stmtCacheSize + err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, unsafe.Pointer(&stmtCacheSize), 0, C.OCI_ATTR_STMTCACHESIZE) + if err != nil { + return nil, fmt.Errorf("stmt cache size attribute set error: %v", err) + } + } + } else { var svcCtxP *C.OCISvcCtx diff --git a/oci8_sql_go_113_test.go b/oci8_sql_go_113_test.go new file mode 100644 index 0000000..8602694 --- /dev/null +++ b/oci8_sql_go_113_test.go @@ -0,0 +1,84 @@ +// +build go1.13 + +package oci8 + +import ( + "context" + "testing" +) + +// TestStatementCaching tests to ensure statement caching is working +func TestStatementCaching(t *testing.T) { + if TestDisableDatabase { + t.SkipNow() + } + + t.Parallel() + + var err error + + db := testGetDB("?stmt_cache_size=10") + if db == nil { + t.Fatal("db is null") + } + + defer func() { + err = db.Close() + if err != nil { + t.Fatal("db close error:", err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout) + conn, err := db.Conn(ctx) + cancel() + // we need to get access to the raw connection so we can access the different fields on the oci8.Stmt + var rawConn *Conn + // NOTE that conn.Raw() is only available with Go >= 1.13 + _ = conn.Raw(func(driverConn interface{}) error { + rawConn = driverConn.(*Conn) + return nil + }) + + ctx, cancel = context.WithTimeout(context.Background(), TestContextTimeout) + stmt, err := rawConn.PrepareContext(ctx, "select ?, ?, ? from dual") + cancel() + if err != nil { + t.Fatal("prepare error:", err) + } + + rawStmt := stmt.(*Stmt) + if rawStmt.cacheKey != "select ?, ?, ? from dual" { + err := stmt.Close() + if err != nil { + t.Fatal("stmt close error:", err) + } + t.Fatalf("cacheKey not equal: expected %s, but got %s", "select ?, ?, ? from dual", rawStmt.cacheKey) + } + + // closing the statement should put the statement into the cache + err = stmt.Close() + if err != nil { + t.Fatal("stmt close error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), TestContextTimeout) + stmt, err = rawConn.PrepareContext(ctx, "select ?, ?, ? from dual") + cancel() + if err != nil { + t.Fatal("prepare error:", err) + } + + rawStmt = stmt.(*Stmt) + if rawStmt.cacheKey != "select ?, ?, ? from dual" { + err := stmt.Close() + if err != nil { + t.Fatal("stmt close error:", err) + } + t.Fatalf("cacheKey not equal: expected %s, but got %s", "select ?, ?, ? from dual", rawStmt.cacheKey) + } + err = stmt.Close() + if err != nil { + t.Fatal("stmt close error:", err) + } +} diff --git a/oci8_sql_test.go b/oci8_sql_test.go index 5a972e6..b8462f9 100644 --- a/oci8_sql_test.go +++ b/oci8_sql_test.go @@ -66,7 +66,7 @@ func testExecQuery(t *testing.T, query string, args []interface{}) { } // testGetRows runs a statement and returns the rows as [][]interface{} -func testGetRows(t *testing.T, stmt *sql.Stmt, args []interface{}) ([][]interface{}, error) { +func testGetRows(t testing.TB, stmt *sql.Stmt, args []interface{}) ([][]interface{}, error) { // get rows ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout) defer cancel() @@ -468,7 +468,7 @@ end; _, err = stmt.ExecContext(ctx) cancel() expected := "ORA-01013" - if err == nil || len(err.Error()) < len(expected) || err.Error()[:len(expected)] != expected { + if err == nil || len(err.Error()) < len(expected) || !strings.Contains(err.Error(), expected) { t.Fatalf("stmt exec - expected: %v - received: %v", expected, err) } @@ -488,7 +488,7 @@ end; ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) _, err = stmt.QueryContext(ctx) cancel() - if err == nil || len(err.Error()) < len(expected) || err.Error()[:len(expected)] != expected { + if err == nil || len(err.Error()) < len(expected) || !strings.Contains(err.Error(), expected) { t.Fatalf("stmt query - expected: %v - received: %v", expected, err) } @@ -1475,6 +1475,104 @@ func benchmarkPrefetchSelect(b *testing.B, prefetchRows int64, prefetchMemory in } } +// TestSelectParallelWithStatementCaching checks parallel select from dual but with statement caching enabled +func TestSelectParallelWithStatementCaching(t *testing.T) { + if TestDisableDatabase { + t.SkipNow() + } + db := testGetDB("?stmt_cache_size=100") + if db == nil { + t.Fatal("db is null") + } + + var waitGroup sync.WaitGroup + waitGroup.Add(50) + + for i := 0; i < 50; i++ { + go func(num int) { + defer waitGroup.Done() + + selectNumFromDual(t, db, float64(num)) + }(i) + } + + waitGroup.Wait() +} + +// selectNumFromDual will execute a "select :1 from dual" where the parameter is the num param of this function +func selectNumFromDual(t testing.TB, db *sql.DB, num float64) { + ctx, cancel := context.WithTimeout(context.Background(), TestContextTimeout) + stmt, err := db.PrepareContext(ctx, "select :1 from dual") + cancel() + if err != nil { + t.Fatal("prepare error:", err) + } + defer func() { + if stmt != nil { + err := stmt.Close() + if err != nil { + t.Fatal("stmt close error:", err) + } + } + }() + + var result [][]interface{} + result, err = testGetRows(t, stmt, []interface{}{num}) + if err != nil { + t.Fatal("get rows error:", err) + } + if result == nil { + t.Fatal("result is nil") + } + if len(result) != 1 { + t.Fatal("len result not equal to 1") + } + if len(result[0]) != 1 { + t.Fatal("len result[0] not equal to 1") + } + data, ok := result[0][0].(float64) + if !ok { + t.Fatal("result not float64") + } + if data != num { + t.Fatal("result not equal to:", num) + } +} + +func BenchmarkSelectNoCaching(b *testing.B) { + if TestDisableDatabase || TestDisableDestructive { + b.SkipNow() + } + for i := 0; i < b.N; i++ { + selectNumFromDual(b, TestDB, float64(i)) + } +} + +func BenchmarkSelectWithCaching(b *testing.B) { + b.StopTimer() + + if TestDisableDatabase || TestDisableDestructive { + b.SkipNow() + } + + db := testGetDB("?stmt_cache_size=100") + if db == nil { + b.Fatal("db is null") + } + + defer func() { + err := db.Close() + if err != nil { + b.Fatal("db close error:", err) + } + }() + + b.StartTimer() + for i := 0; i < b.N; i++ { + selectNumFromDual(b, db, float64(i)) + } +} + func BenchmarkPrefetchR0M32768(b *testing.B) { b.StopTimer() diff --git a/oci8_test.go b/oci8_test.go index 6b61b07..bf571de 100644 --- a/oci8_test.go +++ b/oci8_test.go @@ -185,16 +185,18 @@ func TestParseDSN(t *testing.T) { const prefetchRows = 0 const prefetchMemory = 4096 + const stmtCacheSize = 0 var dsnTests = []struct { dsnString string expectedDSN *DSN }{ - {"oracle://xxmc:xxmc@107.20.30.169:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5]}}, - {"xxmc/xxmc@107.20.30.169:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5]}}, - {"sys/syspwd@107.20.30.169:1521/ORCL?loc=America%2FPhoenix&as=sysdba", &DSN{Username: "sys", Password: "syspwd", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: timeLocations[5], operationMode: 0x00000002}}, // with operationMode: 0x00000002 = C.OCI_SYDBA - {"xxmc/xxmc@107.20.30.169:1521/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: time.UTC}}, - {"xxmc/xxmc@107.20.30.169/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, timeLocation: time.UTC}}, + {"oracle://xxmc:xxmc@107.20.30.169:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5]}}, + {"xxmc/xxmc@107.20.30.169:1521/ORCL?loc=America%2FPhoenix", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5]}}, + {"sys/syspwd@107.20.30.169:1521/ORCL?loc=America%2FPhoenix&as=sysdba", &DSN{Username: "sys", Password: "syspwd", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: timeLocations[5], operationMode: 0x00000002}}, // with operationMode: 0x00000002 = C.OCI_SYDBA + {"xxmc/xxmc@107.20.30.169:1521/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169:1521/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: time.UTC}}, + {"xxmc/xxmc@107.20.30.169/ORCL", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: stmtCacheSize, timeLocation: time.UTC}}, + {"xxmc/xxmc@107.20.30.169/ORCL?stmt_cache_size=50", &DSN{Username: "xxmc", Password: "xxmc", Connect: "107.20.30.169/ORCL", prefetchRows: prefetchRows, prefetchMemory: prefetchMemory, stmtCacheSize: 50, timeLocation: time.UTC}}, } for _, tt := range dsnTests { diff --git a/rows.go b/rows.go index 01b9178..2fb8a9c 100644 --- a/rows.go +++ b/rows.go @@ -202,7 +202,7 @@ func (rows *Rows) Next(dest []driver.Value) error { // SQLT_RSET - ref cursor case C.SQLT_RSET: stmtP := (**C.OCIStmt)(rows.defines[i].pbuf) - subStmt := &Stmt{conn: rows.stmt.conn, stmt: *stmtP, ctx: rows.stmt.ctx} + subStmt := &Stmt{conn: rows.stmt.conn, stmt: *stmtP, ctx: rows.stmt.ctx, releaseMode: C.ub4(C.OCI_DEFAULT)} if rows.defines[i].subDefines == nil { var err error rows.defines[i].subDefines, err = subStmt.makeDefines() diff --git a/statement.go b/statement.go index 51a0e21..33000ea 100644 --- a/statement.go +++ b/statement.go @@ -22,13 +22,28 @@ func (stmt *Stmt) Close() error { } stmt.closed = true - result := C.OCIStmtRelease( - stmt.stmt, // statement handle - stmt.conn.errHandle, // error handle - nil, // key to be associated with the statement in the cache - C.ub4(0), // length of the key - C.ub4(C.OCI_DEFAULT), // mode - ) + var result C.sword + if stmt.cacheKey == "" { + result = C.OCIStmtRelease( + stmt.stmt, // statement handle + stmt.conn.errHandle, // error handle + nil, // key to be associated with the statement in the cache + C.ub4(0), // length of the key + stmt.releaseMode, // mode + ) + } else { + cacheKeyP := cString(stmt.cacheKey) + defer C.free(unsafe.Pointer(cacheKeyP)) + + result = C.OCIStmtRelease( + stmt.stmt, // statement handle + stmt.conn.errHandle, // error handle + cacheKeyP, // key to be associated with the statement in the cache + C.ub4(len(stmt.cacheKey)), // length of the key + stmt.releaseMode, // mode + ) + } + stmt.stmt = nil return stmt.conn.getError(result) @@ -1011,5 +1026,10 @@ func (stmt *Stmt) ociStmtExecute(iters C.ub4, mode C.ub4) error { mode, // The mode: https://docs.oracle.com/cd/E11882_01/appdev.112/e10646/oci17msc001.htm#LNOCI17163 ) + if stmt.cacheKey != "" && result != C.OCI_SUCCESS && result != C.OCI_SUCCESS_WITH_INFO { + // drop statement from cache for all errors when caching is enabled + stmt.releaseMode = C.OCI_STRLS_CACHE_DELETE + } + return stmt.conn.getError(result) } diff --git a/test.sh b/test.sh index a49f601..9f36f5a 100644 --- a/test.sh +++ b/test.sh @@ -53,10 +53,13 @@ sqlplus -L -S "sys/oracle@${DOCKER_IP}:1521 as sysdba" <