diff --git a/setup.py b/setup.py
index 161c670..90ec878 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,9 @@
         'boto3==1.9.57',
         'singer-encodings==0.1.2',
         'singer-python==5.12.1',
-        'voluptuous==0.10.5'
+        'voluptuous==0.10.5',
+        'pyarrow==6.0.1',
+        'python-dateutil',
     ],
     extras_require={
         'dev': [
diff --git a/tap_s3_csv/__init__.py b/tap_s3_csv/__init__.py
index 66335c8..1c6a7ea 100644
--- a/tap_s3_csv/__init__.py
+++ b/tap_s3_csv/__init__.py
@@ -10,7 +10,7 @@
 
 LOGGER = singer.get_logger()
 
-REQUIRED_CONFIG_KEYS = ["start_date", "bucket", "account_id", "external_id", "role_name"]
+REQUIRED_CONFIG_KEYS = ["start_date", "bucket"]
 
 
 def do_discover(config):
@@ -75,7 +75,7 @@ def main():
     config['tables'] = validate_table_config(config)
 
     try:
-        for page in s3.list_files_in_bucket(config['bucket']):
+        for page in s3.list_files_in_bucket(config['bucket'], config=config):
             break
         LOGGER.warning("I have direct access to the bucket without assuming the configured role.")
     except:
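For reference, the config surface after this change: only "start_date" and "bucket" remain required, credentials come from "aws_access_key_id" / "aws_secret_access_key" in the config (falling back to the environment), and "end_date" and "max_sample_files" are optional keys read further down in s3.py. A minimal sketch of such a config, with illustrative bucket, key, and table values:

import json

# Illustrative config for the patched tap; only "start_date" and "bucket" are required.
# All other keys shown here are optional and read via config.get(...) elsewhere in this diff.
example_config = {
    "start_date": "2021-01-01T00:00:00Z",
    "bucket": "example-bucket",                    # hypothetical bucket name
    "aws_access_key_id": "<access-key-id>",        # falls back to the environment when omitted
    "aws_secret_access_key": "<secret-access-key>",
    "end_date": "2022-01-01T00:00:00Z",            # optional upper bound on LastModified
    "max_sample_files": 5,                         # optional cap on files sampled during discovery
    "tables": [
        {
            "table_name": "events",
            "search_prefix": "exports/",
            "search_pattern": ".*\\.parquet$",
            "key_properties": ["id"]
        }
    ]
}

with open("config.json", "w") as config_file:
    json.dump(example_config, config_file, indent=2)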
diff --git a/tap_s3_csv/s3.py b/tap_s3_csv/s3.py
index 5cdd919..8bd9d55 100644
--- a/tap_s3_csv/s3.py
+++ b/tap_s3_csv/s3.py
@@ -6,7 +6,13 @@
 import backoff
 import boto3
 import singer
+import tempfile
+import pathlib
+import os
+import pyarrow.parquet as pq
+import pytz
+from dateutil.parser import parse
 
 from botocore.credentials import (
     AssumeRoleCredentialFetcher,
     CredentialResolver,
@@ -60,35 +66,20 @@ def load(self):
 
 @retry_pattern()
 def setup_aws_client(config):
-    role_arn = "arn:aws:iam::{}:role/{}".format(config['account_id'].replace('-', ''),
-                                                config['role_name'])
-    session = Session()
-    fetcher = AssumeRoleCredentialFetcher(
-        session.create_client,
-        session.get_credentials(),
-        role_arn,
-        extra_args={
-            'DurationSeconds': 3600,
-            'RoleSessionName': 'TapS3CSV',
-            'ExternalId': config['external_id']
-        },
-        cache=JSONFileCache()
-    )
-
-    refreshable_session = Session()
-    refreshable_session.register_component(
-        'credential_provider',
-        CredentialResolver([AssumeRoleProvider(fetcher)])
-    )
-
-    LOGGER.info("Attempting to assume_role on RoleArn: %s", role_arn)
-    boto3.setup_default_session(botocore_session=refreshable_session)
+    key = config.get('aws_access_key_id', os.environ.get("aws_access_key_id"))
+    secret = config.get('aws_secret_access_key', os.environ.get("aws_secret_access_key"))
+
+    return boto3.session.Session(aws_access_key_id=key, aws_secret_access_key=secret)
 
 
 def get_sampled_schema_for_table(config, table_spec):
     LOGGER.info('Sampling records to determine table schema.')
 
-    s3_files_gen = get_input_files_for_table(config, table_spec)
+    s3_files_gen = get_input_files_for_table(
+        config,
+        table_spec,
+        modified_since=parse(config.get('start_date')).replace(tzinfo=pytz.UTC),
+        modified_until=config.get('end_date'))
 
     samples = [sample for sample in sample_files(config, table_spec, s3_files_gen)]
 
@@ -188,27 +179,63 @@ def get_records_for_jsonl(s3_path, sample_rate, iterator):
     LOGGER.info("Sampled %s rows from %s", sampled_row_count, s3_path)
 
 
+def get_records_for_parquet(s3_bucket, s3_path, sample_rate, config):
+
+    local_path = os.path.join(tempfile.gettempdir(), s3_path)
+    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if os.path.isfile(local_path):
+        LOGGER.info(f"Skipping download, file exists: {local_path}")
+    else:
+        LOGGER.info(f"Downloading {s3_path} to {local_path}")
+        session = setup_aws_client(config)
+        session.resource("s3").Bucket(s3_bucket).download_file(s3_path, local_path)
+
+    parquet_file = pq.ParquetFile(local_path)
+
+    current_row = 0
+    sampled_row_count = 0
+
+    for i in range(parquet_file.num_row_groups):
+        table = parquet_file.read_row_group(i)
+        for batch in table.to_batches():
+            for row in zip(*batch.columns):
+                if (current_row % sample_rate) == 0:
+                    sampled_row_count += 1
+                    yield {
+                        table.column_names[i]: val.as_py()
+                        for i, val in enumerate(row, start=0)
+                    }
+                current_row += 1
+
+    LOGGER.info("Sampled %s rows from %s", sampled_row_count, s3_path)
+
+
 def check_key_properties_and_date_overrides_for_jsonl_file(table_spec, jsonl_sample_records, s3_path):
+    rows = 0
     all_keys = set()
     for record in jsonl_sample_records:
+        rows += 1
         keys = record.keys()
         all_keys.update(keys)
+        if rows > 5000:
+            break
 
     if table_spec.get('key_properties'):
         key_properties = set(table_spec['key_properties'])
         if not key_properties.issubset(all_keys):
-            raise Exception('JSONL file "{}" is missing required key_properties key: {}'
+            raise Exception('JSONL/parquet file "{}" is missing required key_properties key: {}'
                             .format(s3_path, key_properties - all_keys))
 
     if table_spec.get('date_overrides'):
         date_overrides = set(table_spec['date_overrides'])
         if not date_overrides.issubset(all_keys):
-            raise Exception('JSONL file "{}" is missing date_overrides key: {}'
+            raise Exception('JSONL/parquet file "{}" is missing date_overrides key: {}'
                             .format(s3_path, date_overrides - all_keys))
 
 
 #pylint: disable=global-statement
-def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate):
+def sampling_gz_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, config):
     global skipped_files_count
     if s3_path.endswith(".tar.gz"):
         LOGGER.warning('Skipping "%s" file as .tar.gz extension is not supported',s3_path)
@@ -235,12 +262,19 @@ def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate):
                 return []
 
             gz_file_extension = gz_file_name.split(".")[-1].lower()
-            return sample_file(table_spec, s3_path + "/" + gz_file_name, io.BytesIO(gz_file_obj.read()), sample_rate, gz_file_extension)
+            return sample_file(table_spec, s3_bucket, s3_path + "/" + gz_file_name, io.BytesIO(gz_file_obj.read()), sample_rate, gz_file_extension, config)
 
     raise Exception('"{}" file has some error(s)'.format(s3_path))
 
 
+def peek(iterable):
+    # Return None if the iterable is empty, otherwise the first item plus an equivalent iterator.
+    try:
+        first = next(iterable)
+    except StopIteration:
+        return None
+    return first, itertools.chain([first], iterable)
+
 #pylint: disable=global-statement
-def sample_file(table_spec, s3_path, file_handle, sample_rate, extension):
+def sample_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, extension, config):
     global skipped_files_count
 
     # Check whether file is without extension or not
@@ -260,21 +294,36 @@ def sample_file(table_spec, s3_path, file_handle, sample_rate, extension):
             skipped_files_count = skipped_files_count + 1
         return csv_records
     if extension == "gz":
-        return sampling_gz_file(table_spec, s3_path, file_handle, sample_rate)
+        return sampling_gz_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, config)
     if extension == "jsonl":
         # If file object read from s3 bucket file else use extracted file object from zip or gz
+        file_handle = file_handle._raw_stream if hasattr(file_handle, "_raw_stream") else file_handle
-        records = get_records_for_jsonl(
-            s3_path, sample_rate, file_handle)
-        check_jsonl_sample_records, records = itertools.tee(
-            records)
-        jsonl_sample_records = list(check_jsonl_sample_records)
-        if len(jsonl_sample_records) == 0:
+        records = get_records_for_jsonl(s3_path, sample_rate, file_handle)
+        check_jsonl_sample_records, records = itertools.tee(records)
+
+        result = peek(check_jsonl_sample_records)
+        if result is None:
             LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
             skipped_files_count = skipped_files_count + 1
-        check_key_properties_and_date_overrides_for_jsonl_file(
-            table_spec, jsonl_sample_records, s3_path)
-
+            return []
+        else:
+            check_jsonl_sample_records = result[1]
+            check_key_properties_and_date_overrides_for_jsonl_file(table_spec, check_jsonl_sample_records, s3_path)
+
         return records
+    if extension == "parquet":
+        records = get_records_for_parquet(s3_bucket, s3_path, sample_rate, config)
+        check_jsonl_sample_records, records = itertools.tee(records)
+
+        result = peek(check_jsonl_sample_records)
+        if result is None:
+            LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
+            skipped_files_count = skipped_files_count + 1
+            return []
+        else:
+            check_jsonl_sample_records = result[1]
+            check_key_properties_and_date_overrides_for_jsonl_file(table_spec, check_jsonl_sample_records, s3_path)
+        return records
     if extension == "zip":
         LOGGER.warning('Skipping "%s" file as it contains nested compression.',s3_path)
@@ -303,7 +352,7 @@ def get_files_to_sample(config, s3_files, max_files):
     global skipped_files_count
     sampled_files = []
 
-    OTHER_FILES = ["csv","gz","jsonl","txt"]
+    OTHER_FILES = ["csv","gz","jsonl","txt","parquet"]
 
     for s3_file in s3_files:
         file_key = s3_file.get('key')
@@ -343,11 +392,12 @@ def sample_files(config, table_spec, s3_files,
                  sample_rate=5, max_records=1000, max_files=5):
     global skipped_files_count
+    max_files = config.get("max_sample_files", max_files)
     LOGGER.info("Sampling files (max files: %s)", max_files)
     for s3_file in itertools.islice(get_files_to_sample(config, s3_files, max_files), max_files):
-
+        s3_bucket = config['bucket']
         s3_path = s3_file.get("s3_path","")
         file_handle = s3_file.get("file_handle")
         file_type = s3_file.get("type")
@@ -364,7 +414,7 @@ def sample_files(config, table_spec, s3_files,
                     max_records,
                     sample_rate)
         try:
-            yield from itertools.islice(sample_file(table_spec, s3_path, file_handle, sample_rate, extension), max_records)
+            yield from itertools.islice(sample_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, extension, config), max_records)
         except (UnicodeDecodeError,json.decoder.JSONDecodeError):
             # UnicodeDecodeError will be raised if non csv file parsed to csv parser
             # JSONDecodeError will be reaised if non JSONL file parsed to JSON parser
@@ -373,7 +423,7 @@ def sample_files(config, table_spec, s3_files,
             skipped_files_count = skipped_files_count + 1
 
 
 #pylint: disable=global-statement
-def get_input_files_for_table(config, table_spec, modified_since=None):
+def get_input_files_for_table(config, table_spec, modified_since=None, modified_until=None):
     global skipped_files_count
     bucket = config['bucket']
@@ -389,13 +439,13 @@ def get_input_files_for_table(config, table_spec, modified_since=None):
              "https://docs.python.org/3.5/library/re.html#regular-expression-syntax").format(table_spec['table_name']),
             pattern) from e
 
-    LOGGER.info(
-        'Checking bucket "%s" for keys matching "%s"', bucket, pattern)
+    LOGGER.info('Checking bucket "%s" for keys matching "%s"', bucket, pattern)
+    LOGGER.info('Window period: since %s until %s', modified_since, modified_until)
 
     matched_files_count = 0
     unmatched_files_count = 0
     max_files_before_log = 30000
-    for s3_object in list_files_in_bucket(bucket, table_spec.get('search_prefix')):
+    for s3_object in list_files_in_bucket(bucket, table_spec.get('search_prefix'), config=config):
         key = s3_object['Key']
         last_modified = s3_object['LastModified']
@@ -408,10 +458,9 @@ def get_input_files_for_table(config, table_spec, modified_since=None):
         if matcher.search(key):
             matched_files_count += 1
             if modified_since is None or modified_since < last_modified:
-                LOGGER.info('Will download key "%s" as it was last modified %s',
-                            key,
-                            last_modified)
-                yield {'key': key, 'last_modified': last_modified}
+                if modified_until is None or last_modified.replace(tzinfo=pytz.UTC) < parse(modified_until).replace(tzinfo=pytz.UTC):
+                    LOGGER.info('Will download key "%s" as it was last modified %s', key, last_modified)
+                    yield {'key': key, 'last_modified': last_modified}
         else:
             unmatched_files_count += 1
@@ -431,8 +480,8 @@ def get_input_files_for_table(config, table_spec, modified_since=None):
 
 
 @retry_pattern()
-def list_files_in_bucket(bucket, search_prefix=None):
-    s3_client = boto3.client('s3')
+def list_files_in_bucket(bucket, search_prefix=None, config=None):
+    s3_client = setup_aws_client(config).client('s3')
 
     s3_object_count = 0
@@ -462,7 +511,7 @@ def list_files_in_bucket(bucket, search_prefix=None):
 @retry_pattern()
 def get_file_handle(config, s3_path):
     bucket = config['bucket']
-    s3_client = boto3.resource('s3')
+    s3_client = setup_aws_client(config).resource('s3')
 
     s3_bucket = s3_client.Bucket(bucket)
     s3_object = s3_bucket.Object(s3_path)
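The core of the parquet support above is the row-group/batch walk in get_records_for_parquet and sync_parquet_file. A self-contained sketch of that pattern (the helper name is ours, and the file path is illustrative; it assumes a parquet file already sits on local disk):

import pyarrow.parquet as pq

def iter_parquet_records(local_path, sample_rate=1):
    # Yield one dict per row, keeping every sample_rate-th row (1 keeps every row),
    # mirroring how the functions above turn columnar batches back into records.
    parquet_file = pq.ParquetFile(local_path)
    current_row = 0
    for group in range(parquet_file.num_row_groups):
        table = parquet_file.read_row_group(group)
        for batch in table.to_batches():
            # zip(*batch.columns) walks the columnar batch row by row.
            for row in zip(*batch.columns):
                if current_row % sample_rate == 0:
                    yield {
                        table.column_names[idx]: value.as_py()
                        for idx, value in enumerate(row)
                    }
                current_row += 1

# Example usage with a hypothetical file:
# for record in iter_parquet_records("/tmp/example.parquet", sample_rate=5):
#     print(record)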
diff --git a/tap_s3_csv/sync.py b/tap_s3_csv/sync.py
index 2b9a6f8..7d51ac9 100644
--- a/tap_s3_csv/sync.py
+++ b/tap_s3_csv/sync.py
@@ -3,6 +3,11 @@
 import io
 import json
 import gzip
+import os
+import tempfile
+import pathlib
+import boto3
+import pyarrow.parquet as pq
 
 from singer import metadata
 from singer import Transformer
@@ -29,7 +34,7 @@ def sync_stream(config, state, table_spec, stream):
     LOGGER.info('Getting files modified since %s.', modified_since)
 
     s3_files = s3.get_input_files_for_table(
-        config, table_spec, modified_since)
+        config, table_spec, modified_since, modified_until=config.get('end_date'))
 
     records_streamed = 0
@@ -64,9 +69,10 @@ def sync_table_file(config, s3_path, table_spec, stream):
     try:
         if extension == "zip":
             return sync_compressed_file(config, s3_path, table_spec, stream)
-        if extension in ["csv", "gz", "jsonl", "txt"]:
+        if extension in ["csv", "gz", "jsonl", "txt", "parquet"]:
             return handle_file(config, s3_path, table_spec, stream, extension)
         LOGGER.warning('"%s" having the ".%s" extension will not be synced.',s3_path,extension)
+        raise Exception(f"Extension {extension} not supported.")
     except (UnicodeDecodeError,json.decoder.JSONDecodeError):
         # UnicodeDecodeError will be raised if non csv file passed to csv parser
         # JSONDecodeError will be raised if non JSONL file passed to JSON parser
@@ -107,6 +113,16 @@ def handle_file(config, s3_path, table_spec, stream, extension, file_handler = N
             LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
         return records
 
+    if extension == "parquet":
+
+        # Parquet files are downloaded from S3 and read locally, so no file handle is passed.
+        records = sync_parquet_file(config, None, s3_path, table_spec, stream)
+        if records == 0:
+            # Skip the parquet file if it produced no records.
+            s3.skipped_files_count = s3.skipped_files_count + 1
+            LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
+        return records
+
     if extension == "zip":
         LOGGER.warning('Skipping "%s" file as it contains nested compression.',s3_path)
         s3.skipped_files_count = s3.skipped_files_count + 1
@@ -164,7 +180,7 @@ def sync_compressed_file(config, s3_path, table_spec, stream):
     for decompressed_file in decompressed_files:
         extension = decompressed_file.name.split(".")[-1].lower()
 
-        if extension in ["csv", "jsonl", "gz", "txt"]:
+        if extension in ["csv", "jsonl", "gz", "txt", "parquet"]:
             # Append the extracted file name with zip file.
             s3_file_path = s3_path + "/" + decompressed_file.name
@@ -275,3 +291,67 @@ def sync_jsonl_file(config, iterator, s3_path, table_spec, stream):
         records_synced += 1
 
     return records_synced
+
+
+def sync_parquet_file(config, iterator, s3_path, table_spec, stream):
+    LOGGER.info('Syncing file "%s".', s3_path)
+
+    bucket = config['bucket']
+    table_name = table_spec['table_name']
+
+    local_path = os.path.join(tempfile.gettempdir(), s3_path)
+    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if os.path.isfile(local_path):
+        LOGGER.info(f"Skipping download, file exists: {local_path}")
+    else:
+        LOGGER.info(f"Downloading {s3_path} to {local_path}")
+        session = s3.setup_aws_client(config)
+        session.resource("s3").Bucket(config["bucket"]).download_file(s3_path, local_path)
+
+    parquet_file = pq.ParquetFile(local_path)
+    records_synced = 0
+
+    for i in range(parquet_file.num_row_groups):
+        table = parquet_file.read_row_group(i)
+        for batch in table.to_batches():
+            for row in zip(*batch.columns):
+                custom_columns = {
+                    s3.SDC_SOURCE_BUCKET_COLUMN: bucket,
+                    s3.SDC_SOURCE_FILE_COLUMN: s3_path,
+
+                    # records_synced is zero-based, so line numbers start from 1
+                    s3.SDC_SOURCE_LINENO_COLUMN: records_synced + 1
+                }
+                raw_rec = {
+                    table.column_names[i]: val.as_py()
+                    for i, val in enumerate(row, start=0)
+                }
+                rec = {**raw_rec, **custom_columns}
+
+                with Transformer() as transformer:
+                    to_write = transformer.transform(rec, stream['schema'], metadata.to_map(stream['metadata']))
+                # collecting the values which were removed in transform to add them to _sdc_extra
+                value = [{field: rec[field]} for field in set(rec) - set(to_write)]
+
+                if value:
+                    LOGGER.warning(
+                        "\"%s\" is not found in catalog and its value will be stored in the \"_sdc_extra\" field.", value)
+                    extra_data = {
+                        s3.SDC_EXTRA_COLUMN: value
+                    }
+                    update_to_write = {**to_write, **extra_data}
+                else:
+                    update_to_write = to_write
+
+                # Transform again to validate the _sdc_extra value.
+                with Transformer() as transformer:
+                    update_to_write = transformer.transform(update_to_write, stream['schema'], metadata.to_map(stream['metadata']))
+
+                singer.write_record(table_name, update_to_write)
+                records_synced += 1
+
+    LOGGER.info(f"Cleaning file: {local_path}")
+    os.remove(local_path)
+
+    return records_synced
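The end_date plumbing added in sync.py and s3.py reduces to a window check on each object's LastModified timestamp. A minimal sketch of that comparison, using the same dateutil/pytz calls as the patch (the helper name and timestamps here are illustrative):

import pytz
from datetime import datetime
from dateutil.parser import parse

def in_sync_window(last_modified, modified_since=None, modified_until=None):
    # Mirrors the filter in get_input_files_for_table: keep a key only if it was
    # modified after modified_since and before modified_until (both exclusive).
    if modified_since is not None and not modified_since < last_modified:
        return False
    if modified_until is not None:
        until = parse(modified_until).replace(tzinfo=pytz.UTC)
        if not last_modified.replace(tzinfo=pytz.UTC) < until:
            return False
    return True

# Illustrative values: a key modified mid-2021 checked against a 2021 window.
since = parse("2021-01-01T00:00:00Z").replace(tzinfo=pytz.UTC)
last_modified = datetime(2021, 6, 1, tzinfo=pytz.UTC)
print(in_sync_window(last_modified, since, "2022-01-01"))  # True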