diff --git a/ably/ably_test.go b/ably/ably_test.go index 89fd0627..9b5cf2f2 100644 --- a/ably/ably_test.go +++ b/ably/ably_test.go @@ -272,6 +272,9 @@ func (rec *MessageRecorder) CheckIfReceived(action ably.ProtoAction, times int) } } } + if times == 0 && times == counter { + return true + } return false } } diff --git a/ably/proto_protocol_message.go b/ably/proto_protocol_message.go index 4b6542c2..2fc460b6 100644 --- a/ably/proto_protocol_message.go +++ b/ably/proto_protocol_message.go @@ -183,6 +183,9 @@ func (msg *protocolMessage) String() string { case actionMessage: return fmt.Sprintf("(action=%q, id=%q, messages=%v)", msg.Action, msg.ConnectionID, msg.Messages) + case actionAuth: + return fmt.Sprintf("(action=%q, id=%q, auth=%v)", msg.Action, + msg.ConnectionID, msg.Auth) default: return fmt.Sprintf("%#v", msg) } diff --git a/ably/realtime_conn_spec_integration_test.go b/ably/realtime_conn_spec_integration_test.go index c78846bc..cacf1d4a 100644 --- a/ably/realtime_conn_spec_integration_test.go +++ b/ably/realtime_conn_spec_integration_test.go @@ -1783,7 +1783,7 @@ func TestRealtimeConn_RTN15h3_Success(t *testing.T) { ablytest.Instantly.NoRecv(t, nil, stateChanges, t.Fatalf) } -func TestRealtimeConn_RTN15h_Integration_ClientInitiatedAuth(t *testing.T) { +func TestRealtimeConn_RTN22a_RTN15h2_Integration_ServerInitiatedAuth(t *testing.T) { t.Parallel() app, restClient := ablytest.NewREST() defer safeclose(t, app) @@ -1809,20 +1809,30 @@ func TestRealtimeConn_RTN15h_Integration_ClientInitiatedAuth(t *testing.T) { err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) assert.NoError(t, err) + changes := make(ably.ConnStateChanges, 2) + off := realtime.Connection.OnAll(changes.Receive) + defer off() + var state ably.ConnectionStateChange + for i := 0; i < 3; i++ { - err = ablytest.Wait(ablytest.ConnWaiter(realtime, nil, ably.ConnectionEventDisconnected), nil) - var errorInfo *ably.ErrorInfo - assert.Error(t, err) - assert.ErrorAs(t, err, &errorInfo) - assert.Equal(t, 401, errorInfo.StatusCode) - assert.Equal(t, 40142, int(errorInfo.Code)) - assert.ErrorContains(t, err, "token expired") - err = ablytest.Wait(ablytest.ConnWaiter(realtime, nil, ably.ConnectionEventConnected), nil) - assert.NoError(t, err) + ablytest.Soon.Recv(t, &state, changes, t.Fatalf) + assert.Equal(t, ably.ConnectionEventDisconnected, state.Event) + assert.Equal(t, ably.ConnectionStateDisconnected, state.Current) + assert.Error(t, state.Reason) + assert.Equal(t, 401, state.Reason.StatusCode) + assert.Equal(t, 40142, int(state.Reason.Code)) + assert.ErrorContains(t, state.Reason, "token expired") + + ablytest.Soon.Recv(t, &state, changes, t.Fatalf) + assert.Equal(t, ably.ConnectionEventConnecting, state.Event) + ablytest.Soon.Recv(t, &state, changes, t.Fatalf) + assert.Equal(t, ably.ConnectionEventConnected, state.Event) + assert.Nil(t, state.Reason) assert.Equal(t, ably.ConnectionStateConnected, realtime.Connection.State()) } - + ablytest.Instantly.NoRecv(t, nil, changes, t.Fatalf) assert.True(t, ablytest.Instantly.IsTrue(recorder.CheckIfReceived(ably.ActionDisconnected, 3))) + tokens := []string{} assert.Len(t, recorder.URLs(), 4) // 4 connect attempts made in total, disconnect received after each one for _, url := range recorder.URLs() { @@ -1833,6 +1843,65 @@ func TestRealtimeConn_RTN15h_Integration_ClientInitiatedAuth(t *testing.T) { assert.ElementsMatch(t, authCallbackTokens, tokens) } +func TestRealtimeConn_RTN22_RTC8_Integration_ServerInitiatedAuth(t *testing.T) { + app, restClient := ablytest.NewREST() + defer safeclose(t, app) + + recorder := NewMessageRecorder() + authCallbackTokens := []string{} + + // Server sends AUTH message 30 seconds before token expiry. + // So sending client token with expiry of 33 seconds, server will send AUTH msg after 3 seconds. + authCallback := func(ctx context.Context, tp ably.TokenParams) (ably.Tokener, error) { + tokenExpiry := 33000 + token, err := restClient.Auth.RequestToken(context.Background(), &ably.TokenParams{TTL: int64(tokenExpiry)}) + authCallbackTokens = append(authCallbackTokens, token.Token) + return token, err + } + + realtime, err := ably.NewRealtime( + ably.WithAutoConnect(false), + ably.WithDial(recorder.Dial), + ably.WithUseBinaryProtocol(false), + ably.WithEnvironment(ablytest.Environment), + ably.WithAuthCallback(authCallback)) + + assert.NoError(t, err) + defer realtime.Close() + + err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + + changes := make(ably.ConnStateChanges, 2) + off := realtime.Connection.OnAll(changes.Receive) + defer off() + var state ably.ConnectionStateChange + + for i := 0; i < 3; i++ { + // auth msg sent by ably server every 3 seconds, so connection is updated by client + ablytest.Soon.Recv(t, &state, changes, t.Fatalf) + assert.Equal(t, ably.ConnectionEventUpdate, state.Event) + assert.Equal(t, ably.ConnectionStateConnected, state.Previous) + assert.Nil(t, state.Reason) + assert.Equal(t, ably.ConnectionStateConnected, realtime.Connection.State()) + assert.True(t, ablytest.Instantly.IsTrue(recorder.CheckIfReceived(ably.ActionAuth, i+1))) + } + ablytest.Instantly.NoRecv(t, nil, changes, t.Fatalf) + assert.True(t, ablytest.Instantly.IsTrue(recorder.CheckIfReceived(ably.ActionDisconnected, 0))) + + // Only one dial attempt + tokens := []string{} + assert.Len(t, recorder.URLs(), 1) + for _, url := range recorder.URLs() { + tokens = append(tokens, url.Query().Get("access_token")) + } + assert.Len(t, tokens, 1) + + assert.Len(t, authCallbackTokens, 4) + assert.Equal(t, tokens[0], authCallbackTokens[0]) + assertUnique(t, authCallbackTokens) +} + func TestRealtimeConn_RTN15i_OnErrorWhenConnected(t *testing.T) { in := make(chan *ably.ProtocolMessage, 1)