diff --git a/manage/manager.go b/manage/manager.go index c11f391..ad09790 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -218,7 +218,7 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, } return ti, nil } - + // get authorization code data func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { ti, err := m.tokenStore.GetByCode(ctx, code) @@ -296,6 +296,10 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, } } + if gt == oauth2.ClientCredentials && cli.IsPublic() == true { + return nil, errors.ErrInvalidClient + } + if gt == oauth2.AuthorizationCode { ti, err := m.getAndDelAuthorizationCode(ctx, tgr) if err != nil { diff --git a/model.go b/model.go index 121a42d..ec52244 100644 --- a/model.go +++ b/model.go @@ -10,6 +10,7 @@ type ( GetID() string GetSecret() string GetDomain() string + IsPublic() bool GetUserID() string } diff --git a/models/client.go b/models/client.go index e7ad7f5..c31ad7f 100644 --- a/models/client.go +++ b/models/client.go @@ -5,6 +5,7 @@ type Client struct { ID string Secret string Domain string + Public bool UserID string } @@ -23,6 +24,11 @@ func (c *Client) GetDomain() string { return c.Domain } +// IsPublic public +func (c *Client) IsPublic() bool { + return c.Public +} + // GetUserID user id func (c *Client) GetUserID() string { return c.UserID diff --git a/server/server_test.go b/server/server_test.go index 0fd00d5..1eb0bcd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "testing" "github.com/gavv/httpexpect" @@ -26,9 +25,10 @@ var ( clientSecret = "11111111" plainChallenge = "ThisIsAFourtyThreeCharactersLongStringThing" - s256Challenge = "s256test" - // echo s256test | sha256 | base64 | tr '/+' '_-' - s256ChallengeHash = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=" + s256Challenge = "s256tests256tests256tests256tests256tests256test" + // sha2562 := sha256.Sum256([]byte(s256Challenge)) + // fmt.Printf(base64.URLEncoding.EncodeToString(sha2562[:])) + s256ChallengeHash = "To2Xqv01cm16bC9Sf7KRRS8CO2SFss_HSMQOr3sdCDE=" ) func init() { @@ -36,12 +36,19 @@ func init() { manager.MustTokenStorage(store.NewMemoryTokenStore()) } -func clientStore(domain string) oauth2.ClientStore { +func clientStore(domain string, public bool) oauth2.ClientStore { clientStore := store.NewClientStore() + var secret string + if public { + secret = "" + } else { + secret = clientSecret + } clientStore.Set(clientID, &models.Client{ ID: clientID, - Secret: clientSecret, + Secret: secret, Domain: domain, + Public: public, }) return clientStore } @@ -95,7 +102,7 @@ func TestAuthorizeCode(t *testing.T) { })) defer csrv.Close() - manager.MapClientStorage(clientStore(csrv.URL)) + manager.MapClientStorage(clientStore(csrv.URL, true)) srv = server.NewDefaultServer(manager) srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { userID = "000000" @@ -107,7 +114,7 @@ func TestAuthorizeCode(t *testing.T) { WithQuery("client_id", clientID). WithQuery("scope", "all"). WithQuery("state", "123"). - WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("redirect_uri", csrv.URL+"/oauth2"). Expect().Status(http.StatusOK) } @@ -134,7 +141,7 @@ func TestAuthorizeCodeWithChallengePlain(t *testing.T) { WithFormField("grant_type", "authorization_code"). WithFormField("client_id", clientID). WithFormField("code", code). - WithBasicAuth("code_verifier", "testchallenge"). + WithFormField("code_verifier", plainChallenge). Expect(). Status(http.StatusOK). JSON().Object() @@ -146,19 +153,20 @@ func TestAuthorizeCodeWithChallengePlain(t *testing.T) { })) defer csrv.Close() - manager.MapClientStorage(clientStore(csrv.URL)) + manager.MapClientStorage(clientStore(csrv.URL, true)) srv = server.NewDefaultServer(manager) srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { userID = "000000" return }) + srv.SetClientInfoHandler(server.ClientFormHandler) e.GET("/authorize"). WithQuery("response_type", "code"). WithQuery("client_id", clientID). WithQuery("scope", "all"). WithQuery("state", "123"). - WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("redirect_uri", csrv.URL+"/oauth2"). WithQuery("code_challenge", plainChallenge). Expect().Status(http.StatusOK) } @@ -186,7 +194,7 @@ func TestAuthorizeCodeWithChallengeS256(t *testing.T) { WithFormField("grant_type", "authorization_code"). WithFormField("client_id", clientID). WithFormField("code", code). - WithBasicAuth("code_verifier", s256Challenge). + WithFormField("code_verifier", s256Challenge). Expect(). Status(http.StatusOK). JSON().Object() @@ -198,19 +206,20 @@ func TestAuthorizeCodeWithChallengeS256(t *testing.T) { })) defer csrv.Close() - manager.MapClientStorage(clientStore(csrv.URL)) + manager.MapClientStorage(clientStore(csrv.URL, true)) srv = server.NewDefaultServer(manager) srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { userID = "000000" return }) + srv.SetClientInfoHandler(server.ClientFormHandler) e.GET("/authorize"). WithQuery("response_type", "code"). WithQuery("client_id", clientID). WithQuery("scope", "all"). WithQuery("state", "123"). - WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("redirect_uri", csrv.URL+"/oauth2"). WithQuery("code_challenge", s256ChallengeHash). WithQuery("code_challenge_method", "S256"). Expect().Status(http.StatusOK) @@ -226,7 +235,7 @@ func TestImplicit(t *testing.T) { csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) defer csrv.Close() - manager.MapClientStorage(clientStore(csrv.URL)) + manager.MapClientStorage(clientStore(csrv.URL, false)) srv = server.NewDefaultServer(manager) srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { userID = "000000" @@ -238,7 +247,7 @@ func TestImplicit(t *testing.T) { WithQuery("client_id", clientID). WithQuery("scope", "all"). WithQuery("state", "123"). - WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("redirect_uri", csrv.URL+"/oauth2"). Expect().Status(http.StatusOK) } @@ -249,7 +258,7 @@ func TestPasswordCredentials(t *testing.T) { defer tsrv.Close() e := httpexpect.New(t, tsrv.URL) - manager.MapClientStorage(clientStore("")) + manager.MapClientStorage(clientStore("", false)) srv = server.NewDefaultServer(manager) srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) { if username == "admin" && password == "123456" { @@ -282,7 +291,7 @@ func TestClientCredentials(t *testing.T) { defer tsrv.Close() e := httpexpect.New(t, tsrv.URL) - manager.MapClientStorage(clientStore("")) + manager.MapClientStorage(clientStore("", false)) srv = server.NewDefaultServer(manager) srv.SetClientInfoHandler(server.ClientFormHandler) @@ -372,7 +381,7 @@ func TestRefreshing(t *testing.T) { })) defer csrv.Close() - manager.MapClientStorage(clientStore(csrv.URL)) + manager.MapClientStorage(clientStore(csrv.URL, true)) srv = server.NewDefaultServer(manager) srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { userID = "000000" @@ -384,7 +393,7 @@ func TestRefreshing(t *testing.T) { WithQuery("client_id", clientID). WithQuery("scope", "all"). WithQuery("state", "123"). - WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("redirect_uri", csrv.URL+"/oauth2"). Expect().Status(http.StatusOK) }