Skip to content

Commit

Permalink
Add a function to instantiate a DatasetReference from a dataset dir
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686902586
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Oct 17, 2024
1 parent a15ea31 commit c388d5f
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 1 deletion.
41 changes: 40 additions & 1 deletion tensorflow_datasets/core/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ def from_tfds_name(
) -> DatasetReference:
"""Returns the `DatasetReference` for the given TFDS dataset."""
parsed_name, builder_kwargs = parse_builder_name_kwargs(tfds_name)
version, config = None, None
version = builder_kwargs.get('version')
config = builder_kwargs.get('config')
return cls(
Expand All @@ -304,6 +303,46 @@ def from_tfds_name(
data_dir=data_dir,
)

@classmethod
def from_path(
cls,
dataset_dir: epath.PathLike,
root_data_dir: epath.PathLike,
) -> DatasetReference:
"""Returns the `DatasetReference` for the given dataset directory.
Args:
dataset_dir: The path to the dataset directory, e.g.,
`/data/my_dataset/my_config/1.2.3`.
root_data_dir: The root data directory, e.g., `/data`.
"""
dataset_dir = os.fspath(dataset_dir)
root_data_dir = os.fspath(root_data_dir)

if not dataset_dir.startswith(root_data_dir):
raise ValueError(f'{dataset_dir=} does not start with {root_data_dir=}!')

relative_path = dataset_dir.removeprefix(root_data_dir)
relative_path = relative_path.removeprefix('/').removesuffix('/')
parts = relative_path.split('/')
dataset_name = parts[0]
if len(parts) == 2:
config_name = None
version = parts[1]
elif len(parts) == 3:
config_name = parts[1]
version = parts[2]
else:
raise ValueError(
f'Invalid {relative_path=} for {root_data_dir=} and {dataset_dir=}'
)
return cls(
dataset_name=dataset_name,
config=config_name,
version=version,
data_dir=root_data_dir.removesuffix('/'),
)


def references_for(
name_to_tfds_name: Mapping[str, str],
Expand Down
94 changes: 94 additions & 0 deletions tensorflow_datasets/core/naming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,100 @@ def test_dataset_reference_from_tfds_name(
)


# Expected references shared by the `from_path` happy-path cases below.
_REF_WITH_CONFIG = naming.DatasetReference(
    dataset_name='ds',
    version='1.2.3',
    config='config',
    data_dir='/data',
)
_REF_WITHOUT_CONFIG = naming.DatasetReference(
    dataset_name='ds',
    version='1.2.3',
    config=None,
    data_dir='/data',
)


@pytest.mark.parametrize(
    ('dataset_dir', 'root_data_dir', 'expected'),
    [
        # Dataset with a config and a version.
        ('/data/ds/config/1.2.3', '/data', _REF_WITH_CONFIG),
        # Dataset with no config and a version.
        ('/data/ds/1.2.3', '/data', _REF_WITHOUT_CONFIG),
        # Dataset dir with trailing slash.
        ('/data/ds/config/1.2.3/', '/data', _REF_WITH_CONFIG),
        # Root data dir with trailing slash.
        ('/data/ds/config/1.2.3', '/data/', _REF_WITH_CONFIG),
        # Dataset dir and root data dir with trailing slash.
        ('/data/ds/config/1.2.3/', '/data/', _REF_WITH_CONFIG),
    ],
)
def test_dataset_reference_from_path(dataset_dir, root_data_dir, expected):
  """`from_path` parses name, config and version from a dataset directory."""
  actual = naming.DatasetReference.from_path(
      dataset_dir=dataset_dir, root_data_dir=root_data_dir
  )
  assert actual == expected


@pytest.mark.parametrize(
    ('dataset_dir', 'root_data_dir'),
    [
        # Root data dir is not a prefix of the dataset dir.
        ('/data/ds/config/1.2.3', '/somewhere_else'),
        # Too many nested folders.
        ('/data/ds/config/another_folder/1.2.3', '/data'),
        # Too few nested folders.
        ('/data/ds/', '/data'),
    ],
)
def test_dataset_reference_from_path_invalid(dataset_dir, root_data_dir):
  """`from_path` raises ValueError for dirs not matching the expected layout."""
  with pytest.raises(ValueError):
    naming.DatasetReference.from_path(
        dataset_dir=dataset_dir, root_data_dir=root_data_dir
    )


@pytest.mark.parametrize(
('ds_name', 'namespace', 'version', 'config', 'tfds_name'),
[
Expand Down

0 comments on commit c388d5f

Please sign in to comment.