diff --git a/fastFM/als.py b/fastFM/als.py index 513a9dd..8ec1db0 100644 --- a/fastFM/als.py +++ b/fastFM/als.py @@ -169,6 +169,10 @@ def fit(self, X_train, y_train): y_train = _validate_class_labels(y_train) self.classes_ = np.unique(y_train) + if len(self.classes_) != 2: + raise ValueError("This solver only supports binary classification" + " but the data contains" + " class: %r" % self.classes_) # fastFM-core expects labels to be in {-1,1} y_train = y_train.copy() diff --git a/fastFM/mcmc.py b/fastFM/mcmc.py index 4ea6bf4..fa841e5 100644 --- a/fastFM/mcmc.py +++ b/fastFM/mcmc.py @@ -211,6 +211,10 @@ def fit_predict_proba(self, X_train, y_train, X_test): self.task = "classification" self.classes_ = np.unique(y_train) + if len(self.classes_) != 2: + raise ValueError("This solver only supports binary classification" + " but the data contains" + " class: %r" % self.classes_) # fastFM-core expects labels to be in {-1,1} y_train = y_train.copy() @@ -218,10 +222,6 @@ def fit_predict_proba(self, X_train, y_train, X_test): y_train[i_class1] = -1 y_train[-i_class1] = 1 - if len(self.classes_) != 2: - raise ValueError("This solver only supports binary classification" - "but the data contains" - " class: %r" % self.classes_[0]) X_train, y_train, X_test = _validate_mcmc_fit_input(X_train, y_train, X_test) diff --git a/fastFM/sgd.py b/fastFM/sgd.py index 07df241..34aca45 100644 --- a/fastFM/sgd.py +++ b/fastFM/sgd.py @@ -158,6 +158,10 @@ def fit(self, X, y): """ y = _validate_class_labels(y) self.classes_ = np.unique(y) + if len(self.classes_) != 2: + raise ValueError("This solver only supports binary classification" + " but the data contains" + " class: %r" % self.classes_) # fastFM-core expects labels to be in {-1,1} y_train = y.copy() @@ -165,11 +169,6 @@ def fit(self, X, y): y_train[i_class1] = -1 y_train[-i_class1] = 1 - if len(self.classes_) != 2: - raise ValueError("This solver only supports binary classification" - "but the data contains" - " class: %r" % self.classes_[0]) - check_consistent_length(X, y) y = y.astype(np.float64) X = X.T