Skip to content

Commit

Permalink
Merge branch 'main' into feat/ollama-cohere-models
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanorosanelli committed Nov 4, 2024
2 parents 9b00c98 + 6143a86 commit 30a7d75
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 9 deletions.
34 changes: 34 additions & 0 deletions brevia/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Brevia Middleware classes"""
from importlib.metadata import version
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp


class VersionHeaderMiddleware(BaseHTTPMiddleware):
"""Middleware to add version headers to response"""
def __init__(
self,
app: ASGIApp,
api_version: str = '',
api_name: str = '',
) -> None:
super().__init__(app)
self.api_version = api_version
self.api_name = api_name

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
"""Add version header to response"""
response = await call_next(request)
response.headers['X-Brevia-Version'] = version('Brevia')
# Add API version and name headers
# about the custom API created with Brevia
if self.api_name:
response.headers['X-API-Name'] = self.api_name
if self.api_version:
response.headers['X-API-Version'] = self.api_version

return response
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from brevia.middleware import VersionHeaderMiddleware
from brevia.routers.app_routers import add_routers
from brevia.utilities.openapi import metadata

Expand All @@ -15,6 +16,7 @@
allow_methods=["POST"],
allow_headers=["*"],
)
app.add_middleware(VersionHeaderMiddleware)
add_routers(app)


Expand Down
34 changes: 26 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ repository = "https://github.com/brevia-ai/brevia"
[tool.poetry.group.dev.dependencies]
flake8 = "^6.1.0"
pylint = "^3.0.1"
pytest = "^7.4.2"
pytest = "^8.2.0"
pytest-cov = "^4.1.0"
httpx = "^0.27.2"
pytest-asyncio = "^0.24.0"

[tool.pylint.main]
fail-under = 9.5
Expand Down
34 changes: 34 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
""" Middleware module tests. """
from importlib.metadata import version
import pytest
from fastapi import FastAPI, Response
from brevia.middleware import VersionHeaderMiddleware

pytest_plugins = ('pytest_asyncio',)


@pytest.mark.asyncio
async def test_dispatch():
"""Test dispatch method of VersionHeaderMiddleware class."""
middleware = VersionHeaderMiddleware(app=FastAPI())

async def call_next_test(request):
return Response()

response = await middleware.dispatch(request=None, call_next=call_next_test)

assert response.headers.get('X-Brevia-Version') == version('Brevia')
assert response.headers.get('X-API-Version') is None
assert response.headers.get('X-API-Name') is None

middleware = VersionHeaderMiddleware(
app=FastAPI(),
api_version='1.0',
api_name='Test API',
)

response = await middleware.dispatch(request=None, call_next=call_next_test)

assert response.headers.get('X-Brevia-Version') == version('Brevia')
assert response.headers.get('X-API-Version') == '1.0'
assert response.headers.get('X-API-Name') == 'Test API'

0 comments on commit 30a7d75

Please sign in to comment.