-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: introducing completion endpoint * feat: completion route * feat: first release + postman update * feat: add token_usage_callback * chore: fix comments * feat: adding test case * chore: style fix
- Loading branch information
Showing
6 changed files
with
195 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Simple completion functions""" | ||
from langchain.chains.base import Chain | ||
from langchain.chains.llm import LLMChain | ||
from langchain.prompts.loading import load_prompt_from_config | ||
from pydantic import BaseModel | ||
from brevia.models import load_chatmodel | ||
from brevia.settings import get_settings | ||
|
||
|
||
class CompletionParams(BaseModel): | ||
""" Q&A basic conversation chain params """ | ||
prompt: dict | None = None | ||
|
||
|
||
def load_custom_prompt(prompt: dict | None): | ||
""" Load custom prompt """ | ||
return load_prompt_from_config(prompt) | ||
|
||
|
||
def simple_completion_chain( | ||
completion_params: CompletionParams, | ||
) -> Chain: | ||
""" | ||
Return simple completion chain for a generic input text | ||
completion_params: basic completion params including: | ||
prompt: custom prompt for execute simple completion commands | ||
""" | ||
|
||
settings = get_settings() | ||
llm_conf = settings.qa_completion_llm | ||
comp_llm = load_chatmodel(llm_conf) | ||
verbose = settings.verbose_mode | ||
# Create chain for follow-up question using chat history (if present) | ||
completion_llm = LLMChain( | ||
llm=comp_llm, | ||
prompt=load_custom_prompt(completion_params.prompt), | ||
verbose=verbose, | ||
) | ||
|
||
return completion_llm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Completion API endpoints""" | ||
from langchain.chains.base import Chain | ||
from fastapi import APIRouter | ||
from brevia.dependencies import ( | ||
get_dependencies, | ||
) | ||
from brevia.completions import CompletionParams, simple_completion_chain | ||
from brevia.callback import token_usage_callback, token_usage, TokensCallbackHandler | ||
|
||
router = APIRouter() | ||
|
||
|
||
class CompletionBody(CompletionParams): | ||
""" /completion request body """ | ||
text: str | ||
prompt: dict | None = None | ||
token_data: bool = False | ||
|
||
|
||
@router.post('/completion', dependencies=get_dependencies()) | ||
async def completion_action( | ||
completion_body: CompletionBody, | ||
): | ||
""" /completion endpoint, send a text with a custom prompt and get a completion """ | ||
|
||
chain = simple_completion_chain( | ||
completion_params=CompletionParams(**completion_body.model_dump()), | ||
) | ||
|
||
return await run_chain( | ||
chain=chain, | ||
completion_body=completion_body, | ||
) | ||
|
||
|
||
async def run_chain( | ||
chain: Chain, | ||
completion_body: CompletionBody, | ||
): | ||
"""Run chain usign async methods and return result""" | ||
with token_usage_callback() as callb: | ||
result = await chain.acall({ | ||
'text': completion_body.text | ||
}) | ||
return completion_result( | ||
result=result, | ||
callb=callb, | ||
) | ||
|
||
|
||
def completion_result( | ||
result: dict, | ||
callb: TokensCallbackHandler, | ||
) -> dict: | ||
""" Handle chat result: save chat history and return answer """ | ||
answer = result['text'].strip(" \n") | ||
|
||
return { | ||
'completion': answer, | ||
'usage': token_usage(callb), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Complerions router tests""" | ||
from fastapi.testclient import TestClient | ||
from fastapi import FastAPI | ||
from brevia.routers.completion_router import router | ||
|
||
|
||
app = FastAPI() | ||
app.include_router(router) | ||
client = TestClient(app) | ||
|
||
|
||
def test_completion(): | ||
"""Test POST /completion endpoint""" | ||
response = client.post( | ||
'/completion', | ||
headers={'Content-Type': 'application/json'}, | ||
content='''{ | ||
"text": "test", | ||
"prompt": { | ||
"_type": "prompt", | ||
"input_variables": ["text"], | ||
"template": "test"} | ||
} | ||
''', | ||
) | ||
assert response.status_code == 200 | ||
data = response.json() | ||
assert data is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Query module tests""" | ||
from langchain.chains.base import Chain | ||
from brevia.completions import simple_completion_chain, CompletionParams | ||
|
||
|
||
fake_prompt = CompletionParams() | ||
fake_prompt.prompt = { | ||
'_type': 'prompt', | ||
'input_variables': ['text'], | ||
'template': 'Fake', | ||
} | ||
|
||
|
||
def test_simple_completion_chain(): | ||
"""Test simple_completion_chain method""" | ||
result = simple_completion_chain(fake_prompt) | ||
assert result is not None | ||
assert isinstance(result, Chain) |