From 97c9963579129a8480a50e84cd8a33a137a52ecb Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 18 Sep 2024 17:18:29 -0400 Subject: [PATCH] Fix auto eval parsing (#358) * Fix auto eval parsing * Add some test cases --- open_instruct/test_utils.py | 6 ++++++ open_instruct/utils.py | 10 +++++++--- ..._beaker_dataset_model_upload_then_evaluate_model.py | 2 -- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 49f7a3996..291ac701d 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -16,6 +16,7 @@ import unittest import pytest +from dateutil import parser from open_instruct.utils import get_datasets @@ -76,6 +77,11 @@ def test_loading_preference_data(self): pref_datasets = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["chosen", "rejected"]) self.assertEqual(len(pref_datasets["train"]), 2000) + def test_time_parser_used_in_get_beaker_dataset_ids(self): + # two special cases which beaker uses + self.assertTrue(parser.parse("2024-09-16T19:03:02.31502Z")) + self.assertTrue(parser.parse("0001-01-01T00:00:00Z")) + # useful for checking if public datasets are still available # class CheckTuluDatasetsTest(unittest.TestCase): diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 8c608463e..72083d838 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -22,13 +22,13 @@ import sys import time from dataclasses import dataclass -from datetime import datetime from typing import Any, List, NewType, Optional, Tuple, Union import requests from accelerate.logging import get_logger from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk from datasets.builder import DatasetGenerationError +from dateutil import parser from huggingface_hub import HfApi from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser @@ -734,14 +734,18 @@ def get_beaker_dataset_ids(experiment_id: str, sort=False) -> Optional[List[str] dataset_infos.extend( [ DatasetInfo( - id=dataset["id"], committed=dataset["committed"], non_empty=False if dataset["storage"]["totalSize"] is None else dataset["storage"]["totalSize"] > 0 + id=dataset["id"], + committed=dataset["committed"], + non_empty=( + False if dataset["storage"]["totalSize"] is None else dataset["storage"]["totalSize"] > 0 + ), ) for dataset in datasets ] ) if sort: # sort based on empty, then commited - dataset_infos.sort(key=lambda x: (x.non_empty, datetime.strptime(x.committed, "%Y-%m-%dT%H:%M:%S.%fZ"))) + dataset_infos.sort(key=lambda x: (x.non_empty, parser.parse(x.committed))) print(dataset_infos) return [dataset.id for dataset in dataset_infos] diff --git a/scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py b/scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py index fadd408ea..c29ff28e5 100644 --- a/scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py +++ b/scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py @@ -25,8 +25,6 @@ class Args: def main(args: Args, beaker_runtime_config: BeakerRuntimeConfig): print(args) - beaker_dataset_ids = get_beaker_dataset_ids(beaker_runtime_config.beaker_workload_id, sort=True) - print(beaker_experiment_succeeded(beaker_runtime_config.beaker_workload_id)) start_time = time.time() while time.time() - start_time < args.max_wait_time_for_beaker_dataset_upload_seconds: