diff --git a/brevia/middleware.py b/brevia/middleware.py new file mode 100644 index 0000000..3889c27 --- /dev/null +++ b/brevia/middleware.py @@ -0,0 +1,34 @@ +"""Brevia Middleware classes""" +from importlib.metadata import version +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp + + +class VersionHeaderMiddleware(BaseHTTPMiddleware): + """Middleware to add version headers to response""" + def __init__( + self, + app: ASGIApp, + api_version: str = '', + api_name: str = '', + ) -> None: + super().__init__(app) + self.api_version = api_version + self.api_name = api_name + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + """Add version header to response""" + response = await call_next(request) + response.headers['X-Brevia-Version'] = version('Brevia') + # Add API version and name headers + # about the custom API created with Brevia + if self.api_name: + response.headers['X-API-Name'] = self.api_name + if self.api_version: + response.headers['X-API-Version'] = self.api_version + + return response diff --git a/main.py b/main.py index 555dfe1..ce346aa 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware import uvicorn +from brevia.middleware import VersionHeaderMiddleware from brevia.routers.app_routers import add_routers from brevia.utilities.openapi import metadata @@ -15,6 +16,7 @@ allow_methods=["POST"], allow_headers=["*"], ) +app.add_middleware(VersionHeaderMiddleware) add_routers(app) diff --git a/poetry.lock b/poetry.lock index 5a392da..d725f5c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3867,13 +3867,13 @@ image = ["Pillow (>=8.0.0)"] [[package]] name = "pytest" -version = "7.4.4" +version = "8.3.3" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, - {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, ] [package.dependencies] @@ -3881,11 +3881,29 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "pytest-cov" @@ -5702,4 +5720,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b8d340f5760d3dc5e56727b77f0abca308c057579fe26fa4b40e8f2af2d799d4" +content-hash = "7ab411550d629de889801e36610e7f029662ebacf23e412cc63508e90872037f" diff --git a/pyproject.toml b/pyproject.toml index 7682ee5..9abb242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,10 @@ repository = "https://github.com/brevia-ai/brevia" [tool.poetry.group.dev.dependencies] flake8 = "^6.1.0" pylint = "^3.0.1" -pytest = "^7.4.2" +pytest = "^8.2.0" pytest-cov = "^4.1.0" httpx = "^0.27.2" +pytest-asyncio = "^0.24.0" [tool.pylint.main] fail-under = 9.5 diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..3dbd3e0 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,34 @@ +""" Middleware module tests. """ +from importlib.metadata import version +import pytest +from fastapi import FastAPI, Response +from brevia.middleware import VersionHeaderMiddleware + +pytest_plugins = ('pytest_asyncio',) + + +@pytest.mark.asyncio +async def test_dispatch(): + """Test dispatch method of VersionHeaderMiddleware class.""" + middleware = VersionHeaderMiddleware(app=FastAPI()) + + async def call_next_test(request): + return Response() + + response = await middleware.dispatch(request=None, call_next=call_next_test) + + assert response.headers.get('X-Brevia-Version') == version('Brevia') + assert response.headers.get('X-API-Version') is None + assert response.headers.get('X-API-Name') is None + + middleware = VersionHeaderMiddleware( + app=FastAPI(), + api_version='1.0', + api_name='Test API', + ) + + response = await middleware.dispatch(request=None, call_next=call_next_test) + + assert response.headers.get('X-Brevia-Version') == version('Brevia') + assert response.headers.get('X-API-Version') == '1.0' + assert response.headers.get('X-API-Name') == 'Test API'