diff --git a/lenskit/metrics/topn.py b/lenskit/metrics/topn.py index ae1b70f3c..7f616c53d 100644 --- a/lenskit/metrics/topn.py +++ b/lenskit/metrics/topn.py @@ -300,11 +300,11 @@ def ndcg(recs, truth, discount=np.log2, k=None): The maximum list length. """ - tpos = truth.index.get_indexer(recs['item']) - if k is not None: recs = recs.iloc[:k] + tpos = truth.index.get_indexer(recs['item']) + if 'rating' in truth.columns: i_rates = np.sort(truth.rating.values)[::-1] if k is not None: diff --git a/lkbuild/boot-env.yml b/lkbuild/boot-env.yml index 892591fbf..e8d9d7fae 100644 --- a/lkbuild/boot-env.yml +++ b/lkbuild/boot-env.yml @@ -7,5 +7,5 @@ dependencies: - invoke=1 - requests=2 - pip -- conda-lock +- conda-lock=0.13 - mamba diff --git a/pyproject.toml b/pyproject.toml index 68b0f0bfc..8c625d2ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dev = [ "flake8 >= 3", "coverage >= 5", "pytest-cov >= 2.12", - "ipython >= 7", + "ipython == 7.*", "docopt >= 0.6", "tqdm >= 4", "sphinx-autobuild >= 2021", @@ -63,6 +63,8 @@ doc = [ "sphinxcontrib-bibtex >= 2.0", "sphinx_rtd_theme >= 0.5", "nbsphinx >= 0.8", + "ipython == 7.*", + "notebook >=6", ] demo = [ "notebook >= 6", diff --git a/tests/test_topn_ndcg.py b/tests/test_topn_ndcg.py index 5d1601d9a..2fdf9e573 100644 --- a/tests/test_topn_ndcg.py +++ b/tests/test_topn_ndcg.py @@ -87,6 +87,14 @@ def test_ndcg_perfect(): assert ndcg(recs, truth) == approx(1.0) +def test_ndcg_perfect_k_short(): + recs = pd.DataFrame({'item': [2, 3, 1]}) + truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]}) + truth = truth.set_index('item') + assert ndcg(recs, truth, k=2) == approx(1.0) + assert ndcg(recs[:2], truth, k=2) == approx(1.0) + + def test_ndcg_wrong(): recs = pd.DataFrame({'item': [1, 2]}) truth = pd.DataFrame({'item': [1, 2, 3], 'rating': [3.0, 5.0, 4.0]})