From 8fb6f40e160f716ada3c8cdfdc2328abf4097c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0lker=20SI=C4=9EIRCI?= Date: Fri, 9 Aug 2024 13:39:33 +0300 Subject: [PATCH] score all endpoint returns lists instead of one element --- notebooks/playground.ipynb | 68 +++++++++++++++++++ .../routers/score_calculation.py | 11 ++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/notebooks/playground.ipynb b/notebooks/playground.ipynb index d316332..a4b0081 100644 --- a/notebooks/playground.ipynb +++ b/notebooks/playground.ipynb @@ -22,6 +22,13 @@ "load_dotenv()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Downloaded Model" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -77,6 +84,67 @@ "\n", "scorer = BERTScorer(model_type=model_name, lang=\"tr\", rescale_with_baseline=True)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test the API" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import httpx\n", + "\n", + "\n", + "def get_bert_score(candidates: list[str], references: list[str]) -> dict:\n", + " url = \"http://localhost:8888/score_calculation/all\"\n", + "\n", + " data = {\n", + " \"candidate\": candidates,\n", + " \"reference\": references,\n", + " }\n", + "\n", + " response = httpx.post(url, json=data)\n", + "\n", + " return response.json()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'precision': [0.8399680852890015, 0.5356022715568542],\n", + " 'recall': [0.8359646797180176, 0.5197705626487732],\n", + " 'f1': [0.8379616141319275, 0.5275676846504211]}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "candidates = [\"The quick brown fox jumps over the lazy dog\", \"This is a sample test\"]\n", + "references = [\"A fast brown fox leaps over a lazy dog\", \"For example purposes\"]\n", + "\n", + "get_bert_score(candidates, references)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/bert_score_api/routers/score_calculation.py b/src/bert_score_api/routers/score_calculation.py index f14e045..00a6114 100644 --- a/src/bert_score_api/routers/score_calculation.py +++ b/src/bert_score_api/routers/score_calculation.py @@ -1,3 +1,4 @@ +import logging from typing import Annotated from fastapi import APIRouter, Depends, HTTPException @@ -5,6 +6,7 @@ from bert_score_api.deps import get_bert_scorer from bert_score_api.schemes import TextPair +logger = logging.getLogger(__name__) router = APIRouter(prefix="/score_calculation", tags=["calculate"]) @@ -24,10 +26,13 @@ async def calculate_bert_score( text_pair.reference, verbose=True, ) + + logger.debug(f"Precision type: {type(P)}") + return { - "precision": P.mean().item(), - "recall": R.mean().item(), - "f1": F1.mean().item(), + "precision": P.tolist(), + "recall": R.tolist(), + "f1": F1.tolist(), } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e