diff --git a/.env_example b/.env_example index 41962da..aa3004f 100644 --- a/.env_example +++ b/.env_example @@ -4,4 +4,4 @@ HF_HOME=TO_BE_FILLED HF_HUB_ENABLE_HF_TRANSFER=1 LIBRARY_BASE_PATH=TO_BE_FILLED LANGUAGE=tr -RESCALE_WITH_BASELINE=False +RESCALE_WITH_BASELINE=0 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6860243..e95c815 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -9,6 +9,7 @@ on: - main jobs: + # NOTE: python -m pytest -n auto doesn't work and not correctly set environment variables test_matrix: strategy: matrix: @@ -33,6 +34,6 @@ jobs: rye sync --no-lock - name: Test the Project run: | - python -m pytest -n auto + python -m pytest # - name: Publish code coverage # uses: codecov/codecov-action@v3 diff --git a/notebooks/playground.ipynb b/notebooks/playground.ipynb index a4b0081..db973a7 100644 --- a/notebooks/playground.ipynb +++ b/notebooks/playground.ipynb @@ -85,6 +85,32 @@ "scorer = BERTScorer(model_type=model_name, lang=\"tr\", rescale_with_baseline=True)" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Baseline not Found for dbmdz/bert-base-turkish-cased on tr at /home/ilker/Documents/MyRepos/bert-score-api/.venv/lib/python3.11/site-packages/bert_score/rescale_baseline/tr/dbmdz/bert-base-turkish-cased.tsv", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m candidates \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe quick brown fox jumps over the lazy dog\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis is a sample test\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 2\u001b[0m references \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mA fast brown fox leaps over a lazy dog\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor example purposes\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m----> 4\u001b[0m \u001b[43mscorer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscore\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcandidates\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreferences\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/MyRepos/bert-score-api/.venv/lib/python3.11/site-packages/bert_score/scorer.py:239\u001b[0m, in \u001b[0;36mBERTScorer.score\u001b[0;34m(self, cands, refs, verbose, batch_size, return_hash)\u001b[0m\n\u001b[1;32m 236\u001b[0m all_preds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(max_preds, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 238\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrescale_with_baseline:\n\u001b[0;32m--> 239\u001b[0m all_preds \u001b[38;5;241m=\u001b[39m (all_preds \u001b[38;5;241m-\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbaseline_vals\u001b[49m) \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbaseline_vals)\n\u001b[1;32m 241\u001b[0m out \u001b[38;5;241m=\u001b[39m all_preds[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m0\u001b[39m], all_preds[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m1\u001b[39m], all_preds[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;241m2\u001b[39m] \u001b[38;5;66;03m# P, R, F\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verbose:\n", + "File \u001b[0;32m~/Documents/MyRepos/bert-score-api/.venv/lib/python3.11/site-packages/bert_score/scorer.py:151\u001b[0m, in \u001b[0;36mBERTScorer.baseline_vals\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_baseline_vals \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 144\u001b[0m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(pd\u001b[38;5;241m.\u001b[39mread_csv(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbaseline_path)\u001b[38;5;241m.\u001b[39mto_numpy())[\n\u001b[1;32m 145\u001b[0m :, \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 149\u001b[0m )\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 152\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBaseline not Found for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m on \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlang\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbaseline_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 153\u001b[0m )\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_baseline_vals\n", + "\u001b[0;31mValueError\u001b[0m: Baseline not Found for dbmdz/bert-base-turkish-cased on tr at /home/ilker/Documents/MyRepos/bert-score-api/.venv/lib/python3.11/site-packages/bert_score/rescale_baseline/tr/dbmdz/bert-base-turkish-cased.tsv" + ] + } + ], + "source": [ + "candidates = [\"The quick brown fox jumps over the lazy dog\", \"This is a sample test\"]\n", + "references = [\"A fast brown fox leaps over a lazy dog\", \"For example purposes\"]\n", + "\n", + "scorer.score(candidates, references)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/bert_score_api/deps.py b/src/bert_score_api/deps.py index 0c0c542..2830ed7 100644 --- a/src/bert_score_api/deps.py +++ b/src/bert_score_api/deps.py @@ -2,22 +2,24 @@ from bert_score import BERTScorer -from bert_score_api.utils import check_env_vars, get_language_model_map +from bert_score_api.utils import ( + check_env_vars, + get_language_model_map, + is_valid_language, +) def get_bert_scorer() -> BERTScorer: check_env_vars(["LANGUAGE", "RESCALE_WITH_BASELINE"]) LANGUAGE = os.environ["LANGUAGE"].lower() - RESCALE_WITH_BASELINE = bool(os.environ["RESCALE_WITH_BASELINE"]) + RESCALE_WITH_BASELINE = bool(int(os.environ["RESCALE_WITH_BASELINE"])) - model_type = get_language_model_map().get(LANGUAGE, None) - - if model_type is None: + if is_valid_language(LANGUAGE) is False: raise ValueError(f"Language {LANGUAGE} is not supported.") return BERTScorer( - model_type=model_type, + model_type=get_language_model_map()[LANGUAGE], lang=LANGUAGE, rescale_with_baseline=RESCALE_WITH_BASELINE, ) diff --git a/tests/.env.test b/tests/.env.test index 24e74e4..f885e2e 100644 --- a/tests/.env.test +++ b/tests/.env.test @@ -4,5 +4,5 @@ HF_HOME=$HOME/.cache/huggingface HF_HUB_ENABLE_HF_TRANSFER=1 LIBRARY_BASE_PATH=NOT_NEEDED LANGUAGE=tr -RESCALE_WITH_BASELINE=False +RESCALE_WITH_BASELINE=0 TEST_ENV_KEY=TEST_ENV_VALUE diff --git a/tests/conftest.py b/tests/conftest.py index 2bb93de..34181a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,8 +24,7 @@ def anyio_backend(): @pytest_asyncio.fixture() async def async_client() -> AsyncGenerator[AsyncClient, Any]: - #FIXME: This doesn't take env vars into account - # Find a way to load env vars asynchronusly + #FIXME: This doesn't take some env vars into account. Specifically, HF_HOME app = create_app()