Skip to content

Commit

Permalink
better docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
gonlairo committed Sep 18, 2023
1 parent f0d33e6 commit 12e8833
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
14 changes: 8 additions & 6 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,20 +279,22 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
"""Loads a CEBRA model with a Sklearn backend.
Args:
cebra_info: A dictionary containing information about the CEBRA model.
cebra_info: A dictionary containing information about the CEBRA object,
including the arguments, the state of the object and the state
dictionary of the model.
Returns:
The loaded CEBRA object.
Raises:
ValueError: If the loaded CEBRA model is not fitted, indicating that loading it is not supported.
ValueError: If the loaded CEBRA model was not already fit, indicating that loading it is not supported.
"""
required_keys = ['args', 'state', 'state_dict']
missing_keys = [key for key in required_keys if key not in cebra_info]
if missing_keys:
raise ValueError(
f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
f"You can try loading the CEBRA model with a different backend.")
f"You can try loading the CEBRA model with the torch backend.")

args, state, state_dict = cebra_info['args'], cebra_info[
'state'], cebra_info['state_dict']
Expand All @@ -305,7 +307,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":

if not sklearn_utils.check_fitted(cebra_):
raise ValueError(
"CEBRA model is not fitted. Loading it is not supported.")
"CEBRA model was not already fit. Loading it is not supported.")

if cebra_.num_sessions_ is None:
model = cebra.models.init(
Expand Down Expand Up @@ -1388,7 +1390,7 @@ def load(cls,
Experimental functionality. Do not expect the save/load functionalities to be
backward compatible yet between CEBRA versions!
For information about the file format we refer to :py:meth:`cebra.CEBRA.save`.
For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.
Example:
Expand Down Expand Up @@ -1416,7 +1418,7 @@ def load(cls,
f"Cannot use 'torch' backend with a dictionary-based checkpoint. "
f"Please try a different backend.")
if not isinstance(checkpoint, dict) and backend == "sklearn":
raise RuntimeError(f"Cannot use 'sklearn' backend. "
raise RuntimeError(f"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
f"Please try a different backend.")

if backend == "sklearn":
Expand Down
14 changes: 7 additions & 7 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,23 +808,23 @@ def _assert_same_state_dict(first, second):
assert first[key] == second[key]


def check_fitted(model):
"""Check if a model is fitted.
def check_if_fit(model):
"""Check if a model was already fit.
Args:
model: The model to assess.
model: The model to check.
Returns:
True if fitted.
True if the model was already fit.
"""
return hasattr(model, "n_features_")


def _assert_equal(original_model, loaded_model):
assert original_model.get_params() == loaded_model.get_params()
assert check_fitted(loaded_model) == check_fitted(original_model)
assert check_if_fit(loaded_model) == check_if_fit(original_model)

if check_fitted(loaded_model):
if check_if_fit(loaded_model):
_assert_same_state_dict(original_model.state_dict_,
loaded_model.state_dict_)
X = np.random.normal(0, 1, (100, 1))
Expand Down Expand Up @@ -884,7 +884,7 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture,

original_model = action(original_model)
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
if not check_fitted(original_model):
if not check_if_fit(original_model):
with pytest.raises(ValueError):
original_model.save(savefile.name, backend=backend_save)
else:
Expand Down

0 comments on commit 12e8833

Please sign in to comment.