Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast array extraction #7227

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,7 +2510,7 @@ def set_format(

# Check that the format_type and format_kwargs are valid and make it possible to have a Formatter
type = get_format_type_from_alias(type)
get_formatter(type, features=self._info.features, **format_kwargs)
get_formatter(type, features=self._info.features, **format_kwargs) if type is not None else None

# Check filter column
if isinstance(columns, str):
Expand Down
4 changes: 4 additions & 0 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
cast_to_python_objects,
generate_from_arrow_type,
get_nested_type,
list_of_dicts_to_pyarrow_structarray,
list_of_np_array_to_pyarrow_listarray,
numpy_to_pyarrow_listarray,
to_pyarrow_listarray,
Expand Down Expand Up @@ -183,6 +184,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
out = numpy_to_pyarrow_listarray(data)
elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], np.ndarray):
out = list_of_np_array_to_pyarrow_listarray(data)
elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], dict):
# pa_type should be a struct type
out = list_of_dicts_to_pyarrow_structarray(data, pa_type)
else:
trying_cast_to_python_objects = True
out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = [
"Array1D",
"Audio",
"Array2D",
"Array3D",
Expand All @@ -14,6 +15,6 @@
"TranslationVariableLanguages",
]
from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .features import Array1D, Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
91 changes: 82 additions & 9 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,33 @@ def encode_example(self, value):
return value


@dataclass
class Array1D(_ArrayXD):
"""Create a one-dimensional array.

Unlike Sequence, will be extracted as a numpy array irrespective of formatting.

Args:
shape (`tuple`):
Size of each dimension.
dtype (`str`):
Name of the data type.

Example:

```py
>>> from datasets import Features
>>> features = Features({'x': Array1D(shape=(3,), dtype='int32')})
```
"""

shape: tuple
dtype: str
id: Optional[str] = None
# Automatically constructed
_type: str = field(default="Array1D", init=False, repr=False)


@dataclass
class Array2D(_ArrayXD):
"""Create a two-dimensional array.
Expand Down Expand Up @@ -649,8 +676,8 @@ class _ArrayXDExtensionType(pa.ExtensionType):
ndims: Optional[int] = None

def __init__(self, shape: tuple, dtype: str):
if self.ndims is None or self.ndims <= 1:
raise ValueError("You must instantiate an array type with a value for dim that is > 1")
if self.ndims is None:
raise ValueError("You must instantiate an array type with a value for dim that is >= 1")
if len(shape) != self.ndims:
raise ValueError(f"shape={shape} and ndims={self.ndims} don't match")
for dim in range(1, self.ndims):
Expand Down Expand Up @@ -691,6 +718,10 @@ def to_pandas_dtype(self):
return PandasArrayExtensionDtype(self.value_type)


class Array1DExtensionType(_ArrayXDExtensionType):
ndims = 1


class Array2DExtensionType(_ArrayXDExtensionType):
ndims = 2

Expand All @@ -708,6 +739,7 @@ class Array5DExtensionType(_ArrayXDExtensionType):


# Register the extension types for deserialization
pa.register_extension_type(Array1DExtensionType((1,), "int64"))
pa.register_extension_type(Array2DExtensionType((1, 2), "int64"))
pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64"))
pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64"))
Expand Down Expand Up @@ -791,10 +823,7 @@ def to_numpy(self, zero_copy_only=True):
def to_pylist(self):
zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True)
numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only)
if self.type.shape[0] is None and numpy_arr.dtype == object:
return [arr.tolist() for arr in numpy_arr.tolist()]
else:
return numpy_arr.tolist()
return list(numpy_arr)


class PandasArrayExtensionDtype(PandasExtensionDtype):
Expand Down Expand Up @@ -1196,6 +1225,7 @@ class LargeList:
TranslationVariableLanguages,
LargeList,
Sequence,
Array1D,
Array2D,
Array3D,
Array4D,
Expand Down Expand Up @@ -1411,6 +1441,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
TranslationVariableLanguages.__name__: TranslationVariableLanguages,
LargeList.__name__: LargeList,
Sequence.__name__: Sequence,
Array1D.__name__: Array1D,
Array2D.__name__: Array2D,
Array3D.__name__: Array3D,
Array4D.__name__: Array4D,
Expand Down Expand Up @@ -1494,8 +1525,11 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
feature = generate_from_arrow_type(pa_type.value_type)
return LargeList(feature=feature)
elif isinstance(pa_type, _ArrayXDExtensionType):
array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
if pa_type.ndims >= 1:
array_feature = [Array1D, Array2D, Array3D, Array4D, Array5D][pa_type.ndims - 1]
return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
else:
raise ValueError("Cannot convert 0-dimensional array to Array Feature type.")
elif isinstance(pa_type, pa.DataType):
return Value(dtype=_arrow_to_datasets_dtype(pa_type))
else:
Expand Down Expand Up @@ -1586,6 +1620,45 @@ def to_pyarrow_listarray(data: Any, pa_type: _ArrayXDExtensionType) -> pa.Array:
return pa.array(data, pa_type.storage_dtype)


def list_of_dicts_to_pyarrow_structarray(
data: List[Dict[str, Any]], struct_type: Optional[pa.StructType] = None
) -> pa.StructArray:
"""Convert a list of dictionaries to a pyarrow StructArray.

