@classmethod
def from_path(
    cls,
    dataset_dir: epath.PathLike,
    root_data_dir: epath.PathLike,
) -> DatasetReference:
  """Returns the `DatasetReference` for the given dataset directory.

  The part of `dataset_dir` located below `root_data_dir` must be laid out as
  either `<dataset_name>/<version>` or `<dataset_name>/<config>/<version>`.

  Args:
    dataset_dir: The path to the dataset directory, e.g.,
      `/data/my_dataset/my_config/1.2.3`.
    root_data_dir: The root data directory, e.g., `/data`.

  Returns:
    A reference whose `data_dir` is `root_data_dir` without its trailing `/`
    and whose `config` is `None` when the path has no config folder.

  Raises:
    ValueError: If `dataset_dir` is not located under `root_data_dir`, or if
      the path below the root does not consist of exactly two or three
      non-empty folders.
  """
  dataset_dir = os.fspath(dataset_dir)
  root_data_dir = os.fspath(root_data_dir)

  is_prefix = dataset_dir.startswith(root_data_dir)
  remainder = dataset_dir.removeprefix(root_data_dir)
  # A plain string prefix is not enough: `/data` is a textual prefix of
  # `/data2/ds/1.2.3` without being one of its parent directories. The root
  # must therefore stop exactly where a path component ends.
  ends_on_boundary = (
      not root_data_dir
      or root_data_dir.endswith('/')
      or not remainder
      or remainder.startswith('/')
  )
  if not (is_prefix and ends_on_boundary):
    raise ValueError(f'{dataset_dir=} does not start with {root_data_dir=}!')

  # Drop every leading/trailing separator so that `.../1.2.3/` and
  # `.../1.2.3//` are parsed the same way as `.../1.2.3`.
  relative_path = remainder.strip('/')
  parts = relative_path.split('/')
  # An empty component (e.g. from `ds//1.2.3`) can never be a valid dataset
  # name, config or version, so it is rejected rather than returned as `''`.
  if len(parts) not in (2, 3) or not all(parts):
    raise ValueError(
        f'Invalid {relative_path=} for {root_data_dir=} and {dataset_dir=}'
    )

  dataset_name, *config_parts, version = parts
  config_name = config_parts[0] if config_parts else None
  return cls(
      dataset_name=dataset_name,
      config=config_name,
      version=version,
      data_dir=root_data_dir.removesuffix('/'),
  )
def _expected_path_reference(config):
  """Builds the reference expected for `ds` at version 1.2.3 under `/data`."""
  return naming.DatasetReference(
      dataset_name='ds',
      version='1.2.3',
      config=config,
      data_dir='/data',
  )


@pytest.mark.parametrize(
    ('dataset_dir', 'root_data_dir', 'expected'),
    [
        # Dataset with a config and a version.
        ('/data/ds/config/1.2.3', '/data', _expected_path_reference('config')),
        # Dataset with no config and a version.
        ('/data/ds/1.2.3', '/data', _expected_path_reference(None)),
        # Dataset dir with trailing slash.
        ('/data/ds/config/1.2.3/', '/data', _expected_path_reference('config')),
        # Root data dir with trailing slash.
        ('/data/ds/config/1.2.3', '/data/', _expected_path_reference('config')),
        # Dataset dir and root data dir with trailing slash.
        (
            '/data/ds/config/1.2.3/',
            '/data/',
            _expected_path_reference('config'),
        ),
    ],
)
def test_dataset_reference_from_path(dataset_dir, root_data_dir, expected):
  """Checks that valid dataset directories are parsed into references."""
  reference = naming.DatasetReference.from_path(
      dataset_dir=dataset_dir, root_data_dir=root_data_dir
  )
  assert reference == expected


@pytest.mark.parametrize(
    ('dataset_dir', 'root_data_dir'),
    [
        # Root data dir is not a prefix of the dataset dir.
        ('/data/ds/config/1.2.3', '/somewhere_else'),
        # Too many nested folders.
        ('/data/ds/config/another_folder/1.2.3', '/data'),
        # Too few nested folders.
        ('/data/ds/', '/data'),
    ],
)
def test_dataset_reference_from_path_invalid(dataset_dir, root_data_dir):
  """Checks that malformed dataset directories raise `ValueError`."""
  with pytest.raises(ValueError):
    naming.DatasetReference.from_path(
        dataset_dir=dataset_dir, root_data_dir=root_data_dir
    )