diff --git a/lenskit/util/__init__.py b/lenskit/util/__init__.py index 0ebe9619e..03a4d2333 100644 --- a/lenskit/util/__init__.py +++ b/lenskit/util/__init__.py @@ -60,8 +60,10 @@ def clone(algo): sps = dict([(k, clone(v)) for (k, v) in params.items()]) return algo.__class__(**sps) - elif isinstance(algo, list) or isinstance(algo, tuple): + elif isinstance(algo, list): return [clone(a) for a in algo] + elif isinstance(algo, tuple): + return tuple(clone(a) for a in algo) else: return deepcopy(algo) diff --git a/tests/test_bias.py b/tests/test_bias.py index 50af8106a..cde6084f4 100644 --- a/tests/test_bias.py +++ b/tests/test_bias.py @@ -69,6 +69,22 @@ def test_bias_clone(): assert getattr(a2, "user_offsets_", None) is None +def test_bias_clone_damping(): + algo = Bias(damping=(10, 5)) + algo.fit(simple_df) + + params = algo.get_params() + assert sorted(params.keys()) == ["damping", "items", "users"] + + a2 = lku.clone(algo) + assert a2 is not algo + assert getattr(a2, "item_damping", None) == 5 + assert getattr(a2, "user_damping", None) == 10 + assert getattr(a2, "mean_", None) is None + assert getattr(a2, "item_offsets_", None) is None + assert getattr(a2, "user_offsets_", None) is None + + def test_bias_global_only(): algo = Bias(users=False, items=False) algo.fit(simple_df) diff --git a/tests/test_util.py b/tests/test_util.py index 18bf34630..34d408998 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,12 +4,11 @@ # Licensed under the MIT license, see LICENSE.md for details. # SPDX-License-Identifier: MIT -import time import re -import pathlib +import time -import numpy as np -import pandas as pd +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st from lenskit import util as lku @@ -50,6 +49,7 @@ def test_stopwatch_long_str(): def test_stopwatch_minutes(): w = lku.Stopwatch() w.stop() + assert w.stop_time is not None w.start_time = w.stop_time - 62 s = str(w) p = re.compile(r"1m2.\d\ds") @@ -59,6 +59,7 @@ def test_stopwatch_minutes(): def test_stopwatch_hours(): w = lku.Stopwatch() w.stop() + assert w.stop_time is not None w.start_time = w.stop_time - 3663 s = str(w) p = re.compile(r"1h1m3.\d\ds") @@ -80,3 +81,23 @@ def func(foo): assert len(history) == 1 cache("bar") assert len(history) == 2 + + +@settings(suppress_health_check=[HealthCheck.too_slow]) +@given( + st.one_of( + st.integers(), + st.floats(allow_nan=False), + st.lists(st.floats(allow_nan=False), max_size=100), + st.tuples(st.floats(allow_nan=False)), + st.tuples(st.floats(allow_nan=False), st.floats(allow_nan=False)), + st.tuples( + st.floats(allow_nan=False), st.floats(allow_nan=False), st.floats(allow_nan=False) + ), + st.emails(), + ) +) +def test_clone_core_obj(obj): + o2 = lku.clone(obj) + assert o2 == obj + assert type(o2) == type(obj)