Skip to content

Commit

Permalink
Use json.RawMessage to delay serialisation/deserialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
aidantwoods committed Nov 18, 2023
1 parent 8199925 commit dbf6f1f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 18 deletions.
42 changes: 24 additions & 18 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,30 @@ import (

// Token is a set of paseto claims, and a footer
type Token struct {
claims map[string]tokenValue
claims map[string]json.RawMessage
footer []byte
}

// NewToken returns a token with no claims and no footer.
func NewToken() Token {
return Token{make(map[string]tokenValue), nil}
return Token{make(map[string]json.RawMessage), nil}
}

func makeToken(claims map[string]json.RawMessage, footer []byte) (*Token, error) {
tokenValueClaims := make(map[string]json.RawMessage)

token := Token{tokenValueClaims, footer}

for key, value := range claims {
token.claims[key] = value
}

return &token, nil
}

// MakeToken allows specifying both claims and a footer.
func MakeToken(claims map[string]interface{}, footer []byte) (*Token, error) {
tokenValueClaims := make(map[string]tokenValue)
tokenValueClaims := make(map[string]json.RawMessage)

token := Token{tokenValueClaims, footer}

Expand All @@ -37,12 +49,12 @@ func MakeToken(claims map[string]interface{}, footer []byte) (*Token, error) {
// NewTokenFromClaimsJSON parses the JSON using encoding/json in claimsData
// and returns a token with those claims, and the specified footer.
func NewTokenFromClaimsJSON(claimsData []byte, footer []byte) (*Token, error) {
var claims map[string]interface{}
var claims map[string]json.RawMessage
if err := json.Unmarshal(claimsData, &claims); err != nil {
return nil, err
}

return MakeToken(claims, footer)
return makeToken(claims, footer)
}

// Set sets the key with the specified value. Note that this value needs to
Expand All @@ -51,7 +63,7 @@ func NewTokenFromClaimsJSON(claimsData []byte, footer []byte) (*Token, error) {
func (token *Token) Set(key string, value interface{}) error {
return t.Chain[any](
marshalTokenValue(value)).
AndThen(func(value tokenValue) t.Result[any] {
AndThen(func(value json.RawMessage) t.Result[any] {
token.claims[key] = value
return t.Ok[any](nil)
}).
Expand All @@ -67,7 +79,7 @@ func (t Token) Get(key string, output interface{}) (err error) {
return fmt.Errorf("value for key `%s' not present in claims", key)
}

if err := json.Unmarshal(v.rawValue, &output); err != nil {
if err := json.Unmarshal(v, &output); err != nil {
output = nil
return err
}
Expand Down Expand Up @@ -118,7 +130,7 @@ func (t Token) Claims() map[string]interface{} {

for key, value := range t.claims {
var claimValue interface{}
if err := json.Unmarshal(value.rawValue, &claimValue); err != nil {
if err := json.Unmarshal(value, &claimValue); err != nil {
// we only store claims that have gone through json.Marshal
// it is *very* unexpected if this is not reversable
panic(err)
Expand All @@ -134,7 +146,7 @@ func (t Token) Claims() map[string]interface{} {
func (token Token) ClaimsJSON() []byte {
// these were *just* unmarshalled (and a top level of string keys added)
// it is *very* unexpected if this is not reversable
data := t.NewResult(json.Marshal(token.Claims())).
data := t.NewResult(json.Marshal(token.claims)).
Expect("internal claims data should be well formed JSON")

return data
Expand Down Expand Up @@ -200,16 +212,10 @@ func (t Token) V4Encrypt(key V4SymmetricKey, implicit []byte) string {
return v4LocalEncrypt(t.packet(), key, implicit, nil).encoded()
}

type tokenValue struct {
// we store the encoded value, and let json.Unmarshal take care of
// conversion
rawValue []byte
}

func newTokenValue(bytes []byte) tokenValue {
return tokenValue{bytes}
func newTokenValue(bytes []byte) json.RawMessage {
return json.RawMessage(bytes)
}

func marshalTokenValue(value interface{}) t.Result[tokenValue] {
func marshalTokenValue(value interface{}) t.Result[json.RawMessage] {
return t.Map(t.NewResult(json.Marshal(value)), newTokenValue)
}
33 changes: 33 additions & 0 deletions token_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package paseto_test

import (
"math"
"testing"
"time"

"aidanwoods.dev/go-paseto"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -285,3 +287,34 @@ func TestReadmePublicExample(t *testing.T) {
string(token.Footer()),
)
}

func TestBigUint64Claim(t *testing.T) {
token := paseto.NewToken()
var claims struct {
ID uint64 `json:"id"`
}
claims.ID = math.MaxUint64 // 18446744073709551615
require.NoError(t, token.Set("data", claims), "must be able to write a uint64 value to claims")
token.SetExpiration(time.Now().Add(time.Minute))
token.SetNotBefore(time.Now())
token.SetIssuedAt(time.Now())

secretKey := paseto.NewV4AsymmetricSecretKey()

signed := token.V4Sign(secretKey, nil)

parser := paseto.NewParser()
publicKey := secretKey.Public()

parsedToken, err := parser.ParseV4Public(publicKey, signed, nil)
require.NoError(t, err)
inputJson := token.ClaimsJSON()
parsedJson := parsedToken.ClaimsJSON()
require.JSONEq(t, string(inputJson), string(parsedJson))
var result struct {
ID uint64 `json:"id"`
}
// Returns an error because the value of ID is 18446744073709552000
require.NoError(t, parsedToken.Get("data", &result), "must decode data claims")
require.Equal(t, claims.ID, result.ID, "ID should be equal")
}

0 comments on commit dbf6f1f

Please sign in to comment.