Skip to content

Commit

Permalink
Merge pull request #234 from MaindeckAS/supportPublicGrantsWithIsPubl…
Browse files Browse the repository at this point in the history
…icFunction

Introducing public field to client models
  • Loading branch information
LyricTian authored Feb 3, 2023
2 parents 62bc01d + 4d9fa1e commit 8f0d487
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
6 changes: 5 additions & 1 deletion manage/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType,
}
return ti, nil
}

// get authorization code data
func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
ti, err := m.tokenStore.GetByCode(ctx, code)
Expand Down Expand Up @@ -296,6 +296,10 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType,
}
}

if gt == oauth2.ClientCredentials && cli.IsPublic() == true {
return nil, errors.ErrInvalidClient
}

if gt == oauth2.AuthorizationCode {
ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type (
GetID() string
GetSecret() string
GetDomain() string
IsPublic() bool
GetUserID() string
}

Expand Down
6 changes: 6 additions & 0 deletions models/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type Client struct {
ID string
Secret string
Domain string
Public bool
UserID string
}

Expand All @@ -23,6 +24,11 @@ func (c *Client) GetDomain() string {
return c.Domain
}

// IsPublic public
func (c *Client) IsPublic() bool {
return c.Public
}

// GetUserID user id
func (c *Client) GetUserID() string {
return c.UserID
Expand Down
49 changes: 29 additions & 20 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/gavv/httpexpect"
Expand All @@ -26,22 +25,30 @@ var (
clientSecret = "11111111"

plainChallenge = "ThisIsAFourtyThreeCharactersLongStringThing"
s256Challenge = "s256test"
// echo s256test | sha256 | base64 | tr '/+' '_-'
s256ChallengeHash = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o="
s256Challenge = "s256tests256tests256tests256tests256tests256test"
// sha2562 := sha256.Sum256([]byte(s256Challenge))
// fmt.Printf(base64.URLEncoding.EncodeToString(sha2562[:]))
s256ChallengeHash = "To2Xqv01cm16bC9Sf7KRRS8CO2SFss_HSMQOr3sdCDE="
)

func init() {
manager = manage.NewDefaultManager()
manager.MustTokenStorage(store.NewMemoryTokenStore())
}

func clientStore(domain string) oauth2.ClientStore {
func clientStore(domain string, public bool) oauth2.ClientStore {
clientStore := store.NewClientStore()
var secret string
if public {
secret = ""
} else {
secret = clientSecret
}
clientStore.Set(clientID, &models.Client{
ID: clientID,
Secret: clientSecret,
Secret: secret,
Domain: domain,
Public: public,
})
return clientStore
}
Expand Down Expand Up @@ -95,7 +102,7 @@ func TestAuthorizeCode(t *testing.T) {
}))
defer csrv.Close()

manager.MapClientStorage(clientStore(csrv.URL))
manager.MapClientStorage(clientStore(csrv.URL, true))
srv = server.NewDefaultServer(manager)
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
userID = "000000"
Expand All @@ -107,7 +114,7 @@ func TestAuthorizeCode(t *testing.T) {
WithQuery("client_id", clientID).
WithQuery("scope", "all").
WithQuery("state", "123").
WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")).
WithQuery("redirect_uri", csrv.URL+"/oauth2").
Expect().Status(http.StatusOK)
}

Expand All @@ -134,7 +141,7 @@ func TestAuthorizeCodeWithChallengePlain(t *testing.T) {
WithFormField("grant_type", "authorization_code").
WithFormField("client_id", clientID).
WithFormField("code", code).
WithBasicAuth("code_verifier", "testchallenge").
WithFormField("code_verifier", plainChallenge).
Expect().
Status(http.StatusOK).
JSON().Object()
Expand All @@ -146,19 +153,20 @@ func TestAuthorizeCodeWithChallengePlain(t *testing.T) {
}))
defer csrv.Close()

