-
Notifications
You must be signed in to change notification settings - Fork 72
/
stmt.go
120 lines (97 loc) · 2.8 KB
/
stmt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package otelsql
import (
"context"
"database/sql/driver"
"go.opentelemetry.io/otel/trace"
)
// otelStmt wraps a driver.Stmt so that ExecContext and QueryContext calls run
// inside spans opened through the package's dbInstrum.
//
// Methods that otelStmt does not define itself (e.g. Close, NumInput) are
// promoted from the embedded driver.Stmt.
type otelStmt struct {
	// Stmt is the wrapped driver statement.
	driver.Stmt

	// query is the query text handed to newStmt; it is passed to
	// instrum.withSpan for every traced call.
	query string
	// instrum opens the spans around exec/query calls.
	instrum *dbInstrum

	// execCtx and queryCtx are the traced call paths, built once by newStmt
	// via createExecCtxFunc / createQueryCtxFunc.
	execCtx  stmtExecCtxFunc
	queryCtx stmtQueryCtxFunc
}

// Compile-time check that *otelStmt satisfies driver.Stmt.
var _ driver.Stmt = (*otelStmt)(nil)
// newStmt wraps stmt in an otelStmt whose ExecContext and QueryContext calls
// are traced through instrum, with query attached to each span.
//
// The exec and query call paths are resolved here, once, so the per-call path
// does not repeat the driver capability checks.
func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt {
	wrapped := &otelStmt{Stmt: stmt, query: query, instrum: instrum}
	// Right-hand side is evaluated left to right: exec path first, then query.
	wrapped.execCtx, wrapped.queryCtx =
		wrapped.createExecCtxFunc(stmt), wrapped.createQueryCtxFunc(stmt)
	return wrapped
}
//------------------------------------------------------------------------------

var _ driver.StmtExecContext = (*otelStmt)(nil)

// ExecContext implements driver.StmtExecContext by delegating to the traced
// exec path that newStmt built with createExecCtxFunc.
func (s *otelStmt) ExecContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Result, error) {
	return s.execCtx(ctx, args)
}
// stmtExecCtxFunc has the signature of driver.StmtExecContext.ExecContext.
type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)

// createExecCtxFunc builds the traced exec path for stmt, the statement being
// wrapped (newStmt passes the same value it stores in s.Stmt).
//
// If stmt implements driver.StmtExecContext, its ExecContext is called
// directly. Otherwise the call falls back to the legacy Stmt.Exec: the named
// arguments are converted with namedValueToValue, and ctx is checked for
// cancellation before the statement starts.
//
// The returned function runs the exec inside a "stmt.Exec" span. When the span
// is recording, it attaches the rows-affected count if the result reports one.
func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc {
	var fn stmtExecCtxFunc
	if execer, ok := stmt.(driver.StmtExecContext); ok {
		fn = execer.ExecContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// Legacy Exec takes no ctx, so it cannot be interrupted. database/sql
			// refuses to start a legacy Exec on a done context, but it skips that
			// check here because otelStmt implements StmtExecContext. Do it
			// ourselves.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			return stmt.Exec(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
		var res driver.Result
		err := s.instrum.withSpan(ctx, "stmt.Exec", s.query,
			func(ctx context.Context, span trace.Span) error {
				var execErr error
				res, execErr = fn(ctx, args)
				if execErr != nil {
					return execErr
				}
				// Guard against a driver returning (nil, nil): enriching the span
				// must never panic the caller.
				if res != nil && span.IsRecording() {
					if affected, rowsErr := res.RowsAffected(); rowsErr == nil {
						span.SetAttributes(dbRowsAffected.Int64(affected))
					}
				}
				return nil
			})
		return res, err
	}
}
//------------------------------------------------------------------------------

var _ driver.StmtQueryContext = (*otelStmt)(nil)

// QueryContext implements driver.StmtQueryContext by delegating to the traced
// query path that newStmt built with createQueryCtxFunc.
func (s *otelStmt) QueryContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Rows, error) {
	return s.queryCtx(ctx, args)
}
// stmtQueryCtxFunc has the signature of driver.StmtQueryContext.QueryContext.
type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)

// createQueryCtxFunc builds the traced query path for stmt, the statement
// being wrapped (newStmt passes the same value it stores in s.Stmt).
//
// If stmt implements driver.StmtQueryContext, its QueryContext is called
// directly. Otherwise the call falls back to the legacy Stmt.Query: the named
// arguments are converted with namedValueToValue, and ctx is checked for
// cancellation before the statement starts.
//
// The returned function runs the query inside a "stmt.Query" span.
func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc {
	var fn stmtQueryCtxFunc
	if queryer, ok := stmt.(driver.StmtQueryContext); ok {
		fn = queryer.QueryContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// Legacy Query takes no ctx. database/sql's own pre-call cancellation
			// check is skipped because otelStmt implements StmtQueryContext, so
			// perform it here.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			// Call the wrapped statement explicitly rather than the promoted
			// s.Query, matching the exec path. This also avoids recursion if
			// otelStmt ever gains its own Query method.
			return stmt.Query(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
		var rows driver.Rows
		err := s.instrum.withSpan(ctx, "stmt.Query", s.query,
			func(ctx context.Context, span trace.Span) error {
				var queryErr error
				rows, queryErr = fn(ctx, args)
				return queryErr
			})
		return rows, err
	}
}