From dbf6f1f666f75468a0fd4918ba08a630eeec00a4 Mon Sep 17 00:00:00 2001 From: Aidan Woods Date: Sat, 18 Nov 2023 16:44:53 +0000 Subject: [PATCH] Use json.RawMessage to delay serialisation/deserialisation --- token.go | 42 ++++++++++++++++++++++++------------------ token_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 18 deletions(-) diff --git a/token.go b/token.go index 335ff56..f135724 100644 --- a/token.go +++ b/token.go @@ -10,18 +10,30 @@ import ( // Token is a set of paseto claims, and a footer type Token struct { - claims map[string]tokenValue + claims map[string]json.RawMessage footer []byte } // NewToken returns a token with no claims and no footer. func NewToken() Token { - return Token{make(map[string]tokenValue), nil} + return Token{make(map[string]json.RawMessage), nil} +} + +func makeToken(claims map[string]json.RawMessage, footer []byte) (*Token, error) { + tokenValueClaims := make(map[string]json.RawMessage) + + token := Token{tokenValueClaims, footer} + + for key, value := range claims { + token.claims[key] = value + } + + return &token, nil } // MakeToken allows specifying both claims and a footer. func MakeToken(claims map[string]interface{}, footer []byte) (*Token, error) { - tokenValueClaims := make(map[string]tokenValue) + tokenValueClaims := make(map[string]json.RawMessage) token := Token{tokenValueClaims, footer} @@ -37,12 +49,12 @@ func MakeToken(claims map[string]interface{}, footer []byte) (*Token, error) { // NewTokenFromClaimsJSON parses the JSON using encoding/json in claimsData // and returns a token with those claims, and the specified footer. func NewTokenFromClaimsJSON(claimsData []byte, footer []byte) (*Token, error) { - var claims map[string]interface{} + var claims map[string]json.RawMessage if err := json.Unmarshal(claimsData, &claims); err != nil { return nil, err } - return MakeToken(claims, footer) + return makeToken(claims, footer) } // Set sets the key with the specified value. Note that this value needs to @@ -51,7 +63,7 @@ func NewTokenFromClaimsJSON(claimsData []byte, footer []byte) (*Token, error) { func (token *Token) Set(key string, value interface{}) error { return t.Chain[any]( marshalTokenValue(value)). - AndThen(func(value tokenValue) t.Result[any] { + AndThen(func(value json.RawMessage) t.Result[any] { token.claims[key] = value return t.Ok[any](nil) }). @@ -67,7 +79,7 @@ func (t Token) Get(key string, output interface{}) (err error) { return fmt.Errorf("value for key `%s' not present in claims", key) } - if err := json.Unmarshal(v.rawValue, &output); err != nil { + if err := json.Unmarshal(v, &output); err != nil { output = nil return err } @@ -118,7 +130,7 @@ func (t Token) Claims() map[string]interface{} { for key, value := range t.claims { var claimValue interface{} - if err := json.Unmarshal(value.rawValue, &claimValue); err != nil { + if err := json.Unmarshal(value, &claimValue); err != nil { // we only store claims that have gone through json.Marshal // it is *very* unexpected if this is not reversable panic(err) @@ -134,7 +146,7 @@ func (t Token) Claims() map[string]interface{} { func (token Token) ClaimsJSON() []byte { // these were *just* unmarshalled (and a top level of string keys added) // it is *very* unexpected if this is not reversable - data := t.NewResult(json.Marshal(token.Claims())). + data := t.NewResult(json.Marshal(token.claims)). Expect("internal claims data should be well formed JSON") return data @@ -200,16 +212,10 @@ func (t Token) V4Encrypt(key V4SymmetricKey, implicit []byte) string { return v4LocalEncrypt(t.packet(), key, implicit, nil).encoded() } -type tokenValue struct { - // we store the encoded value, and let json.Unmarshal take care of - // conversion - rawValue []byte -} - -func newTokenValue(bytes []byte) tokenValue { - return tokenValue{bytes} +func newTokenValue(bytes []byte) json.RawMessage { + return json.RawMessage(bytes) } -func marshalTokenValue(value interface{}) t.Result[tokenValue] { +func marshalTokenValue(value interface{}) t.Result[json.RawMessage] { return t.Map(t.NewResult(json.Marshal(value)), newTokenValue) } diff --git a/token_test.go b/token_test.go index 1daf14e..0912bfa 100644 --- a/token_test.go +++ b/token_test.go @@ -1,7 +1,9 @@ package paseto_test import ( + "math" "testing" + "time" "aidanwoods.dev/go-paseto" "github.com/stretchr/testify/require" @@ -285,3 +287,34 @@ func TestReadmePublicExample(t *testing.T) { string(token.Footer()), ) } + +func TestBigUint64Claim(t *testing.T) { + token := paseto.NewToken() + var claims struct { + ID uint64 `json:"id"` + } + claims.ID = math.MaxUint64 // 18446744073709551615 + require.NoError(t, token.Set("data", claims), "must be able to write a uint64 value to claims") + token.SetExpiration(time.Now().Add(time.Minute)) + token.SetNotBefore(time.Now()) + token.SetIssuedAt(time.Now()) + + secretKey := paseto.NewV4AsymmetricSecretKey() + + signed := token.V4Sign(secretKey, nil) + + parser := paseto.NewParser() + publicKey := secretKey.Public() + + parsedToken, err := parser.ParseV4Public(publicKey, signed, nil) + require.NoError(t, err) + inputJson := token.ClaimsJSON() + parsedJson := parsedToken.ClaimsJSON() + require.JSONEq(t, string(inputJson), string(parsedJson)) + var result struct { + ID uint64 `json:"id"` + } + // Returns an error because the value of ID is 18446744073709552000 + require.NoError(t, parsedToken.Get("data", &result), "must decode data claims") + require.Equal(t, claims.ID, result.ID, "ID should be equal") +}