From 2e80c4974488cab07054aedaa82b42cb256143fb Mon Sep 17 00:00:00 2001 From: Michael Hadley Date: Thu, 13 Apr 2023 12:24:58 -0700 Subject: [PATCH] Return `code` or `error` as `ErrorCode` when present --- pkg/workos_errors/http.go | 16 ++++++++++++---- pkg/workos_errors/http_test.go | 28 ++++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/pkg/workos_errors/http.go b/pkg/workos_errors/http.go index 6b154c9d..3f5b6508 100644 --- a/pkg/workos_errors/http.go +++ b/pkg/workos_errors/http.go @@ -72,15 +72,23 @@ func getJsonErrorMessage(b []byte, statusCode int) (string, string, []string, [] return string(b), "", nil, nil } + var errorCode string + + if payload.Code != "" { + errorCode = payload.Code + } else { + errorCode = payload.Error + } + if payload.Error != "" && payload.ErrorDescription != "" { - return fmt.Sprintf("%s %s", payload.Error, payload.ErrorDescription), "", nil, nil + return fmt.Sprintf("%s %s", payload.Error, payload.ErrorDescription), errorCode, nil, nil } else if payload.Message != "" && len(payload.Errors) == 0 { - return payload.Message, "", nil, nil + return payload.Message, errorCode, nil, nil } else if payload.Message != "" && len(payload.Errors) > 0 { - return payload.Message, payload.Code, payload.Errors, nil + return payload.Message, errorCode, payload.Errors, nil } - return string(b), "", nil, nil + return string(b), errorCode, nil, nil } // HTTPError represents an http error. diff --git a/pkg/workos_errors/http_test.go b/pkg/workos_errors/http_test.go index 072f99c1..f1ada38e 100644 --- a/pkg/workos_errors/http_test.go +++ b/pkg/workos_errors/http_test.go @@ -13,16 +13,36 @@ func TestGetHTTPErrorWithJSONPayload(t *testing.T) { rec.Header().Set("X-Request-ID", "GOrOXx") rec.Header().Set("Content-Type", "application/json") rec.WriteHeader(http.StatusUnauthorized) - rec.WriteString(`{"message":"unauthorized", "error": "unauthorized error", "error_description": "unauthorized error description"}`) + rec.WriteString(`{"message":"unauthorized", "error": "unauthorized_error", "error_description": "unauthorized error description"}`) err := TryGetHTTPError(rec.Result()) require.Error(t, err) httperr := err.(HTTPError) require.Equal(t, http.StatusUnauthorized, httperr.Code) + require.Equal(t, "unauthorized_error", httperr.ErrorCode) require.Equal(t, "401 Unauthorized", httperr.Status) require.Equal(t, "GOrOXx", httperr.RequestID) - require.Equal(t, "unauthorized error unauthorized error description", httperr.Message) + require.Equal(t, "unauthorized_error unauthorized error description", httperr.Message) + + t.Log(httperr) +} + +func TestGetHTTPErrorWithBothErrorAndCode(t *testing.T) { + rec := httptest.NewRecorder() + rec.Header().Set("X-Request-ID", "GOrOXx") + rec.Header().Set("Content-Type", "application/json") + rec.WriteHeader(http.StatusUnauthorized) + rec.WriteString(`{"message":"unauthorized", "code": "bad_credentials", "error": "unauthorized_error"}`) + + err := TryGetHTTPError(rec.Result()) + require.Error(t, err) + + httperr := err.(HTTPError) + require.Equal(t, http.StatusUnauthorized, httperr.Code) + require.Equal(t, "bad_credentials", httperr.ErrorCode) + require.Equal(t, "GOrOXx", httperr.RequestID) + require.Equal(t, "unauthorized", httperr.Message) t.Log(httperr) } @@ -128,7 +148,7 @@ func TestGetHTTPErrorWithoutRequestID(t *testing.T) { rec := httptest.NewRecorder() rec.Header().Set("Content-Type", "application/json") rec.WriteHeader(http.StatusUnauthorized) - rec.WriteString(`{"message":"unauthorized", "error": "unauthorized error", "error_description": "unauthorized error description"}`) + rec.WriteString(`{"message":"unauthorized", "error": "unauthorized_error", "error_description": "unauthorized error description"}`) err := TryGetHTTPError(rec.Result()) require.Error(t, err) @@ -137,7 +157,7 @@ func TestGetHTTPErrorWithoutRequestID(t *testing.T) { require.Equal(t, http.StatusUnauthorized, httperr.Code) require.Equal(t, "401 Unauthorized", httperr.Status) require.Empty(t, httperr.RequestID) - require.Equal(t, "unauthorized error unauthorized error description", httperr.Message) + require.Equal(t, "unauthorized_error unauthorized error description", httperr.Message) t.Log(httperr) }