Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: strongly type get_configure_view #49

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions oidc/provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

from collections.abc import Callable

from django.http import HttpRequest

import time

import requests
from sentry.auth.provider import MigratingIdentityId
from sentry.auth.providers.oauth2 import OAuth2Callback, OAuth2Login, OAuth2Provider
from sentry.auth.services.auth.model import RpcAuthProvider
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.plugins.base.response import DeferredResponse

from .constants import (
AUTHORIZATION_ENDPOINT,
Expand All @@ -14,7 +23,7 @@
TOKEN_ENDPOINT,
USERINFO_ENDPOINT,
)
from .views import FetchUser, OIDCConfigureView
from .views import FetchUser, oidc_configure_view


class OIDCLogin(OAuth2Login):
Expand All @@ -37,7 +46,7 @@ def get_authorize_params(self, state, redirect_uri):


class OIDCProvider(OAuth2Provider):
name = ISSUER
name = ISSUER if ISSUER else "oidc"
reneluria marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, domain=None, domains=None, version=None, **config):
if domain:
Expand All @@ -63,8 +72,10 @@ def get_client_id(self):
def get_client_secret(self):
return CLIENT_SECRET

def get_configure_view(self):
return OIDCConfigureView.as_view()
def get_configure_view(
self,
) -> Callable[[HttpRequest, RpcOrganization, RpcAuthProvider], DeferredResponse]:
return oidc_configure_view

def get_auth_pipeline(self):
return [
Expand Down
35 changes: 22 additions & 13 deletions oidc/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

import logging

from sentry.auth.view import AuthView, ConfigureView
from django.http import HttpRequest
from rest_framework.response import Response

from sentry.auth.services.auth.model import RpcAuthProvider
from sentry.auth.view import AuthView
from sentry.utils import json
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.plugins.base.response import DeferredResponse
max-wittig marked this conversation as resolved.
Show resolved Hide resolved
from sentry.utils.signing import urlsafe_b64decode

from .constants import ERR_INVALID_RESPONSE, ISSUER
Expand All @@ -15,7 +23,7 @@ def __init__(self, domains, version, *args, **kwargs):
self.version = version
super().__init__(*args, **kwargs)

def dispatch(self, request, helper):
def dispatch(self, request: HttpRequest, helper) -> Response: # type: ignore
data = helper.fetch_state("data")

try:
Expand Down Expand Up @@ -52,17 +60,18 @@ def dispatch(self, request, helper):
return helper.next_step()


class OIDCConfigureView(ConfigureView):
def dispatch(self, request, organization, auth_provider):
config = auth_provider.config
if config.get("domain"):
domains = [config["domain"]]
else:
domains = config.get("domains")
return self.render(
"oidc/configure.html",
{"provider_name": ISSUER or "", "domains": domains or []},
)
def oidc_configure_view(
request: HttpRequest, organization: RpcOrganization, auth_provider: RpcAuthProvider
) -> DeferredResponse:
config = auth_provider.config
if config.get("domain"):
domains: list[str] | None
domains = [config["domain"]]
else:
domains = config.get("domains")
return DeferredResponse(
"oidc/configure.html", {"provider_name": ISSUER or "", "domains": domains or []}
reneluria marked this conversation as resolved.
Show resolved Hide resolved
)


def extract_domain(email):
Expand Down
Loading