Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Jul 11, 2024
1 parent 5e3982f commit ab92fb1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
30 changes: 22 additions & 8 deletions audformat/core/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,8 @@ def _load_csv(self, path: str):
than the method applied here.
We first load the CSV file as a :class:`pyarrow.Table`
and convert it to a dataframe afterwards.
If this fails,
we fall back to :func:`pandas.read_csv()`.
Args:
path: path to table, including file extension
Expand All @@ -905,20 +907,32 @@ def _load_csv(self, path: str):
# If pyarrow fails to parse the CSV file
# https://github.com/audeering/audformat/issues/449

# Replace dtype with converter for dates or timestamps
# Collect csv file columns and data types.
# index
columns_and_dtypes = self._levels_and_dtypes
# columns
for column_id, column in self.columns.items():
if column.scheme_id is not None:
columns_and_dtypes[column_id] = self.db.schemes[
column.scheme_id
].dtype
else:
columns_and_dtypes[column_id] = define.DataType.OBJECT

# Replace data type with converter for dates or timestamps
converters = {}
dtypes_wo_converters = {}
for level, dtype in self._levels_and_dtypes.items():
if dtype == "date":
converters[level] = lambda x: pd.to_datetime(x)
elif dtype == "time":
converters[level] = lambda x: pd.to_timedelta(x)
for column, dtype in columns_and_dtypes.items():
if dtype == define.DataType.DATE:
converters[column] = lambda x: pd.to_datetime(x)
elif dtype == define.DataType.TIME:
converters[column] = lambda x: pd.to_timedelta(x)
else:
dtypes_wo_converters[level] = to_pandas_dtype(dtype)
dtypes_wo_converters[column] = to_pandas_dtype(dtype)

df = pd.read_csv(
path,
usecols=levels + columns,
usecols=list(columns_and_dtypes.keys()),
dtype=dtypes_wo_converters,
index_col=levels,
converters=converters,
Expand Down
39 changes: 30 additions & 9 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,9 @@ def test_load(tmpdir):


def test_load_broken_csv(tmpdir):
r"""Test loading of malformed CSV files.
r"""Test loading of malformed csv files.
If csv files contain a lot of special character,
If csv files contain a lot of special characters,
or a different number of columns,
than specified in the database header,
loading of them should not fail.
Expand All @@ -1160,18 +1160,39 @@ def test_load_broken_csv(tmpdir):
build_dir = audeer.mkdir(tmpdir, "build")

# Create database with single table and column
#
# Ensure:
#
# * it contains an empty table
# * the columns use schemes with time and date data types
# * at least one column has no scheme
#
# as those cases needed special care with csv files
#
db = audformat.Database("mydb")
db.schemes["date"] = audformat.Scheme("date")
db.schemes["time"] = audformat.Scheme("time")
db["table"] = audformat.Table(audformat.filewise_index("file.wav"))
db["table"]["column"] = audformat.Column()
db["table"]["column"].set(["label"])

# Add another column to dataframe,
# without adding a column to the header
db["table"].df["hidden-column"] = ["hidden-label"]
db["table"]["date"] = audformat.Column(scheme_id="date")
db["table"]["date"].set([pd.to_datetime("2018-10-26")])
db["table"]["time"] = audformat.Column(scheme_id="time")
db["table"]["time"].set([pd.Timedelta(1)])
db["table"]["no-scheme"] = audformat.Column()
db["table"]["no-scheme"].set(["label"])
db["empty-table"] = audformat.Table(audformat.filewise_index())
db["empty-table"]["column"] = audformat.Column()

# Add a hidden column to the table dataframes,
# without adding it to the table header
db["table"].df["hidden"] = ["hidden"]
db["empty-table"].df["hidden"] = []

db.save(build_dir, storage_format="csv")
db_loaded = audformat.Database.load(build_dir, load_data=True)
assert "hidden-column" not in db_loaded["table"].df
assert "table" in db_loaded
assert "empty-table" in db_loaded
assert "hidden" not in db_loaded["table"].df
assert "hidden-column" not in db_loaded["empty-table"].df


def test_load_old_pickle(tmpdir):
Expand Down

0 comments on commit ab92fb1

Please sign in to comment.