Skip to content

Commit

Permalink
Fix auto eval with preemptible jobs (#353)
Browse files Browse the repository at this point in the history
* push changes

* fix auto eval 7

* quick fix

* revert changes
  • Loading branch information
vwxyzjn authored Sep 18, 2024
1 parent 0673f21 commit e26c023
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
28 changes: 24 additions & 4 deletions open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List, NewType, Optional, Tuple, Union

import requests
Expand Down Expand Up @@ -708,12 +709,20 @@ def beaker_experiment_succeeded(experiment_id: str) -> bool:
)


def get_beaker_dataset_ids(experiment_id: str) -> Optional[List[str]]:
@dataclass
class DatasetInfo:
    """Minimal record describing one Beaker result dataset.

    Instances are built by ``get_beaker_dataset_ids`` from the JSON printed by
    ``beaker dataset get`` and are used there to order datasets before their
    ids are returned.
    """

    # Value of the dataset's "id" field.
    id: str
    # Value of the dataset's "committed" field; the sorting code parses it
    # with the format "%Y-%m-%dT%H:%M:%S.%fZ", so it is expected to be a
    # timestamp string in that shape.
    committed: Any
    # True when the dataset's reported storage totalSize is greater than 0.
    non_empty: bool


def get_beaker_dataset_ids(experiment_id: str, sort=False) -> Optional[List[str]]:
"""if sort is True, the non-empty latest dataset will be availble at the end of the list"""
experiment = get_beaker_experiment_info(experiment_id)
if not experiment:
return None
result_ids = [job["result"]["beaker"] for job in experiment["jobs"]]
dataset_ids = []
dataset_infos = []
for result_id in result_ids:
get_dataset_command = f"beaker dataset get {result_id} --format json"
process = subprocess.Popen(["bash", "-c", get_dataset_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand All @@ -722,8 +731,19 @@ def get_beaker_dataset_ids(experiment_id: str) -> Optional[List[str]]:
print(f"Failed to get Beaker dataset: {stderr}")
return None
datasets = json.loads(stdout)
dataset_ids.extend([dataset["id"] for dataset in datasets])
return dataset_ids
dataset_infos.extend(
[
DatasetInfo(
id=dataset["id"], committed=dataset["committed"], non_empty=dataset["storage"]["totalSize"] > 0
)
for dataset in datasets
]
)
if sort:
# sort by non-emptiness first, then by committed time, so the latest non-empty dataset ends up last
dataset_infos.sort(key=lambda x: (x.non_empty, datetime.strptime(x.committed, "%Y-%m-%dT%H:%M:%S.%fZ")))
print(dataset_infos)
return [dataset.id for dataset in dataset_infos]


def get_beaker_whoami() -> Optional[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Args:

def main(args: Args, beaker_runtime_config: BeakerRuntimeConfig):
print(args)
beaker_dataset_ids = get_beaker_dataset_ids(beaker_runtime_config.beaker_workload_id)
beaker_dataset_ids = get_beaker_dataset_ids(beaker_runtime_config.beaker_workload_id, sort=True)
print(beaker_experiment_succeeded(beaker_runtime_config.beaker_workload_id))

start_time = time.time()
Expand All @@ -35,9 +35,10 @@ def main(args: Args, beaker_runtime_config: BeakerRuntimeConfig):
# NOTE: we are assuming the first beaker dataset has the model
# I have checked a couple of beaker jobs and found the first dataset is the model
# but we should check this assumption
beaker_dataset_ids = get_beaker_dataset_ids(beaker_runtime_config.beaker_workload_id, sort=True)
submit_beaker_eval_jobs(
model_name=args.model_name,
location=beaker_dataset_ids[0],
location=beaker_dataset_ids[-1],
run_oe_eval_experiments=True,
run_safety_evaluations=True,
skip_oi_evals=True,
Expand Down

0 comments on commit e26c023

Please sign in to comment.