manager.MapClientStorage(clientStore(csrv.URL))
manager.MapClientStorage(clientStore(csrv.URL, true))
srv = server.NewDefaultServer(manager)
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
userID = "000000"
return
})
srv.SetClientInfoHandler(server.ClientFormHandler)

e.GET("/authorize").
WithQuery("response_type", "code").
WithQuery("client_id", clientID).
WithQuery("scope", "all").
WithQuery("state", "123").
WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")).
WithQuery("redirect_uri", csrv.URL+"/oauth2").
WithQuery("code_challenge", plainChallenge).
Expect().Status(http.StatusOK)
}
Expand Down Expand Up @@ -186,7 +194,7 @@ func TestAuthorizeCodeWithChallengeS256(t *testing.T) {
WithFormField("grant_type", "authorization_code").
WithFormField("client_id", clientID).
WithFormField("code", code).
WithBasicAuth("code_verifier", s256Challenge).
WithFormField("code_verifier", s256Challenge).
Expect().
Status(http.StatusOK).
JSON().Object()
Expand All @@ -198,19 +206,20 @@ func TestAuthorizeCodeWithChallengeS256(t *testing.T) {
}))
defer csrv.Close()

manager.MapClientStorage(clientStore(csrv.URL))
manager.MapClientStorage(clientStore(csrv.URL, true))
srv = server.NewDefaultServer(manager)
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
userID = "000000"
return
})
srv.SetClientInfoHandler(server.ClientFormHandler)

e.GET("/authorize").
WithQuery("response_type", "code").
WithQuery("client_id", clientID).
WithQuery("scope", "all").
WithQuery("state", "123").
WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")).
WithQuery("redirect_uri", csrv.URL+"/oauth2").
WithQuery("code_challenge", s256ChallengeHash).
WithQuery("code_challenge_method", "S256").
Expect().Status(http.StatusOK)
Expand All @@ -226,7 +235,7 @@ func TestImplicit(t *testing.T) {
csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer csrv.Close()

manager.MapClientStorage(clientStore(csrv.URL))
manager.MapClientStorage(clientStore(csrv.URL, false))
srv = server.NewDefaultServer(manager)
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
userID = "000000"
Expand All @@ -238,7 +247,7 @@ func TestImplicit(t *testing.T) {
WithQuery("client_id", clientID).
WithQuery("scope", "all").
WithQuery("state", "123").
WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")).
WithQuery("redirect_uri", csrv.URL+"/oauth2").
Expect().Status(http.StatusOK)
}

Expand All @@ -249,7 +258,7 @@ func TestPasswordCredentials(t *testing.T) {
defer tsrv.Close()
e := httpexpect.New(t, tsrv.URL)

manager.MapClientStorage(clientStore(""))
manager.MapClientStorage(clientStore("", false))
srv = server.NewDefaultServer(manager)
srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) {
if username == "admin" && password == "123456" {
Expand Down Expand Up @@ -282,7 +291,7 @@ func TestClientCredentials(t *testing.T) {
defer tsrv.Close()
e := httpexpect.New(t, tsrv.URL)

manager.MapClientStorage(clientStore(""))
manager.MapClientStorage(clientStore("", false))

srv = server.NewDefaultServer(manager)
srv.SetClientInfoHandler(server.ClientFormHandler)
Expand Down Expand Up @@ -372,7 +381,7 @@ func TestRefreshing(t *testing.T) {
}))
defer csrv.Close()

manager.MapClientStorage(clientStore(csrv.URL))
manager.MapClientStorage(clientStore(csrv.URL, true))
srv = server.NewDefaultServer(manager)
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
userID = "000000"
Expand All @@ -384,7 +393,7 @@ func TestRefreshing(t *testing.T) {
WithQuery("client_id", clientID).
WithQuery("scope", "all").
WithQuery("state", "123").
WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")).
WithQuery("redirect_uri", csrv.URL+"/oauth2").
Expect().Status(http.StatusOK)
}

Expand Down

0 comments on commit 8f0d487

Please sign in to comment.