Skip to content

Commit

Permalink
extract lists of structs of arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hh committed Oct 15, 2024
1 parent a89ef52 commit 550d2f0
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ def extract_struct_array(pa_array: pa.StructArray) -> list:
for field in pa_array.type:
if pa.types.is_struct(pa_array.field(field.name).type):
batch[field.name] = extract_struct_array(pa_array.field(field.name))
elif pa.types.is_list(pa_array.field(field.name).type):
if pa.types.is_struct(pa_array.field(field.name).type.value_type):
list_array = pa_array.field(field.name)
batch[field.name] = [extract_struct_array(list_array[i]) for i in range(list_array.length)]
else:
batch[field.name] = pa_array.field(field.name).to_pylist()
else:
batch[field.name] = pa_array.field(field.name).to_pylist()
return dict_of_lists_to_list_of_dicts(batch)
Expand All @@ -172,6 +178,9 @@ def extract_row(self, pa_table: pa.Table) -> dict:
def extract_column(self, pa_table: pa.Table) -> list:
    """Extract the first column of ``pa_table`` as a Python list.

    Struct columns are converted with ``extract_struct_array``. List columns
    whose items are structs are converted row by row, so each row becomes the
    result of ``extract_struct_array`` on that row's struct items. Every other
    column type falls back to ``to_pylist()``.

    Args:
        pa_table: Table whose first column is extracted.

    Returns:
        One Python object per row of the first column.
    """
    column = pa_table.column(0)
    column_type = column.type
    if pa.types.is_struct(column_type):
        return extract_struct_array(column)
    # Only lists *of structs* need the struct extractor; plain lists
    # (e.g. list<int64>) must keep using to_pylist(), mirroring the
    # per-field handling in extract_struct_array.
    if pa.types.is_list(column_type) and pa.types.is_struct(column_type.value_type):
        # Iterating a ChunkedArray yields ListScalar objects. The struct items
        # of a row live in ``ListScalar.values``; a null row has no values, so
        # it maps to None just as to_pylist() would.
        return [extract_struct_array(row.values) if row.is_valid else None for row in column]
    return column.to_pylist()

Expand Down Expand Up @@ -203,12 +212,16 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
array: List = [
row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)
]
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
array: List = [
row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)
]
else:
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
Expand Down

0 comments on commit 550d2f0

Please sign in to comment.