Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement strict decoding for JetStream API requests #5858

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions server/jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type JetStreamConfig struct {
Domain string `json:"domain,omitempty"`
CompressOK bool `json:"compress_ok,omitempty"`
UniqueTag string `json:"unique_tag,omitempty"`
Strict bool `json:"strict,omitempty"`
}

// Statistics about JetStream for this server.
Expand Down Expand Up @@ -461,6 +462,7 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
s.Noticef("")
}
s.Noticef("---------------- JETSTREAM ----------------")
s.Noticef(" Strict: %t", cfg.Strict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove these from banner

Copy link
Contributor Author

@caspervonb caspervonb Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd vote to keep it, it's a sanity notice that you got the configuration you expected?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's only print it if cfg.Strict == true.

s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory))
s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore))
s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir)
Expand Down Expand Up @@ -553,6 +555,7 @@ func (s *Server) restartJetStream() error {
MaxMemory: opts.JetStreamMaxMemory,
MaxStore: opts.JetStreamMaxStore,
Domain: opts.JetStreamDomain,
Strict: opts.JetStreamStrict,
}
s.Noticef("Restarting JetStream")
err := s.EnableJetStream(&cfg)
Expand Down Expand Up @@ -2525,6 +2528,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je

opts := s.getOpts()

// Strict mode.
jsc.Strict = opts.JetStreamStrict

// Sync options.
jsc.SyncInterval = opts.SyncInterval
jsc.SyncAlways = opts.SyncAlways
Expand Down
74 changes: 52 additions & 22 deletions server/jetstream_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"os"
"path/filepath"
Expand Down Expand Up @@ -1066,6 +1067,35 @@ func (s *Server) getRequestInfo(c *client, raw []byte) (pci *ClientInfo, acc *Ac
return &ci, acc, hdr, msg, nil
}

func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg []byte, v interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(msg))
decoder.DisallowUnknownFields()

for {
err := decoder.Decode(v)

if err == io.EOF {
return nil
}

if err != nil {
var syntaxErr *json.SyntaxError
if errors.As(err, &syntaxErr) {
err = fmt.Errorf("%w at offset %d", err, syntaxErr.Offset)
}

c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err)

var config = s.JetStreamConfig()
if config.Strict {
return err
caspervonb marked this conversation as resolved.
Show resolved Hide resolved
}

return json.Unmarshal(msg, v)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not convinced by this block, in that if strict mode is disabled, we can in effect unmarshal twice. (Unmarshalling JSON is actually quite expensive on CPU so need to be careful on hot paths.)

Why this vs just having one decoder and setting DisallowUnknownFields() based on the Strict setting?

Copy link
Contributor Author

@caspervonb caspervonb Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first pass did just that, but was discussed offline with @ripienaar that we want to log marshaling failures even if strict is set to false to have a softer error path but keep the current behavior for existing deployments where it would at-least surface in the logs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A misbehaving client is going to a) spam log lines and b) unmarshal twice for each request. I don't think that's wise, needs further discussion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Derek suggested that for 2.11 while we are in soft mode we should still report the problem so client authors can discover and remediate issues so he suggested this approach. It's nice a nice assist before we enable this by default to reject requests but I agree with you it will spam logs until things are fixed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think client authors could enable this as an opt-in too, in order to avoid spamming there is rateLimitFormatWarnf which would log less

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we discussed maybe a header on specific requests - but we have no api behaviours on headers so I am a bit reluctant.

Other option is something at the request level like we have action on consumer or pedantic now - big ripple effect on clients having to support it and ultimately this does not help users. The point is to identify bad clients that perhaps are not compatible or have bugs. So making it opt-in only for clients that support it is a no win

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rateLimitFormatWarnf is a good idea though for sure

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that any opt-in for the client misses the point that Derek had - it's about detecting problems with non-tier-1 and older clients, and those will not have such opt-in implemented.

Copy link
Contributor Author

@caspervonb caspervonb Sep 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the log lines to use rate limited warnings.

}
}
}

func (a *Account) trackAPI() {
a.mu.RLock()
jsa := a.js
Expand Down Expand Up @@ -1195,7 +1225,7 @@ func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Accoun
}

var cfg StreamTemplateConfig
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1252,7 +1282,7 @@ func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if !isEmptyRequest(msg) {
var req JSApiStreamTemplatesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1435,7 +1465,7 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account,
}

