Skip to content

Commit

Permalink
staticmethods to _pyarrow_write_attrs/_pyarrow_read_attrs|add docstri…
Browse files Browse the repository at this point in the history
…ng/types
  • Loading branch information
snowman2 committed Jun 2, 2021
1 parent 1e2fa66 commit 0221b26
Showing 1 changed file with 40 additions and 28 deletions.
68 changes: 40 additions & 28 deletions pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,44 @@ def read(self, path, columns=None, **kwargs):
raise AbstractMethodError(self)


def _pyarrow_write_attrs(table: Any, df: DataFrame) -> Any:
"""
.. versionadded:: 1.3
Copy attts from pandas.DataFrame and pandas.Series to
schema metadata in pyarrow.Table.
"""
schema_metadata = table.schema.metadata or {}
pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}"))
column_attrs = {}
for col in df.columns:
attrs = df[col].attrs
if not attrs or not isinstance(col, str):
continue
column_attrs[col] = attrs
pandas_metadata.update(
attrs=df.attrs,
column_attrs=column_attrs,
)
schema_metadata[b"pandas"] = json.dumps(pandas_metadata)
return table.replace_schema_metadata(schema_metadata)


def _pyarrow_read_attrs(table: Any, df: DataFrame) -> None:
"""
.. versionadded:: 1.3
Copy schema metadata from pyarrow.Table
to attrs in pandas.DataFrame and pandas.Series.
"""
schema_metadata = table.schema.metadata or {}
pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}"))
df.attrs = pandas_metadata.get("attrs", {})
col_attrs = pandas_metadata.get("column_attrs", {})
for col in df.columns:
df[col].attrs = col_attrs.get(col, {})


class PyArrowImpl(BaseImpl):
def __init__(self):
import_optional_dependency(
Expand All @@ -155,32 +193,6 @@ def __init__(self):

self.api = pyarrow

@staticmethod
def _write_attrs(table, df: DataFrame):
schema_metadata = table.schema.metadata or {}
pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}"))
column_attrs = {}
for col in df.columns:
attrs = df[col].attrs
if not attrs or not isinstance(col, str):
continue
column_attrs[col] = attrs
pandas_metadata.update(
attrs=df.attrs,
column_attrs=column_attrs,
)
schema_metadata[b"pandas"] = json.dumps(pandas_metadata)
return table.replace_schema_metadata(schema_metadata)

@staticmethod
def _read_attrs(table, df: DataFrame):
schema_metadata = table.schema.metadata or {}
pandas_metadata = json.loads(schema_metadata.get(b"pandas", "{}"))
df.attrs = pandas_metadata.get("attrs", {})
col_attrs = pandas_metadata.get("column_attrs", {})
for col in df.columns:
df[col].attrs = col_attrs.get(col, {})

def write(
self,
df: DataFrame,
Expand All @@ -198,7 +210,7 @@ def write(
from_pandas_kwargs["preserve_index"] = index

table = self.api.Table.from_pandas(df, **from_pandas_kwargs)
table = self._write_attrs(table, df)
table = _pyarrow_write_attrs(table, df)

path_or_handle, handles, kwargs["filesystem"] = _get_path_or_handle(
path,
Expand Down Expand Up @@ -268,7 +280,7 @@ def read(
path_or_handle, columns=columns, **kwargs
)
result = table.to_pandas(**to_pandas_kwargs)
self._read_attrs(table, result)
_pyarrow_read_attrs(table, result)
if manager == "array":
result = result._as_manager("array", copy=False)
return result
Expand Down

0 comments on commit 0221b26

Please sign in to comment.