Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

azcontainerregistry: delegate token caching #23272

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 8 additions & 118 deletions sdk/containers/azcontainerregistry/authentication_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@
package azcontainerregistry

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync/atomic"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
)

const (
Expand All @@ -28,6 +22,7 @@ const (
)

type authenticationPolicyOptions struct {
*authenticationTokenCacheOptions
}

// authenticationPolicy is a policy to do the challenge-based authentication for container registry service. The authorization flow is as follows:
Expand All @@ -45,19 +40,15 @@ type authenticationPolicyOptions struct {
// Each registry service shares one refresh token, it will be cached in refreshTokenCache until expire time.
// Since the scope will be different for different API/repository/artifact, accessTokenCache will only work when continuously calling same API.
type authenticationPolicy struct {
refreshTokenCache *temporal.Resource[azcore.AccessToken, acquiringResourceState]
accessTokenCache atomic.Value
cred azcore.TokenCredential
aadScopes []string
authClient *authenticationClient
accessTokenCache *authenticationTokenCache
}

func newAuthenticationPolicy(cred azcore.TokenCredential, scopes []string, authClient *authenticationClient, opts *authenticationPolicyOptions) *authenticationPolicy {
if opts == nil {
opts = &authenticationPolicyOptions{}
}
return &authenticationPolicy{
cred: cred,
aadScopes: scopes,
authClient: authClient,
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
accessTokenCache: newAuthenticationTokenCache(cred, scopes, authClient, opts.authenticationTokenCacheOptions),
}
}

Expand All @@ -67,7 +58,7 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
if req.Raw().Header.Get(headerAuthorization) != "" {
// retry request could do the request with existed token directly
resp, err = req.Next()
} else if accessToken := p.accessTokenCache.Load(); accessToken != nil && accessToken != "" {
} else if accessToken := p.accessTokenCache.Load(); accessToken != "" {
// if there is a previous access token, then we try to use this token to do the request
req.Raw().Header.Set(
headerAuthorization,
Expand All @@ -93,10 +84,9 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
if service, scope, err = findServiceAndScope(resp); err != nil {
return nil, err
}
if accessToken, err = p.getAccessToken(req, service, scope); err != nil {
if accessToken, err = p.accessTokenCache.AcquireAccessToken(req.Raw().Context(), service, scope); err != nil {
return nil, err
}
p.accessTokenCache.Store(accessToken)
req.Raw().Header.Set(
headerAuthorization,
fmt.Sprintf("%s%s", bearerHeader, accessToken),
Expand All @@ -111,35 +101,6 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
return resp, nil
}

func (p *authenticationPolicy) getAccessToken(req *policy.Request, service, scope string) (string, error) {
// anonymous access
if p.cred == nil {
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
if err != nil {
return "", err
}
return *resp.acrAccessToken.AccessToken, nil
}

// access with token
// get refresh token from cache/request
refreshToken, err := p.refreshTokenCache.Get(acquiringResourceState{
policy: p,
req: req,
service: service,
})
if err != nil {
return "", err
}

// get access token from request
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(req.Raw().Context(), service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
if err != nil {
return "", err
}
return *resp.acrAccessToken.AccessToken, nil
}

func findServiceAndScope(resp *http.Response) (string, string, error) {
authHeader := resp.Header.Get("WWW-Authenticate")
if authHeader == "" {
Expand Down Expand Up @@ -176,74 +137,3 @@ func getChallengeRequest(oriReq policy.Request) (*policy.Request, error) {
copied.Raw().Header.Del("Content-Type")
return copied, nil
}

type acquiringResourceState struct {
req *policy.Request
policy *authenticationPolicy
service string
}

// acquireRefreshToken acquires or updates the refresh token of ACR service; only one thread/goroutine at a time ever calls this function
func acquireRefreshToken(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
// get AAD token from credential
aadToken, err := state.policy.cred.GetToken(
state.req.Raw().Context(),
policy.TokenRequestOptions{
Scopes: state.policy.aadScopes,
},
)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// exchange refresh token with AAD token
refreshResp, err := state.policy.authClient.ExchangeAADAccessTokenForACRRefreshToken(state.req.Raw().Context(), postContentSchemaGrantTypeAccessToken, state.service, &authenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &aadToken.Token,
})
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

refreshToken := azcore.AccessToken{
Token: *refreshResp.acrRefreshToken.RefreshToken,
}

// get refresh token expire time
refreshToken.ExpiresOn, err = getJWTExpireTime(*refreshResp.acrRefreshToken.RefreshToken)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// return refresh token
return refreshToken, refreshToken.ExpiresOn, nil
}

func getJWTExpireTime(token string) (time.Time, error) {
values := strings.Split(token, ".")
if len(values) > 2 {
value := values[1]
padding := len(value) % 4
if padding > 0 {
for i := 0; i < 4-padding; i++ {
value += "="
}
}
parsedValue, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return time.Time{}, err
}

var jsonValue *jwtOnlyWithExp
err = json.Unmarshal(parsedValue, &jsonValue)
if err != nil {
return time.Time{}, err
}
return time.Unix(jsonValue.Exp, 0), nil
}

return time.Time{}, errors.New("could not parse refresh token expire time")
}

type jwtOnlyWithExp struct {
Exp int64 `json:"exp"`
}
56 changes: 32 additions & 24 deletions sdk/containers/azcontainerregistry/authentication_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ func Test_authenticationPolicy_getAccessToken_live(t *testing.T) {
authClient, err := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
require.NoError(t, err)
p := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
cred,
[]string{options.Cloud.Services[ServiceName].Audience + "/.default"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
cred,
[]string{options.Cloud.Services[ServiceName].Audience + "/.default"},
authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
require.NoError(t, err)
token, err := p.getAccessToken(request, strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
token, err := p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
require.NoError(t, err)
require.NotEmpty(t, token)
}
Expand All @@ -161,22 +163,24 @@ func Test_authenticationPolicy_getAccessToken_error(t *testing.T) {
require.NoError(t, err)

p := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
require.NoError(t, err)
_, err = p.getAccessToken(request, "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
_, err = p.getAccessToken(request, "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
_, err = p.getAccessToken(request, "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
p.cred = nil
_, err = p.getAccessToken(request, "service", "scope")
p.accessTokenCache.cred = nil
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
}

Expand All @@ -186,12 +190,14 @@ func Test_authenticationPolicy_getAccessToken_live_anonymous(t *testing.T) {
authClient, err := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
require.NoError(t, err)
p := &authenticationPolicy{
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
authClient: authClient,
accessTokenCache: &authenticationTokenCache{
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
authClient: authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
require.NoError(t, err)
token, err := p.getAccessToken(request, strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
token, err := p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
require.NoError(t, err)
require.NotEmpty(t, token)
}
Expand Down Expand Up @@ -244,11 +250,13 @@ func Test_authenticationPolicy(t *testing.T) {
authClient, err := newAuthenticationClient(srv.URL(), &authenticationClientOptions{ClientOptions: azcore.ClientOptions{Transport: srv}})
require.NoError(t, err)
authPolicy := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
},
}
pl := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, &policy.ClientOptions{Transport: srv})

Expand Down
Loading