Skip to content

Commit

Permalink
score all endpoint returns lists instead of one element
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkersigirci committed Aug 9, 2024
1 parent f4e65bd commit 8fb6f40
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
68 changes: 68 additions & 0 deletions notebooks/playground.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
"load_dotenv()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Downloaded Model"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down Expand Up @@ -77,6 +84,67 @@
"\n",
"scorer = BERTScorer(model_type=model_name, lang=\"tr\", rescale_with_baseline=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the API"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import httpx\n",
"\n",
"\n",
"def get_bert_score(candidates: list[str], references: list[str]) -> dict:\n",
" url = \"http://localhost:8888/score_calculation/all\"\n",
"\n",
" data = {\n",
" \"candidate\": candidates,\n",
" \"reference\": references,\n",
" }\n",
"\n",
" response = httpx.post(url, json=data)\n",
"\n",
" return response.json()\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': [0.8399680852890015, 0.5356022715568542],\n",
" 'recall': [0.8359646797180176, 0.5197705626487732],\n",
" 'f1': [0.8379616141319275, 0.5275676846504211]}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"candidates = [\"The quick brown fox jumps over the lazy dog\", \"This is a sample test\"]\n",
"references = [\"A fast brown fox leaps over a lazy dog\", \"For example purposes\"]\n",
"\n",
"get_bert_score(candidates, references)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
11 changes: 8 additions & 3 deletions src/bert_score_api/routers/score_calculation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException

from bert_score_api.deps import get_bert_scorer
from bert_score_api.schemes import TextPair

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/score_calculation", tags=["calculate"])


Expand All @@ -24,10 +26,13 @@ async def calculate_bert_score(
text_pair.reference,
verbose=True,
)

logger.debug(f"Precision type: {type(P)}")

return {
"precision": P.mean().item(),
"recall": R.mean().item(),
"f1": F1.mean().item(),
"precision": P.tolist(),
"recall": R.tolist(),
"f1": F1.tolist(),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

0 comments on commit 8fb6f40

Please sign in to comment.