
Commit

fix: producer with options
Signed-off-by: Eray Ates <[email protected]>
rytsh committed Jan 3, 2024
1 parent f7d119c commit 0a79414
Showing 8 changed files with 113 additions and 38 deletions.
15 changes: 15 additions & 0 deletions codec.go
@@ -61,3 +61,18 @@ func (codecJSON[T]) Decode(raw []byte, _ *kgo.Record) (T, error) {

return data, nil
}

type codecByte[T any] struct{}

func (codecByte[T]) Encode(data T) ([]byte, error) {
v, ok := any(data).([]byte)
if !ok {
return nil, fmt.Errorf("invalid data type: %T", data)
}

return v, nil
}

func (codecByte[T]) Decode(raw []byte, _ *kgo.Record) (T, error) {
// Safe in practice: the type switches in WithCallback and NewProducer only
// select this codec when T is []byte, so the assertion below cannot fail there.
return any(raw).(T), nil
}

Check failure (GitHub Actions / sonarcloud) on line 76 in codec.go: Decode returns interface (T) (ireturn)
Check failure (GitHub Actions / sonarcloud) on line 77 in codec.go: type assertion must be checked (forcetypeassert)
18 changes: 17 additions & 1 deletion consumer.go
@@ -132,6 +132,9 @@ func (o *optionConsumer) apply(opts ...OptionConsumer) error {
return nil
}

// WithCallbackBatch to set wkafka consumer's callback function.
// - Default is json.Unmarshal; use the WithDecode option to add a custom decode function.
// - If the type is [][]byte, the default decode function will be skipped.
func WithCallbackBatch[T any](fn func(ctx context.Context, msg []T) error) OptionConsumer {
return func(o *optionConsumer) error {
o.Consumer = &consumerBatch[T]{
@@ -143,11 +146,24 @@ func WithCallbackBatch[T any](fn func(ctx context.Context, msg []T) error) OptionConsumer {
}
}

// WithCallback to set wkafka consumer's callback function.
// - Default is json.Unmarshal; use the WithDecode option to add a custom decode function.
// - If the type is []byte, the default decode function will be skipped.
func WithCallback[T any](fn func(ctx context.Context, msg T) error) OptionConsumer {
return func(o *optionConsumer) error {
var decode func(raw []byte, r *kgo.Record) (T, error)

var msg T
switch any(msg).(type) {
case []byte:
decode = codecByte[T]{}.Decode
default:
decode = codecJSON[T]{}.Decode
}

o.Consumer = &consumerSingle[T]{
Process: fn,
Decode: codecJSON[T]{}.Decode,
Decode: decode,
}

return nil
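Note: the following is a hedged usage sketch, not part of the diff. It shows how the two decode paths described in the doc comments above are meant to be used from the caller side: a struct callback goes through json.Unmarshal, while a []byte callback now receives the raw record value via the new byte codec. The import path and the way these options are later passed to the consumer are assumptions.

package main

import (
	"context"
	"fmt"

	"github.com/worldline-go/wkafka" // assumed import path for this repository
)

type Data struct {
	Name string `json:"name"`
}

func main() {
	// JSON path: T is a struct, so the default json.Unmarshal decode applies.
	jsonOpt := wkafka.WithCallback(func(_ context.Context, msg Data) error {
		fmt.Println("decoded:", msg.Name)
		return nil
	})

	// Raw path: T is []byte, so the new byte codec hands over the record value as-is.
	rawOpt := wkafka.WithCallback(func(_ context.Context, msg []byte) error {
		fmt.Println("raw bytes:", len(msg))
		return nil
	})

	// Batch variant: [][]byte likewise skips JSON decoding per the doc comment above.
	batchOpt := wkafka.WithCallbackBatch(func(_ context.Context, msgs [][]byte) error {
		fmt.Println("batch size:", len(msgs))
		return nil
	})

	// The options would then be passed to the consumer setup (not shown in this hunk).
	_, _, _ = jsonOpt, rawOpt, batchOpt
}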
4 changes: 1 addition & 3 deletions consumer_test.go
@@ -113,9 +113,7 @@ func Test_GroupConsuming(t *testing.T) {

slog.Info("topic created", slog.String("topic", topic.Name))

byteProducer, err := wkafka.NewProducer(client, wkafka.ProducerConfig[Data]{
Topic: topic.Name,
})
byteProducer, err := wkafka.NewProducer[Data](client, topic.Name)
if err != nil {
t.Fatalf("NewProducer() error = %v", err)
}
2 changes: 1 addition & 1 deletion error.go
@@ -11,7 +11,7 @@ var (
ErrNotImplemented = fmt.Errorf("not implemented")
ErrClientClosed = fmt.Errorf("client closed")
ErrNilData = fmt.Errorf("nil data")
// ErrSkip is use to skip message in the PreCheck hook.
// ErrSkip is used to skip a message in the PreCheck hook or Decode function.
ErrSkip = fmt.Errorf("skip message")
// ErrInvalidCompression for producer setting check.
ErrInvalidCompression = fmt.Errorf("invalid compression")
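Note: a hedged sketch of how the extended ErrSkip comment plays out in a custom decode function; returning ErrSkip drops the record before the callback runs. The WithDecode wiring mentioned in the consumer doc comments is only referenced in a comment below because its exact signature is not shown in this diff; the import path is likewise an assumption.

package main

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/worldline-go/wkafka" // assumed import path for this repository
)

type Data struct {
	Name string `json:"name"`
}

// decodeSkippingTombstones mirrors the decode signature used in consumer.go;
// returning wkafka.ErrSkip tells the consumer to drop the record silently.
func decodeSkippingTombstones(raw []byte, _ *kgo.Record) (Data, error) {
	if len(raw) == 0 {
		// Tombstone or empty payload: skip instead of failing.
		return Data{}, wkafka.ErrSkip
	}

	var d Data
	if err := json.Unmarshal(raw, &d); err != nil {
		return Data{}, fmt.Errorf("decode data: %w", err)
	}

	return d, nil
}

func main() {
	_, err := decodeSkippingTombstones(nil, nil)
	if errors.Is(err, wkafka.ErrSkip) {
		fmt.Println("record skipped") // empty payload maps to ErrSkip
	}
	// In a real consumer this function would be passed through the WithDecode
	// option mentioned in the doc comments above (exact signature not shown here).
}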
2 changes: 1 addition & 1 deletion example/produce/main.go
@@ -49,7 +49,7 @@ func run(ctx context.Context, _ *sync.WaitGroup) error {
},
}

producer, err := wkafka.NewProducer(client, wkafka.ProducerConfig[*Data]{})
producer, err := wkafka.NewProducer[*Data](client, "test")
if err != nil {
return err
}
93 changes: 67 additions & 26 deletions producer.go
@@ -2,6 +2,7 @@ package wkafka

import (
"context"
"errors"
"fmt"

"github.com/twmb/franz-go/pkg/kgo"
@@ -16,11 +17,7 @@ type Producer[T any] interface {
Produce(ctx context.Context, data ...T) error
}

type ProducerHook interface {
ProduceHook(r *Record)
}

type ProducerConfig[T any] struct {
type producerConfig[T any] struct {
// Topic is the default topic to produce to.
Topic string
// Headers is the default headers to produce with it.
@@ -29,20 +26,67 @@ type ProducerConfig[T any] struct {
// - If data is []byte, Encode will be ignored.
// - Encode runs after Hook, and only when record.Value is still nil.
Encode func(T) ([]byte, error)
// Hook is used to modify the record before producing.
Hook func(T, *Record) error
}

type OptionProducer[T any] func(*producerConfig[T]) error

func (c *producerConfig[T]) apply(opts ...OptionProducer[T]) error {
for _, opt := range opts {
if err := opt(c); err != nil {
return err
}
}

return nil
}

// WithEncoder to set encoder function.
func WithEncoder[T any](fn func(T) ([]byte, error)) OptionProducer[T] {
return func(o *producerConfig[T]) error {
o.Encode = fn

return nil
}
}

// WithHeaders to append headers.
func WithHeaders[T any](headers ...Header) OptionProducer[T] {
return func(o *producerConfig[T]) error {
o.Headers = append(o.Headers, headers...)

return nil
}
}

// WithHook to set hook function.
// - Hook will be called before Encoder.
// - If Hook returns ErrSkip, the record will be skipped.
// - If Hook does not set a value on the record, Encoder will be called.
func WithHook[T any](fn func(T, *Record) error) OptionProducer[T] {
return func(o *producerConfig[T]) error {
o.Hook = fn

return nil
}
}

func NewProducer[T any](client *Client, cfg ProducerConfig[T]) (Producer[T], error) {
// NewProducer to create a new producer with the given type.
// - If data is []byte, Encoder will be ignored.
func NewProducer[T any](client *Client, topic string, opts ...OptionProducer[T]) (Producer[T], error) {
var encode func(data T) ([]byte, error)

var value T
switch any(value).(type) {
case []byte:
encode = nil
encode = codecByte[T]{}.Encode
default:
encode = codecJSON[T]{}.Encode
}

setCfg := ProducerConfig[T]{
setCfg := &producerConfig[T]{
Topic: topic,
Headers: []Header{
{
Key: "server",
@@ -52,39 +96,35 @@ func NewProducer[T any](client *Client, cfg ProducerConfig[T]) (Producer[T], error) {
Encode: encode,
}

if cfg.Topic != "" {
setCfg.Topic = cfg.Topic
}

if cfg.Headers != nil {
setCfg.Headers = append(setCfg.Headers, cfg.Headers...)
}

if cfg.Encode != nil {
setCfg.Encode = cfg.Encode
if err := setCfg.apply(opts...); err != nil {
return nil, fmt.Errorf("apply options: %w", err)
}

return &produce[T]{
ProducerConfig: setCfg,
producerConfig: *setCfg,
produceRaw: client.ProduceRaw,
}, nil
}

type produce[T any] struct {
ProducerConfig[T]
producerConfig[T]
produceRaw func(ctx context.Context, records []*Record) error
}

func (p *produce[T]) Produce(ctx context.Context, data ...T) error {
records := make([]*Record, len(data))
records := make([]*Record, 0, len(data))

for i, d := range data {
for _, d := range data {
record, err := p.prepare(d)
if err != nil {
if errors.Is(err, ErrSkip) {
continue
}

return fmt.Errorf("prepare record: %w", err)
}

records[i] = record
records = append(records, record)
}

return p.produceRaw(ctx, records)
@@ -96,9 +136,10 @@ func (p *produce[T]) prepare(data T) (*Record, error) {
Topic: p.Topic,
}

// check data has Hook interface
if data, ok := any(data).(ProducerHook); ok {
data.ProduceHook(record)
if p.Hook != nil {
if err := p.Hook(data, record); err != nil {
return nil, err
}
}

if record.Value != nil {
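Note: a hedged sketch (not part of the diff) showing how the reworked producer is meant to be constructed after this change: topic as a plain argument plus variadic options, a hook that can skip records via ErrSkip, and extra headers appended to the default "server" header. The import path, the Header value type, and the surrounding client setup are assumptions.

package example

import (
	"context"

	"github.com/worldline-go/wkafka" // assumed import path for this repository
)

type Data struct {
	Name    string `json:"name"`
	Ignored bool   `json:"-"`
}

func newDataProducer(client *wkafka.Client) (wkafka.Producer[Data], error) {
	return wkafka.NewProducer[Data](client, "test",
		// Appended on top of the default "server" header set by NewProducer.
		// The Header value type is assumed to be []byte (mirroring kgo.RecordHeader).
		wkafka.WithHeaders[Data](wkafka.Header{Key: "source", Value: []byte("example")}),
		// Hook runs before the encoder; returning ErrSkip drops the record,
		// and setting record.Value in the hook skips the encoder entirely.
		wkafka.WithHook(func(d Data, _ *wkafka.Record) error {
			if d.Ignored {
				return wkafka.ErrSkip
			}
			return nil
		}),
	)
}

func produceSome(ctx context.Context, p wkafka.Producer[Data]) error {
	// The second record is dropped by the hook; only the first one is produced.
	return p.Produce(ctx,
		Data{Name: "keep"},
		Data{Name: "drop", Ignored: true},
	)
}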
8 changes: 4 additions & 4 deletions producer_test.go
@@ -26,7 +26,7 @@ func (d *testData) ProduceHook(r *Record) {

func Test_produce_Produce(t *testing.T) {
type fields[T any] struct {
Config ProducerConfig[T]
Config producerConfig[T]
ProduceRaw func(t *testing.T) func(ctx context.Context, records []*kgo.Record) error
}
type args struct {
@@ -44,7 +44,7 @@ func Test_produce_Produce(t *testing.T) {
{
name: "test",
fields: fields[*testData]{
Config: ProducerConfig[*testData]{
Config: producerConfig[*testData]{
Topic: "test",
Headers: []Header{
{
@@ -97,7 +97,7 @@ for _, tt := range tests {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &produce[*testData]{
ProducerConfig: tt.fields.Config,
producerConfig: tt.fields.Config,
produceRaw: tt.fields.ProduceRaw(t),
}

@@ -110,7 +110,7 @@

func BenchmarkProduce(b *testing.B) {
p := &produce[*testData]{
ProducerConfig: ProducerConfig[*testData]{
producerConfig: producerConfig[*testData]{
Topic: "test",
Headers: []Header{
{
9 changes: 7 additions & 2 deletions sasl.go
@@ -8,6 +8,11 @@ import (
"github.com/twmb/franz-go/pkg/sasl/scram"
)

var (
ScramSha256 = "SCRAM-SHA-256"
ScramSha512 = "SCRAM-SHA-512"
)

type SaslConfigs []SalsConfig

func (c SaslConfigs) Generate() ([]sasl.Mechanism, error) {
@@ -101,9 +106,9 @@ func (s SaslSCRAM) Generate() (sasl.Mechanism, error) {
}

switch s.Algorithm {
case "SCRAM-SHA-256":
case ScramSha256:
return auth.AsSha256Mechanism(), nil
case "SCRAM-SHA-512":
case ScramSha512:
return auth.AsSha512Mechanism(), nil
default:
return nil, fmt.Errorf("invalid algorithm %q", s.Algorithm)
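Note: a small hedged sketch of using the new exported SCRAM constants instead of raw algorithm strings; only the two constants are taken from this diff, while the helper function and import path are assumptions.

package main

import (
	"fmt"

	"github.com/worldline-go/wkafka" // assumed import path for this repository
)

// pickAlgorithm validates a configured algorithm name against the newly
// exported constants, so callers stop spelling the SCRAM strings by hand.
func pickAlgorithm(name string) (string, error) {
	switch name {
	case wkafka.ScramSha256, wkafka.ScramSha512:
		return name, nil
	default:
		return "", fmt.Errorf("invalid algorithm %q", name)
	}
}

func main() {
	algo, err := pickAlgorithm("SCRAM-SHA-512")
	if err != nil {
		panic(err)
	}

	fmt.Println("using", algo) // using SCRAM-SHA-512
}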
