Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(axelarnet)!: pick up the new methods for setting general messages in the msg_server and proposal handler #2009

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions x/axelarnet/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func NewProposalHandler(k keeper.Keeper, nexusK types.Nexus, accountK types.Acco
payloadHash := crypto.Keccak256(contractCall.Payload)

msgID, txID, nonce := nexusK.GenerateMessageID(ctx)
msg := nexus.NewGeneralMessage(msgID, sender, recipient, payloadHash, nexus.Approved, txID, nonce, nil)
msg := nexus.NewGeneralMessage_(msgID, sender, recipient, payloadHash, txID, nonce, nil)

events.Emit(ctx, &types.ContractCallSubmitted{
MessageID: msg.ID,
Expand All @@ -123,7 +123,7 @@ func NewProposalHandler(k keeper.Keeper, nexusK types.Nexus, accountK types.Acco
Payload: contractCall.Payload,
})

if err := nexusK.SetNewMessage(ctx, msg); err != nil {
if err := nexusK.SetNewMessage_(ctx, msg); err != nil {
return sdkerrors.Wrap(err, "failed to add general message")
}

Expand Down
22 changes: 3 additions & 19 deletions x/axelarnet/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (s msgServer) CallContract(c context.Context, req *types.CallContractReques
payloadHash := crypto.Keccak256(req.Payload)

msgID, txID, nonce := s.nexus.GenerateMessageID(ctx)
msg := nexus.NewGeneralMessage(msgID, sender, recipient, payloadHash, nexus.Approved, txID, nonce, nil)
msg := nexus.NewGeneralMessage_(msgID, sender, recipient, payloadHash, txID, nonce, nil)

events.Emit(ctx, &types.ContractCallSubmitted{
MessageID: msg.ID,
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s msgServer) CallContract(c context.Context, req *types.CallContractReques
events.Emit(ctx, &feePaidEvent)
}

if err := s.nexus.SetNewMessage(ctx, msg); err != nil {
if err := s.nexus.SetNewMessage_(ctx, msg); err != nil {
return nil, sdkerrors.Wrap(err, "failed to add general message")
}

Expand Down Expand Up @@ -496,26 +496,10 @@ func (s msgServer) RouteMessage(c context.Context, req *types.RouteMessageReques
return nil, fmt.Errorf("message %s not found", req.ID)
}

if !s.nexus.IsChainActivated(ctx, msg.Sender.Chain) {
return nil, fmt.Errorf("chain %s is not activated", msg.GetSourceChain())
}

if !s.nexus.IsChainActivated(ctx, msg.Recipient.Chain) {
return nil, fmt.Errorf("chain %s is not activated", msg.GetDestinationChain())
}

if msg.Type() == nexus.TypeGeneralMessageWithToken {
funcs.MustTrue(s.nexus.IsAssetRegistered(ctx, msg.Recipient.Chain, msg.Asset.GetDenom()))
}

if !msg.Match(req.Payload) {
return nil, fmt.Errorf("payload hash does not match")
}

if !(msg.Is(nexus.Approved) || msg.Is(nexus.Failed)) {
return nil, fmt.Errorf("general message %s already executed", req.ID)
}

// send ibc message if destination is cosmos
if msg.Recipient.Chain.IsFrom(exported.ModuleName) {
bz, err := types.TranslateMessage(msg, req.Payload)
Expand All @@ -534,7 +518,7 @@ func (s msgServer) RouteMessage(c context.Context, req *types.RouteMessageReques
}
}

err := s.nexus.SetMessageProcessing(ctx, msg.ID)
err := s.nexus.SetMessageProcessing_(ctx, msg.ID)
if err != nil {
return nil, err
}
Expand Down
53 changes: 8 additions & 45 deletions x/axelarnet/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ func TestRouteMessage(t *testing.T) {
GetChainByNativeAssetFunc: func(sdk.Context, string) (nexus.Chain, bool) {
return chain, true
},
SetMessageProcessingFunc: func(sdk.Context, string) error {
SetMessageProcessing_Func: func(sdk.Context, string) error {
return nil
},
}
Expand All @@ -1035,18 +1035,6 @@ func TestRouteMessage(t *testing.T) {
server = keeper.NewMsgServerImpl(k, nexusK, bankK, accountK, ibcK)
})

isChainActivated := func(isActivated bool) func() {
return func() {
nexusK.IsChainActivatedFunc = func(_ sdk.Context, chain nexus.Chain) bool { return isActivated }
}
}

isAssetRegistered := func(isRegistered bool) func() {
return func() {
nexusK.IsAssetRegisteredFunc = func(sdk.Context, nexus.Chain, string) bool { return isRegistered }
}
}

isMessageFound := func(isFound bool, status nexus.GeneralMessage_Status) func() {
return func() {
nexusK.GetMessageFunc = func(ctx sdk.Context, messageID string) (nexus.GeneralMessage, bool) {
Expand All @@ -1063,12 +1051,10 @@ func TestRouteMessage(t *testing.T) {
whenMessageIsFromEVM := When("message is from evm", func() {
isMessageFound(true, nexus.Approved)()
msg.Sender.Chain.Module = evmtypes.ModuleName
isChainActivated(true)()
})
whenMessageIsFromCosmos := When("message is from cosmos", func() {
isMessageFound(true, nexus.Approved)()
msg.Sender.Chain.Module = exported.ModuleName
isChainActivated(true)()
})
whenMessageIsToEVM := When("message is to evm", func() {
msg.Recipient.Chain.Module = evmtypes.ModuleName
Expand Down Expand Up @@ -1100,15 +1086,8 @@ func TestRouteMessage(t *testing.T) {
When2(requestIsMade).
Then("should fail", routeFailsWithError("not found")),

When("general message is found", isMessageFound(true, nexus.Approved)).
When("source chain is an EVM chain", func() { msg.Sender.Chain.Module = evmtypes.ModuleName }).
When("chain is not activated", isChainActivated(false)).
When2(requestIsMade).
Then("should fail", routeFailsWithError("not activated")),

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload does not match", func() {
req = types.NewRouteMessage(
rand.AccAddr(),
Expand All @@ -1121,14 +1100,6 @@ func TestRouteMessage(t *testing.T) {

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("general message already executed", isMessageFound(true, nexus.Executed)).
When2(requestIsMade).
Then("should fail", routeFailsWithError("already executed")),

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload with version is invalid", func() {
payload = rand.Bytes(4)
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
Expand All @@ -1138,7 +1109,6 @@ func TestRouteMessage(t *testing.T) {

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is invalid", func() {
payload = axelartestutils.PackPayloadWithVersion(types.CosmWasmV1, rand.BytesBetween(100, 500))
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
Expand All @@ -1148,16 +1118,15 @@ func TestRouteMessage(t *testing.T) {

whenMessageIsFromCosmos.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is invalid", func() {
payload = rand.BytesBetween(100, 500)
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
}).
When2(requestIsMade).
Then("should fail", routeFailsWithError("invalid payload")),

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = randWasmPayload()
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
Expand All @@ -1174,59 +1143,54 @@ func TestRouteMessage(t *testing.T) {

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = randPayload()
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
}).
When2(requestIsMade).
Then("should success", func(t *testing.T) {
_, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req)
fmt.Println(err)
assert.NoError(t, err)
}),

whenMessageIsFromCosmos.
When2(whenMessageIsToEVM).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = rand.BytesBetween(100, 500)
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
}).
When2(requestIsMade).
Then("should success", func(t *testing.T) {
_, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req)
fmt.Println(err)
assert.NoError(t, err)
}),

whenMessageIsFromCosmos.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = randPayload()
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
}).
When2(requestIsMade).
Then("should success", func(t *testing.T) {
_, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req)
fmt.Println(err)
assert.NoError(t, err)
}),

whenMessageIsFromCosmos.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = randWasmPayload()
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
}).
When2(requestIsMade).
Then("should success", func(t *testing.T) {
_, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req)
fmt.Println(err)
assert.NoError(t, err)
}),

whenMessageIsFromEVM.
When2(whenMessageIsToCosmos).
When("asset is registered", isAssetRegistered(true)).
When("payload is valid", func() {
payload = randWasmPayload()
msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes()
Expand All @@ -1245,7 +1209,6 @@ func TestRouteMessage(t *testing.T) {
}).
Then("should success", func(t *testing.T) {
_, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req)
fmt.Println(err)
assert.NoError(t, err)
}),
).Run(t)
Expand Down Expand Up @@ -1304,7 +1267,7 @@ func TestHandleCallContract(t *testing.T) {
})

whenSetNewMessageSucceeds := When("set new message succeeds", func() {
nexusK.SetNewMessageFunc = func(_ sdk.Context, m nexus.GeneralMessage) error {
nexusK.SetNewMessage_Func = func(_ sdk.Context, m nexus.GeneralMessage) error {
msg = m
return m.ValidateBasic()
}
Expand Down Expand Up @@ -1415,7 +1378,7 @@ func TestHandleCallContract(t *testing.T) {
When2(whenChainIsActivated).
When2(whenAddressIsValid).
When("set new message fails", func() {
nexusK.SetNewMessageFunc = func(_ sdk.Context, m nexus.GeneralMessage) error {
nexusK.SetNewMessage_Func = func(_ sdk.Context, m nexus.GeneralMessage) error {
return fmt.Errorf("failed to set message")
}
}).
Expand Down
2 changes: 2 additions & 0 deletions x/axelarnet/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ type Nexus interface {
RateLimitTransfer(ctx sdk.Context, chain nexus.ChainName, asset sdk.Coin, direction nexus.TransferDirection) error
GetMessage(ctx sdk.Context, id string) (m nexus.GeneralMessage, found bool)
SetNewMessage(ctx sdk.Context, m nexus.GeneralMessage) error
SetNewMessage_(ctx sdk.Context, m nexus.GeneralMessage) error
SetMessageProcessing(ctx sdk.Context, id string) error
SetMessageProcessing_(ctx sdk.Context, id string) error
SetMessageExecuted(ctx sdk.Context, id string) error
SetMessageFailed(ctx sdk.Context, id string) error
GenerateMessageID(ctx sdk.Context) (string, []byte, uint64)
Expand Down
Loading
Loading