Skip to content

Commit

Permalink
add parse connection strung function; fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-korotya committed Sep 17, 2024
1 parent 288059c commit be3f269
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 3 deletions.
62 changes: 59 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"log"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/ethereum/go-ethereum/common"
Expand All @@ -20,6 +23,7 @@ import (
"github.com/iden3/go-iden3-auth/v2/proofs"
"github.com/iden3/go-iden3-auth/v2/pubsignals"
"github.com/iden3/go-iden3-auth/v2/state"
core "github.com/iden3/go-iden3-core/v2"
"github.com/iden3/go-iden3-core/v2/w3c"
"github.com/iden3/go-jwz/v2"
schemaloaders "github.com/iden3/go-schema-processor/v2/loaders"
Expand Down Expand Up @@ -153,6 +157,58 @@ func newOpts() verifierOpts {
}
}

// ParseConnectionString parses the connection string and returns a map of state resolvers.
// The connection string format is as follows:
//
// chainID1=rpcURL1|contractAddress1;chainID2=rpcURL2|contractAddress2;...
//
// Each chainID is an integer representing the blockchain ID.
// Each rpcURL is the URL of the RPC endpoint for the blockchain.
// Each contractAddress is the address of the contract on the blockchain.
// The function returns an error if the connection string is in an invalid format or if the contract address is invalid.
// If the length of the result is 0, it means no connection info was found and an error is returned.
func ParseConnectionString(str string) (map[string]pubsignals.StateResolver, error) {
str = strings.TrimSpace(str)
result := make(map[string]pubsignals.StateResolver)
connectionInfo := strings.Split(str, ";")
for _, chain := range connectionInfo {
parts := strings.Split(chain, "=")
if len(parts) != 2 {
return nil, errors.Errorf("invalid format: '%s'", chain)
}
chainIDstr := parts[0]
rpcToContractAddress := strings.Split(parts[1], "|")
if len(rpcToContractAddress) != 2 {
return nil, errors.Errorf("invalid format: '%s'", parts[1])
}

_, err := url.ParseRequestURI(rpcToContractAddress[0])
if err != nil {
return nil, errors.Errorf("invalid rpc url: '%s'", rpcToContractAddress[0])
}
if common.HexToAddress(rpcToContractAddress[1]).Hex() == (common.Address{}).Hex() {
return nil, errors.Errorf("invalid contract address: '%s'", rpcToContractAddress[1])
}

chainID, err := strconv.Atoi(chainIDstr)
if err != nil {
return nil, errors.Errorf("invalid chain id: '%s'", chainIDstr)
}
//nolint:gosec // integer overflow is not possible here
blockchain, network, err := core.NetworkByChainID(core.ChainID(chainID))
if err != nil {
return nil, errors.Errorf("invalid chain id: '%s'", chainIDstr)
}

c := state.NewETHResolver(rpcToContractAddress[0], rpcToContractAddress[1])
result[fmt.Sprintf("%s:%s", blockchain, network)] = c
}
if len(result) == 0 {
return nil, errors.New("no connection info found")
}
return result, nil
}

// NewVerifier returns setup instance of auth library
func NewVerifier(
keyLoader loaders.VerificationKeyLoader,
Expand Down Expand Up @@ -239,7 +295,7 @@ func (v *Verifier) SetupAuthV2ZKPPacker() error {
// SetupJWSPacker sets the JWS packer for the VerifierBuilder.
func (v *Verifier) SetupJWSPacker(didResolver packers.DIDResolverHandlerFunc) error {

signerFnStub := packers.SignerResolverHandlerFunc(func(kid string) (crypto.Signer, error) {
signerFnStub := packers.SignerResolverHandlerFunc(func(_ string) (crypto.Signer, error) {
return nil, nil
})
jwsPacker := packers.NewJWSPacker(didResolver, signerFnStub)
Expand Down Expand Up @@ -519,9 +575,9 @@ func verifyGroupIDMathch(linkID *big.Int, groupID int, requestID uint32, groupID

// VerifyJWZ performs verification of jwz token
func (v *Verifier) VerifyJWZ(
ctx context.Context,
_ context.Context,
token string,
opts ...pubsignals.VerifyOpt,
_ ...pubsignals.VerifyOpt,
) (t *jwz.Token, err error) {

_, _, err = v.packageManager.Unpack([]byte(token))
Expand Down
72 changes: 72 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1225,3 +1225,75 @@ func TestFullVerifyLinkedProofsVerification(t *testing.T) {
require.NotNil(t, returnMsg)
schemaLoader.assert(t)
}

func TestParseConnectionString(t *testing.T) {
tests := []struct {
name string
input string
want map[string]pubsignals.StateResolver
wantErr bool
errMsg string
}{
{
name: "Valid input",
input: "1=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678",
want: map[string]pubsignals.StateResolver{"eth:main": state.NewETHResolver("http://localhost:8545", "0x1234567890abcdef1234567890abcdef12345678")},
wantErr: false,
},
{
name: "Invalid format - missing =",
input: "1http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678",
wantErr: true,
errMsg: "invalid format: '1http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678'",
},
{
name: "Invalid format - missing |",
input: "1=http://localhost:85450x1234567890abcdef1234567890abcdef12345678",
wantErr: true,
errMsg: "invalid format: 'http://localhost:85450x1234567890abcdef1234567890abcdef12345678'",
},
{
name: "Invalid RPC URL",
input: "1=invalid_url|0x1234567890abcdef1234567890abcdef12345678",
wantErr: true,
errMsg: "invalid rpc url: 'invalid_url'",
},
{
name: "Invalid contract address",
input: "1=http://localhost:8545|invalid_address",
wantErr: true,
errMsg: "invalid contract address: 'invalid_address'",
},
{
name: "Invalid chain ID",
input: "invalid_chain_id=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678",
wantErr: true,
errMsg: "invalid chain id: 'invalid_chain_id'",
},
{
name: "Unknown chain ID",
input: "9999=http://localhost:8545|0x1234567890abcdef1234567890abcdef12345678",
wantErr: true,
errMsg: "invalid chain id: '9999'",
},
{
name: "Empty input",
input: "",
wantErr: true,
errMsg: "invalid format: ''",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseConnectionString(tt.input)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

0 comments on commit be3f269

Please sign in to comment.