diff --git a/CHANGELOG.md b/CHANGELOG.md
index b73eeae96..8e2de4534 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,8 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
 - Added support for flash attention and gradient checkpointing to `hf_olmo`.
+- Added an eval-only script that evaluates existing checkpoints on specified tasks.
 - Added `effective_n_kv_heads` to OLMoConfig for hacky VLLM support.
 
+
 ## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26
 
 - Fixed conversion to HuggingFace model for DDP-trained models.
@@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Swapped in correct flan data mix.
 - Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
 - Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
-- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout 
+- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout
 
 ## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
 
diff --git a/configs/peteish1-weka.yaml b/configs/peteish1-weka.yaml
index 071c5399b..896500244 100644
--- a/configs/peteish1-weka.yaml
+++ b/configs/peteish1-weka.yaml
@@ -108,35 +108,35 @@ eval_interval: 1000
 eval_subset_num_batches: -1
 device_eval_batch_size: ${device_train_microbatch_size}
 evaluators:
-  # - label: all-small-ppl-validation
-  #   data:
-  #     num_workers: 0
-  #     drop_last: true
-  #     # generate_doc_lengths: true
-  #     memmap_dtype: uint32
-  #     datasets:
-  #       c4_en-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy
-  #       dolma_books-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_books/val/part-0-00000.npy
-  #       dolma_common-crawl-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_common-crawl/val/part-0-00000.npy
-  #       dolma_pes2o-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_pes2o/val/part-0-00000.npy
-  #       dolma_reddit-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_reddit/val/part-0-00000.npy
-  #       dolma_stack-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_stack/val/part-0-00000.npy
-  #       dolma_wiki-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_wiki/val/part-0-00000.npy
-  #       ice-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/ice/val/part-0-00000.npy
-  #       m2d2_s2orc-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/m2d2_s2orc/val/part-0-00000.npy
-  #       pile-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/pile/val/part-0-00000.npy
-  #       wikitext_103-validation:
-  #         - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/wikitext_103/val/part-0-00000.npy
+  - label: all-small-ppl-validation
+    data:
+      num_workers: 0
+      drop_last: true
+      # generate_doc_lengths: true
+      memmap_dtype: uint32
+      datasets:
+        c4_en-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy
+        dolma_books-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_books/val/part-0-00000.npy
+        dolma_common-crawl-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_common-crawl/val/part-0-00000.npy
+        dolma_pes2o-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_pes2o/val/part-0-00000.npy
+        dolma_reddit-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_reddit/val/part-0-00000.npy
+        dolma_stack-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_stack/val/part-0-00000.npy
+        dolma_wiki-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/dolma_wiki/val/part-0-00000.npy
+        ice-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/ice/val/part-0-00000.npy
+        m2d2_s2orc-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/m2d2_s2orc/val/part-0-00000.npy
+        pile-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/pile/val/part-0-00000.npy
+        wikitext_103-validation:
+          - /weka/oe-training-default/ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/wikitext_103/val/part-0-00000.npy
 
 ##########################
 # Downstream evaluations #
@@ -155,7 +155,7 @@ evaluators:
 
   - label: boolq
     type: downstream
-  
+
   - label: sciq
     type: downstream
 
@@ -231,6 +231,228 @@ evaluators:
   - label: arc_easy_ppl
     type: downstream
 
+  - label: piqa_rc_0shot
+    type: downstream
+
+  - label: piqa_rc_0shot_bpb
+    type: downstream
+
+  - label: piqa_rc_5shot
+    type: downstream
+
+  - label: piqa_rc_5shot_bpb
+    type: downstream
+
+  - label: piqa_mc_5shot
+    type: downstream
+
+  - label: piqa_mc_5shot_bpb
+    type: downstream
+
+  - label: hellaswag_rc_0shot
+    type: downstream
+
+  - label: hellaswag_rc_0shot_bpb
+    type: downstream
+
+  - label: hellaswag_rc_5shot
+    type: downstream
+
+  - label: hellaswag_rc_5shot_bpb
+    type: downstream
+
+  - label: hellaswag_mc_5shot
+    type: downstream
+
+  - label: hellaswag_mc_5shot_bpb
+    type: downstream
+
+  - label: winogrande_rc_0shot
+    type: downstream
+
+  - label: winogrande_rc_0shot_bpb
+    type: downstream
+
+  - label: winogrande_rc_5shot
+    type: downstream
+
+  - label: winogrande_rc_5shot_bpb
+    type: downstream
+
+  - label: winogrande_mc_5shot
+    type: downstream
+
+  - label: winogrande_mc_5shot_bpb
+    type: downstream
+
+  - label: openbookqa_rc_0shot
+    type: downstream
+
+  - label: openbookqa_rc_0shot_bpb
+    type: downstream
+
+  - label: openbookqa_rc_5shot
+    type: downstream
+
+  - label: openbookqa_rc_5shot_bpb
+    type: downstream
+
+  - label: openbookqa_mc_5shot
+    type: downstream
+
+  - label: openbookqa_mc_5shot_bpb
+    type: downstream
+
+  - label: boolq_rc_0shot
+    type: downstream
+
+  - label: boolq_rc_0shot_bpb
+    type: downstream
+
+  - label: boolq_rc_5shot
+    type: downstream
+
+  - label: boolq_rc_5shot_bpb
+    type: downstream
+
+  - label: boolq_mc_5shot
+    type: downstream
+
+  - label: boolq_mc_5shot_bpb
+    type: downstream
+
+  - label: sciq_rc_0shot
+    type: downstream
+
+  - label: sciq_rc_0shot_bpb
+    type: downstream
+
+  # - label: sciq_rc_5shot
+  #   type: downstream
+
+  # - label: sciq_rc_5shot_bpb
+  #   type: downstream
+
+  # - label: sciq_mc_5shot
+  #   type: downstream
+
+  # - label: sciq_mc_5shot_bpb
+  #   type: downstream
+
+  - label: arc_easy_rc_0shot
+    type: downstream
+
+  - label: arc_easy_rc_0shot_bpb
+    type: downstream
+
+  - label: arc_easy_rc_5shot
+    type: downstream
+
+  - label: arc_easy_rc_5shot_bpb
+    type: downstream
+
+  - label: arc_easy_mc_5shot
+    type: downstream
+
+  - label: arc_easy_mc_5shot_bpb
+    type: downstream
+
+  - label: arc_challenge_rc_0shot
+    type: downstream
+
+  - label: arc_challenge_rc_0shot_bpb
+    type: downstream
+
+  - label: arc_challenge_rc_5shot
+    type: downstream
+
+  - label: arc_challenge_rc_5shot_bpb
+    type: downstream
+
+  - label: arc_challenge_mc_5shot
+    type: downstream
+
+  - label: arc_challenge_mc_5shot_bpb
+    type: downstream
+
+  - label: copa_rc_0shot
+    type: downstream
+
+  - label: copa_rc_0shot_bpb
+    type: downstream
+
+  # - label: copa_rc_5shot
+  #   type: downstream
+
+  # - label: copa_rc_5shot_bpb
+  #   type: downstream
+
+  # - label: copa_mc_5shot
+  #   type: downstream
+
+  # - label: copa_mc_5shot_bpb
+  #   type: downstream
+
+  - label: csqa_rc_0shot
+    type: downstream
+
+  - label: csqa_rc_0shot_bpb
+    type: downstream
+
+  - label: csqa_rc_5shot
+    type: downstream
+
+  - label: csqa_rc_5shot_bpb
+    type: downstream
+
+  - label: csqa_mc_5shot
+    type: downstream
+
+  - label: csqa_mc_5shot_bpb
+    type: downstream
+
+  - label: socialiqa_rc_0shot
+    type: downstream
+
+  - label: socialiqa_rc_0shot_bpb
+    type: downstream
+
+  - label: socialiqa_rc_5shot
+    type: downstream
+
+  - label: socialiqa_rc_5shot_bpb
+    type: downstream
+
+  - label: socialiqa_mc_5shot
+    type: downstream
+
+  - label: socialiqa_mc_5shot_bpb
+    type: downstream
+
+  - label: mmlu_stem_var_bpb
+    type: downstream
+
+  - label: mmlu_humanities_var_bpb
+    type: downstream
+
+  - label: mmlu_social_sciences_var_bpb
+    type: downstream
+
+  - label: mmlu_other_var_bpb
+    type: downstream
+
+  - label: mmlu_stem_bpb
+    type: downstream
+
+  - label: mmlu_humanities_bpb
+    type: downstream
+
+  - label: mmlu_social_sciences_bpb
+    type: downstream
+
+  - label: mmlu_other_bpb
+    type: downstream
+
 data:
   pad_direction: right
   # generate_doc_lengths: true
diff --git a/configs/peteish7-weka.yaml b/configs/peteish7-weka.yaml
index a7dc9d66c..5980de319 100644
--- a/configs/peteish7-weka.yaml
+++ b/configs/peteish7-weka.yaml
@@ -154,7 +154,7 @@ evaluators:
 
   - label: boolq
     type: downstream
-  
+
   - label: sciq
     type: downstream
 
diff --git a/olmo/train.py b/olmo/train.py
index 341055003..b7f778bfb 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -1368,8 +1368,8 @@ def close(self, exit_code: int = 0) -> None:
             gc.enable()
         else:
             gc.disable()
-        if wandb.run is not None:
-            wandb.finish(exit_code=exit_code, quiet=True)
+        # if wandb.run is not None:
+        #     wandb.finish(exit_code=exit_code, quiet=True)
 
     def __enter__(self) -> Trainer:
         return self
diff --git a/scripts/beaker/peteish/peteish1-eval-launch.sh b/scripts/beaker/peteish/peteish1-eval-launch.sh
new file mode 100644
index 000000000..09ee396cc
--- /dev/null
+++ b/scripts/beaker/peteish/peteish1-eval-launch.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+
+set -ex
+
+NUM_NODES=16
+
+gantry run \
+  --allow-dirty \
+  --workspace ai2/OLMo-pretraining-stability \
+  --task-name peteish1-eval \
+  --description "Pete-ish 1B eval" \
+  --priority high \
+  --preemptible \
+  --beaker-image petew/olmo-torch23-gantry \
+  --cluster ai2/jupiter-cirrascale-2 \
+  --gpus 8 \
+  --replicas "${NUM_NODES}" \
+  --leader-selection \
+  --host-networking \
+  --propagate-failure \
+  --propagate-preemption \
+  --synchronized-start-timeout 90m \
+  --budget ai2/oe-training \
+  --no-nfs \
+  --weka oe-training-default:/weka/oe-training-default \
+  --no-python \
+  --env LOG_FILTER_TYPE=local_rank0_only \
+  --env OMP_NUM_THREADS=8 \
+  --env OLMO_TASK=model \
+  --env R2_PROFILE=R2 \
+  --env S3_PROFILE=S3 \
+  --env WEKA_PROFILE=WEKA \
+  --env-secret AWS_CONFIG=PETEW_AWS_CONFIG \
+  --env-secret AWS_CREDENTIALS=PETEW_AWS_CREDENTIALS \
+  --env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \
+  --env-secret WEKA_ENDPOINT_URL=WEKA_ENDPOINT_URL \
+  --env-secret WANDB_API_KEY=JIACHENGL_WANDB_API_KEY \
+  --shared-memory 10GiB \
+  --yes \
+  --timeout=-1 \
+  -- /bin/bash -c "scripts/beaker/peteish/peteish1-eval.sh \$BEAKER_LEADER_REPLICA_HOSTNAME ${NUM_NODES} \$BEAKER_REPLICA_RANK"
\ No newline at end of file
diff --git a/scripts/beaker/peteish/peteish1-eval.sh b/scripts/beaker/peteish/peteish1-eval.sh
new file mode 100755
index 000000000..954043b0d
--- /dev/null
+++ b/scripts/beaker/peteish/peteish1-eval.sh
@@ -0,0 +1,57 @@
+#!/usr/bin/env bash
+
+set -exuo pipefail
+IFS=$'\n\t'
+
+BEAKER_LEADER_REPLICA_HOSTNAME=$1
+shift
+
+NUM_NODES=$1
+shift
+
+BEAKER_REPLICA_RANK=$1
+shift
+
+# Setup Python environment.
+conda shell.bash activate base
+
+# Install flash-attn
+#conda install -y -c nvidia cuda-python
+pip install packaging ninja
+export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
+pip install flash-attn==2.5.9.post1 --no-build-isolation
+# pip install awscli
+pip install '.[train]'
+pip freeze
+
+# Move AWS credentials from env to relevant files
+mkdir -p ~/.aws
+printenv AWS_CONFIG > ~/.aws/config
+printenv AWS_CREDENTIALS > ~/.aws/credentials
+
+# Force processes to synchronize at init_process_group
+export TORCH_DIST_INIT_BARRIER=1
+
+# Tell OLMo all ranks share the same filesystem for checkpoints.
+export OLMO_SHARED_FS=1
+
+export NCCL_DEBUG=INFO
+export NCCL_IB_HCA="^=mlx5_bond_0"
+export NCCL_SOCKET_IFNAME=ib
+# export NCCL_IB_GID_INDEX=0
+
+torchrun \
+  --nnodes "${NUM_NODES}:${NUM_NODES}" \
+  --nproc-per-node 8 \
+  --rdzv_id 12347 \
+  --rdzv_backend static \
+  --rdzv_endpoint "${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" \
+  --node_rank "${BEAKER_REPLICA_RANK}" \
+  --rdzv_conf 'read_timeout=420' \
+  scripts/eval.py \
+    configs/peteish1-weka.yaml \
+      --run_name="${GANTRY_TASK_NAME}" \
+      --save_interval_ephemeral=null \
+      --save_overwrite \
+      --wandb.group="peteish1" \
+      --load_path="/weka/oe-training-default/ai2-llm/checkpoints/OLMo-small/peteish1"
diff --git a/scripts/eval.py b/scripts/eval.py
new file mode 100644
index 000000000..39990ed67
--- /dev/null
+++ b/scripts/eval.py
@@ -0,0 +1,286 @@
+"""Run this script with 'torchrun'."""
+
+import logging
+import sys
+from datetime import timedelta
+from pathlib import Path
+from typing import Optional, TextIO
+import glob
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import wandb
+from packaging import version
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from olmo.config import (
+    DDPGradSyncMode,
+    DistributedStrategy,
+    TrainConfig,
+)
+from olmo.eval import build_evaluators
+from olmo.exceptions import OLMoCliError, OLMoConfigurationError
+from olmo.model import OLMo
+from olmo.optim import build_optimizer, build_scheduler
+from olmo.torch_util import (
+    barrier,
+    get_default_device,
+    get_global_rank,
+    get_local_rank,
+    get_local_world_size,
+    get_world_size,
+    peak_gpu_memory,
+    seed_all,
+)
+from olmo.train import Trainer
+from olmo.util import (
+    add_cached_path_clients,
+    clean_opt,
+    log_extra_field,
+    prepare_cli_environment,
+)
+
+log = logging.getLogger("train")
+
+
+def main(cfg: TrainConfig) -> None:
+    # Ensure run name set.
+    if cfg.run_name is None:
+        raise OLMoConfigurationError("--run_name is required")
+    log_extra_field("run_name", cfg.run_name)
+
+    # Sanity check
+    if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
+        log.warning(
+            "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The "
+            "setting has no effect."
+        )
+
+    barrier()
+
+    device = torch.device("cuda")
+
+    # Fill some configuration options.
+    cfg.model.precision = cfg.precision
+    cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
+    assert cfg.device_train_batch_size is not None  # for mypy
+    cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
+    if cfg.optimizer.no_decay_norm_and_bias is not None:
+        log.warning(
+            "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this "
+            "setting will take precedence over all other weight decay configurations. Please change "
+            "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
+        )
+        cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
+        cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
+        cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.
+
+    # Display and save configuration.
+    if get_global_rank() == 0:
+        if cfg.data.paths is not None and len(cfg.data.paths) < 50:
+            log.info("Configuration:")
+            log.info(cfg)
+        if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)):
+            # Save config.
+            save_path = Path(cfg.save_folder) / "config.yaml"
+            if save_path.is_file() and not cfg.save_overwrite:
+                raise OLMoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite")
+            else:
+                log.info(f"Saving config to {save_path}")
+                save_path.parent.mkdir(exist_ok=True, parents=True)
+                cfg.save(save_path)
+            del save_path
+
+    barrier()
+
+    # Maybe start W&B run.
+    if cfg.wandb is not None and (get_global_rank() == 0 or not cfg.wandb.rank_zero_only):
+        wandb_dir = Path(cfg.save_folder) / "wandb"
+        wandb_dir.mkdir(parents=True, exist_ok=True)
+        wandb.init(
+            dir=wandb_dir,
+            project=cfg.wandb.project,
+            entity=cfg.wandb.entity,
+            group=cfg.wandb.group,
+            name=cfg.wandb.name,
+            tags=cfg.wandb.tags,
+            config=cfg.asdict(exclude=["wandb"]),
+        )
+
+    barrier()
+
+    # Set seed.
+    seed_all(cfg.seed)
+
+    # # Construct data loader.
+    # train_loader = build_train_dataloader(cfg)
+    train_loader = None
+
+    # Construct evaluators.
+    evaluators = build_evaluators(cfg, device)
+    barrier()
+
+    if cfg.load_path is None:
+        raise OLMoConfigurationError("To run eval you must provide a load_path")
+    if 'step' in cfg.load_path.split('/')[-1]:
+        load_paths = [cfg.load_path]
+    else:
+        # This globbing does not work with remote paths.
+        load_paths = list(sorted(glob.glob(f"{cfg.load_path}/step*"), key=lambda x: int(x.split('/')[-1].split('step')[-1])))
+
+    for load_path in load_paths:
+        step = int(load_path.split('/')[-1].split('step')[-1])
+
+        # Initialize the model.
+        log.info("Building model...")
+        olmo_model = OLMo(cfg.model)
+        log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
+        log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
+        log.info(f"Peak GPU Memory (MB) before {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}")
+
+        olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)
+
+        if cfg.distributed_strategy == DistributedStrategy.ddp:
+            log.info("Wrapping model with DDP...")
+            assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"
+
+            if cfg.model.init_device != "cuda":
+                raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.")
+
+            if cfg.ddp.find_unused_params is True and cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch:
+                raise OLMoConfigurationError(
+                    "`find_unused_params` is set to True. DDP needs to synchronize gradients for every micro-batch to avoid errors. Set `grad_sync_mode` to `micro_batch`."
+                )
+
+            param_init_fn = None
+
+            # move to cuda before calling ddp
+            dist_model = DDP(olmo_model.to(device), find_unused_parameters=cfg.ddp.find_unused_params)
+        elif cfg.distributed_strategy == DistributedStrategy.fsdp:
+            # Wrap the model in FSDP.
+            log.info("Wrapping model with FSDP...")
+            assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!"
+            wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
+
+            if version.parse(torch.__version__) >= version.parse("2.1.0"):
+                # This prevents any parameters from being initialized twice
+                def dummy_init_fn(module: torch.nn.Module) -> None:
+                    module.to_empty(device=get_default_device())
+
+                param_init_fn = dummy_init_fn
+            else:
+                param_init_fn = None
+
+            # Set up device mesh for hybrid sharding in order to specify which nodes are associated with a given model replica
+            device_mesh = None
+            hybrid_sharding_fsdp_kwargs = {}
+            if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
+                if version.parse(torch.__version__) < version.parse("2.2.0"):
+                    # Device mesh was not added to PyTorch until v2.2.0
+                    raise OLMoConfigurationError(
+                        "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
+                    )
+
+                from torch.distributed.device_mesh import init_device_mesh
+
+                num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
+                    get_world_size() // get_local_world_size()
+                )
+
+                if num_model_replicas <= 0:
+                    raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")
+
+                if get_world_size() % num_model_replicas != 0:
+                    raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide world size")
+
+                device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
+                hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh
+
+            dist_model = FSDP(
+                olmo_model,
+                sharding_strategy=cfg.fsdp.sharding_strategy,
+                mixed_precision=cfg.fsdp_precision,
+                auto_wrap_policy=wrap_policy,
+                use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
+                limit_all_gathers=True,
+                device_id=get_local_rank(),
+                param_init_fn=param_init_fn,
+                **hybrid_sharding_fsdp_kwargs,
+            )
+        elif cfg.distributed_strategy is None:
+            raise NotImplementedError("Single accelerator training not implemented yet!")
+
+        # when param_init_fn is None, FSDP will call reset_parameters() automatically
+        if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp:
+            olmo_model.reset_parameters()
+
+        log.info(f"Peak GPU Memory (MB) after {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}")
+        log.info("Model:")
+        log.info(dist_model)
+
+        # Construct optimizer and learning rate scheduler.
+        optim = build_optimizer(cfg, dist_model)
+        scheduler = build_scheduler(cfg)
+
+        # Data indices file.
+        indices_file: Optional[TextIO] = None
+
+        # Consolidate components into `Trainer` object.
+        with Trainer(
+            cfg=cfg,
+            epoch=cfg.epoch,
+            model=olmo_model,
+            dist_model=dist_model,
+            optim=optim,
+            scheduler=scheduler,
+            train_loader=train_loader,
+            device=device,
+            evaluators=evaluators,
+            indices_file=indices_file,
+        ) as trainer:
+
+            log.info(f"Loading checkpoint from {load_path}...")
+            trainer.restore_checkpoint(
+                load_path,
+                load_optimizer_state=False,
+                load_trainer_state=False,
+                sharded_checkpointer=cfg.load_path_sharded_checkpointer,
+            )
+            log.info("Checkpoint successfully loaded")
+
+            log.info("Starting evaluation...")
+            eval_metrics = trainer.eval()
+            if wandb.run is not None:
+                wandb.log(eval_metrics, step=step)
+            log.info("Evaluation complete")
+
+
+if __name__ == "__main__":
+    try:
+        mp.set_start_method("spawn", force=True)
+    except RuntimeError as e:
+        print(f"failed to set multiprocessing start method: {e}")
+    log.info(f"Multiprocessing start method set to '{mp.get_start_method()}'")
+
+    # Set CUDA device.
+    torch.cuda.set_device(f"cuda:{get_local_rank()}")
+
+    # Initialize process group.
+    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
+    log.info("Process group initialized")
+
+    prepare_cli_environment()
+    log.info("CLI environment prepared")
+
+    add_cached_path_clients()
+
+    try:
+        yaml_path, args_list = sys.argv[1], sys.argv[2:]
+    except IndexError:
+        raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")
+
+    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])
+    main(cfg)
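
For running the new eval script outside the Beaker launcher above, a minimal single-node sketch (the run name and checkpoint directory are illustrative placeholders; everything else mirrors scripts/beaker/peteish/peteish1-eval.sh):

    # Evaluate checkpoints on 8 local GPUs using the same config and overrides as the Beaker job.
    torchrun --nproc-per-node 8 scripts/eval.py \
      configs/peteish1-weka.yaml \
        --run_name=peteish1-eval \
        --save_overwrite \
        --load_path=/path/to/checkpoints/peteish1

Per the logic in scripts/eval.py, `--load_path` may point either at a single `stepNNNN` checkpoint directory or at a run folder containing several `step*` directories, in which case each checkpoint is evaluated in turn; the `step*` globbing only works for local (or Weka-mounted) paths, not remote ones.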