Skip to content

Commit

Permalink
feat: Replaced use on Access token in merlin extractor with Idtoken (#50
Browse files Browse the repository at this point in the history
)

Co-authored-by: Mayur Jagtap <[email protected]>
  • Loading branch information
Mayurjag and solsticemj25 authored Jan 12, 2024
1 parent 4bf1b95 commit d0af130
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
10 changes: 7 additions & 3 deletions plugins/extractors/merlin/internal/merlin/merlin_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ import (

"github.com/goto/meteor/metrics/otelhttpclient"
"github.com/goto/meteor/plugins/internal/urlbuilder"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
)

const (
audience = "sdk.caraml"
)

var authScopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
Expand Down Expand Up @@ -152,12 +156,12 @@ func authenticatedClient(ctx context.Context, serviceAccountJSON []byte, scopes
return google.DefaultClient(ctx, scopes...)
}

creds, err := google.CredentialsFromJSON(ctx, serviceAccountJSON, authScopes...)
client, err := idtoken.NewClient(ctx, audience, idtoken.WithCredentialsJSON(serviceAccountJSON))
if err != nil {
return nil, fmt.Errorf("google credentials from JSON: %w", err)
}

return oauth2.NewClient(ctx, creds.TokenSource), nil
return client, nil
}

// drainBody drains and closes the response body to avoid the following
Expand Down
47 changes: 23 additions & 24 deletions plugins/extractors/merlin/internal/merlin/merlin_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ import (
var (
ctx = context.Background()
credsJSON = []byte(`{"type":"service_account","project_id":"company-data-platform","private_key_id":"698vxv308w3i68p938040bz817r95b1e0k4kmvqs","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEoQIBAAKCAQBVzQ0WPuaqdwMNapCGKdKUR/MOgWNByruT60SJwd5lY/2Sjx1w\nQ4sJ6xk/+Tz7bT3CgNBAPQ+rZfLD2fdQJIBeYElRcHw6a2PA/6TaX2e4qq0+5xk3\ngnItlqZm0hQElZd76LNlMcItHNmneLICowTOdzl0hUd2IgrqLB545v3KOGfwoEAp\nz3mPm/iF1+zTPWy041w7ajvWK2N3mRygKoP79ne2gDuN2+QHmW8wPFzQ3pdQZU65\n1npgP9N4wRHQT8vowTUSYdSRZG1p1MPKCXsrqhMUv7yLdrOcacAvcZqYOcMhJQ1p\nbYpsinDR65ARDduMeKoEUkFb3hf2zPUcdYNhAgMBAAECggEAQldOxCGUlr94o7n+\nz02tHavYGiIfDfLkQIYLs3wsKjc7DEQOHgyLh/q4xkc/SKR5uVeCLflIkV09bQOu\nftAKVW6bohWYaE86jTLdU1+rQhTt6ZIkZFA/WlJ+jUfn5HeJ7mvJsffcTKde/2eK\nNBG6GK4Exbx7ubKuv8unMBJiryUycioPykWZEVYl72+0IBsKCQOX39Fd/pgJF9jL\nFPelgCsrvPA/3lodgQu3m8VENlu4G6z3kPQghAvI37xC9NlUNVvx1yxCukQhf0zQ\nQ55kUTwgZ9sIGGcI/2K6H1YHv+m3vnM5D5iL9eTHn1HnlGtplQJhmhKjCxXIpbHx\nQToOwQKBgQCcEZP6H3nq3eH7d5ro1fvA6YEoERfzIzaU4Kk3Sb9e1tXjYSz8ccNv\nK3gZsHV2YZy3q9mCYnc0oPwwx5dSwhzpOrBrwvyopPbkKpD9WCXtZtRkwRTN7CXR\nE+2eSSpu2y14SKysPQoDZmyJo8bs7rseLQTiZeUPlYdlP6adOGSX+QKBgQCMvVqE\n6nbX41DcLJuUxT026T9zncnpRu3gkfyY0O5QF8/Vcq6y5LxdQtyMNbcbkDY8isAM\nwTP4KaXPul38TOCjfG3MODDbzmeQ27qKL/9Ueyi812BN4XIrpguoPKgFtlyi1JNH\nZiUtimedOoNG4LuuDEqeNyW1Qm/WlQu5fqKwqQKBgGscuVW6Ep+6RuWisePJMO62\nk9ke2jQZ39UP17NFXx1FDyjuQcTEg2AiElx3OjbUSY3ZWP/eenfZYRxNb7Lx3IvJ\nptleyq8oAPaZrEbkH6uunmjEB3ZI869qIPQ4vPG2ZZ+fKTtQ7TVmL2nLyLRGKJBO\nT4LecfZfJry7katnz8ppAoGAI0FXyI33YVNHMTBXdOgH0paRV4QCTVaARk4rqZhE\n6nlcjcqhqpyT9wTFvLXD/bqda4MSYt+PBi5go+26l3Ymm62Sz6KP0rAcz3PLgcxO\nOLp1VQDa1geQkxCQQP+Y032ALSX1EuCqlYLjO8aplfq76PiZRJLp9kMDQwypGDl5\nxakCgYAm7pO0LA/hTvdrZ7zGUIfTTZxf1qD+W0iUh2MtyaZM9uQhDoaahf7f2TT/\nt2+wlyIlHMdUxfDYf8U5owl9IysqaPMZsQmYNgYmXpW8/AhNcKFnslyrtd57Of3C\nlFHpNwfjNlxDTsql2kWbcwJbY0EblPRItplE7gDlUvfgSNTj+g==\n-----END PRIVATE KEY-----\n","client_email":"[email protected]","client_id":"043161688880430795893","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_x509_cert_url":"https://www.googleapis.com/robot/v1/metadata/x509/systems-meteor%40company-data-platform.iam.gserviceaccount.com"}`)
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
)

func TestNewClient(t *testing.T) {
cases := []struct {
name string
params ClientParams
errStr string
ctx context.Context
}{
{
name: "Valid",
params: ClientParams{
BaseURL: "http://company.com/api/merlin/",
ServiceAccountJSON: credsJSON,
},
ctx: ctxWithClientWithIDToken(t, token),
},
{
name: "WithoutCredentials",
Expand All @@ -55,11 +58,12 @@ func TestNewClient(t *testing.T) {
ServiceAccountJSON: credsJSON,
},
errStr: `invalid input: parse "http://Gintama - Yorozuya Gin-chan": invalid character " " in host name`,
ctx: ctxWithClientWithIDToken(t, token),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewClient(ctx, tc.params)
_, err := NewClient(tc.ctx, tc.params)
if tc.errStr != "" {
assert.ErrorContains(t, err, tc.errStr)
} else {
Expand Down Expand Up @@ -147,7 +151,6 @@ func TestProjects(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token := "MyIncrediblyPowerfulAccessToken"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.URL.Path, "/api/merlin/v1/projects")
Expand All @@ -157,7 +160,7 @@ func TestProjects(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, token), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: 1 * time.Second,
Expand All @@ -183,7 +186,7 @@ func TestProjects(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, "MyIncrediblyPowerfulAccessToken"), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: timeout,
Expand Down Expand Up @@ -346,7 +349,6 @@ func TestModels(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token := "MyIncrediblyPowerfulAccessToken"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.URL.Path, fmt.Sprintf("/api/merlin/v1/projects/%d/models", tc.projectID))
Expand All @@ -356,7 +358,7 @@ func TestModels(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, token), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: 1 * time.Second,
Expand All @@ -382,7 +384,7 @@ func TestModels(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, "MyIncrediblyPowerfulAccessToken"), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: timeout,
Expand Down Expand Up @@ -467,7 +469,6 @@ func TestModelVersion(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token := "MyIncrediblyPowerfulAccessToken"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.URL.Path, fmt.Sprintf(
Expand All @@ -479,7 +480,7 @@ func TestModelVersion(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, token), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: 1 * time.Second,
Expand All @@ -505,7 +506,7 @@ func TestModelVersion(t *testing.T) {
}))
defer srv.Close()

c, err := NewClient(ctxWithClient(t, "MyIncrediblyPowerfulAccessToken"), ClientParams{
c, err := NewClient(ctxWithClientWithIDToken(t, token), ClientParams{
BaseURL: srv.URL + "/api/merlin",
ServiceAccountJSON: credsJSON,
Timeout: timeout,
Expand All @@ -517,23 +518,23 @@ func TestModelVersion(t *testing.T) {
})
}

func ctxWithClient(t *testing.T, token string) context.Context {
func ctxWithClientWithIDToken(t *testing.T, token string) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
Transport: mockOauthRoundTripper{
T: t,
AccessToken: token,
Base: http.DefaultTransport,
Transport: mockIDTokenRoundTripper{
T: t,
IDToken: token,
Base: http.DefaultTransport,
},
})
}

type mockOauthRoundTripper struct {
T *testing.T
AccessToken string
Base http.RoundTripper
type mockIDTokenRoundTripper struct {
T *testing.T
IDToken string
Base http.RoundTripper
}

func (m mockOauthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (m mockIDTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if m.match(req) {
return &http.Response{
Status: http.StatusText(http.StatusOK),
Expand All @@ -543,9 +544,7 @@ func (m mockOauthRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
ProtoMinor: 1,
Header: make(http.Header),
Body: testutils.ValueAsJSONReader(m.T, map[string]interface{}{
"access_token": m.AccessToken,
"expires_in": 3599,
"token_type": "Bearer",
"id_token": m.IDToken,
}),
Uncompressed: true,
}, nil
Expand All @@ -554,7 +553,7 @@ func (m mockOauthRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
return m.Base.RoundTrip(req)
}

func (m mockOauthRoundTripper) match(r *http.Request) bool {
func (m mockIDTokenRoundTripper) match(r *http.Request) bool {
return r.Method == http.MethodPost &&
r.URL.Host == "oauth2.googleapis.com" &&
r.URL.Path == "/token"
Expand Down

0 comments on commit d0af130

Please sign in to comment.