Skip to content

Commit

Permalink
Enable feather and parquet in S3 (#361)
Browse files Browse the repository at this point in the history
* fix: adapt with feather and parquet in S3
* docs: add fixme comment

---------

Co-authored-by: Mamoru Miura <[email protected]>
  • Loading branch information
mamo3gr and Mamoru Miura authored Apr 3, 2024
1 parent 837742d commit 624200b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
24 changes: 19 additions & 5 deletions gokart/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger

import luigi
Expand Down Expand Up @@ -203,11 +204,17 @@ def __init__(self, engine='pyarrow', compression=None):
super(ParquetFileProcessor, self).__init__()

def format(self):
return None
return luigi.format.Nop

def load(self, file):
# MEMO: read_parquet only supports a filepath as string (not a file handle)
return pd.read_parquet(file.name)
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_parquet accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
return pd.read_parquet(file.name)
else:
return pd.read_parquet(BytesIO(file.read()))

def dump(self, obj, file):
assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
Expand All @@ -222,10 +229,17 @@ def __init__(self, store_index_in_feather: bool):
self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

def format(self):
return None
return luigi.format.Nop

def load(self, file):
loaded_df = pd.read_feather(file.name)
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_feather accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
loaded_df = pd.read_feather(file.name)
else:
loaded_df = pd.read_feather(BytesIO(file.read()))

if self._store_index_in_feather:
if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
Expand Down
28 changes: 28 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ def test_last_modified_time_without_file(self):
with self.assertRaises(FileNotFoundError):
target.last_modification_time()

@mock_s3
def test_save_on_s3_feather(self):
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='test')

obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
file_path = os.path.join('s3://test/', 'test.feather')

target = make_target(file_path=file_path, unique_id=None)
target.dump(obj)
loaded = target.load()

pd.testing.assert_frame_equal(loaded, obj)

@mock_s3
def test_save_on_s3_parquet(self):
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='test')

obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
file_path = os.path.join('s3://test/', 'test.parquet')

target = make_target(file_path=file_path, unique_id=None)
target.dump(obj)
loaded = target.load()

pd.testing.assert_frame_equal(loaded, obj)


class ModelTargetTest(unittest.TestCase):
def tearDown(self):
Expand Down

0 comments on commit 624200b

Please sign in to comment.