Skip to content

Commit

Permalink
better n_class check
Browse files Browse the repository at this point in the history
  • Loading branch information
Immanuel Bayer committed Apr 8, 2015
1 parent 490ac6a commit 518b947
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 4 additions & 0 deletions fastFM/als.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def fit(self, X_train, y_train):
y_train = _validate_class_labels(y_train)

self.classes_ = np.unique(y_train)
if len(self.classes_) != 2:
raise ValueError("This solver only supports binary classification"
" but the data contains"
" class: %r" % self.classes_)

# fastFM-core expects labels to be in {-1,1}
y_train = y_train.copy()
Expand Down
8 changes: 4 additions & 4 deletions fastFM/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,17 @@ def fit_predict_proba(self, X_train, y_train, X_test):
self.task = "classification"

self.classes_ = np.unique(y_train)
if len(self.classes_) != 2:
raise ValueError("This solver only supports binary classification"
" but the data contains"
" class: %r" % self.classes_)

# fastFM-core expects labels to be in {-1,1}
y_train = y_train.copy()
i_class1 = (y_train == self.classes_[0])
y_train[i_class1] = -1
y_train[-i_class1] = 1

if len(self.classes_) != 2:
raise ValueError("This solver only supports binary classification"
"but the data contains"
" class: %r" % self.classes_[0])

X_train, y_train, X_test = _validate_mcmc_fit_input(X_train, y_train,
X_test)
Expand Down
9 changes: 4 additions & 5 deletions fastFM/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,17 @@ def fit(self, X, y):
"""
y = _validate_class_labels(y)
self.classes_ = np.unique(y)
if len(self.classes_) != 2:
raise ValueError("This solver only supports binary classification"
" but the data contains"
" class: %r" % self.classes_)

# fastFM-core expects labels to be in {-1,1}
y_train = y.copy()
i_class1 = (y_train == self.classes_[0])
y_train[i_class1] = -1
y_train[-i_class1] = 1

if len(self.classes_) != 2:
raise ValueError("This solver only supports binary classification"
"but the data contains"
" class: %r" % self.classes_[0])

check_consistent_length(X, y)
y = y.astype(np.float64)
X = X.T
Expand Down

0 comments on commit 518b947

Please sign in to comment.