var cfg StreamConfigRequest
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1546,7 +1576,7 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account,
return
}
var ncfg StreamConfigRequest
if err := json.Unmarshal(msg, &ncfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &ncfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1645,7 +1675,7 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account,

if !isEmptyRequest(msg) {
var req JSApiStreamNamesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1775,7 +1805,7 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s

if !isEmptyRequest(msg) {
var req JSApiStreamListRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1945,7 +1975,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
var offset int
if !isEmptyRequest(msg) {
var req JSApiStreamInfoRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2302,7 +2332,7 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco
}

var req JSApiStreamRemovePeerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2382,7 +2412,7 @@ func (s *Server) jsLeaderServerRemoveRequest(sub *subscription, c *client, _ *Ac
}

var req JSApiMetaServerRemoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2485,7 +2515,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}}

var req JSApiMetaServerStreamMoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand All @@ -2512,7 +2542,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
currCluster = sa.Group.Cluster
Expand Down Expand Up @@ -2654,7 +2684,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
}
Expand Down Expand Up @@ -2829,7 +2859,7 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun

if !isEmptyRequest(msg) {
var req JSApiLeaderStepdownRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3035,7 +3065,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su
return
}
var req JSApiMsgDeleteRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3154,7 +3184,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje
return
}
var req JSApiMsgGetRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3297,7 +3327,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account,
var purgeRequest *JSApiStreamPurgeRequest
if !isEmptyRequest(msg) {
var req JSApiStreamPurgeRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3387,7 +3417,7 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
}

var req JSApiStreamRestoreRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3690,7 +3720,7 @@ func (s *Server) jsStreamSnapshotRequest(sub *subscription, c *client, _ *Accoun
}

var req JSApiStreamSnapshotRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, smsg, s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3888,7 +3918,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}}

var req CreateConsumerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4130,7 +4160,7 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if !isEmptyRequest(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4252,7 +4282,7 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account,
var offset int
if !isEmptyRequest(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4563,7 +4593,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
var resp = JSApiConsumerPauseResponse{ApiResponse: ApiResponse{Type: JSApiConsumerPauseResponseType}}

if !isEmptyRequest(msg) {
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down
115 changes: 115 additions & 0 deletions server/jetstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24150,6 +24150,121 @@ func TestJetStreamStreamCreatePedanticMode(t *testing.T) {
}
}

func TestJetStreamStrictMode(t *testing.T) {
cfgFmt := []byte(fmt.Sprintf(`
jetstream: {
strict: true
enabled: true
max_file_store: 100MB
store_dir: %s
limits: {duplicate_window: "1m", max_request_batch: 250}
}
`, t.TempDir()))
conf := createConfFile(t, cfgFmt)
s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to NATS: %v", err)
}
defer nc.Close()

tests := []struct {
name string
subject string
payload []byte
expectedErr string
}{
{
name: "Stream Create",
subject: "$JS.API.STREAM.CREATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Update",
subject: "$JS.API.STREAM.UPDATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Delete",
subject: "$JS.API.STREAM.DELETE.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream Info",
subject: "$JS.API.STREAM.INFO.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Create",
subject: "$JS.API.CONSUMER.CREATE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"durable_name":"TEST_CONSUMER","ack_policy":"explicit","extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Delete",
subject: "$JS.API.CONSUMER.DELETE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Consumer Info",
subject: "$JS.API.CONSUMER.INFO.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream List",
subject: "$JS.API.STREAM.LIST",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer List",
subject: "$JS.API.CONSUMER.LIST.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := nc.Request(tt.subject, tt.payload, time.Second*10)
if err != nil {
t.Fatalf("Request failed: %v", err)
}

var apiResp map[string]interface{}
caspervonb marked this conversation as resolved.
Show resolved Hide resolved
if err := json.Unmarshal(resp.Data, &apiResp); err != nil {
t.Fatalf("Error unmarshalling response: %v", err)
}

if apiResp["error"] == nil {
t.Fatalf("Expected error containing %q, but got no error", tt.expectedErr)
}

errorObj, ok := apiResp["error"].(map[string]interface{})
if !ok {
t.Fatalf("Expected error to be an object, got %T", apiResp["error"])
}

errorDescription, ok := errorObj["description"].(string)
if !ok {
t.Fatalf("Expected error description to be a string, got %T", errorObj["description"])
}

if !strings.Contains(errorDescription, tt.expectedErr) {
t.Errorf("Expected error containing %q, but got %q", tt.expectedErr, errorDescription)
}
})
}
}

func addConsumerWithError(t *testing.T, nc *nats.Conn, cfg *CreateConsumerRequest) (*ConsumerInfo, *ApiError) {
t.Helper()
req, err := json.Marshal(cfg)
Expand Down
Loading