From a181015979225e1743c6ca8c49fe5cfab482ce21 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Mon, 14 Oct 2024 21:43:31 +0100 Subject: [PATCH 01/13] fast array extraction --- src/datasets/features/features.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 1d241e0b7b7..19340ab4afe 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -791,10 +791,7 @@ def to_numpy(self, zero_copy_only=True): def to_pylist(self): zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True) numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only) - if self.type.shape[0] is None and numpy_arr.dtype == object: - return [arr.tolist() for arr in numpy_arr.tolist()] - else: - return numpy_arr.tolist() + return list(numpy_arr) class PandasArrayExtensionDtype(PandasExtensionDtype): From 426178e42ec1c568368e800894fa8c1760841d4d Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 10:06:29 +0100 Subject: [PATCH 02/13] add array 1d feature --- src/datasets/features/__init__.py | 3 ++- src/datasets/features/features.py | 45 +++++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index 35ebfb4ac0c..d45eb1d06ce 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -1,4 +1,5 @@ __all__ = [ + "Array1D", "Audio", "Array2D", "Array3D", @@ -14,6 +15,6 @@ "TranslationVariableLanguages", ] from .audio import Audio -from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value +from .features import Array1D, Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value from .image import Image from .translation import Translation, TranslationVariableLanguages diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 19340ab4afe..2b92a577f5d 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -545,6 +545,33 @@ def encode_example(self, value): return value +@dataclass +class Array1D(_ArrayXD): + """Create a one-dimensional array. + + Unlike Sequence, will be extracted as a numpy array irrespective of formatting. + + Args: + shape (`tuple`): + Size of each dimension. + dtype (`str`): + Name of the data type. + + Example: + + ```py + >>> from datasets import Features + >>> features = Features({'x': Array1D(shape=(3,), dtype='int32')}) + ``` + """ + + shape: tuple + dtype: str + id: Optional[str] = None + # Automatically constructed + _type: str = field(default="Array1D", init=False, repr=False) + + @dataclass class Array2D(_ArrayXD): """Create a two-dimensional array. @@ -649,8 +676,8 @@ class _ArrayXDExtensionType(pa.ExtensionType): ndims: Optional[int] = None def __init__(self, shape: tuple, dtype: str): - if self.ndims is None or self.ndims <= 1: - raise ValueError("You must instantiate an array type with a value for dim that is > 1") + if self.ndims is None: + raise ValueError("You must instantiate an array type with a value for dim that is >= 1") if len(shape) != self.ndims: raise ValueError(f"shape={shape} and ndims={self.ndims} don't match") for dim in range(1, self.ndims): @@ -691,6 +718,10 @@ def to_pandas_dtype(self): return PandasArrayExtensionDtype(self.value_type) +class Array1DExtensionType(_ArrayXDExtensionType): + ndims = 1 + + class Array2DExtensionType(_ArrayXDExtensionType): ndims = 2 @@ -708,6 +739,7 @@ class Array5DExtensionType(_ArrayXDExtensionType): # Register the extension types for deserialization +pa.register_extension_type(Array1DExtensionType((1,), "int64")) pa.register_extension_type(Array2DExtensionType((1, 2), "int64")) pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64")) pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64")) @@ -1491,8 +1523,11 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType: feature = generate_from_arrow_type(pa_type.value_type) return LargeList(feature=feature) elif isinstance(pa_type, _ArrayXDExtensionType): - array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims] - return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) + if pa_type.ndims >= 1: + array_feature = [Array1D, Array2D, Array3D, Array4D, Array5D][pa_type.ndims - 1] + return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) + else: + raise ValueError("Cannot convert 0-dimensional array to Array Feature type.") elif isinstance(pa_type, pa.DataType): return Value(dtype=_arrow_to_datasets_dtype(pa_type)) else: @@ -1712,7 +1747,7 @@ class Features(dict): - - [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays. + - [`Array1D`], [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays. - [`Audio`] feature to store the absolute path to an audio file or a dictionary with the relative path to an audio file ("path" key) and its bytes content ("bytes" key). This feature extracts the audio data. - [`Image`] feature to store the absolute path to an image file, an `np.ndarray` object, a `PIL.Image.Image` object From 303c4e20a090698bd5647b0d07299d808cff3825 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 10:31:46 +0100 Subject: [PATCH 03/13] fast struct extraction by invoking extension type to_pylist --- src/datasets/formatting/formatting.py | 109 ++++++++++++++++---------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 2dae3a52fd3..bd513e19ba2 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -107,6 +107,51 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool: return pa_array.null_count > 0 +def _arrow_array_to_numpy(pa_array: pa.Array) -> np.ndarray: + if isinstance(pa_array, pa.ChunkedArray): + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + else: + zero_copy_only = _is_zero_copy_only(pa_array.type) and all( + not _is_array_with_nulls(chunk) for chunk in pa_array.chunks + ) + array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + else: + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only) + else: + zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array) + array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist() + + if len(array) > 0: + if any( + (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape)) + or (isinstance(x, float) and np.isnan(x)) + for x in array + ): + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array, dtype=object) + return np.array(array, copy=False, dtype=object) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array) + else: + return np.array(array, copy=False) + + +def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]: + # convert to list of dicts + list_of_dicts = [] + keys = dict_of_lists.keys() + value_arrays = [dict_of_lists[key] for key in keys] + for vals in zip(*value_arrays): + list_of_dicts.append(dict(zip(keys, vals))) + return list_of_dicts + + class BaseArrowExtractor(Generic[RowFormat, ColumnFormat, BatchFormat]): """ Arrow extractor are used to extract data from pyarrow tables. @@ -140,6 +185,20 @@ def extract_batch(self, pa_table: pa.Table) -> pa.Table: return pa_table +def extract_struct_array(pa_array: pa.StructArray) -> list: + if isinstance(pa_array, pa.ChunkedArray): + batch_chunks = [extract_struct_array(chunk) for chunk in pa_array.chunks] + return [item for chunk in batch_chunks for item in chunk] + + batch = {} + for field in pa_array.type: + if pa.types.is_struct(pa_array.field(field.name).type): + batch[field.name] = extract_struct_array(pa_array.field(field.name)) + else: + batch[field.name] = pa_array.field(field.name).to_pylist() + return dict_of_lists_to_list_of_dicts(batch) + + class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]): def extract_row(self, pa_table: pa.Table) -> dict: return _unnest(pa_table.to_pydict()) @@ -148,7 +207,15 @@ def extract_column(self, pa_table: pa.Table) -> list: return pa_table.column(0).to_pylist() def extract_batch(self, pa_table: pa.Table) -> dict: - return pa_table.to_pydict() + batch = {} + for col in pa_table.column_names: + if pa.types.is_list(pa_table[col].type): + batch[col] = list(pa_table[col].to_numpy()) + elif pa.types.is_struct(pa_table[col].type): + batch[col] = extract_struct_array(pa_table[col]) + else: + batch[col] = pa_table[col].to_pylist() + return batch class NumpyArrowExtractor(BaseArrowExtractor[dict, np.ndarray, dict]): @@ -162,45 +229,7 @@ def extract_column(self, pa_table: pa.Table) -> np.ndarray: return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]]) def extract_batch(self, pa_table: pa.Table) -> dict: - return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names} - - def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: - if isinstance(pa_array, pa.ChunkedArray): - if isinstance(pa_array.type, _ArrayXDExtensionType): - # don't call to_pylist() to preserve dtype of the fixed-size array - zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = [ - row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) - ] - else: - zero_copy_only = _is_zero_copy_only(pa_array.type) and all( - not _is_array_with_nulls(chunk) for chunk in pa_array.chunks - ) - array: List = [ - row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) - ] - else: - if isinstance(pa_array.type, _ArrayXDExtensionType): - # don't call to_pylist() to preserve dtype of the fixed-size array - zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only) - else: - zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array) - array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist() - - if len(array) > 0: - if any( - (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape)) - or (isinstance(x, float) and np.isnan(x)) - for x in array - ): - if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": - return np.asarray(array, dtype=object) - return np.array(array, copy=False, dtype=object) - if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": - return np.asarray(array) - else: - return np.array(array, copy=False) + return {col: _arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names} class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]): From 0be08953494f10e05d354bdfc9ec3c489318f24a Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 10:32:33 +0100 Subject: [PATCH 04/13] also use to_pylist for list array --- src/datasets/formatting/formatting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index bd513e19ba2..9157283961e 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -209,9 +209,7 @@ def extract_column(self, pa_table: pa.Table) -> list: def extract_batch(self, pa_table: pa.Table) -> dict: batch = {} for col in pa_table.column_names: - if pa.types.is_list(pa_table[col].type): - batch[col] = list(pa_table[col].to_numpy()) - elif pa.types.is_struct(pa_table[col].type): + if pa.types.is_struct(pa_table[col].type): batch[col] = extract_struct_array(pa_table[col]) else: batch[col] = pa_table[col].to_pylist() From deee87ed40e6641be774129e939e1a2ff1a7a4b4 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 10:51:39 +0100 Subject: [PATCH 05/13] improve struct extraction --- src/datasets/features/features.py | 2 ++ src/datasets/formatting/formatting.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 2b92a577f5d..873f12b6bee 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1225,6 +1225,7 @@ class LargeList: TranslationVariableLanguages, LargeList, Sequence, + Array1D, Array2D, Array3D, Array4D, @@ -1440,6 +1441,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni TranslationVariableLanguages.__name__: TranslationVariableLanguages, LargeList.__name__: LargeList, Sequence.__name__: Sequence, + Array1D.__name__: Array1D, Array2D.__name__: Array2D, Array3D.__name__: Array3D, Array4D.__name__: Array4D, diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 9157283961e..3a0d1161e5c 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -201,10 +201,13 @@ def extract_struct_array(pa_array: pa.StructArray) -> list: class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]): def extract_row(self, pa_table: pa.Table) -> dict: - return _unnest(pa_table.to_pydict()) + return _unnest(self.extract_batch(pa_table)) def extract_column(self, pa_table: pa.Table) -> list: - return pa_table.column(0).to_pylist() + if pa.types.is_struct(pa_table[pa_table.column_names[0]].type): + return extract_struct_array(pa_table[pa_table.column_names[0]]) + else: + return pa_table.column(0).to_pylist() def extract_batch(self, pa_table: pa.Table) -> dict: batch = {} From ac5a46d2685c12e52ba5c7a8374533378d04e6db Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 11:55:40 +0100 Subject: [PATCH 06/13] handle structs and lists of arrays --- src/datasets/formatting/formatting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 3a0d1161e5c..84fd7f00fef 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -186,6 +186,7 @@ def extract_batch(self, pa_table: pa.Table) -> pa.Table: def extract_struct_array(pa_array: pa.StructArray) -> list: + """StructArray.to_pylist / to_pydict does not call sub-arrays to_pylist / to_pydict methods so handle them manually.""" if isinstance(pa_array, pa.ChunkedArray): batch_chunks = [extract_struct_array(chunk) for chunk in pa_array.chunks] return [item for chunk in batch_chunks for item in chunk] From a89ef529fdddf7037d5491b15df2b3e6ea1dbb69 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 12:08:19 +0100 Subject: [PATCH 07/13] restore arrow array to numpy to numpy extractor --- src/datasets/formatting/formatting.py | 71 +++++++++++++-------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 84fd7f00fef..6400d7e901c 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -107,41 +107,6 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool: return pa_array.null_count > 0 -def _arrow_array_to_numpy(pa_array: pa.Array) -> np.ndarray: - if isinstance(pa_array, pa.ChunkedArray): - if isinstance(pa_array.type, _ArrayXDExtensionType): - # don't call to_pylist() to preserve dtype of the fixed-size array - zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] - else: - zero_copy_only = _is_zero_copy_only(pa_array.type) and all( - not _is_array_with_nulls(chunk) for chunk in pa_array.chunks - ) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] - else: - if isinstance(pa_array.type, _ArrayXDExtensionType): - # don't call to_pylist() to preserve dtype of the fixed-size array - zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only) - else: - zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array) - array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist() - - if len(array) > 0: - if any( - (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape)) - or (isinstance(x, float) and np.isnan(x)) - for x in array - ): - if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": - return np.asarray(array, dtype=object) - return np.array(array, copy=False, dtype=object) - if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": - return np.asarray(array) - else: - return np.array(array, copy=False) - - def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]: # convert to list of dicts list_of_dicts = [] @@ -231,7 +196,41 @@ def extract_column(self, pa_table: pa.Table) -> np.ndarray: return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]]) def extract_batch(self, pa_table: pa.Table) -> dict: - return {col: _arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names} + return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names} + + def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: + if isinstance(pa_array, pa.ChunkedArray): + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + else: + zero_copy_only = _is_zero_copy_only(pa_array.type) and all( + not _is_array_with_nulls(chunk) for chunk in pa_array.chunks + ) + array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + else: + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only) + else: + zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array) + array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist() + + if len(array) > 0: + if any( + (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape)) + or (isinstance(x, float) and np.isnan(x)) + for x in array + ): + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array, dtype=object) + return np.array(array, copy=False, dtype=object) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array) + else: + return np.array(array, copy=False) class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]): From 7f1e2173ec9b2df14210eb88bed55ede95435a85 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 17:43:43 +0100 Subject: [PATCH 08/13] fix failing array tests --- src/datasets/formatting/formatting.py | 2 ++ tests/features/test_array_xd.py | 9 +++++++-- tests/test_table.py | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 6400d7e901c..aafce2f68b8 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -172,7 +172,9 @@ def extract_row(self, pa_table: pa.Table) -> dict: def extract_column(self, pa_table: pa.Table) -> list: if pa.types.is_struct(pa_table[pa_table.column_names[0]].type): return extract_struct_array(pa_table[pa_table.column_names[0]]) + # TODO: handle list of struct else: + # should work for list of ArrayXD return pa_table.column(0).to_pylist() def extract_batch(self, pa_table: pa.Table) -> dict: diff --git a/tests/features/test_array_xd.py b/tests/features/test_array_xd.py index 8a50823b996..b57ababdeda 100644 --- a/tests/features/test_array_xd.py +++ b/tests/features/test_array_xd.py @@ -419,7 +419,7 @@ def test_array_xd_with_none(): def test_array_xd_with_np(seq_type, dtype, shape, feature_class): feature = feature_class(dtype=dtype, shape=shape) data = np.zeros(shape, dtype=dtype) - expected = data.tolist() + expected = data if seq_type == "sequence": feature = datasets.Sequence(feature) data = [data] @@ -429,7 +429,12 @@ def test_array_xd_with_np(seq_type, dtype, shape, feature_class): data = [[data]] expected = [[expected]] ds = datasets.Dataset.from_dict({"col": [data]}, features=datasets.Features({"col": feature})) - assert ds[0]["col"] == expected + if seq_type == "sequence": + assert (ds[0]["col"][0] == expected[0]).all() + elif seq_type == "sequence_of_sequence": + assert (ds[0]["col"][0][0] == expected[0][0]).all() + else: + assert (ds[0]["col"] == expected).all() @pytest.mark.parametrize("with_none", [False, True]) diff --git a/tests/test_table.py b/tests/test_table.py index 3d3db09e5d6..9a3f92564d0 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1339,11 +1339,11 @@ def test_cast_array_xd_to_features_sequence(): # Variable size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"))) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"))) - assert casted_array.to_pylist() == arr.to_pylist() + assert (casted_array.to_pylist() == arr.to_pylist()).all() # Fixed size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) - assert casted_array.to_pylist() == arr.to_pylist() + assert (casted_array.to_pylist() == arr.to_pylist()).all() def test_embed_array_storage(image_file): From abbb59a14c23899c577bae94af45b5d706a74e40 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 17:50:37 +0100 Subject: [PATCH 09/13] test cast array xd to features fix --- tests/test_table.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/test_table.py b/tests/test_table.py index 9a3f92564d0..a921527eae7 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1332,6 +1332,21 @@ def test_cast_array_to_feature_with_list_array_and_large_list_feature(from_list_ assert cast_array.type == expected_array_type +def all_arrays_equal(arr1, arr2): + if len(arr1) != len(arr2): + return False + for a1, a2 in zip(arr1, arr2): + if isinstance(a1, list) and isinstance(a2, list): + if not all_arrays_equal(a1, a2): + return False + elif isinstance(a1, np.ndarray) and isinstance(a2, np.ndarray): + if not (a1 == a2).all(): + return False + elif a1 != a2: + return False + return True + + def test_cast_array_xd_to_features_sequence(): arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist() arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64())))) @@ -1339,11 +1354,11 @@ def test_cast_array_xd_to_features_sequence(): # Variable size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"))) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"))) - assert (casted_array.to_pylist() == arr.to_pylist()).all() + assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist()) # Fixed size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) - assert (casted_array.to_pylist() == arr.to_pylist()).all() + assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist()) def test_embed_array_storage(image_file): From c39c4bc3b33619241b195d0ce456a400eac40b45 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 17:54:09 +0100 Subject: [PATCH 10/13] test array write --- tests/features/test_array_xd.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/features/test_array_xd.py b/tests/features/test_array_xd.py index b57ababdeda..2029b6b5230 100644 --- a/tests/features/test_array_xd.py +++ b/tests/features/test_array_xd.py @@ -175,20 +175,20 @@ def get_dict_examples(self, shape_1, shape_2): def _check_getitem_output_type(self, dataset, shape_1, shape_2, first_matrix): matrix_column = dataset["matrix"] self.assertIsInstance(matrix_column, list) - self.assertIsInstance(matrix_column[0], list) - self.assertIsInstance(matrix_column[0][0], list) + self.assertIsInstance(matrix_column[0], np.ndarray) + self.assertIsInstance(matrix_column[0][0], np.ndarray) self.assertTupleEqual(np.array(matrix_column).shape, (2, *shape_2)) matrix_field_of_first_example = dataset[0]["matrix"] - self.assertIsInstance(matrix_field_of_first_example, list) - self.assertIsInstance(matrix_field_of_first_example, list) + self.assertIsInstance(matrix_field_of_first_example, np.ndarray) + self.assertIsInstance(matrix_field_of_first_example[0], np.ndarray) self.assertEqual(np.array(matrix_field_of_first_example).shape, shape_2) np.testing.assert_array_equal(np.array(matrix_field_of_first_example), np.array(first_matrix)) matrix_field_of_first_two_examples = dataset[:2]["matrix"] self.assertIsInstance(matrix_field_of_first_two_examples, list) - self.assertIsInstance(matrix_field_of_first_two_examples[0], list) - self.assertIsInstance(matrix_field_of_first_two_examples[0][0], list) + self.assertIsInstance(matrix_field_of_first_two_examples[0], np.ndarray) + self.assertIsInstance(matrix_field_of_first_two_examples[0][0], np.ndarray) self.assertTupleEqual(np.array(matrix_field_of_first_two_examples).shape, (2, *shape_2)) with dataset.formatted_as("numpy"): @@ -268,7 +268,7 @@ def test_to_pylist(self): pylist = arr_xd.to_pylist() for first_dim, single_arr in zip(first_dim_list, pylist): - self.assertIsInstance(single_arr, list) + self.assertIsInstance(single_arr, np.ndarray) self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape)) def test_to_numpy(self): @@ -311,8 +311,8 @@ def test_iter_dataset(self): for first_dim, ds_row in zip(first_dim_list, dataset): single_arr = ds_row["image"] - self.assertIsInstance(single_arr, list) - self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape)) + self.assertIsInstance(single_arr, np.ndarray) + self.assertTupleEqual(single_arr.shape, (first_dim, *fixed_shape)) def test_to_pandas(self): fixed_shape = (2, 2) @@ -353,8 +353,8 @@ def test_map_dataset(self): # check also if above function resulted with 2x bigger first dim for first_dim, ds_row in zip(first_dim_list, dataset): single_arr = ds_row["image"] - self.assertIsInstance(single_arr, list) - self.assertTupleEqual(np.array(single_arr).shape, (first_dim * 2, *fixed_shape)) + self.assertIsInstance(single_arr, np.ndarray) + self.assertTupleEqual(single_arr.shape, (first_dim * 2, *fixed_shape)) @pytest.mark.parametrize("dtype, dummy_value", [("int32", 1), ("bool", True), ("float64", 1)]) From 67f65b5ca392a3f19c962d0749d65d326d637431 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Tue, 15 Oct 2024 18:07:26 +0100 Subject: [PATCH 11/13] formatting --- src/datasets/formatting/formatting.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index aafce2f68b8..ab76da3972e 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -205,12 +205,16 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: if isinstance(pa_array.type, _ArrayXDExtensionType): # don't call to_pylist() to preserve dtype of the fixed-size array zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + array: List = [ + row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) + ] else: zero_copy_only = _is_zero_copy_only(pa_array.type) and all( not _is_array_with_nulls(chunk) for chunk in pa_array.chunks ) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + array: List = [ + row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) + ] else: if isinstance(pa_array.type, _ArrayXDExtensionType): # don't call to_pylist() to preserve dtype of the fixed-size array From 97f0f19e5a3aac9d80b7d90701d86b8379651cc2 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Thu, 17 Oct 2024 14:06:49 +0100 Subject: [PATCH 12/13] fix a couple more test cases --- src/datasets/arrow_dataset.py | 2 +- src/datasets/formatting/formatting.py | 23 +++++++++++++++++++++-- tests/test_arrow_dataset.py | 4 ++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index b289fba4106..fe0fe451692 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2510,7 +2510,7 @@ def set_format( # Check that the format_type and format_kwargs are valid and make it possible to have a Formatter type = get_format_type_from_alias(type) - get_formatter(type, features=self._info.features, **format_kwargs) + get_formatter(type, features=self._info.features, **format_kwargs) if type is not None else None # Check filter column if isinstance(columns, str): diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index ab76da3972e..b114c09bccc 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -161,10 +161,24 @@ def extract_struct_array(pa_array: pa.StructArray) -> list: if pa.types.is_struct(pa_array.field(field.name).type): batch[field.name] = extract_struct_array(pa_array.field(field.name)) else: - batch[field.name] = pa_array.field(field.name).to_pylist() + # use logic from _arrow_array_to_numpy to preserve dtype + if isinstance(pa_array.type, _ArrayXDExtensionType): + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + batch[field.name] = list(pa_array.to_numpy(zero_copy_only=zero_copy_only)) + else: + batch[field.name] = pa_array.field(field.name).to_pylist() return dict_of_lists_to_list_of_dicts(batch) +def extract_array_xdextension_array(pa_array: pa.Array) -> list: + print("Extracting array xdextension array") + if isinstance(pa_array, pa.ChunkedArray): + return [arr for chunk in pa_array.chunks for arr in extract_array_xdextension_array(chunk)] + else: + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + return list(pa_array.to_numpy(zero_copy_only=zero_copy_only)) + + class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]): def extract_row(self, pa_table: pa.Table) -> dict: return _unnest(self.extract_batch(pa_table)) @@ -183,7 +197,12 @@ def extract_batch(self, pa_table: pa.Table) -> dict: if pa.types.is_struct(pa_table[col].type): batch[col] = extract_struct_array(pa_table[col]) else: - batch[col] = pa_table[col].to_pylist() + pa_array = pa_table[col] + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + batch[col] = extract_array_xdextension_array(pa_array) + else: + batch[col] = pa_table[col].to_pylist() return batch diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ffa048644e2..dff7067f4a2 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -195,8 +195,8 @@ def test_dummy_dataset(self, in_memory): } ), ) - self.assertEqual(dset[0]["col_2"], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]) - self.assertEqual(dset["col_2"][0], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]) + assert (dset[0]["col_2"] == np.array([[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]])).all() + assert (dset["col_2"][0] == np.array([[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]])).all() def test_dataset_getitem(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: From 0f37d058a971b7d9d21ccce02c65c7d7a1a617cd Mon Sep 17 00:00:00 2001 From: alex-hh Date: Fri, 18 Oct 2024 12:06:02 +0100 Subject: [PATCH 13/13] fix writing struct arrays --- src/datasets/arrow_writer.py | 4 ++++ src/datasets/features/features.py | 39 +++++++++++++++++++++++++++++++ tests/test_arrow_dataset.py | 15 +++++++----- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 3b9993736e4..1f25929ccd4 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -31,6 +31,7 @@ cast_to_python_objects, generate_from_arrow_type, get_nested_type, + list_of_dicts_to_pyarrow_structarray, list_of_np_array_to_pyarrow_listarray, numpy_to_pyarrow_listarray, to_pyarrow_listarray, @@ -183,6 +184,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): out = numpy_to_pyarrow_listarray(data) elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], np.ndarray): out = list_of_np_array_to_pyarrow_listarray(data) + elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], dict): + # pa_type should be a struct type + out = list_of_dicts_to_pyarrow_structarray(data, pa_type) else: trying_cast_to_python_objects = True out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True)) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 873f12b6bee..e41da2ff274 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1620,6 +1620,45 @@ def to_pyarrow_listarray(data: Any, pa_type: _ArrayXDExtensionType) -> pa.Array: return pa.array(data, pa_type.storage_dtype) +def list_of_dicts_to_pyarrow_structarray( + data: List[Dict[str, Any]], struct_type: Optional[pa.StructType] = None +) -> pa.StructArray: + """Convert a list of dictionaries to a pyarrow StructArray. + + First builds a dict of lists, then converts each list to a pyarrow array, + then creates a StructArray from the arrays. + """ + if not data: + raise ValueError("Input data must be a non-empty list of dictionaries.") + + field_arrays = {key: [] for key in data[0].keys()} + + for row in data: + for key in field_arrays.keys(): + value = row.get(key, None) + field_arrays[key].append(value) + + # TODO: do these need to be ordered? + pa_fields = [] + for key, values in field_arrays.items(): + if struct_type is not None: + index = struct_type.get_field_index(key) + field_type = struct_type[index].type + else: + field_type = None + # TODO: should field_type None be handled better? + pa_field = ( + to_pyarrow_listarray(values, field_type) + if contains_any_np_array(values) and field_type is not None + else pa.array(values) + ) + pa_fields.append((key, pa_field)) + + field_names, field_arrays = zip(*pa_fields) + + return pa.StructArray.from_arrays(field_arrays, field_names) + + def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureType]]) -> FeatureType: """Visit a (possibly nested) feature. diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index dff7067f4a2..37167ed3316 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -4386,30 +4386,33 @@ def f(x): def f(x): """May return a mix of LazyDict and regular Dict, but using an extension type""" if x["a"][0][0] < 2: - x["a"] = [[-1]] + x["a"] = np.array([[-1]], dtype="int32") return dict(x) if return_lazy_dict is False else x else: return x if return_lazy_dict is True else {} features = Features({"a": Array2D(shape=(1, 1), dtype="int32")}) - ds = Dataset.from_dict({"a": [[[i]] for i in [0, 1, 2, 3]]}, features=features) + # If not passing array we get exceptions that are not easy to understand - not sure if there could be some type-checking needed somewhere? + ds = Dataset.from_dict({"a": [np.array([[i]], dtype="int32") for i in [0, 1, 2, 3]]}, features=features) ds = ds.map(f) outputs = ds[:] - assert outputs == {"a": [[[i]] for i in [-1, -1, 2, 3]]} + assert outputs == {"a": [np.array([[i]], dtype="int32") for i in [-1, -1, 2, 3]]} def f(x): """May return a mix of LazyDict and regular Dict, but using a nested extension type""" if x["a"]["nested"][0][0] < 2: - x["a"] = {"nested": [[-1]]} + x["a"] = {"nested": np.array([[-1]], dtype="int64")} return dict(x) if return_lazy_dict is False else x else: return x if return_lazy_dict is True else {} features = Features({"a": {"nested": Array2D(shape=(1, 1), dtype="int64")}}) - ds = Dataset.from_dict({"a": [{"nested": [[i]]} for i in [0, 1, 2, 3]]}, features=features) + ds = Dataset.from_dict( + {"a": [{"nested": np.array([[i]], dtype="int64")} for i in [0, 1, 2, 3]]}, features=features + ) ds = ds.map(f) outputs = ds[:] - assert outputs == {"a": [{"nested": [[i]]} for i in [-1, -1, 2, 3]]} + assert outputs == {"a": [{"nested": np.array([[i]], dtype="int64")} for i in [-1, -1, 2, 3]]} def test_dataset_getitem_raises():