Skip to content

Commit

Permalink
Merge pull request #310 from mdekstrand/fix/309-ndcg-truncate
Browse files Browse the repository at this point in the history
Fix nDCG truncation bug (#309)
  • Loading branch information
mdekstrand authored Mar 11, 2022
2 parents 19f03bb + 690818d commit 2646e79
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
4 changes: 2 additions & 2 deletions lenskit/metrics/topn.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,11 @@ def ndcg(recs, truth, discount=np.log2, k=None):
The maximum list length.
"""

tpos = truth.index.get_indexer(recs['item'])

if k is not None:
recs = recs.iloc[:k]

tpos = truth.index.get_indexer(recs['item'])

if 'rating' in truth.columns:
i_rates = np.sort(truth.rating.values)[::-1]
if k is not None:
Expand Down
2 changes: 1 addition & 1 deletion lkbuild/boot-env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ dependencies:
- invoke=1
- requests=2
- pip
- conda-lock
- conda-lock=0.13
- mamba
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dev = [
"flake8 >= 3",
"coverage >= 5",
"pytest-cov >= 2.12",
"ipython >= 7",
"ipython == 7.*",
"docopt >= 0.6",
"tqdm >= 4",
"sphinx-autobuild >= 2021",
Expand All @@ -63,6 +63,8 @@ doc = [
"sphinxcontrib-bibtex >= 2.0",
"sphinx_rtd_theme >= 0.5",
"nbsphinx >= 0.8",
"ipython == 7.*",
"notebook >=6",
]
demo = [
"notebook >= 6",
Expand Down
8 changes: 8 additions & 0 deletions tests/test_topn_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def test_ndcg_perfect():
assert ndcg(recs, truth) == approx(1.0)


def test_ndcg_perfect_k_short():
recs = pd.DataFrame({'item': [2, 3, 1]})
truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]})
truth = truth.set_index('item')
assert ndcg(recs, truth, k=2) == approx(1.0)
assert ndcg(recs[:2], truth, k=2) == approx(1.0)


def test_ndcg_wrong():
recs = pd.DataFrame({'item': [1, 2]})
truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]})
Expand Down

0 comments on commit 2646e79

Please sign in to comment.