diff --git a/ludwig/automl/auto_tune_config.py b/ludwig/automl/auto_tune_config.py index 4da1c2d751a..009c9ebe138 100644 --- a/ludwig/automl/auto_tune_config.py +++ b/ludwig/automl/auto_tune_config.py @@ -269,7 +269,7 @@ def memory_tune_config(config, dataset, model_category, row_count): modified_config = copy.deepcopy(config) modified_config[HYPEROPT]["parameters"] = modified_hyperparam_search_space - modified_config[HYPEROPT]["sampler"]["num_samples"] = _update_num_samples( - modified_config[HYPEROPT]["sampler"]["num_samples"], modified_hyperparam_search_space + modified_config[HYPEROPT]["executor"]["num_samples"] = _update_num_samples( + modified_config[HYPEROPT]["executor"]["num_samples"], modified_hyperparam_search_space ) return modified_config, fits_in_memory diff --git a/ludwig/automl/automl.py b/ludwig/automl/automl.py index d2377c71116..bd54721badb 100644 --- a/ludwig/automl/automl.py +++ b/ludwig/automl/automl.py @@ -204,7 +204,7 @@ def train_with_config( # TODO (ASN): Decide how we want to proceed if at least one trial has # completed for trial in hyperopt_results.ordered_trials: - if np.isnan(trial.metric_score): + if isinstance(trial.metric_score, str) or np.isnan(trial.metric_score): warnings.warn( "There was an error running the experiment. " "A trial failed to start. " @@ -250,7 +250,7 @@ def _model_select( model_category = TEXT input_feature["encoder"] = AUTOML_DEFAULT_TEXT_ENCODER base_config = merge_dict(base_config, default_configs[TEXT][AUTOML_DEFAULT_TEXT_ENCODER]) - base_config[HYPEROPT]["sampler"]["num_samples"] = 5 # set for small hyperparameter search space + base_config[HYPEROPT]["executor"]["num_samples"] = 5 # set for small hyperparameter search space # TODO (ASN): add image heuristics if input_feature["type"] == IMAGE: diff --git a/ludwig/automl/base_config.py b/ludwig/automl/base_config.py index 5ea3c672c71..ecd207a23c5 100644 --- a/ludwig/automl/base_config.py +++ b/ludwig/automl/base_config.py @@ -21,7 +21,21 @@ from ludwig.automl.data_source import DataframeSource, DataSource from ludwig.automl.utils import _ray_init, FieldConfig, FieldInfo, FieldMetadata, get_available_resources -from ludwig.constants import AUDIO, BINARY, CATEGORY, DATE, IMAGE, NUMBER, TEXT +from ludwig.constants import ( + AUDIO, + BINARY, + CATEGORY, + COMBINER, + DATE, + EXECUTOR, + HYPEROPT, + IMAGE, + NUMBER, + SCHEDULER, + SEARCH_ALG, + TEXT, + TYPE, +) from ludwig.utils import strings_utils from ludwig.utils.data_utils import load_dataset, load_yaml from ludwig.utils.defaults import default_random_seed @@ -121,18 +135,18 @@ def _create_default_config( dataset_info.fields, dataset_info.row_count, resources, target_name ) # create set of all feature types appearing in the dataset - feature_types = [[feat["type"] for feat in features] for features in input_and_output_feature_config.values()] + feature_types = [[feat[TYPE] for feat in features] for features in input_and_output_feature_config.values()] feature_types = set(sum(feature_types, [])) model_configs = {} # read in base config and update with experiment resources base_automl_config = load_yaml(BASE_AUTOML_CONFIG) - base_automl_config["hyperopt"]["executor"].update(experiment_resources) - base_automl_config["hyperopt"]["executor"]["time_budget_s"] = time_limit_s + base_automl_config[HYPEROPT][EXECUTOR].update(experiment_resources) + base_automl_config[HYPEROPT][EXECUTOR]["time_budget_s"] = time_limit_s if time_limit_s is not None: - base_automl_config["hyperopt"]["sampler"]["scheduler"]["max_t"] = time_limit_s - base_automl_config["hyperopt"]["sampler"]["search_alg"]["random_state_seed"] = random_seed + base_automl_config[HYPEROPT][EXECUTOR][SCHEDULER]["max_t"] = time_limit_s + base_automl_config[HYPEROPT][SEARCH_ALG]["random_state_seed"] = random_seed base_automl_config.update(input_and_output_feature_config) model_configs["base_config"] = base_automl_config @@ -146,10 +160,10 @@ def _create_default_config( model_configs[feat_type][encoder_name] = load_yaml(encoder_config_path) # read in all combiner configs - model_configs["combiner"] = {} + model_configs[COMBINER] = {} for combiner_type, default_config in combiner_defaults.items(): combiner_config = load_yaml(default_config) - model_configs["combiner"][combiner_type] = combiner_config + model_configs[COMBINER][combiner_type] = combiner_config return model_configs diff --git a/ludwig/automl/defaults/base_automl_config.yaml b/ludwig/automl/defaults/base_automl_config.yaml index c7fc00fe8ef..e267a9e7f96 100644 --- a/ludwig/automl/defaults/base_automl_config.yaml +++ b/ludwig/automl/defaults/base_automl_config.yaml @@ -4,11 +4,13 @@ trainer: # validation_metric: accuracy hyperopt: - sampler: + search_alg: + # Gives results like default + supports random_state_seed for sample sequence repeatability + type: hyperopt + executor: type: ray - search_alg: - # Gives results like default + supports random_state_seed for sample sequence repeatability - type: hyperopt + num_samples: 10 + time_budget_s: 7200 scheduler: type: async_hyperband time_attr: time_total_s @@ -16,8 +18,3 @@ hyperopt: grace_period: 72 # Increased over default to get more pruning/exploration reduction_factor: 5 - num_samples: 10 - - executor: - type: ray - time_budget_s: 7200 diff --git a/ludwig/automl/utils.py b/ludwig/automl/utils.py index 75bb50cd9d3..1f2c13c1042 100644 --- a/ludwig/automl/utils.py +++ b/ludwig/automl/utils.py @@ -8,7 +8,7 @@ from numpy import nan_to_num from pandas import Series -from ludwig.constants import COMBINER, CONFIG, HYPEROPT, NAME, NUMBER, PARAMETERS, SAMPLER, TRAINER, TYPE +from ludwig.constants import COMBINER, CONFIG, HYPEROPT, NAME, NUMBER, PARAMETERS, SEARCH_ALG, TRAINER, TYPE from ludwig.features.feature_registries import output_type_registry from ludwig.modules.metric_registry import metric_registry from ludwig.utils.defaults import default_combiner_type @@ -127,7 +127,7 @@ def _add_transfer_config(base_config: Dict, ref_configs: Dict) -> Dict: point_to_evaluate = {} _add_option_to_evaluate(point_to_evaluate, min_dataset_config, hyperopt_params, COMBINER) _add_option_to_evaluate(point_to_evaluate, min_dataset_config, hyperopt_params, TRAINER) - base_config[HYPEROPT][SAMPLER]["search_alg"]["points_to_evaluate"] = [point_to_evaluate] + base_config[HYPEROPT][SEARCH_ALG]["points_to_evaluate"] = [point_to_evaluate] return base_config diff --git a/ludwig/constants.py b/ludwig/constants.py index 6b09db2e9c4..04d3c1db440 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -123,6 +123,8 @@ MINIMIZE = "minimize" MAXIMIZE = "maximize" SAMPLER = "sampler" +SEARCH_ALG = "search_alg" +SCHEDULER = "scheduler" PARAMETERS = "parameters" NAME = "name" diff --git a/ludwig/hyperopt/execution.py b/ludwig/hyperopt/execution.py index a357e720834..76b10995c3c 100644 --- a/ludwig/hyperopt/execution.py +++ b/ludwig/hyperopt/execution.py @@ -11,14 +11,14 @@ import uuid from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union from ludwig.api import LudwigModel from ludwig.backend import initialize_backend, RAY from ludwig.callbacks import Callback from ludwig.constants import COLUMN, MAXIMIZE, TEST, TRAINER, TRAINING, TYPE, VALIDATION from ludwig.hyperopt.results import HyperoptResults, RayTuneResults, TrialResults -from ludwig.hyperopt.sampling import HyperoptSampler, RayTuneSampler +from ludwig.hyperopt.sampling import get_search_algorithm, RayTuneSampler from ludwig.hyperopt.utils import load_json_values from ludwig.modules.metric_modules import get_best_function from ludwig.utils import metric_utils @@ -33,28 +33,40 @@ import ray from ray import tune from ray.tune import register_trainable - from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter + from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter, SEARCH_ALG_IMPORT from ray.tune.sync_client import CommandBasedClient from ray.tune.syncer import get_cloud_sync_client from ray.tune.utils import wait_for_gpu from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.util.queue import Queue as RayQueue - from ludwig.backend.ray import RayBackend except ImportError as e: - logger.warn(f"ImportError (execution.py) failed to import ray with error: \n\t{e}") + logger.warning(f"ImportError (execution.py) failed to import ray with error: \n\t{e}") ray = None get_horovod_kwargs = None + +try: + from ludwig.backend.ray import RayBackend + + # TODO: refactor this into an interface + def _is_ray_backend(backend) -> bool: + if isinstance(backend, str): + return backend == RAY + return isinstance(backend, RayBackend) + +except ImportError as e: + logger.warning( + f"ImportError (execution.py) failed to import RayBackend with error: \n\t{e}. " + "The LocalBackend will be used instead. If you want to use the RayBackend, please install ludwig[distributed]." + ) + get_horovod_kwargs = None + class RayBackend: pass - -# TODO: refactor this into an interface -def _is_ray_backend(backend) -> bool: - if isinstance(backend, str): - return backend == RAY - return isinstance(backend, RayBackend) + def _is_ray_backend(backend) -> bool: + return False def _get_relative_checkpoints_dir_parts(path: Path): @@ -63,7 +75,7 @@ def _get_relative_checkpoints_dir_parts(path: Path): class HyperoptExecutor(ABC): def __init__( - self, hyperopt_sampler: Union[dict, HyperoptSampler], output_feature: str, metric: str, split: str + self, hyperopt_sampler: Union[dict, RayTuneSampler], output_feature: str, metric: str, split: str ) -> None: self.hyperopt_sampler = hyperopt_sampler self.output_feature = output_feature @@ -184,121 +196,22 @@ def execute( pass -class SerialExecutor(HyperoptExecutor): - def __init__( - self, hyperopt_sampler: HyperoptSampler, output_feature: str, metric: str, split: str, **kwargs - ) -> None: - HyperoptExecutor.__init__(self, hyperopt_sampler, output_feature, metric, split) - - def execute( - self, - config, - dataset=None, - training_set=None, - validation_set=None, - test_set=None, - training_set_metadata=None, - data_format=None, - experiment_name="hyperopt", - model_name="run", - # model_load_path=None, - # model_resume_path=None, - skip_save_training_description=False, - skip_save_training_statistics=False, - skip_save_model=False, - skip_save_progress=False, - skip_save_log=False, - skip_save_processed_input=True, - skip_save_unprocessed_output=False, - skip_save_predictions=False, - skip_save_eval_stats=False, - output_directory="results", - gpus=None, - gpu_memory_limit=None, - allow_parallel_threads=True, - callbacks=None, - backend=None, - random_seed=default_random_seed, - debug=False, - **kwargs, - ) -> HyperoptResults: - trial_results = [] - trials = 0 - while not self.hyperopt_sampler.finished(): - sampled_parameters = self.hyperopt_sampler.sample_batch() - metric_scores = [] - - for i, parameters in enumerate(sampled_parameters): - modified_config = substitute_parameters(copy.deepcopy(config), parameters) - - trial_id = trials + i - - model = LudwigModel( - config=modified_config, - backend=backend, - gpus=gpus, - gpu_memory_limit=gpu_memory_limit, - allow_parallel_threads=allow_parallel_threads, - callbacks=callbacks, - ) - eval_stats, train_stats, _, _ = model.experiment( - dataset=dataset, - training_set=training_set, - validation_set=validation_set, - test_set=test_set, - training_set_metadata=training_set_metadata, - data_format=data_format, - experiment_name=f"{experiment_name}_{trial_id}", - model_name=model_name, - # model_load_path=model_load_path, - # model_resume_path=model_resume_path, - eval_split=self.split, - skip_save_training_description=skip_save_training_description, - skip_save_training_statistics=skip_save_training_statistics, - skip_save_model=skip_save_model, - skip_save_progress=skip_save_progress, - skip_save_log=skip_save_log, - skip_save_processed_input=skip_save_processed_input, - skip_save_unprocessed_output=skip_save_unprocessed_output, - skip_save_predictions=skip_save_predictions, - skip_save_eval_stats=skip_save_eval_stats, - output_directory=output_directory, - skip_collect_predictions=True, - skip_collect_overall_stats=False, - random_seed=random_seed, - debug=debug, - ) - metric_score = self.get_metric_score(train_stats) - metric_scores.append(metric_score) - - trial_results.append( - TrialResults( - parameters=parameters, - metric_score=metric_score, - training_stats=train_stats, - eval_stats=eval_stats, - ) - ) - trials += len(sampled_parameters) - - self.hyperopt_sampler.update_batch(zip(sampled_parameters, metric_scores)) - - ordered_trials = self.sort_hyperopt_results(trial_results) - return HyperoptResults(ordered_trials=ordered_trials) - - class RayTuneExecutor(HyperoptExecutor): def __init__( self, hyperopt_sampler, output_feature: str, metric: str, + goal: str, split: str, + search_alg: Optional[Dict] = None, cpu_resources_per_trial: int = None, gpu_resources_per_trial: int = None, kubernetes_namespace: str = None, time_budget_s: Union[int, float, datetime.timedelta] = None, max_concurrent_trials: Optional[int] = None, + num_samples: int = 1, + scheduler: Optional[Dict] = None, **kwargs, ) -> None: if ray is None: @@ -316,10 +229,14 @@ def __init__( logger.info("Initializing new Ray cluster...") ray.init(ignore_reinit_error=True) self.search_space = hyperopt_sampler.search_space - self.num_samples = hyperopt_sampler.num_samples - self.goal = hyperopt_sampler.goal - self.search_alg_dict = hyperopt_sampler.search_alg_dict - self.scheduler = hyperopt_sampler.scheduler + self.num_samples = num_samples + self.goal = goal + self.search_algorithm = ( + get_search_algorithm(None)(search_alg) + if search_alg is None + else get_search_algorithm(search_alg.get(TYPE, None))(search_alg) + ) + self.scheduler = None if scheduler is None else tune.create_scheduler(scheduler[TYPE], **scheduler) self.decode_ctx = hyperopt_sampler.decode_ctx self.output_feature = output_feature self.metric = metric @@ -643,6 +560,7 @@ def execute( backend=None, random_seed=default_random_seed, debug=False, + hyperopt_log_verbosity=3, **kwargs, ) -> RayTuneResults: if isinstance(dataset, str) and not has_remote_protocol(dataset) and not os.path.isabs(dataset): @@ -696,13 +614,21 @@ def execute( mode = "min" if self.goal != MAXIMIZE else "max" metric = "metric_score" - if self.search_alg_dict is not None: - if TYPE not in self.search_alg_dict: - logger.warning("WARNING: Kindly set type param for search_alg " "to utilize Tune's Search Algorithms.") + # if random seed not set, use Ludwig seed + self.search_algorithm.check_for_random_seed(random_seed) + if self.search_algorithm.search_alg_dict is not None: + if TYPE not in self.search_algorithm.search_alg_dict: + candiate_search_algs = [search_alg for search_alg in SEARCH_ALG_IMPORT.keys()] + logger.warning( + "WARNING: search_alg type parameter missing, using 'variant_generator' as default. " + f"These are possible values for the type parameter: {candiate_search_algs}." + ) search_alg = None else: - search_alg_type = self.search_alg_dict[TYPE] - search_alg = tune.create_searcher(search_alg_type, metric=metric, mode=mode, **self.search_alg_dict) + search_alg_type = self.search_algorithm.search_alg_dict[TYPE] + search_alg = tune.create_searcher( + search_alg_type, metric=metric, mode=mode, **self.search_algorithm.search_alg_dict + ) else: search_alg = None @@ -781,6 +707,7 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None): trial_name_creator=lambda trial: f"trial_{trial.trial_id}", trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}", callbacks=tune_callbacks, + verbose=hyperopt_log_verbosity, ) except Exception as e: # Explicitly raise a RuntimeError if an error is encountered during a Ray trial. @@ -842,7 +769,7 @@ def get_build_hyperopt_executor(executor_type): return get_from_registry(executor_type, executor_registry) -executor_registry = {"serial": SerialExecutor, "ray": RayTuneExecutor} +executor_registry = {"ray": RayTuneExecutor} def set_values(model_dict, name, parameters_dict): diff --git a/ludwig/hyperopt/run.py b/ludwig/hyperopt/run.py index 01624569236..14464610d56 100644 --- a/ludwig/hyperopt/run.py +++ b/ludwig/hyperopt/run.py @@ -8,11 +8,11 @@ from ludwig.api import LudwigModel from ludwig.backend import Backend, initialize_backend, LocalBackend from ludwig.callbacks import Callback -from ludwig.constants import COMBINED, EXECUTOR, HYPEROPT, LOSS, MINIMIZE, SAMPLER, TEST, TRAINING, TYPE, VALIDATION +from ludwig.constants import COMBINED, EXECUTOR, HYPEROPT, LOSS, MINIMIZE, RAY, TEST, TRAINING, TYPE, VALIDATION from ludwig.features.feature_registries import output_type_registry from ludwig.hyperopt.execution import executor_registry, get_build_hyperopt_executor, RayTuneExecutor from ludwig.hyperopt.results import HyperoptResults -from ludwig.hyperopt.sampling import get_build_hyperopt_sampler, sampler_registry +from ludwig.hyperopt.sampling import get_build_hyperopt_sampler from ludwig.hyperopt.utils import print_hyperopt_results, save_hyperopt_stats, should_tune_preprocessing from ludwig.utils.defaults import default_random_seed, merge_with_defaults from ludwig.utils.fs_utils import makedirs, open_file @@ -56,6 +56,7 @@ def hyperopt( callbacks: List[Callback] = None, backend: Union[Backend, str] = None, random_seed: int = default_random_seed, + hyperopt_log_verbosity: int = 3, **kwargs, ) -> HyperoptResults: """This method performs an hyperparameter optimization. @@ -148,6 +149,9 @@ def hyperopt( of backend to use to execute preprocessing / training steps. :param random_seed: (int: default: 42) random seed used for weights initialization, splits and any other random function. + :param hyperopt_log_verbosity: (int: default: 3) controls verbosity of + ray tune log messages. Valid values: 0 = silent, 1 = only status updates, + 2 = status and brief trial results, 3 = status and detailed trial results. # Return @@ -175,7 +179,7 @@ def hyperopt( logger.info(pformat(hyperopt_config, indent=4)) logger.info("\n") - sampler = hyperopt_config["sampler"] + search_alg = hyperopt_config["search_alg"] executor = hyperopt_config["executor"] parameters = hyperopt_config["parameters"] split = hyperopt_config["split"] @@ -243,10 +247,10 @@ def hyperopt( ) ) - hyperopt_sampler = get_build_hyperopt_sampler(sampler[TYPE])(goal, parameters, **sampler) + hyperopt_sampler = get_build_hyperopt_sampler(RAY)(parameters) hyperopt_executor = get_build_hyperopt_executor(executor[TYPE])( - hyperopt_sampler, output_feature, metric, split, **executor + hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor ) # Explicitly default to a local backend to avoid picking up Ray or Horovod @@ -324,6 +328,7 @@ def hyperopt( callbacks=callbacks, backend=backend, random_seed=random_seed, + hyperopt_log_verbosity=hyperopt_log_verbosity, **kwargs, ) @@ -351,24 +356,13 @@ def hyperopt( def update_hyperopt_params_with_defaults(hyperopt_params): - set_default_value(hyperopt_params, SAMPLER, {}) set_default_value(hyperopt_params, EXECUTOR, {}) set_default_value(hyperopt_params, "split", VALIDATION) set_default_value(hyperopt_params, "output_feature", COMBINED) set_default_value(hyperopt_params, "metric", LOSS) set_default_value(hyperopt_params, "goal", MINIMIZE) - set_default_values(hyperopt_params[SAMPLER], {TYPE: "random"}) - - sampler = get_from_registry(hyperopt_params[SAMPLER][TYPE], sampler_registry) - sampler_defaults = {k: v for k, v in sampler.__dict__.items() if k in get_class_attributes(sampler)} - set_default_values( - hyperopt_params[SAMPLER], - sampler_defaults, - ) - - set_default_values(hyperopt_params[EXECUTOR], {TYPE: "serial"}) - + set_default_values(hyperopt_params[EXECUTOR], {TYPE: "ray"}) executor = get_from_registry(hyperopt_params[EXECUTOR][TYPE], executor_registry) executor_defaults = {k: v for k, v in executor.__dict__.items() if k in get_class_attributes(executor)} set_default_values( diff --git a/ludwig/hyperopt/sampling.py b/ludwig/hyperopt/sampling.py index 0c05c766531..f9602256cb1 100644 --- a/ludwig/hyperopt/sampling.py +++ b/ludwig/hyperopt/sampling.py @@ -13,21 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import copy -import itertools import json import logging -from abc import ABC, abstractmethod +from abc import ABC +from importlib import import_module from inspect import signature -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Dict -import numpy as np -from bayesmark.builtin_opt.pysot_optimizer import PySOTOptimizer -from bayesmark.space import JointSpace - -from ludwig.constants import CATEGORY, FLOAT, INT, MAXIMIZE, MINIMIZE, SPACE, TYPE from ludwig.utils.misc_utils import get_from_registry -from ludwig.utils.strings_utils import str2bool try: from ray import tune @@ -67,309 +60,23 @@ def ray_resource_allocation_function( logger = logging.getLogger(__name__) -def int_grid_function(low: int, high: int, steps=None, **kwargs): - if steps is None: - steps = high - low + 1 - samples = np.linspace(low, high, num=steps, dtype=int) - return samples.tolist() - - -def float_grid_function(low: float, high: float, steps=None, space="linear", base=None, **kwargs): - if steps is None: - steps = int(high - low + 1) - if space == "linear": - samples = np.linspace(low, high, num=steps) - elif space == "log": - if base: - samples = np.logspace(low, high, num=steps, base=base) - else: - samples = np.geomspace(low, high, num=steps) - else: - raise ValueError( - 'The space parameter of the float grid function is "{}". ' 'Available ones are: {"linear", "log"}' - ) - return samples.tolist() - - -def category_grid_function(values, **kwargs): - return values - - def identity(x): return x -grid_functions_registry = { - "int": int_grid_function, - "float": float_grid_function, - "category": category_grid_function, -} - - -class HyperoptSampler(ABC): - def __init__(self, goal: str, parameters: Dict[str, Any], batch_size: int = 1) -> None: - assert goal in [MINIMIZE, MAXIMIZE] - self.goal = goal # useful for Bayesian strategy - self.parameters = parameters - self.default_batch_size = batch_size - - @abstractmethod - def sample(self) -> Dict[str, Any]: - # Yields a set of parameters names and their values. - # Define `build_hyperopt_strategy` which would take parameters as inputs - pass - - def sample_batch(self, batch_size: int = None) -> List[Dict[str, Any]]: - samples = [] - if batch_size is None: - batch_size = self.default_batch_size - for _ in range(batch_size): - try: - samples.append(self.sample()) - except IndexError: - # Logic: is samples is empty it means that we encountered - # the IndexError the first time we called self.sample() - # so we should raise the exception. If samples is not empty - # we should just return it, even if it will contain - # less samples than the specified batch_size. - # This is fine as from now on finished() will return True. - if not samples: - raise IndexError - return samples - - @abstractmethod - def update(self, sampled_parameters: Dict[str, Any], metric_score: float): - # Given the results of previous computation, it updates - # the strategy (not needed for stateless strategies like "grid" - # and random, but will be needed by Bayesian) - pass - - def update_batch(self, parameters_metric_tuples: Iterable[Tuple[Dict[str, Any], float]]): - for (sampled_parameters, metric_score) in parameters_metric_tuples: - self.update(sampled_parameters, metric_score) - - @abstractmethod - def finished(self) -> bool: - # Should return true when all samples have been sampled - pass - - -class RandomSampler(HyperoptSampler): - num_samples = 10 - - def __init__(self, goal: str, parameters: Dict[str, Any], num_samples=10, **kwargs) -> None: - HyperoptSampler.__init__(self, goal, parameters) - params_for_join_space = copy.deepcopy(parameters) - - cat_params_values_types = {} - for param_name, param_values in params_for_join_space.items(): - if param_values[TYPE] == CATEGORY: - param_values[TYPE] = "cat" - values_str = [] - values_types = {} - for value in param_values["values"]: - value_type = type(value) - if value_type == bool: - value_str = str(value) - value_type = str2bool - elif value_type == str or value_type == int or value_type == float: - value_str = str(value) - else: - value_str = json.dumps(value) - value_type = json.loads - values_str.append(value_str) - values_types[value_str] = value_type - param_values["values"] = values_str - cat_params_values_types[param_name] = values_types - if param_values[TYPE] == FLOAT: - param_values[TYPE] = "real" - if param_values[TYPE] == INT or param_values[TYPE] == "real": - if SPACE not in param_values: - param_values[SPACE] = "linear" - param_values["range"] = (param_values["low"], param_values["high"]) - del param_values["low"] - del param_values["high"] - - self.cat_params_values_types = cat_params_values_types - self.space = JointSpace(params_for_join_space) - self.num_samples = num_samples - self.samples = self._determine_samples() - self.sampled_so_far = 0 - self.default_batch_size = self.num_samples - - def _determine_samples(self): - samples = [] - for _ in range(self.num_samples): - bounds = self.space.get_bounds() - x = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * np.random.rand(1, len(self.space.get_bounds())) - sample = self.space.unwarp(x)[0] - samples.append(sample) - return samples - - def sample(self) -> Dict[str, Any]: - if self.sampled_so_far >= len(self.samples): - raise IndexError() - sample = self.samples[self.sampled_so_far] - for key in sample: - if key in self.cat_params_values_types: - values_types = self.cat_params_values_types[key] - sample[key] = values_types[sample[key]](sample[key]) - self.sampled_so_far += 1 - return sample - - def update(self, sampled_parameters: Dict[str, Any], metric_score: float): - pass - - def finished(self) -> bool: - return self.sampled_so_far >= len(self.samples) - - -class GridSampler(HyperoptSampler): - def __init__(self, goal: str, parameters: Dict[str, Any], **kwargs) -> None: - HyperoptSampler.__init__(self, goal, parameters) - self.search_space = self._create_search_space() - self.samples = self._get_grids() - self.sampled_so_far = 0 - self.default_batch_size = len(self.samples) - - def _create_search_space(self): - search_space = {} - for hp_name, hp_params in self.parameters.items(): - grid_function = get_from_registry(hp_params[TYPE], grid_functions_registry) - search_space[hp_name] = grid_function(**hp_params) - return search_space - - def _get_grids(self): - hp_params = sorted(self.search_space) - grids = [ - dict(zip(hp_params, prod)) - for prod in itertools.product(*(self.search_space[hp_name] for hp_name in hp_params)) - ] - - return grids - - def sample(self) -> Dict[str, Any]: - if self.sampled_so_far >= len(self.samples): - raise IndexError() - sample = self.samples[self.sampled_so_far] - self.sampled_so_far += 1 - return sample - - def update(self, sampled_parameters: Dict[str, Any], statistics: Dict[str, Any]): - # actual implementation ... - pass - - def finished(self) -> bool: - return self.sampled_so_far >= len(self.samples) - - -class PySOTSampler(HyperoptSampler): - """pySOT: Surrogate optimization in Python. - This is a wrapper around the pySOT package (https://github.com/dme65/pySOT): - David Eriksson, David Bindel, Christine Shoemaker - pySOT and POAP: An event-driven asynchronous framework for surrogate optimization - """ - - def __init__(self, goal: str, parameters: Dict[str, Any], num_samples=10, **kwargs) -> None: - HyperoptSampler.__init__(self, goal, parameters) - params_for_join_space = copy.deepcopy(parameters) - - cat_params_values_types = {} - for param_name, param_values in params_for_join_space.items(): - if param_values[TYPE] == CATEGORY: - param_values[TYPE] = "cat" - values_str = [] - values_types = {} - for value in param_values["values"]: - value_type = type(value) - if value_type == bool: - value_str = str(value) - value_type = str2bool - elif value_type == str or value_type == int or value_type == float: - value_str = str(value) - else: - value_str = json.dumps(value) - value_type = json.loads - values_str.append(value_str) - values_types[value_str] = value_type - param_values["values"] = values_str - cat_params_values_types[param_name] = values_types - if param_values[TYPE] == FLOAT: - param_values[TYPE] = "real" - if param_values[TYPE] == INT or param_values[TYPE] == "real": - if SPACE not in param_values: - param_values[SPACE] = "linear" - param_values["range"] = (param_values["low"], param_values["high"]) - del param_values["low"] - del param_values["high"] - - self.cat_params_values_types = cat_params_values_types - self.pysot_optimizer = PySOTOptimizer(params_for_join_space) - self.sampled_so_far = 0 - self.num_samples = num_samples - - def sample(self) -> Dict[str, Any]: - """Suggest one new point to be evaluated.""" - if self.sampled_so_far >= self.num_samples: - raise IndexError() - sample = self.pysot_optimizer.suggest(n_suggestions=1)[0] - for key in sample: - if key in self.cat_params_values_types: - values_types = self.cat_params_values_types[key] - sample[key] = values_types[sample[key]](sample[key]) - self.sampled_so_far += 1 - return sample - - def update(self, sampled_parameters: Dict[str, Any], metric_score: float): - for key in sampled_parameters: - if key in self.cat_params_values_types: - if type(sampled_parameters[key]) not in {bool, int, float, str}: - sampled_parameters[key] = json.dumps(sampled_parameters[key]) - else: - sampled_parameters[key] = str(sampled_parameters[key]) - self.pysot_optimizer.observe([sampled_parameters], [metric_score]) - - def finished(self) -> bool: - return self.sampled_so_far >= self.num_samples - - -class RayTuneSampler(HyperoptSampler): +class RayTuneSampler: def __init__( self, - goal: str, parameters: Dict[str, Any], - search_alg: dict = None, - scheduler: dict = None, - num_samples=1, **kwargs, ) -> None: - HyperoptSampler.__init__(self, goal, parameters) self._check_ray_tune() self.search_space, self.decode_ctx = self._get_search_space(parameters) - self.search_alg_dict = search_alg - self.scheduler = self._create_scheduler(scheduler, parameters) - self.num_samples = num_samples - self.goal = goal def _check_ray_tune(self): if not _HAS_RAY_TUNE: raise ValueError("Requested Ray sampler but Ray Tune is not installed. Run `pip install ray[tune]`") - def _create_scheduler(self, scheduler_config, parameters): - if not scheduler_config: - return None - - dynamic_resource_allocation = scheduler_config.pop("dynamic_resource_allocation", False) - - if scheduler_config.get("type") == "pbt": - scheduler_config.update({"hyperparam_mutations": self.search_space}) - - scheduler = tune.create_scheduler(scheduler_config.get("type"), **scheduler_config) - - if dynamic_resource_allocation: - scheduler = ResourceChangingScheduler(scheduler, ray_resource_allocation_function) - return scheduler - def _get_search_space(self, parameters): config = {} ctx = {} @@ -395,15 +102,6 @@ def _get_search_space(self, parameters): config[param] = param_search_space(**param_search_input_args) return config, ctx - def sample(self) -> Dict[str, Any]: - pass - - def update(self, sampled_parameters: Dict[str, Any], statistics: Dict[str, Any]): - pass - - def finished(self) -> bool: - pass - @staticmethod def encode_values(param, values, ctx): """JSON encodes any search spaces whose values are lists / dicts. @@ -432,4 +130,166 @@ def get_build_hyperopt_sampler(strategy_type): return get_from_registry(strategy_type, sampler_registry) -sampler_registry = {"grid": GridSampler, "random": RandomSampler, "pysot": PySOTSampler, "ray": RayTuneSampler} +sampler_registry = {"ray": RayTuneSampler} + + +# TODO: split to separate module? +def _is_package_installed(package_name: str, search_algo_name: str) -> bool: + try: + import_module(package_name) + return True + except ImportError: + raise ImportError( + f"Search algorithm {search_algo_name} requires package {package_name}, however package is not installed." + " Please refer to Ray Tune documentation for packages required for this search algorithm." + ) + + +class SearchAlgorithm(ABC): + def __init__(self, search_alg_dict: Dict) -> None: + self.search_alg_dict = search_alg_dict + self.random_seed_attribute_name = None + + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + if self.random_seed_attribute_name not in self.search_alg_dict: + self.search_alg_dict[self.random_seed_attribute_name] = ludwig_random_seed + + +class BasicVariantSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "random_state" + + +class HyperoptSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("hyperopt", "hyperopt") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "random_state_seed" + + +class BOHBSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("hpbandster", "bohb") + _is_package_installed("ConfigSpace", "bohb") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "seed" + + +class AxSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("sqlalchemy", "ax") + _is_package_installed("ax", "ax") + super().__init__(search_alg_dict) + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +class BayesOptSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("bayes_opt", "bayesopt") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "random_state" + + +class BlendsearchSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("flaml", "blendsearch") + super().__init__(search_alg_dict) + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +class CFOSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("flaml", "cfo") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "seed" + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +class DragonflySA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("dragonfly", "dragonfly") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "random_state_seed" + + +class HEBOSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("hebo", "hebo") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "random_state_seed" + + +class SkoptSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("skopt", "skopt") + super().__init__(search_alg_dict) + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +class NevergradSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("nevergrad", "nevergrad") + super().__init__(search_alg_dict) + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +class OptunaSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("optuna", "optuna") + super().__init__(search_alg_dict) + self.random_seed_attribute_name = "seed" + + +class ZooptSA(SearchAlgorithm): + def __init__(self, search_alg_dict: Dict) -> None: + _is_package_installed("zoopt", "zoopt") + super().__init__(search_alg_dict) + + # override parent method, this search algorithm does not support + # setting random seed + def check_for_random_seed(self, ludwig_random_seed: int) -> None: + pass + + +def get_search_algorithm(search_algo): + return get_from_registry(search_algo, search_algo_registry) + + +search_algo_registry = { + None: BasicVariantSA, + "variant_generator": BasicVariantSA, + "random": BasicVariantSA, + "hyperopt": HyperoptSA, + "bohb": BOHBSA, + "ax": AxSA, + "bayesopt": BayesOptSA, + "blendsearch": BlendsearchSA, + "cfo": CFOSA, + "dragonfly": DragonflySA, + "hebo": HEBOSA, + "skopt": SkoptSA, + "nevergrad": NevergradSA, + "optuna": OptunaSA, + "zoopt": ZooptSA, +} diff --git a/ludwig/hyperopt/utils.py b/ludwig/hyperopt/utils.py index 368420f221a..426dbe24933 100644 --- a/ludwig/hyperopt/utils.py +++ b/ludwig/hyperopt/utils.py @@ -1,9 +1,10 @@ +import dataclasses import json import logging import os from ludwig.constants import HYPEROPT, PARAMETERS, PREPROCESSING -from ludwig.hyperopt.results import HyperoptResults +from ludwig.hyperopt.results import HyperoptResults, TrialResults from ludwig.utils.data_utils import save_json from ludwig.utils.print_utils import print_boxed @@ -13,7 +14,8 @@ def print_hyperopt_results(hyperopt_results: HyperoptResults): print_boxed("HYPEROPT RESULTS", print_fun=logger.info) for trial_results in hyperopt_results.ordered_trials: - logger.info(f"score: {trial_results.metric_score:.6f} | parameters: {trial_results.parameters}") + if not isinstance(trial_results.metric_score, str): + logger.info(f"score: {trial_results.metric_score:.6f} | parameters: {trial_results.parameters}") logger.info("") @@ -30,8 +32,16 @@ def load_json_value(v): return v +# define set containing names to return for TrialResults +TRIAL_RESULTS_NAMES_SET = {f.name for f in dataclasses.fields(TrialResults)} + + def load_json_values(d): - return {k: load_json_value(v) for k, v in d.items()} + # ensure metric_score is a string for the json load to eliminate extraneous exception message + d["metric_score"] = str(d["metric_score"]) + + # load only data required for TrialResults + return {k: load_json_value(v) for k, v in d.items() if k in TRIAL_RESULTS_NAMES_SET} def should_tune_preprocessing(config): diff --git a/ludwig/hyperopt_cli.py b/ludwig/hyperopt_cli.py index 2b673fec7fb..fb81311770f 100644 --- a/ludwig/hyperopt_cli.py +++ b/ludwig/hyperopt_cli.py @@ -59,6 +59,7 @@ def hyperopt_cli( callbacks: List[Callback] = None, backend: Union[Backend, str] = None, random_seed: int = default_random_seed, + hyperopt_log_verbosity: int = 3, **kwargs, ): """Searches for optimal hyperparameters. @@ -150,6 +151,9 @@ def hyperopt_cli( of backend to use to execute preprocessing / training steps. :param random_seed: (int: default: 42) random seed used for weights initialization, splits and any other random function. + :param hyperopt_log_verbosity: (int: default: 3) Controls verbosity of ray tune log messages. Valid values: + 0 = silent, 1 = only status updates, 2 = status and brief trial + results, 3 = status and detailed trial results. # Return :return" (`None`) @@ -183,6 +187,7 @@ def hyperopt_cli( callbacks=callbacks, backend=backend, random_seed=random_seed, + hyperopt_log_verbosity=hyperopt_log_verbosity, **kwargs, ) @@ -356,6 +361,16 @@ def cli(sys_argv): "to a random number generator: data splitting, parameter " "initialization and training set shuffling", ) + parser.add_argument( + "-hlv", + "--hyperopt_log_verbosity", + type=int, + default=3, + choices=[0, 1, 2, 3], + help="Controls verbosity of ray tune log messages. Valid values: " + "0 = silent, 1 = only status updates, 2 = status and brief trial " + "results, 3 = status and detailed trial results.", + ) parser.add_argument("-g", "--gpus", nargs="+", type=int, default=None, help="list of gpus to use") parser.add_argument( "-gml", "--gpu_memory_limit", type=int, default=None, help="maximum memory in MB to allocate per GPU device" diff --git a/ludwig/utils/defaults.py b/ludwig/utils/defaults.py index bc3a7606cf7..7cb1941a949 100644 --- a/ludwig/utils/defaults.py +++ b/ludwig/utils/defaults.py @@ -29,12 +29,17 @@ COMBINED, DROP_ROW, EVAL_BATCH_SIZE, + EXECUTOR, HYPEROPT, LOSS, NAME, NUMBER, + PARAMETERS, PREPROCESSING, PROC_COLUMN, + RAY, + SAMPLER, + SEARCH_ALG, TRAINER, TYPE, ) @@ -161,16 +166,75 @@ def _upgrade_deprecated_fields(config: Dict[str, Any]): ) feature[TYPE] = NUMBER - if HYPEROPT in config and "parameters" in config[HYPEROPT]: - hparams = config[HYPEROPT]["parameters"] - for k, v in list(hparams.items()): - substr = "training." - if k.startswith(substr): + if HYPEROPT in config: + # check for use of legacy "training" reference, if any found convert to "trainer" + if PARAMETERS in config[HYPEROPT]: + hparams = config[HYPEROPT][PARAMETERS] + for k, v in list(hparams.items()): + substr = "training." + if k.startswith(substr): + warnings.warn( + 'Config section "training" renamed to "trainer" and will be removed in v0.6', DeprecationWarning + ) + hparams["trainer." + k[len(substr) :]] = v + del hparams[k] + + # check for legacy parameters in "executor" + if EXECUTOR in config[HYPEROPT]: + hpexecutor = config[HYPEROPT][EXECUTOR] + executor_type = hpexecutor[TYPE] + if executor_type != RAY: warnings.warn( - 'Config section "training" renamed to "trainer" and will be removed in v0.6', DeprecationWarning + f'executor type "{executor_type}" not supported, converted to "ray" will be flagged as error ' + "in v0.6", + DeprecationWarning, ) - hparams["trainer." + k[len(substr) :]] = v - del hparams[k] + hpexecutor[TYPE] = RAY + + # if search_alg not at top level and is present in executor, promote to top level + if SEARCH_ALG in hpexecutor: + # promote only if not in top-level, otherwise use current top-level + if SEARCH_ALG not in config[HYPEROPT]: + config[HYPEROPT][SEARCH_ALG] = hpexecutor[SEARCH_ALG] + del hpexecutor[SEARCH_ALG] + else: + warnings.warn( + 'Missing "executor" section, adding "ray" executor will be flagged as error in v0.6', DeprecationWarning + ) + config[HYPEROPT][EXECUTOR] = {TYPE: RAY} + + # check for legacy "sampler" section + if SAMPLER in config[HYPEROPT]: + warnings.warn( + f'"{SAMPLER}" is no longer supported, converted to "{SEARCH_ALG}". "{SAMPLER}" will be flagged as ' + "error in v0.6", + DeprecationWarning, + ) + if SEARCH_ALG in config[HYPEROPT][SAMPLER]: + if SEARCH_ALG not in config[HYPEROPT]: + config[HYPEROPT][SEARCH_ALG] = config[HYPEROPT][SAMPLER][SEARCH_ALG] + warnings.warn('Moved "search_alg" to hyperopt config top-level', DeprecationWarning) + + # if num_samples or scheduler exist in SAMPLER move to EXECUTOR Section + if "num_samples" in config[HYPEROPT][SAMPLER] and "num_samples" not in config[HYPEROPT][EXECUTOR]: + config[HYPEROPT][EXECUTOR]["num_samples"] = config[HYPEROPT][SAMPLER]["num_samples"] + warnings.warn('Moved "num_samples" from "sampler" to "executor"', DeprecationWarning) + + if "scheduler" in config[HYPEROPT][SAMPLER] and "scheduler" not in config[HYPEROPT][EXECUTOR]: + config[HYPEROPT][EXECUTOR]["scheduler"] = config[HYPEROPT][SAMPLER]["scheduler"] + warnings.warn('Moved "scheduler" from "sampler" to "executor"', DeprecationWarning) + + # remove legacy section + del config[HYPEROPT][SAMPLER] + + if SEARCH_ALG not in config[HYPEROPT]: + # make top-level as search_alg, if missing put in default value + config[HYPEROPT][SEARCH_ALG] = {TYPE: "variant_generator"} + warnings.warn( + 'Missing "search_alg" at hyperopt top-level, adding in default value, will be flagged as error ' + "in v0.6", + DeprecationWarning, + ) if TRAINER in config: trainer = config[TRAINER] @@ -237,7 +301,7 @@ def _merge_hyperopt_with_trainer(config: dict) -> None: if "hyperopt" not in config: return - scheduler = config["hyperopt"].get("sampler", {}).get("scheduler") + scheduler = config["hyperopt"].get("executor", {}).get("scheduler") if not scheduler: return diff --git a/ludwig/utils/visualization_utils.py b/ludwig/utils/visualization_utils.py index 1c6497033e7..4a104d55540 100644 --- a/ludwig/utils/visualization_utils.py +++ b/ludwig/utils/visualization_utils.py @@ -23,7 +23,7 @@ import numpy as np import pandas as pd -from ludwig.constants import TRAINING, TYPE, VALIDATION +from ludwig.constants import SPACE, TRAINING, VALIDATION logger = logging.getLogger(__name__) @@ -51,6 +51,10 @@ INT_QUANTILES = 10 FLOAT_QUANTILES = 10 +# mapping from RayTune search space to Ludwig types (float, int, category) for hyperopt visualizations +RAY_TUNE_FLOAT_SPACES = {"uniform", "quniform", "loguniform", "qloguniform", "randn", "qrandn"} +RAY_TUNE_INT_SPACES = {"randint", "qrandint", "lograndint", "qlograndint"} +RAY_TUNE_CATEGORY_SPACES = {"choice", "grid_search"} _matplotlib_34 = LooseVersion(mpl.__version__) >= LooseVersion("3.4") @@ -1270,7 +1274,7 @@ def bar_plot( def hyperopt_report(hyperparameters, hyperopt_results_df, metric, filename_template, float_precision=3): title = "Hyperopt Report: {}" for hp_name, hp_params in hyperparameters.items(): - if hp_params[TYPE] == "int": + if hp_params[SPACE] in RAY_TUNE_INT_SPACES: hyperopt_int_plot( hyperopt_results_df, hp_name, @@ -1278,7 +1282,7 @@ def hyperopt_report(hyperparameters, hyperopt_results_df, metric, filename_templ title.format(hp_name), filename_template.format(hp_name) if filename_template else None, ) - elif hp_params[TYPE] == "float": + elif hp_params[SPACE] in RAY_TUNE_FLOAT_SPACES: hyperopt_float_plot( hyperopt_results_df, hp_name, @@ -1287,7 +1291,7 @@ def hyperopt_report(hyperparameters, hyperopt_results_df, metric, filename_templ filename_template.format(hp_name) if filename_template else None, log_scale_x=hp_params["scale"] == "log" if "scale" in hp_params else False, ) - elif hp_params[TYPE] == "category": + elif hp_params[SPACE] in RAY_TUNE_CATEGORY_SPACES: hyperopt_category_plot( hyperopt_results_df, hp_name, @@ -1295,14 +1299,20 @@ def hyperopt_report(hyperparameters, hyperopt_results_df, metric, filename_templ title.format(hp_name), filename_template.format(hp_name) if filename_template else None, ) + else: + # TODO: more research needed on how to handle RayTune "sample_from" search space + raise ValueError( + f"{hp_params[SPACE]} search space not supported in Ludwig. " + f"Supported values are {RAY_TUNE_FLOAT_SPACES | RAY_TUNE_INT_SPACES | RAY_TUNE_CATEGORY_SPACES}." + ) # quantize float and int columns for hp_name, hp_params in hyperparameters.items(): - if hp_params[TYPE] == "int": + if hp_params[SPACE] in RAY_TUNE_INT_SPACES: num_distinct_values = len(hyperopt_results_df[hp_name].unique()) if num_distinct_values > INT_QUANTILES: hyperopt_results_df[hp_name] = pd.qcut(hyperopt_results_df[hp_name], q=INT_QUANTILES, precision=0) - elif hp_params[TYPE] == "float": + elif hp_params[SPACE] in RAY_TUNE_FLOAT_SPACES: hyperopt_results_df[hp_name] = pd.qcut( hyperopt_results_df[hp_name], q=FLOAT_QUANTILES, diff --git a/ludwig/visualize.py b/ludwig/visualize.py index 187bd9d64c9..18a75bfc34a 100644 --- a/ludwig/visualize.py +++ b/ludwig/visualize.py @@ -29,7 +29,7 @@ from ludwig.backend import LOCAL_BACKEND from ludwig.callbacks import Callback -from ludwig.constants import ACCURACY, EDIT_DISTANCE, HITS_AT_K, LOSS, PREDICTIONS, SPLIT, TRAINING, TYPE, VALIDATION +from ludwig.constants import ACCURACY, EDIT_DISTANCE, HITS_AT_K, LOSS, PREDICTIONS, SPACE, SPLIT, TRAINING, VALIDATION from ludwig.contrib import add_contrib_callback_args from ludwig.utils import visualization_utils from ludwig.utils.data_utils import ( @@ -3708,9 +3708,20 @@ def hyperopt_hiplot(hyperopt_stats_path, output_directory=None, **kwargs): ) +def _convert_space_to_dtype(space: str) -> str: + if space in visualization_utils.RAY_TUNE_FLOAT_SPACES: + return "float" + elif space in visualization_utils.RAY_TUNE_INT_SPACES: + return "int" + else: + return "object" + + def hyperopt_results_to_dataframe(hyperopt_results, hyperopt_parameters, metric): df = pd.DataFrame([{metric: res["metric_score"], **res["parameters"]} for res in hyperopt_results]) - df = df.astype({hp_name: hp_params[TYPE] for hp_name, hp_params in hyperopt_parameters.items()}) + df = df.astype( + {hp_name: _convert_space_to_dtype(hp_params[SPACE]) for hp_name, hp_params in hyperopt_parameters.items()} + ) return df diff --git a/requirements_distributed.txt b/requirements_distributed.txt index 077a09e8500..ea807b6e024 100644 --- a/requirements_distributed.txt +++ b/requirements_distributed.txt @@ -1,6 +1,7 @@ # requirements for dask dask[dataframe]>2021.3.1,<2022.1.1 pyarrow==6.0.1 # https://github.com/ray-project/ray/issues/22310 + # requirements for horovod horovod[pytorch]>=0.24.0 # requirements for ray diff --git a/requirements_hyperopt.txt b/requirements_hyperopt.txt index e9e56fa42db..12611f19c04 100644 --- a/requirements_hyperopt.txt +++ b/requirements_hyperopt.txt @@ -1,4 +1,8 @@ -bayesmark>=0.0.7 -pySOT + +# ray[default,tune]>=1.9.2,!=1.10 # TODO: remove +ray[default,tune]>=1.11.0 + + +# required for Ray Tune Search Algorithm support for AutoML +#search_alg: hyperopt hyperopt -ray[default,tune]>=1.9.2,!=1.10 diff --git a/requirements_test.txt b/requirements_test.txt index 1f64e927741..25f3ac7ea6e 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -6,6 +6,35 @@ comet_ml mlflow whylogs -# For testing BOHB with Ray Tune +# For testing optional Ray Tune Search Algorithms +# search_alg: bohb hpbandster ConfigSpace + +# search_alg: ax +ax-platform +sqlalchemy + +# search_alg: bayesopt +bayesian-optimization + +# search_alg: cfo and blendsearch +flaml[blendsearch] + +# search_alg: dragonfly +dragonfly-opt + +# search_alg: hebo +HEBO + +# search_alg: nevergrad +nevergrad + +# search_alg: optuna +optuna + +# search_alg: skopt +scikit-optimize + +# search_alg: zoopt +zoopt diff --git a/tests/conftest.py b/tests/conftest.py index ec710450a08..face73aea6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,20 +64,23 @@ def hyperopt_results(): hyperopt_configs = { "parameters": { "trainer.learning_rate": { - "type": "float", - "low": 0.0001, - "high": 0.01, - "space": "log", - "steps": 3, + "space": "loguniform", + "lower": 0.0001, + "upper": 0.01, }, - output_feature_name + ".output_size": {"type": "int", "low": 32, "high": 256, "steps": 5}, - output_feature_name + ".num_fc_layers": {"type": "int", "low": 1, "high": 5, "space": "linear", "steps": 4}, + output_feature_name + ".output_size": {"space": "choice", "categories": [32, 64, 128, 256]}, + output_feature_name + ".num_fc_layers": {"space": "randint", "lower": 1, "upper": 6}, }, "goal": "minimize", "output_feature": output_feature_name, "validation_metrics": "loss", - "executor": {"type": "serial"}, - "sampler": {"type": "random", "num_samples": 2}, + "executor": { + "type": "ray", + "num_samples": 2, + }, + "search_alg": { + "type": "variant_generator", + }, } # add hyperopt parameter space to the config diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index 89e651b1d8d..7ce74d32589 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -87,18 +87,21 @@ def _prepare_hyperopt_data(csv_filename, config_filename): "hyperopt": { "parameters": { "trainer.learning_rate": { - "type": "float", - "low": 0.0001, - "high": 0.01, - "space": "log", - "steps": 3, + "space": "loguniform", + "lower": 0.0001, + "upper": 0.01, } }, "goal": "minimize", "output_feature": output_features[0]["name"], "validation_metrics": "loss", - "executor": {"type": "serial"}, - "sampler": {"type": "random", "num_samples": 2}, + "executor": { + "type": "ray", + "num_samples": 2, + }, + "search_alg": { + "type": "variant_generator", + }, }, } diff --git a/tests/integration_tests/test_hyperopt.py b/tests/integration_tests/test_hyperopt.py index a8685a8cad8..667fe39efe2 100644 --- a/tests/integration_tests/test_hyperopt.py +++ b/tests/integration_tests/test_hyperopt.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import contextlib import logging import os.path +from typing import Dict, Optional, Tuple import pytest +import ray import torch -from ludwig.constants import ACCURACY, TRAINER +from ludwig.constants import ACCURACY, RAY, TRAINER from ludwig.hyperopt.execution import get_build_hyperopt_executor -from ludwig.hyperopt.results import HyperoptResults +from ludwig.hyperopt.results import HyperoptResults, RayTuneResults from ludwig.hyperopt.run import hyperopt, update_hyperopt_params_with_defaults from ludwig.hyperopt.sampling import get_build_hyperopt_sampler from ludwig.utils.defaults import merge_with_defaults @@ -30,47 +33,53 @@ logger.setLevel(logging.INFO) logging.getLogger("ludwig").setLevel(logging.INFO) +RANDOM_SEARCH_SIZE = 4 + HYPEROPT_CONFIG = { "parameters": { + # using only float parameter as common in all search algorithms "trainer.learning_rate": { - "type": "float", - "low": 0.0001, - "high": 0.1, - "space": "log", - "steps": 3, - }, - "combiner.num_fc_layers": { - "type": "int", - "low": 1, - "high": 4, - "space": "linear", - "steps": 3, + "space": "loguniform", + "lower": 0.001, + "upper": 0.1, }, - "combiner.fc_layers": { - "type": "category", - "values": [[{"output_size": 64}, {"output_size": 32}], [{"output_size": 64}], [{"output_size": 32}]], - }, - "utterance.cell_type": {"type": "category", "values": ["rnn", "gru"]}, - "utterance.bidirectional": {"type": "category", "values": [True, False]}, }, "goal": "minimize", + "executor": {"type": "ray", "num_samples": 2, "scheduler": {"type": "fifo"}}, + "search_alg": {"type": "variant_generator"}, } -SAMPLERS = [ - {"type": "grid"}, - {"type": "random", "num_samples": 5}, - {"type": "pysot", "num_samples": 5}, +SEARCH_ALGS = [ + None, + "variant_generator", + "random", + "hyperopt", + "bohb", + "ax", + "bayesopt", + "blendsearch", + "cfo", + "dragonfly", + "hebo", + "skopt", + "optuna", ] -EXECUTORS = [ - {"type": "serial"}, +SCHEDULERS = [ + "fifo", + "asynchyperband", + "async_hyperband", + "median_stopping_rule", + "medianstopping", + "hyperband", + "hb_bohb", + "pbt", + # "pb2", commented out for now: https://github.com/ray-project/ray/issues/24815 + "resource_changing", ] -@pytest.mark.distributed -@pytest.mark.parametrize("sampler", SAMPLERS) -@pytest.mark.parametrize("executor", EXECUTORS) -def test_hyperopt_executor(sampler, executor, csv_filename, validate_output_feature=False, validation_metric=None): +def _setup_ludwig_config(dataset_fp: str) -> Tuple[Dict, str]: input_features = [ text_feature(name="utterance", cell_type="lstm", reduce_output="sum"), category_feature(vocab_size=2, reduce_input="sum"), @@ -78,7 +87,7 @@ def test_hyperopt_executor(sampler, executor, csv_filename, validate_output_feat output_features = [category_feature(vocab_size=2, reduce_input="sum")] - rel_path = generate_data(input_features, output_features, csv_filename) + rel_path = generate_data(input_features, output_features, dataset_fp) config = { "input_features": input_features, @@ -89,10 +98,45 @@ def test_hyperopt_executor(sampler, executor, csv_filename, validate_output_feat config = merge_with_defaults(config) + return config, rel_path + + +@contextlib.contextmanager +def ray_start(num_cpus: Optional[int] = None, num_gpus: Optional[int] = None): + res = ray.init( + num_cpus=num_cpus, + num_gpus=num_gpus, + include_dashboard=False, + object_store_memory=150 * 1024 * 1024, + ) + try: + yield res + finally: + ray.shutdown() + + +@pytest.mark.distributed +@pytest.mark.parametrize("search_alg", SEARCH_ALGS) +def test_hyperopt_search_alg(search_alg, csv_filename, validate_output_feature=False, validation_metric=None): + config, rel_path = _setup_ludwig_config(csv_filename) + hyperopt_config = HYPEROPT_CONFIG.copy() + # finalize hyperopt config settings + if search_alg == "dragonfly": + hyperopt_config["search_alg"] = { + "type": search_alg, + "domain": "euclidean", + "optimizer": "random", + } + elif search_alg is None: + hyperopt_config["search_alg"] = {} + else: + hyperopt_config["search_alg"] = { + "type": search_alg, + } if validate_output_feature: - hyperopt_config["output_feature"] = output_features[0]["name"] + hyperopt_config["output_feature"] = config["output_features"][0]["name"] if validation_metric: hyperopt_config["validation_metric"] = validation_metric @@ -103,22 +147,24 @@ def test_hyperopt_executor(sampler, executor, csv_filename, validate_output_feat output_feature = hyperopt_config["output_feature"] metric = hyperopt_config["metric"] goal = hyperopt_config["goal"] + executor = hyperopt_config["executor"] + search_alg = hyperopt_config["search_alg"] - hyperopt_sampler = get_build_hyperopt_sampler(sampler["type"])(goal, parameters, **sampler) - - hyperopt_executor = get_build_hyperopt_executor(executor["type"])( - hyperopt_sampler, output_feature, metric, split, **executor - ) + hyperopt_sampler = get_build_hyperopt_sampler(RAY)(parameters) gpus = [i for i in range(torch.cuda.device_count())] - hyperopt_executor.execute(config, dataset=rel_path, gpus=gpus) + with ray_start(num_gpus=len(gpus)): + hyperopt_executor = get_build_hyperopt_executor(RAY)( + hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor + ) + raytune_results = hyperopt_executor.execute(config, dataset=rel_path) + assert isinstance(raytune_results, RayTuneResults) @pytest.mark.distributed def test_hyperopt_executor_with_metric(csv_filename): - test_hyperopt_executor( - {"type": "random", "num_samples": 2}, - {"type": "serial"}, + test_hyperopt_search_alg( + "variant_generator", csv_filename, validate_output_feature=True, validation_metric=ACCURACY, @@ -126,8 +172,64 @@ def test_hyperopt_executor_with_metric(csv_filename): @pytest.mark.distributed -@pytest.mark.parametrize("samplers", SAMPLERS) -def test_hyperopt_run_hyperopt(csv_filename, samplers): +@pytest.mark.parametrize("scheduler", SCHEDULERS) +def test_hyperopt_scheduler(scheduler, csv_filename, validate_output_feature=False, validation_metric=None): + config, rel_path = _setup_ludwig_config(csv_filename) + + hyperopt_config = HYPEROPT_CONFIG.copy() + # finalize hyperopt config settings + if scheduler == "pb2": + # setup scheduler hyperparam_bounds parameter + min = hyperopt_config["parameters"]["trainer.learning_rate"]["lower"] + max = hyperopt_config["parameters"]["trainer.learning_rate"]["upper"] + hyperparam_bounds = { + "trainer.learning_rate": [min, max], + } + hyperopt_config["executor"]["scheduler"] = { + "type": scheduler, + "hyperparam_bounds": hyperparam_bounds, + } + else: + hyperopt_config["executor"]["scheduler"] = { + "type": scheduler, + } + + if validate_output_feature: + hyperopt_config["output_feature"] = config["output_features"][0]["name"] + if validation_metric: + hyperopt_config["validation_metric"] = validation_metric + + update_hyperopt_params_with_defaults(hyperopt_config) + + parameters = hyperopt_config["parameters"] + split = hyperopt_config["split"] + output_feature = hyperopt_config["output_feature"] + metric = hyperopt_config["metric"] + goal = hyperopt_config["goal"] + executor = hyperopt_config["executor"] + search_alg = hyperopt_config["search_alg"] + + hyperopt_sampler = get_build_hyperopt_sampler(RAY)(parameters) + + gpus = [i for i in range(torch.cuda.device_count())] + with ray_start(num_gpus=len(gpus)): + # TODO: Determine if we still need this if-then-else construct + if search_alg["type"] in {""}: + with pytest.raises(ImportError): + get_build_hyperopt_executor(RAY)( + hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor + ) + else: + hyperopt_executor = get_build_hyperopt_executor(RAY)( + hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor + ) + raytune_results = hyperopt_executor.execute(config, dataset=rel_path) + assert isinstance(raytune_results, RayTuneResults) + + +@pytest.mark.distributed +@pytest.mark.parametrize("search_space", ["random", "grid"]) +def test_hyperopt_run_hyperopt(csv_filename, search_space): input_features = [ text_feature(name="utterance", cell_type="lstm", reduce_output="sum"), category_feature(vocab_size=2, reduce_input="sum"), @@ -146,38 +248,57 @@ def test_hyperopt_run_hyperopt(csv_filename, samplers): output_feature_name = output_features[0]["name"] - hyperopt_configs = { - "parameters": { + if search_space == "random": + # random search will be size of num_samples + search_parameters = { "trainer.learning_rate": { - "type": "float", - "low": 0.0001, - "high": 0.01, - "space": "log", - "steps": 3, + "lower": 0.0001, + "upper": 0.01, + "space": "loguniform", }, output_feature_name + ".fc_layers": { - "type": "category", - "values": [ + "space": "choice", + "categories": [ [{"output_size": 64}, {"output_size": 32}], [{"output_size": 64}], [{"output_size": 32}], ], }, - output_feature_name + ".output_size": {"type": "int", "low": 16, "high": 36, "steps": 5}, - output_feature_name + ".num_fc_layers": {"type": "int", "low": 1, "high": 5, "space": "linear", "steps": 4}, - }, + output_feature_name + ".output_size": {"space": "choice", "categories": [16, 21, 26, 31, 36]}, + output_feature_name + ".num_fc_layers": {"space": "randint", "lower": 1, "upper": 6}, + } + else: + # grid search space will be product each parameter size + search_parameters = { + "trainer.learning_rate": {"space": "grid_search", "values": [0.001, 0.005, 0.01]}, + output_feature_name + ".output_size": {"space": "grid_search", "values": [16, 21, 36]}, + output_feature_name + ".num_fc_layers": {"space": "grid_search", "values": [1, 3, 6]}, + } + + hyperopt_configs = { + "parameters": search_parameters, "goal": "minimize", "output_feature": output_feature_name, "validation_metrics": "loss", - "executor": {"type": "serial"}, - "sampler": {"type": samplers["type"], "num_samples": 2}, + "executor": {"type": "ray", "num_samples": 1 if search_space == "grid" else RANDOM_SEARCH_SIZE}, + "search_alg": {"type": "variant_generator"}, } # add hyperopt parameter space to the config config["hyperopt"] = hyperopt_configs - hyperopt_results = hyperopt(config, dataset=rel_path, output_directory="results_hyperopt") + with ray_start(): + hyperopt_results = hyperopt(config, dataset=rel_path, output_directory="results_hyperopt") + + if search_space == "random": + assert hyperopt_results.experiment_analysis.results_df.shape[0] == RANDOM_SEARCH_SIZE + else: + # compute size of search space for grid search + grid_search_size = 1 + for k, v in search_parameters.items(): + grid_search_size *= len(v["values"]) + assert hyperopt_results.experiment_analysis.results_df.shape[0] == grid_search_size # check for return results assert isinstance(hyperopt_results, HyperoptResults) @@ -187,39 +308,3 @@ def test_hyperopt_run_hyperopt(csv_filename, samplers): if os.path.isfile(os.path.join("results_hyperopt", "hyperopt_statistics.json")): os.remove(os.path.join("results_hyperopt", "hyperopt_statistics.json")) - - -@pytest.mark.distributed -def test_hyperopt_executor_get_metric_score(): - executor = EXECUTORS[0] - output_feature = "of_name" - split = "validation" - - train_stats = { - "training": { - output_feature: { - "loss": [0.58760345, 1.5066891], - "accuracy": [0.6666667, 0.33333334], - "hits_at_k": [1.0, 1.0], - }, - "combined": {"loss": [0.58760345, 1.5066891]}, - }, - "validation": { - output_feature: {"loss": [0.30233705, 2.6505466], "accuracy": [1.0, 0.0], "hits_at_k": [1.0, 1.0]}, - "combined": {"loss": [0.30233705, 2.6505466]}, - }, - "test": { - output_feature: {"loss": [1.0876318, 1.4353828], "accuracy": [0.7, 0.5], "hits_at_k": [1.0, 1.0]}, - "combined": {"loss": [1.0876318, 1.4353828]}, - }, - } - - metric = "loss" - hyperopt_executor = get_build_hyperopt_executor(executor["type"])(None, output_feature, metric, split, **executor) - score = hyperopt_executor.get_metric_score(train_stats) - assert score == 0.30233705 - - metric = "accuracy" - hyperopt_executor = get_build_hyperopt_executor(executor["type"])(None, output_feature, metric, split, **executor) - score = hyperopt_executor.get_metric_score(train_stats) - assert score == 1.0 diff --git a/tests/integration_tests/test_hyperopt_ray.py b/tests/integration_tests/test_hyperopt_ray.py index d29d9ea2dae..460c1101d7c 100644 --- a/tests/integration_tests/test_hyperopt_ray.py +++ b/tests/integration_tests/test_hyperopt_ray.py @@ -58,27 +58,25 @@ } -SAMPLERS = [ - {"type": "ray"}, - {"type": "ray", "num_samples": 2}, +SCENARIOS = [ + {"executor": {"type": "ray"}, "search_alg": {"type": "variant_generator"}}, + {"executor": {"type": "ray", "num_samples": 2}, "search_alg": {"type": "variant_generator"}}, { - "type": "ray", - "search_alg": {"type": "bohb"}, - "scheduler": { - "type": "hb_bohb", - "time_attr": "training_iteration", - "reduction_factor": 4, + "executor": { + "type": "ray", + "num_samples": 3, + "scheduler": { + "type": "hb_bohb", + "time_attr": "training_iteration", + "reduction_factor": 4, + }, }, - "num_samples": 3, + "search_alg": {"type": "bohb"}, }, ] -EXECUTORS = [ - {"type": "ray"}, -] - -def _get_config(sampler, executor): +def _get_config(search_alg, executor): input_features = [ text_feature(name="utterance", cell_type="lstm", reduce_output="sum"), category_feature(vocab_size=2, reduce_input="sum"), @@ -94,7 +92,7 @@ def _get_config(sampler, executor): "hyperopt": { **HYPEROPT_CONFIG, "executor": executor, - "sampler": sampler, + "search_alg": search_alg, }, } @@ -114,14 +112,14 @@ def ray_start_4_cpus(): @spawn def run_hyperopt_executor( - sampler, + search_alg, executor, csv_filename, validate_output_feature=False, validation_metric=None, use_split=True, ): - config = _get_config(sampler, executor) + config = _get_config(search_alg, executor) rel_path = generate_data(config["input_features"], config["output_features"], csv_filename) if not use_split: @@ -141,19 +139,21 @@ def run_hyperopt_executor( update_hyperopt_params_with_defaults(hyperopt_config) parameters = hyperopt_config["parameters"] - if sampler.get("search_alg", {}).get("type", "") == "bohb": + if search_alg.get("type", "") == "bohb": # bohb does not support grid_search search space del parameters["utterance.cell_type"] + hyperopt_config["parameters"] = parameters split = hyperopt_config["split"] output_feature = hyperopt_config["output_feature"] metric = hyperopt_config["metric"] goal = hyperopt_config["goal"] + search_alg = hyperopt_config["search_alg"] - hyperopt_sampler = get_build_hyperopt_sampler(sampler["type"])(goal, parameters, **sampler) + hyperopt_sampler = get_build_hyperopt_sampler("ray")(parameters) hyperopt_executor = get_build_hyperopt_executor(executor["type"])( - hyperopt_sampler, output_feature, metric, split, **executor + hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor ) hyperopt_executor.execute( @@ -164,11 +164,12 @@ def run_hyperopt_executor( @pytest.mark.distributed -@pytest.mark.parametrize("sampler", SAMPLERS) -@pytest.mark.parametrize("executor", EXECUTORS) -def test_hyperopt_executor(sampler, executor, csv_filename): +@pytest.mark.parametrize("scenario", SCENARIOS) +def test_hyperopt_executor(scenario, csv_filename): + search_alg = scenario["search_alg"] + executor = scenario["executor"] with ray_start_4_cpus(): - run_hyperopt_executor(sampler, executor, csv_filename) + run_hyperopt_executor(search_alg, executor, csv_filename) @pytest.mark.distributed @@ -176,8 +177,8 @@ def test_hyperopt_executor(sampler, executor, csv_filename): def test_hyperopt_executor_with_metric(use_split, csv_filename): with ray_start_4_cpus(): run_hyperopt_executor( - {"type": "ray", "num_samples": 2}, - {"type": "ray"}, + {"type": "variant_generator"}, # search_alg + {"type": "ray", "num_samples": 2}, # executor csv_filename, validate_output_feature=True, validation_metric=ACCURACY, @@ -219,8 +220,8 @@ def test_hyperopt_run_hyperopt(csv_filename): "goal": "minimize", "output_feature": output_feature_name, "validation_metrics": "loss", - "executor": {"type": "ray"}, - "sampler": {"type": "ray", "num_samples": 2}, + "executor": {"type": "ray", "num_samples": 2}, + "search_alg": {"type": "variant_generator"}, } # add hyperopt parameter space to the config @@ -236,7 +237,9 @@ def test_hyperopt_ray_mlflow(csv_filename, tmpdir): client = MlflowClient(tracking_uri=mlflow_uri) num_samples = 2 - config = _get_config({"type": "ray", "num_samples": num_samples}, {"type": "ray"}) + config = _get_config( + {"type": "variant_generator"}, {"type": "ray", "num_samples": num_samples} # search_alg # executor + ) rel_path = generate_data(config["input_features"], config["output_features"], csv_filename) diff --git a/tests/ludwig/utils/test_defaults.py b/tests/ludwig/utils/test_defaults.py index 6604a69fb1f..04048557dbb 100644 --- a/tests/ludwig/utils/test_defaults.py +++ b/tests/ludwig/utils/test_defaults.py @@ -43,7 +43,7 @@ ], }, }, - "sampler": {"type": "ray"}, + "search_alg": {"type": "hyperopt"}, "executor": {"type": "ray"}, "goal": "minimize", } @@ -88,7 +88,7 @@ def test_merge_with_defaults_early_stop(use_train, use_hyperopt_scheduler): if use_hyperopt_scheduler: # hyperopt scheduler cannot be used with early stopping - config[HYPEROPT]["sampler"]["scheduler"] = SCHEDULER + config[HYPEROPT]["executor"]["scheduler"] = SCHEDULER merged_config = merge_with_defaults(config) @@ -136,6 +136,11 @@ def test_deprecated_field_aliases(): }, }, "goal": "minimize", + "sampler": {"type": "grid", "num_samples": 2, "scheduler": {"type": "fifo"}}, + "executor": { + "type": "grid", + "search_alg": "bohb", + }, }, } @@ -151,3 +156,9 @@ def test_deprecated_field_aliases(): hparams = merged_config[HYPEROPT]["parameters"] assert "training.learning_rate" not in hparams assert "trainer.learning_rate" in hparams + + assert "sampler" not in merged_config[HYPEROPT] + + assert merged_config[HYPEROPT]["executor"]["type"] == "ray" + assert "num_samples" in merged_config[HYPEROPT]["executor"] + assert "scheduler" in merged_config[HYPEROPT]["executor"] diff --git a/tests/ludwig/utils/test_hyperopt_ray_utils.py b/tests/ludwig/utils/test_hyperopt_ray_utils.py index 94559f51086..fa2db5727fe 100644 --- a/tests/ludwig/utils/test_hyperopt_ray_utils.py +++ b/tests/ludwig/utils/test_hyperopt_ray_utils.py @@ -30,8 +30,6 @@ "combiner.num_fc_layers": {"space": "qrandint", "lower": 3, "upper": 6, "q": 3}, "utterance.cell_type": {"space": "grid_search", "values": ["rnn", "gru", "lstm"]}, }, - "goal": "minimize", - "num_samples": 3, }, "test_2": { "parameters": { @@ -44,8 +42,6 @@ "combiner.num_fc_layers": {"space": "randint", "lower": 2, "upper": 6}, "utterance.cell_type": {"space": "choice", "categories": ["rnn", "gru", "lstm"]}, }, - "goal": "maximize", - "num_samples": 4, }, } @@ -71,11 +67,9 @@ def test_grid_strategy(key): hyperopt_test_params = HYPEROPT_PARAMS[key] expected_search_space = EXPECTED_SEARCH_SPACE[key] - goal = hyperopt_test_params["goal"] - num_samples = hyperopt_test_params["num_samples"] tune_sampler_params = hyperopt_test_params["parameters"] - tune_sampler = RayTuneSampler(goal=goal, parameters=tune_sampler_params, num_samples=num_samples) + tune_sampler = RayTuneSampler(parameters=tune_sampler_params) search_space = tune_sampler.search_space actual_params_keys = search_space.keys() @@ -85,4 +79,3 @@ def test_grid_strategy(key): assert isinstance(search_space[param], type(expected_search_space[param])) assert actual_params_keys == expected_params_keys - assert tune_sampler.num_samples == num_samples diff --git a/tests/ludwig/utils/test_hyperopt_utils.py b/tests/ludwig/utils/test_hyperopt_utils.py deleted file mode 100644 index 64167aad35b..00000000000 --- a/tests/ludwig/utils/test_hyperopt_utils.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) 2019 Uber Technologies, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import pytest - -from ludwig.hyperopt.sampling import GridSampler, PySOTSampler, RandomSampler - -HYPEROPT_PARAMS = { - "test_1": { - "parameters": { - "trainer.learning_rate": {"type": "float", "low": 0.0001, "high": 0.1, "steps": 4, "space": "log"}, - "combiner.num_fc_layers": {"type": "int", "low": 1, "high": 4}, - "utterance.cell_type": {"type": "category", "values": ["rnn", "gru", "lstm"]}, - }, - "expected_search_space": { - "trainer.learning_rate": [0.0001, 0.001, 0.01, 0.1], - "combiner.num_fc_layers": [1, 2, 3, 4], - "utterance.cell_type": ["rnn", "gru", "lstm"], - }, - "goal": "minimize", - "expected_len_grids": 48, - "num_samples": 10, - }, - "test_2": { - "parameters": { - "trainer.learning_rate": {"type": "float", "low": 0.001, "high": 0.1, "steps": 4, "space": "linear"}, - "combiner.num_fc_layers": {"type": "int", "low": 2, "high": 6, "steps": 3}, - }, - "expected_search_space": { - "trainer.learning_rate": [0.001, 0.034, 0.067, 0.1], - "combiner.num_fc_layers": [2, 4, 6], - }, - "goal": "maximize", - "expected_len_grids": 12, - "num_samples": 5, - }, -} - - -@pytest.mark.parametrize("key", ["test_1", "test_2"]) -def test_grid_strategy(key): - hyperopt_test_params = HYPEROPT_PARAMS[key] - goal = hyperopt_test_params["goal"] - grid_sampler_params = hyperopt_test_params["parameters"] - - grid_sampler = GridSampler(goal=goal, parameters=grid_sampler_params) - - actual_params_keys = grid_sampler.sample().keys() - expected_params_keys = grid_sampler_params.keys() - - for sample in grid_sampler.samples: - for param in actual_params_keys: - value = sample[param] - param_type = grid_sampler_params[param]["type"] - if param_type == "int" or param_type == "float": - low = grid_sampler_params[param]["low"] - high = grid_sampler_params[param]["high"] - assert value >= low and value <= high - else: - assert value in set(grid_sampler_params[param]["values"]) - - assert actual_params_keys == expected_params_keys - assert grid_sampler.search_space == hyperopt_test_params["expected_search_space"] - assert len(grid_sampler.samples) == hyperopt_test_params["expected_len_grids"] - - -@pytest.mark.parametrize("key", ["test_1", "test_2"]) -def test_random_sampler(key): - hyperopt_test_params = HYPEROPT_PARAMS[key] - goal = hyperopt_test_params["goal"] - random_sampler_params = hyperopt_test_params["parameters"] - num_samples = hyperopt_test_params["num_samples"] - - random_sampler = RandomSampler(goal=goal, parameters=random_sampler_params, num_samples=num_samples) - - actual_params_keys = random_sampler.sample().keys() - expected_params_keys = random_sampler_params.keys() - - for sample in random_sampler.samples: - for param in actual_params_keys: - value = sample[param] - param_type = random_sampler_params[param]["type"] - if param_type == "int" or param_type == "float": - low = random_sampler_params[param]["low"] - high = random_sampler_params[param]["high"] - assert value >= low and value <= high - else: - assert value in set(random_sampler_params[param]["values"]) - - assert actual_params_keys == expected_params_keys - assert len(random_sampler.samples) == num_samples - - -@pytest.mark.parametrize("key", ["test_1", "test_2"]) -def test_pysot_sampler(key): - hyperopt_test_params = HYPEROPT_PARAMS[key] - goal = hyperopt_test_params["goal"] - pysot_sampler_params = hyperopt_test_params["parameters"] - num_samples = hyperopt_test_params["num_samples"] - - pysot_sampler = PySOTSampler(goal=goal, parameters=pysot_sampler_params, num_samples=num_samples) - - actual_params_keys = pysot_sampler.sample().keys() - expected_params_keys = pysot_sampler_params.keys() - - pysot_sampler_samples = 1 - - for _ in range(num_samples - 1): - sample = pysot_sampler.sample() - for param in actual_params_keys: - value = sample[param] - param_type = pysot_sampler_params[param]["type"] - if param_type == "int" or param_type == "float": - low = pysot_sampler_params[param]["low"] - high = pysot_sampler_params[param]["high"] - assert value >= low and value <= high - else: - assert value in set(pysot_sampler_params[param]["values"]) - pysot_sampler_samples += 1 - - assert actual_params_keys == expected_params_keys - assert pysot_sampler_samples == num_samples