Skip to content

Commit

Permalink
option to select the classifiers for the ml predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Oct 26, 2024
1 parent 44c4b4e commit b581787
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions matsim/scenariogen/ml/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .models import create_regressor, model_to_java, model_to_py

classifier = {
CLASSIFIER = {
'mean',
'XGBRFRegressor',
'XGBRegressor',
Expand Down Expand Up @@ -42,11 +42,12 @@
class MLRegressor:
""" General class for machine learning regression models """

def __init__(self, n_trials=100, error="mae", fold=None, bounds=None):
def __init__(self, n_trials=100, error="mae", fold=None, bounds=None, classifier=None):
self.n_trials = n_trials
self.fold = fold if fold else KFold(n_splits=5, shuffle=False)
self.bounds = bounds
self.error = mean_absolute_error
self.classifier = classifier if classifier else CLASSIFIER
self.models = {}
self.df = None
self.exclude = None
Expand Down Expand Up @@ -152,8 +153,8 @@ def _fn(trial):

optuna.logging.set_verbosity(optuna.logging.WARNING)

with tqdm(total=len(classifier), position=0, leave=True) as pbar:
for m in classifier:
with tqdm(total=len(self.classifier), position=0, leave=True) as pbar:
for m in self.classifier:
pbar.set_description(f"Training model {m}")

with tqdm(total=self.n_trials, desc="Iteration", position=1, leave=True) as self.pb:
Expand Down

0 comments on commit b581787

Please sign in to comment.