Skip to content

Commit

Permalink
FIX: fixing error on temporal decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
robbisg committed Jul 26, 2024
1 parent aa6c740 commit 1093936
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 3 additions & 1 deletion sekupy/analysis/tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def test_temporal_decoding(fetch_ds):

roi_result = scores['mask-brain_value-2.0']
assert len(roi_result) == n_permutation + 1
assert roi_result[0]['test_score'].shape == (n_splits, 3, 3)

test_results = np.array(roi_result[0]['test_score'])
assert test_results.shape == (n_splits, 3, 3)
assert np.max(roi_result[0]['test_score']) <= 1.


Expand Down
5 changes: 2 additions & 3 deletions sekupy/ext/sklearn/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def cross_validate(
cv = check_cv(cv, y, classifier=is_classifier(estimator))

scorers = check_scoring(
estimator, scoring=scoring, #raise_exc=(error_score == "raise")
estimator, scoring=scoring, allow_none=(error_score == "raise")
)

if _routing_enabled():
Expand Down Expand Up @@ -1048,8 +1048,7 @@ def _score(estimator, X_test, y_test, scorer, score_params, error_score="raise")
with suppress(ValueError):
# e.g. unwrap memmapped scalars
scores = scores.item()
if not isinstance(scores, numbers.Number):
raise ValueError(error_msg % (scores, type(scores), scorer))

return scores


Expand Down

0 comments on commit 1093936

Please sign in to comment.