diff --git a/canaille/oidc/endpoints.py b/canaille/oidc/endpoints.py index 00c4f861..a9d81711 100644 --- a/canaille/oidc/endpoints.py +++ b/canaille/oidc/endpoints.py @@ -2,7 +2,6 @@ import uuid from authlib.integrations.flask_oauth2 import current_token -from authlib.jose import JsonWebKey from authlib.jose import jwt from authlib.oauth2 import OAuth2Error from canaille import csrf @@ -28,10 +27,9 @@ from .oauth import authorization from .oauth import ClientConfigurationEndpoint from .oauth import ClientRegistrationEndpoint -from .oauth import DEFAULT_JWT_ALG -from .oauth import DEFAULT_JWT_KTY from .oauth import generate_user_info from .oauth import get_issuer +from .oauth import get_jwks from .oauth import IntrospectionEndpoint from .oauth import require_oauth from .oauth import RevocationEndpoint @@ -211,22 +209,7 @@ def client_registration_management(client_id): @bp.route("/jwks.json") def jwks(): - kty = current_app.config["OIDC"]["JWT"].get("KTY", DEFAULT_JWT_KTY) - alg = current_app.config["OIDC"]["JWT"].get("ALG", DEFAULT_JWT_ALG) - jwk = JsonWebKey.import_key( - current_app.config["OIDC"]["JWT"]["PUBLIC_KEY"], {"kty": kty} - ) - return jsonify( - { - "keys": [ - { - "use": "sig", - "alg": alg, - **jwk, - } - ] - } - ) + return jsonify(get_jwks()) @bp.route("/userinfo") diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index 9a6a1063..8b781f9c 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -2,6 +2,7 @@ from authlib.integrations.flask_oauth2 import AuthorizationServer from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.jose import JsonWebKey from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) @@ -66,7 +67,7 @@ def get_issuer(): return request.url_root -def get_jwt_config(grant): +def get_jwt_config(grant=None): return { "key": current_app.config["OIDC"]["JWT"]["PRIVATE_KEY"], "alg": current_app.config["OIDC"]["JWT"].get("ALG", DEFAULT_JWT_ALG), @@ -75,6 +76,23 @@ def get_jwt_config(grant): } +def get_jwks(): + kty = current_app.config["OIDC"]["JWT"].get("KTY", DEFAULT_JWT_KTY) + alg = current_app.config["OIDC"]["JWT"].get("ALG", DEFAULT_JWT_ALG) + jwk = JsonWebKey.import_key( + current_app.config["OIDC"]["JWT"]["PUBLIC_KEY"], {"kty": kty} + ) + return { + "keys": [ + { + "use": "sig", + "alg": alg, + **jwk, + } + ] + } + + def claims_from_scope(scope): claims = {"sub"} if "profile" in scope: