Skip to content

Commit

Permalink
Relax hf hub pin (#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jul 1, 2024
1 parent cece2b5 commit 62165de
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
),
)
SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet']
HUGGINGFACE_FOLDER_EXTENSIONS = ['.lock', '.metadata']

PromptResponseDict = Mapping[str, str]
ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
Expand Down Expand Up @@ -886,7 +887,8 @@ def build_from_hf(
f for _, _, files in os.walk(dataset_name) for f in files
]
if not all(
Path(f).suffix in SUPPORTED_EXTENSIONS
Path(f).suffix in SUPPORTED_EXTENSIONS +
HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore'
for f in dataset_files
):
raise InvalidFileExtensionError(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
'onnx==1.14.0',
'onnxruntime==1.15.1',
'boto3>=1.21.45,<2',
'huggingface-hub>=0.19.0,<0.23',
'huggingface-hub>=0.19.0,<0.24',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
'tenacity>=8.2.3,<9',
'catalogue>=2,<3',
Expand Down
6 changes: 4 additions & 2 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from llmfoundry.data.finetuning.tasks import (
DOWNLOADED_FT_DATASETS_DIRPATH,
HUGGINGFACE_FOLDER_EXTENSIONS,
SUPPORTED_EXTENSIONS,
dataset_constructor,
is_valid_ift_example,
Expand Down Expand Up @@ -471,14 +472,15 @@ def test_finetuning_dataloader_safe_load(
)

# If no raised errors, we should expect downloaded files with only safe file types.
if expectation == does_not_raise():
if isinstance(expectation, does_not_raise):
download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name)
downloaded_files = [
file for _, _, files in os.walk(download_dir) for file in files
]
assert len(downloaded_files) > 0
assert all(
Path(file).suffix in SUPPORTED_EXTENSIONS
Path(file).suffix in SUPPORTED_EXTENSIONS +
HUGGINGFACE_FOLDER_EXTENSIONS or file == '.gitignore'
for file in downloaded_files
)

Expand Down

0 comments on commit 62165de

Please sign in to comment.