Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load csv tables with pandas if pyarrow fails #450

Merged
merged 5 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions audformat/core/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from audformat.core.column import Column
from audformat.core.common import HeaderBase
from audformat.core.common import HeaderDict
from audformat.core.common import to_pandas_dtype
from audformat.core.errors import BadIdError
from audformat.core.index import filewise_index
from audformat.core.index import index_type
Expand Down Expand Up @@ -880,25 +881,63 @@ def _load_csv(self, path: str):
than the method applied here.
We first load the CSV file as a :class:`pyarrow.Table`
and convert it to a dataframe afterwards.
If this fails,
we fall back to :func:`pandas.read_csv()`.

Args:
path: path to table, including file extension

"""
levels = list(self._levels_and_dtypes.keys())
columns = list(self.columns.keys())
table = csv.read_csv(
path,
read_options=csv.ReadOptions(
column_names=levels + columns,
skip_rows=1,
),
convert_options=csv.ConvertOptions(
column_types=self._pyarrow_csv_schema(),
strings_can_be_null=True,
),
)
df = self._pyarrow_table_to_dataframe(table, from_csv=True)
try:
table = csv.read_csv(
path,
read_options=csv.ReadOptions(
column_names=levels + columns,
skip_rows=1,
),
convert_options=csv.ConvertOptions(
column_types=self._pyarrow_csv_schema(),
strings_can_be_null=True,
),
)
df = self._pyarrow_table_to_dataframe(table, from_csv=True)
except pa.lib.ArrowInvalid:
# If pyarrow fails to parse the CSV file
# https://github.com/audeering/audformat/issues/449

# Collect csv file columns and data types.
# index
columns_and_dtypes = self._levels_and_dtypes
# columns
for column_id, column in self.columns.items():
if column.scheme_id is not None:
columns_and_dtypes[column_id] = self.db.schemes[
column.scheme_id
].dtype
else:
columns_and_dtypes[column_id] = define.DataType.OBJECT

# Replace data type with converter for dates or timestamps
converters = {}
dtypes_wo_converters = {}
for column, dtype in columns_and_dtypes.items():
if dtype == define.DataType.DATE:
converters[column] = lambda x: pd.to_datetime(x)
elif dtype == define.DataType.TIME:
converters[column] = lambda x: pd.to_timedelta(x)
else:
dtypes_wo_converters[column] = to_pandas_dtype(dtype)

df = pd.read_csv(
path,
usecols=list(columns_and_dtypes.keys()),
dtype=dtypes_wo_converters,
index_col=levels,
converters=converters,
float_precision="round_trip",
)

self._df = df

Expand Down
73 changes: 73 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,79 @@ def test_load(tmpdir):
os.remove(f"{path_no_ext}.{ext}")


class TestLoadBrokenCsv:
    r"""Test loading of malformed csv files.

    If csv files contain a lot of special characters,
    or a different number of columns
    than specified in the database header,
    loading of them should not fail.

    See https://github.com/audeering/audformat/issues/449

    """

    def database_with_hidden_columns(self) -> audformat.Database:
        r"""Database with hidden columns.

        Create a database with hidden columns
        that are stored in the csv files,
        but not in the header of the tables.

        Ensure:

        * it contains an empty table
        * the columns use schemes with time and date data types
        * at least one column has no scheme

        as those cases needed special care with csv files,
        before switching to use pyarrow.csv.read_csv()
        in https://github.com/audeering/audformat/pull/419.

        Returns:
            database

        """
        db = audformat.Database("mydb")
        db.schemes["date"] = audformat.Scheme("date")
        db.schemes["time"] = audformat.Scheme("time")
        db["table"] = audformat.Table(audformat.filewise_index("file.wav"))
        db["table"]["date"] = audformat.Column(scheme_id="date")
        db["table"]["date"].set([pd.to_datetime("2018-10-26")])
        db["table"]["time"] = audformat.Column(scheme_id="time")
        db["table"]["time"].set([pd.Timedelta(1)])
        db["table"]["no-scheme"] = audformat.Column()
        db["table"]["no-scheme"].set(["label"])
        db["empty-table"] = audformat.Table(audformat.filewise_index())
        db["empty-table"]["column"] = audformat.Column()
        # Add a hidden column to the table dataframes,
        # without adding it to the table header.
        # When saved, the column ends up in the csv file,
        # producing more columns than the header declares.
        db["table"].df["hidden"] = ["hidden"]
        db["empty-table"].df["hidden"] = []
        return db

    def test_load_broken_csv(self, tmpdir):
        r"""Test loading a database table from broken csv files.

        Broken csv files
        refer to csv tables,
        that raise an error
        when loading with ``pyarrow.csv.read_csv()``.

        Args:
            tmpdir: tmpdir fixture

        """
        db = self.database_with_hidden_columns()
        build_dir = audeer.mkdir(tmpdir, "build")
        db.save(build_dir, storage_format="csv")
        db_loaded = audformat.Database.load(build_dir, load_data=True)
        assert "table" in db_loaded
        assert "empty-table" in db_loaded
        # The hidden column is present in the stored csv files,
        # but must be dropped on load
        # as it is not part of the table header.
        # NOTE: the original asserted "hidden-column" for the empty table,
        # a name that is never used, so that check passed vacuously.
        assert "hidden" not in db_loaded["table"].df
        assert "hidden" not in db_loaded["empty-table"].df


def test_load_old_pickle(tmpdir):
# We have stored string dtype as object dtype before
# and have to fix this when loading old PKL files from cache.
Expand Down