From be3f2691eb90373cb5463fe9d864c12fff990b6d Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Tue, 17 Sep 2024 14:56:12 +0200 Subject: [PATCH] add parse connection strung function; fix linter --- auth.go | 62 +++++++++++++++++++++++++++++++++++++++++--- auth_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/auth.go b/auth.go index da97909..24e7460 100644 --- a/auth.go +++ b/auth.go @@ -9,6 +9,9 @@ import ( "log" "math/big" "net/http" + "net/url" + "strconv" + "strings" "time" "github.com/ethereum/go-ethereum/common" @@ -20,6 +23,7 @@ import ( "github.com/iden3/go-iden3-auth/v2/proofs" "github.com/iden3/go-iden3-auth/v2/pubsignals" "github.com/iden3/go-iden3-auth/v2/state" + core "github.com/iden3/go-iden3-core/v2" "github.com/iden3/go-iden3-core/v2/w3c" "github.com/iden3/go-jwz/v2" schemaloaders "github.com/iden3/go-schema-processor/v2/loaders" @@ -153,6 +157,58 @@ func newOpts() verifierOpts { } } +// ParseConnectionString parses the connection string and returns a map of state resolvers. +// The connection string format is as follows: +// +// chainID1=rpcURL1|contractAddress1;chainID2=rpcURL2|contractAddress2;... +// +// Each chainID is an integer representing the blockchain ID. +// Each rpcURL is the URL of the RPC endpoint for the blockchain. +// Each contractAddress is the address of the contract on the blockchain. +// The function returns an error if the connection string is in an invalid format or if the contract address is invalid. +// If the length of the result is 0, it means no connection info was found and an error is returned. +func ParseConnectionString(str string) (map[string]pubsignals.StateResolver, error) { + str = strings.TrimSpace(str) + result := make(map[string]pubsignals.StateResolver) + connectionInfo := strings.Split(str, ";") + for _, chain := range connectionInfo { + parts := strings.Split(chain, "=") + if len(parts) != 2 { + return nil, errors.Errorf("invalid format: '%s'", chain) + } + chainIDstr := parts[0] + rpcToContractAddress := strings.Split(parts[1], "|") + if len(rpcToContractAddress) != 2 { + return nil, errors.Errorf("invalid format: '%s'", parts[1]) + } + + _, err := url.ParseRequestURI(rpcToContractAddress[0]) + if err != nil { + return nil, errors.Errorf("invalid rpc url: '%s'", rpcToContractAddress[0]) + } + if common.HexToAddress(rpcToContractAddress[1]).Hex() == (common.Address{}).Hex() { + return nil, errors.Errorf("invalid contract address: '%s'", rpcToContractAddress[1]) + } + + chainID, err := strconv.Atoi(chainIDstr) + if err != nil { + return nil, errors.Errorf("invalid chain id: '%s'", chainIDstr) + } + //nolint:gosec // integer overflow is not possible here + blockchain, network, err := core.NetworkByChainID(core.ChainID(chainID)) + if err != nil { + return nil, errors.Errorf("invalid chain id: '%s'", chainIDstr) + } + + c := state.NewETHResolver(rpcToContractAddress[0], rpcToContractAddress[1]) + result[fmt.Sprintf("%s:%s", blockchain, network)] = c + } + if len(result) == 0 { + return nil, errors.New("no connection info found") + } + return result, nil +} + // NewVerifier returns setup instance of auth library func NewVerifier( keyLoader loaders.VerificationKeyLoader, @@ -239,7 +295,7 @@ func (v *Verifier) SetupAuthV2ZKPPacker() error { // SetupJWSPacker sets the JWS packer for the VerifierBuilder. func (v *Verifier) SetupJWSPacker(didResolver packers.DIDResolverHandlerFunc) error { - signerFnStub := packers.SignerResolverHandlerFunc(func(kid string) (crypto.Signer, error) { + signerFnStub := packers.SignerResolverHandlerFunc(func(_ string) (crypto.Signer, error) { return nil, nil }) jwsPacker := packers.NewJWSPacker(didResolver, signerFnStub) @@ -519,9 +575,9 @@ func verifyGroupIDMathch(linkID *big.Int, groupID int, requestID uint32, groupID // VerifyJWZ performs verification of jwz token func (v *Verifier) VerifyJWZ( - ctx context.Context, + _ context.Context, token string, - opts ...pubsignals.VerifyOpt, + _ ...pubsignals.VerifyOpt, ) (t *jwz.Token, err error) { _, _, err = v.packageManager.Unpack([]byte(token)) diff --git a/auth_test.go b/auth_test.go index 7699c14..bda75df 100644 --- a/auth_test.go +++ b/auth_test.go @@ -1225,3 +1225,75 @@ func TestFullVerifyLinkedProofsVerification(t *testing.T) { require.NotNil(t, returnMsg) schemaLoader.assert(t) } + +func TestParseConnectionString(t *testing.T) { + tests := []struct { + name string + input string + want map[string]pubsignals.StateResolver + wantErr bool + errMsg string + }{ + { + name: "Valid input", + input: "1=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678", + want: map[string]pubsignals.StateResolver{"eth:main": state.NewETHResolver("http://localhost:8545", "0x1234567890abcdef1234567890abcdef12345678")}, + wantErr: false, + }, + { + name: "Invalid format - missing =", + input: "1http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678", + wantErr: true, + errMsg: "invalid format: '1http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678'", + }, + { + name: "Invalid format - missing |", + input: "1=http://localhost:85450x1234567890abcdef1234567890abcdef12345678", + wantErr: true, + errMsg: "invalid format: 'http://localhost:85450x1234567890abcdef1234567890abcdef12345678'", + }, + { + name: "Invalid RPC URL", + input: "1=invalid_url|0x1234567890abcdef1234567890abcdef12345678", + wantErr: true, + errMsg: "invalid rpc url: 'invalid_url'", + }, + { + name: "Invalid contract address", + input: "1=http://localhost:8545|invalid_address", + wantErr: true, + errMsg: "invalid contract address: 'invalid_address'", + }, + { + name: "Invalid chain ID", + input: "invalid_chain_id=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678", + wantErr: true, + errMsg: "invalid chain id: 'invalid_chain_id'", + }, + { + name: "Unknown chain ID", + input: "9999=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678", + wantErr: true, + errMsg: "invalid chain id: '9999'", + }, + { + name: "Empty input", + input: "", + wantErr: true, + errMsg: "invalid format: ''", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseConnectionString(tt.input) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +}