diff --git a/jupyverse_api/jupyverse_api/app/__init__.py b/jupyverse_api/jupyverse_api/app/__init__.py index 9eb61145..71f3f203 100644 --- a/jupyverse_api/jupyverse_api/app/__init__.py +++ b/jupyverse_api/jupyverse_api/app/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from typing import Dict, List @@ -16,8 +18,13 @@ class App: _app: FastAPI _router_paths: Dict[str, List[str]] - def __init__(self, app: FastAPI): - self._app = app + def __init__(self, app: FastAPI, mount_path: str | None = None): + if mount_path is None: + self._app = app + else: + subapi = FastAPI() + app.mount(mount_path, subapi) + self._app = subapi app.add_exception_handler(RedirectException, _redirect_exception_handler) self._router_paths = defaultdict(list) diff --git a/jupyverse_api/jupyverse_api/main/__init__.py b/jupyverse_api/jupyverse_api/main/__init__.py index 79db98cb..477aabbe 100644 --- a/jupyverse_api/jupyverse_api/main/__init__.py +++ b/jupyverse_api/jupyverse_api/main/__init__.py @@ -13,13 +13,21 @@ class AppComponent(Component): + def __init__( + self, + *, + mount_path: str | None = None, + ) -> None: + super().__init__() + self.mount_path = mount_path + async def start( self, ctx: Context, ) -> None: app = await ctx.request_resource(FastAPI) - _app = App(app) + _app = App(app, mount_path=self.mount_path) ctx.add_resource(_app) diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 00000000..d2516a5a --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,46 @@ +import pytest +from asphalt.core import Context +from fastapi import APIRouter +from httpx import AsyncClient +from jupyverse_api import Router +from jupyverse_api.app import App +from jupyverse_api.main import JupyverseComponent + +from utils import configure + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mount_path", + ( + None, + "/foo", + ), +) +async def test_mount_path(mount_path, unused_tcp_port): + components = configure({"app": {"type": "app"}}, {"app": {"mount_path": mount_path}}) + + async with Context() as ctx, AsyncClient() as http: + await JupyverseComponent( + components=components, + port=unused_tcp_port, + ).start(ctx) + + app = await ctx.request_resource(App) + router = APIRouter() + + @router.get("/") + async def get(): + pass + + Router(app).include_router(router) + + response = await http.get(f"http://127.0.0.1:{unused_tcp_port}") + expected = 200 if mount_path is None else 404 + assert response.status_code == expected + + response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/bar") + assert response.status_code == 404 + + response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/foo") + expected = 404 if mount_path is None else 200