diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 6400d7e901c..2f46a30b608 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -160,6 +160,12 @@ def extract_struct_array(pa_array: pa.StructArray) -> list: for field in pa_array.type: if pa.types.is_struct(pa_array.field(field.name).type): batch[field.name] = extract_struct_array(pa_array.field(field.name)) + elif pa.types.is_list(pa_array.field(field.name).type): + if pa.types.is_struct(pa_array.field(field.name).type.value_type): + list_array = pa_array.field(field.name) + batch[field.name] = [extract_struct_array(list_array[i]) for i in range(list_array.length)] + else: + batch[field.name] = pa_array.field(field.name).to_pylist() else: batch[field.name] = pa_array.field(field.name).to_pylist() return dict_of_lists_to_list_of_dicts(batch) @@ -172,6 +178,9 @@ def extract_row(self, pa_table: pa.Table) -> dict: def extract_column(self, pa_table: pa.Table) -> list: if pa.types.is_struct(pa_table[pa_table.column_names[0]].type): return extract_struct_array(pa_table[pa_table.column_names[0]]) + elif pa.types.is_list(pa_table[pa_table.column_names[0]].type): + list_array = pa_table[pa_table.column_names[0]] + return [extract_struct_array(list_array[i]) for i in range(list_array.length)] else: return pa_table.column(0).to_pylist() @@ -203,12 +212,16 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: if isinstance(pa_array.type, _ArrayXDExtensionType): # don't call to_pylist() to preserve dtype of the fixed-size array zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + array: List = [ + row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) + ] else: zero_copy_only = _is_zero_copy_only(pa_array.type) and all( not _is_array_with_nulls(chunk) for chunk in pa_array.chunks ) - array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)] + array: List = [ + row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only) + ] else: if isinstance(pa_array.type, _ArrayXDExtensionType): # don't call to_pylist() to preserve dtype of the fixed-size array