Skip to content

Commit

Permalink
feat(xtrace): 记录 trace 函数优化
Browse files Browse the repository at this point in the history
  • Loading branch information
Ccheers committed Apr 1, 2024
1 parent d34f89a commit 89c3ab3
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
47 changes: 47 additions & 0 deletions xtrace/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,61 @@ import (

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/codes"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)

// Call
// 基于otel规范实现一个 call 函数,用于注入任意函数实现链路追踪
// 使用这个函数的 trace 一定会被记录, 请自行判断是否调用, 或则使用 CallWithSampler
func Call[T any](ctx context.Context, spanName string, f func(ctx context.Context) (T, error)) (T, error) {
// 基于otel规范实现一个 call 函数,用于注入任意函数实现链路追踪
tr := otel.Tracer("xpkg.trace")
psc := trace.SpanContextFromContext(ctx)
if !psc.IsValid() {
tid, sid := idGenerator.NewIDs(ctx)
ctx = trace.ContextWithSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{
TraceID: tid,
SpanID: sid,
TraceFlags: trace.FlagsSampled,
}))
}
ctx, span := tr.Start(ctx, spanName,
trace.WithSpanKind(trace.SpanKindInternal),
)
defer span.End()
reply, err := f(ctx)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
} else {
span.SetStatus(codes.Ok, "OK")
}
return reply, err
}

// CallWithSampler
// 基于 otel 规范实现一个 call 函数,用于注入任意函数实现链路追踪
// sdktrace.Sampler 用于判断是否采样
func CallWithSampler[T any](ctx context.Context, spanName string, sampler sdktrace.Sampler, f func(ctx context.Context) (T, error)) (T, error) {
tr := otel.Tracer("xpkg.trace")
psc := trace.SpanContextFromContext(ctx)
if !psc.IsValid() {
tid, sid := idGenerator.NewIDs(ctx)
result := sampler.ShouldSample(sdktrace.SamplingParameters{
ParentContext: nil,
TraceID: tid,
Name: spanName,
Kind: trace.SpanKindInternal,
})
if result.Decision == sdktrace.RecordAndSample {
ctx = trace.ContextWithSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{
TraceID: tid,
SpanID: sid,
TraceFlags: trace.FlagsSampled,
}))
}
}
ctx, span := tr.Start(ctx, spanName,
trace.WithSpanKind(trace.SpanKindInternal),
)
Expand Down
54 changes: 54 additions & 0 deletions xtrace/call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)

func TestCall(t *testing.T) {
Expand All @@ -26,9 +27,62 @@ func TestCall(t *testing.T) {
// Call the function with tracing
_, err = Call(context.TODO(), "测试一下", func(ctx context.Context) (struct{}, error) {
fmt.Println("Hello, World!")
psc := trace.SpanContextFromContext(ctx)
if psc.TraceFlags() != trace.FlagsSampled {
t.Errorf("trace flags should be sampled")
}
return struct{}{}, nil
})
if err != nil {
panic(err)
}
}

func TestCallWithSampler(t *testing.T) {
// Setup exporter
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
if err != nil {
panic(err)
}

// Setup provider with the exporter
// Set the provider as the global tracer provider
otel.SetTracerProvider(sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
sdktrace.WithSampler(sdktrace.AlwaysSample()),
))

// Call the function with tracing
_, err = Call(context.TODO(), "测试一下[1]", func(ctx context.Context) (struct{}, error) {
fmt.Println("Hello, World!")
psc := trace.SpanContextFromContext(ctx)
if psc.TraceFlags() != trace.FlagsSampled {
t.Errorf("trace flags should be sampled")
}
return struct{}{}, nil
})
if err != nil {
panic(err)
}

// Setup provider with the exporter
// Set the provider as the global tracer provider
otel.SetTracerProvider(sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
sdktrace.WithSampler(sdktrace.NeverSample()),
))

// Call the function with tracing
_, err = Call(context.TODO(), "测试一下[2]", func(ctx context.Context) (struct{}, error) {
fmt.Println("Hello, World!")
psc := trace.SpanContextFromContext(ctx)
if psc.TraceFlags() == trace.FlagsSampled {
t.Errorf("trace flags should not be sampled")
}
return struct{}{}, nil
})
if err != nil {
panic(err)
}

}
50 changes: 50 additions & 0 deletions xtrace/idgenerator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package xtrace

import (
"context"
crand "crypto/rand"
"encoding/binary"
"math/rand"
"sync"

sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)

type randomIDGenerator struct {
sync.Mutex
randSource *rand.Rand
}

var _ sdktrace.IDGenerator = &randomIDGenerator{}

var idGenerator = defaultIDGenerator()

// NewSpanID returns a non-zero span ID from a randomly-chosen sequence.
func (gen *randomIDGenerator) NewSpanID(ctx context.Context, traceID trace.TraceID) trace.SpanID {
gen.Lock()
defer gen.Unlock()
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])
return sid
}

// NewIDs returns a non-zero trace ID and a non-zero span ID from a
// randomly-chosen sequence.
func (gen *randomIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.SpanID) {
gen.Lock()
defer gen.Unlock()
tid := trace.TraceID{}
_, _ = gen.randSource.Read(tid[:])
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])
return tid, sid
}

func defaultIDGenerator() sdktrace.IDGenerator {
gen := &randomIDGenerator{}
var rngSeed int64
_ = binary.Read(crand.Reader, binary.LittleEndian, &rngSeed)
gen.randSource = rand.New(rand.NewSource(rngSeed))
return gen
}

0 comments on commit 89c3ab3

Please sign in to comment.