Skip to content

Commit

Permalink
New /completion endpoint (#37)
Browse files Browse the repository at this point in the history
* feat: introducing completion endpoint

* feat: completion route

* feat: first release + postman update

* feat: add token_usage_callback

* chore: fix comments

* feat: adding test case

* chore: style fix
  • Loading branch information
nikazzio authored Jan 15, 2024
1 parent 22fc4a7 commit 57f6787
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 3 deletions.
41 changes: 41 additions & 0 deletions brevia/completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Simple completion functions"""
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts.loading import load_prompt_from_config
from pydantic import BaseModel
from brevia.models import load_chatmodel
from brevia.settings import get_settings


class CompletionParams(BaseModel):
""" Q&A basic conversation chain params """
prompt: dict | None = None


def load_custom_prompt(prompt: dict | None):
""" Load custom prompt """
return load_prompt_from_config(prompt)


def simple_completion_chain(
completion_params: CompletionParams,
) -> Chain:
"""
Return simple completion chain for a generic input text
completion_params: basic completion params including:
prompt: custom prompt for execute simple completion commands
"""

settings = get_settings()
llm_conf = settings.qa_completion_llm
comp_llm = load_chatmodel(llm_conf)
verbose = settings.verbose_mode
# Create chain for follow-up question using chat history (if present)
completion_llm = LLMChain(
llm=comp_llm,
prompt=load_custom_prompt(completion_params.prompt),
verbose=verbose,
)

return completion_llm
48 changes: 45 additions & 3 deletions brevia/postman/Brevia API.postman_collection.json
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
{
"info": {
"_postman_id": "8bcda8f4-e21c-4bf9-95ba-45f4405138a4",
"_postman_id": "484c3a34-cbf0-48a1-8d51-72bc8f460502",
"name": "Brevia API",
"schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json",
"_exporter_id": "234034"
"_exporter_id": "4185449"
},
"item": [
{
"name": "Chat",
"item": [
{
"name": "chat - Question/Answer flow basic",
"name": "chat - Simple Completion",
"request": {
"auth": {
"type": "bearer",
Expand Down Expand Up @@ -166,6 +166,48 @@
}
},
"response": []
},
{
"name": "completion - Completion with custom prompt",
"request": {
"auth": {
"type": "bearer",
"bearer": [
{
"key": "token",
"value": "{{access_token}}",
"type": "string"
}
]
},
"method": "POST",
"header": [
{
"key": "Content-Type",
"value": "application/json",
"type": "text"
},
{
"key": "X-Chat-Session",
"value": "{{session_id}}",
"type": "text"
}
],
"body": {
"mode": "raw",
"raw": "{\n \"text\": \"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum\",\n \"prompt\": {\n \"_type\": \"prompt\",\n \"input_variables\": [\n \"text\"\n ],\n \"template\": \"formatta il seguente testo in formato EMAIL :\\n\\n{text}\\n\\n SCRIVI L'EMAIL:\"\n },\n \"token_data\": true\n}"
},
"url": {
"raw": "{{baseUrl}}/completion",
"host": [
"{{baseUrl}}"
],
"path": [
"completion"
]
}
},
"response": []
}
]
},
Expand Down
2 changes: 2 additions & 0 deletions brevia/routers/app_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
jobs_router,
qa_router,
status_router,
completion_router
)


Expand All @@ -23,5 +24,6 @@ def add_routers(app: FastAPI) -> None:
app.include_router(chat_history_router.router)
app.include_router(qa_router.router)
app.include_router(status_router.router)
app.include_router(completion_router.router)

index.init_index()
61 changes: 61 additions & 0 deletions brevia/routers/completion_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Completion API endpoints"""
from langchain.chains.base import Chain
from fastapi import APIRouter
from brevia.dependencies import (
get_dependencies,
)
from brevia.completions import CompletionParams, simple_completion_chain
from brevia.callback import token_usage_callback, token_usage, TokensCallbackHandler

router = APIRouter()


class CompletionBody(CompletionParams):
""" /completion request body """
text: str
prompt: dict | None = None
token_data: bool = False


@router.post('/completion', dependencies=get_dependencies())
async def completion_action(
completion_body: CompletionBody,
):
""" /completion endpoint, send a text with a custom prompt and get a completion """

chain = simple_completion_chain(
completion_params=CompletionParams(**completion_body.model_dump()),
)

return await run_chain(
chain=chain,
completion_body=completion_body,
)


async def run_chain(
chain: Chain,
completion_body: CompletionBody,
):
"""Run chain usign async methods and return result"""
with token_usage_callback() as callb:
result = await chain.acall({
'text': completion_body.text
})
return completion_result(
result=result,
callb=callb,
)


def completion_result(
result: dict,
callb: TokensCallbackHandler,
) -> dict:
""" Handle chat result: save chat history and return answer """
answer = result['text'].strip(" \n")

return {
'completion': answer,
'usage': token_usage(callb),
}
28 changes: 28 additions & 0 deletions tests/routers/test_completion_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Complerions router tests"""
from fastapi.testclient import TestClient
from fastapi import FastAPI
from brevia.routers.completion_router import router


app = FastAPI()
app.include_router(router)
client = TestClient(app)


def test_completion():
"""Test POST /completion endpoint"""
response = client.post(
'/completion',
headers={'Content-Type': 'application/json'},
content='''{
"text": "test",
"prompt": {
"_type": "prompt",
"input_variables": ["text"],
"template": "test"}
}
''',
)
assert response.status_code == 200
data = response.json()
assert data is not None
18 changes: 18 additions & 0 deletions tests/test_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Query module tests"""
from langchain.chains.base import Chain
from brevia.completions import simple_completion_chain, CompletionParams


fake_prompt = CompletionParams()
fake_prompt.prompt = {
'_type': 'prompt',
'input_variables': ['text'],
'template': 'Fake',
}


def test_simple_completion_chain():
"""Test simple_completion_chain method"""
result = simple_completion_chain(fake_prompt)
assert result is not None
assert isinstance(result, Chain)

0 comments on commit 57f6787

Please sign in to comment.