First builds a dict of lists, then converts each list to a pyarrow array,
then creates a StructArray from the arrays.
"""
if not data:
raise ValueError("Input data must be a non-empty list of dictionaries.")

field_arrays = {key: [] for key in data[0].keys()}

for row in data:
for key in field_arrays.keys():
value = row.get(key, None)
field_arrays[key].append(value)

# TODO: do these need to be ordered?
pa_fields = []
for key, values in field_arrays.items():
if struct_type is not None:
index = struct_type.get_field_index(key)
field_type = struct_type[index].type
else:
field_type = None
# TODO: should field_type None be handled better?
pa_field = (
to_pyarrow_listarray(values, field_type)
if contains_any_np_array(values) and field_type is not None
else pa.array(values)
)
pa_fields.append((key, pa_field))

field_names, field_arrays = zip(*pa_fields)

return pa.StructArray.from_arrays(field_arrays, field_names)


def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureType]]) -> FeatureType:
"""Visit a (possibly nested) feature.

Expand Down Expand Up @@ -1715,7 +1788,7 @@ class Features(dict):

</Tip>

- [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays.
- [`Array1D`], [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays.
- [`Audio`] feature to store the absolute path to an audio file or a dictionary with the relative path
to an audio file ("path" key) and its bytes content ("bytes" key). This feature extracts the audio data.
- [`Image`] feature to store the absolute path to an image file, an `np.ndarray` object, a `PIL.Image.Image` object
Expand Down
61 changes: 58 additions & 3 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool:
return pa_array.null_count > 0


def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]:
# convert to list of dicts
list_of_dicts = []
keys = dict_of_lists.keys()
value_arrays = [dict_of_lists[key] for key in keys]
for vals in zip(*value_arrays):
list_of_dicts.append(dict(zip(keys, vals)))
return list_of_dicts


class BaseArrowExtractor(Generic[RowFormat, ColumnFormat, BatchFormat]):
"""
Arrow extractor are used to extract data from pyarrow tables.
Expand Down Expand Up @@ -140,15 +150,60 @@ def extract_batch(self, pa_table: pa.Table) -> pa.Table:
return pa_table


def extract_struct_array(pa_array: pa.StructArray) -> list:
"""StructArray.to_pylist / to_pydict does not call sub-arrays to_pylist / to_pydict methods so handle them manually."""
if isinstance(pa_array, pa.ChunkedArray):
batch_chunks = [extract_struct_array(chunk) for chunk in pa_array.chunks]
return [item for chunk in batch_chunks for item in chunk]

batch = {}
for field in pa_array.type:
if pa.types.is_struct(pa_array.field(field.name).type):
batch[field.name] = extract_struct_array(pa_array.field(field.name))
Comment on lines +161 to +162
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can also check if it's a list or large_list type

Copy link
Contributor Author

@alex-hh alex-hh Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked that lists of ArrayExtensionType features will call ArrayExtensionArray.to_pylist(), which didn't seem to be the case for struct, and is the main performance issue there

Not sure about large list?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool ! maybe also check list of struct of ArrayExtensionType but no big deal, we can fix that rare case later (large list is also rare)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the list of struct case might require an ArrayExtensionScalar or something with an as_py method that returns a numpy object.

Seems like it could be useful but have no idea whether this is possible or how best to do it if so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless you know how to do this could we leave as issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just add a TODO comment about it for now ?

else:
# use logic from _arrow_array_to_numpy to preserve dtype
if isinstance(pa_array.type, _ArrayXDExtensionType):
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
batch[field.name] = list(pa_array.to_numpy(zero_copy_only=zero_copy_only))
else:
batch[field.name] = pa_array.field(field.name).to_pylist()
return dict_of_lists_to_list_of_dicts(batch)


def extract_array_xdextension_array(pa_array: pa.Array) -> list:
print("Extracting array xdextension array")
if isinstance(pa_array, pa.ChunkedArray):
return [arr for chunk in pa_array.chunks for arr in extract_array_xdextension_array(chunk)]
else:
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
return list(pa_array.to_numpy(zero_copy_only=zero_copy_only))


class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
def extract_row(self, pa_table: pa.Table) -> dict:
return _unnest(pa_table.to_pydict())
return _unnest(self.extract_batch(pa_table))

def extract_column(self, pa_table: pa.Table) -> list:
return pa_table.column(0).to_pylist()
if pa.types.is_struct(pa_table[pa_table.column_names[0]].type):
return extract_struct_array(pa_table[pa_table.column_names[0]])
# TODO: handle list of struct
else:
# should work for list of ArrayXD
return pa_table.column(0).to_pylist()

def extract_batch(self, pa_table: pa.Table) -> dict:
return pa_table.to_pydict()
batch = {}
for col in pa_table.column_names:
if pa.types.is_struct(pa_table[col].type):
batch[col] = extract_struct_array(pa_table[col])
else:
pa_array = pa_table[col]
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
batch[col] = extract_array_xdextension_array(pa_array)
else:
batch[col] = pa_table[col].to_pylist()
return batch


class NumpyArrowExtractor(BaseArrowExtractor[dict, np.ndarray, dict]):
Expand Down
31 changes: 18 additions & 13 deletions tests/features/test_array_xd.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,20 @@ def get_dict_examples(self, shape_1, shape_2):
def _check_getitem_output_type(self, dataset, shape_1, shape_2, first_matrix):
matrix_column = dataset["matrix"]
self.assertIsInstance(matrix_column, list)
self.assertIsInstance(matrix_column[0], list)
self.assertIsInstance(matrix_column[0][0], list)
self.assertIsInstance(matrix_column[0], np.ndarray)
self.assertIsInstance(matrix_column[0][0], np.ndarray)
self.assertTupleEqual(np.array(matrix_column).shape, (2, *shape_2))

matrix_field_of_first_example = dataset[0]["matrix"]
self.assertIsInstance(matrix_field_of_first_example, list)
self.assertIsInstance(matrix_field_of_first_example, list)
self.assertIsInstance(matrix_field_of_first_example, np.ndarray)
self.assertIsInstance(matrix_field_of_first_example[0], np.ndarray)
self.assertEqual(np.array(matrix_field_of_first_example).shape, shape_2)
np.testing.assert_array_equal(np.array(matrix_field_of_first_example), np.array(first_matrix))

matrix_field_of_first_two_examples = dataset[:2]["matrix"]
self.assertIsInstance(matrix_field_of_first_two_examples, list)
self.assertIsInstance(matrix_field_of_first_two_examples[0], list)
self.assertIsInstance(matrix_field_of_first_two_examples[0][0], list)
self.assertIsInstance(matrix_field_of_first_two_examples[0], np.ndarray)
self.assertIsInstance(matrix_field_of_first_two_examples[0][0], np.ndarray)
self.assertTupleEqual(np.array(matrix_field_of_first_two_examples).shape, (2, *shape_2))

with dataset.formatted_as("numpy"):
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_to_pylist(self):
pylist = arr_xd.to_pylist()

for first_dim, single_arr in zip(first_dim_list, pylist):
self.assertIsInstance(single_arr, list)
self.assertIsInstance(single_arr, np.ndarray)
self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape))

def test_to_numpy(self):
Expand Down Expand Up @@ -311,8 +311,8 @@ def test_iter_dataset(self):

for first_dim, ds_row in zip(first_dim_list, dataset):
single_arr = ds_row["image"]
self.assertIsInstance(single_arr, list)
self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape))
self.assertIsInstance(single_arr, np.ndarray)
self.assertTupleEqual(single_arr.shape, (first_dim, *fixed_shape))

def test_to_pandas(self):
fixed_shape = (2, 2)
Expand Down Expand Up @@ -353,8 +353,8 @@ def test_map_dataset(self):
# check also if above function resulted with 2x bigger first dim
for first_dim, ds_row in zip(first_dim_list, dataset):
single_arr = ds_row["image"]
self.assertIsInstance(single_arr, list)
self.assertTupleEqual(np.array(single_arr).shape, (first_dim * 2, *fixed_shape))
self.assertIsInstance(single_arr, np.ndarray)
self.assertTupleEqual(single_arr.shape, (first_dim * 2, *fixed_shape))


@pytest.mark.parametrize("dtype, dummy_value", [("int32", 1), ("bool", True), ("float64", 1)])
Expand Down Expand Up @@ -419,7 +419,7 @@ def test_array_xd_with_none():
def test_array_xd_with_np(seq_type, dtype, shape, feature_class):
feature = feature_class(dtype=dtype, shape=shape)
data = np.zeros(shape, dtype=dtype)
expected = data.tolist()
expected = data
if seq_type == "sequence":
feature = datasets.Sequence(feature)
data = [data]
Expand All @@ -429,7 +429,12 @@ def test_array_xd_with_np(seq_type, dtype, shape, feature_class):
data = [[data]]
expected = [[expected]]
ds = datasets.Dataset.from_dict({"col": [data]}, features=datasets.Features({"col": feature}))
assert ds[0]["col"] == expected
if seq_type == "sequence":
assert (ds[0]["col"][0] == expected[0]).all()
elif seq_type == "sequence_of_sequence":
assert (ds[0]["col"][0][0] == expected[0][0]).all()
else:
assert (ds[0]["col"] == expected).all()


@pytest.mark.parametrize("with_none", [False, True])
Expand Down
Loading