Nc/typo fix #59

Open · wants to merge 8 commits into base: master
4 changes: 3 additions & 1 deletion setup.py
@@ -14,7 +14,9 @@
'boto3==1.9.57',
'singer-encodings==0.1.2',
'singer-python==5.12.1',
'voluptuous==0.10.5'
'voluptuous==0.10.5',
'pyarrow==6.0.1',
'python-dateutil',
],
extras_require={
'dev': [
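For context, the two new pins cover the parquet reader and the date parsing the tap now relies on. A minimal sketch of how they are exercised (the local file name is illustrative, not from the PR):

import pyarrow.parquet as pq
from dateutil.parser import parse

pf = pq.ParquetFile("example.parquet")    # reader used by get_records_for_parquet in tap_s3_csv/s3.py
start = parse("2021-01-01T00:00:00Z")     # dateutil parses the start_date / end_date strings
print(pf.num_row_groups, start.isoformat())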
4 changes: 2 additions & 2 deletions tap_s3_csv/__init__.py
@@ -10,7 +10,7 @@

LOGGER = singer.get_logger()

REQUIRED_CONFIG_KEYS = ["start_date", "bucket", "account_id", "external_id", "role_name"]
REQUIRED_CONFIG_KEYS = ["start_date", "bucket"]


def do_discover(config):
@@ -75,7 +75,7 @@ def main():
config['tables'] = validate_table_config(config)

try:
for page in s3.list_files_in_bucket(config['bucket']):
for page in s3.list_files_in_bucket(config['bucket'], config=config):
break
LOGGER.warning("I have direct access to the bucket without assuming the configured role.")
except:
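With account_id, external_id and role_name dropped from REQUIRED_CONFIG_KEYS, a minimal config can rely on key-based credentials rather than role assumption. A rough sketch (values are illustrative; the aws_* keys may also come from the environment, per setup_aws_client in s3.py below):

config = {
    "start_date": "2021-01-01T00:00:00Z",
    "bucket": "my-example-bucket",
    "aws_access_key_id": "AKIA...",         # optional, falls back to the environment
    "aws_secret_access_key": "...",         # optional, falls back to the environment
}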
155 changes: 102 additions & 53 deletions tap_s3_csv/s3.py
@@ -6,7 +6,13 @@
import backoff
import boto3
import singer
import tempfile
import pathlib
import os
import pyarrow.parquet as pq
import pytz

from dateutil.parser import parse
from botocore.credentials import (
AssumeRoleCredentialFetcher,
CredentialResolver,
@@ -60,35 +66,20 @@ def load(self):

@retry_pattern()
def setup_aws_client(config):
role_arn = "arn:aws:iam::{}:role/{}".format(config['account_id'].replace('-', ''),
config['role_name'])
session = Session()
fetcher = AssumeRoleCredentialFetcher(
session.create_client,
session.get_credentials(),
role_arn,
extra_args={
'DurationSeconds': 3600,
'RoleSessionName': 'TapS3CSV',
'ExternalId': config['external_id']
},
cache=JSONFileCache()
)

refreshable_session = Session()
refreshable_session.register_component(
'credential_provider',
CredentialResolver([AssumeRoleProvider(fetcher)])
)

LOGGER.info("Attempting to assume_role on RoleArn: %s", role_arn)
boto3.setup_default_session(botocore_session=refreshable_session)
key = config.get('aws_access_key_id', os.environ.get("aws_access_key_id"))
secret = config.get('aws_secret_access_key', os.environ.get("aws_secret_access_key"))
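# If neither the config nor the environment supplies credentials, both values are None and
# boto3.session.Session falls back to its default credential chain (shared config, instance profile, etc.).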

return boto3.session.Session(aws_access_key_id=key, aws_secret_access_key=secret)


def get_sampled_schema_for_table(config, table_spec):
LOGGER.info('Sampling records to determine table schema.')

s3_files_gen = get_input_files_for_table(config, table_spec)
s3_files_gen = get_input_files_for_table(
config,
table_spec,
modified_since=parse(config.get('start_date')).replace(tzinfo=pytz.UTC),
modified_until=config.get('end_date'))

samples = [sample for sample in sample_files(config, table_spec, s3_files_gen)]

@@ -188,27 +179,63 @@ def get_records_for_jsonl(s3_path, sample_rate, iterator):
LOGGER.info("Sampled %s rows from %s", sampled_row_count, s3_path)


def get_records_for_parquet(s3_bucket, s3_path, sample_rate, config):

local_path = os.path.join(tempfile.gettempdir(), s3_path)
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

if os.path.isfile(local_path):
LOGGER.info(f"Skipping download, file exists: {local_path}")
else:
LOGGER.info(f"Downloading {s3_path} to {local_path}")
session = setup_aws_client(config)
session.resource("s3").Bucket(s3_bucket).download_file(s3_path, local_path)

parquet_file = pq.ParquetFile(local_path)

current_row = 0
sampled_row_count = 0

for i in range(parquet_file.num_row_groups):
table = parquet_file.read_row_group(i)
for batch in table.to_batches():
for row in zip(*batch.columns):
if (current_row % sample_rate) == 0:
current_row += 1
sampled_row_count += 1
yield {
table.column_names[i]: val.as_py()
for i, val in enumerate(row, start=0)
}

LOGGER.info("Sampled %s rows from %s", sampled_row_count, s3_path)


def check_key_properties_and_date_overrides_for_jsonl_file(table_spec, jsonl_sample_records, s3_path):

rows = 0
all_keys = set()
for record in jsonl_sample_records:
rows += 1
keys = record.keys()
all_keys.update(keys)
if rows > 5000:
break

if table_spec.get('key_properties'):
key_properties = set(table_spec['key_properties'])
if not key_properties.issubset(all_keys):
raise Exception('JSONL file "{}" is missing required key_properties key: {}'
raise Exception('JSONL/parquet file "{}" is missing required key_properties key: {}'
.format(s3_path, key_properties - all_keys))

if table_spec.get('date_overrides'):
date_overrides = set(table_spec['date_overrides'])
if not date_overrides.issubset(all_keys):
raise Exception('JSONL file "{}" is missing date_overrides key: {}'
raise Exception('JSONL/parquet file "{}" is missing date_overrides key: {}'
.format(s3_path, date_overrides - all_keys))

#pylint: disable=global-statement
def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate):
def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate, config):
global skipped_files_count
if s3_path.endswith(".tar.gz"):
LOGGER.warning('Skipping "%s" file as .tar.gz extension is not supported',s3_path)
@@ -235,12 +262,19 @@ def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate):
return []

gz_file_extension = gz_file_name.split(".")[-1].lower()
return sample_file(table_spec, s3_path + "/" + gz_file_name, io.BytesIO(gz_file_obj.read()), sample_rate, gz_file_extension)
return sample_file(table_spec, s3_path + "/" + gz_file_name, io.BytesIO(gz_file_obj.read()), sample_rate, gz_file_extension, config)

raise Exception('"{}" file has some error(s)'.format(s3_path))

def peek(iterable):
try:
first = next(iterable)
except StopIteration:
return None
return first, itertools.chain([first], iterable)
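# peek() distinguishes an empty iterator from a populated one without losing the first element,
# e.g. peek(iter([])) -> None, while peek(iter([1, 2])) -> (1, <iterator yielding 1, 2>).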

#pylint: disable=global-statement
def sample_file(table_spec, s3_path, file_handle, sample_rate, extension):
def sample_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, extension, config):
global skipped_files_count

# Check whether file is without extension or not
@@ -260,21 +294,36 @@ def sample_file(table_spec, s3_path, file_handle, sample_rate, extension):
skipped_files_count = skipped_files_count + 1
return csv_records
if extension == "gz":
return sampling_gz_file(table_spec, s3_path, file_handle, sample_rate)
return sampling_gz_file(table_spec, s3_path, file_handle, sample_rate, config)
if extension == "jsonl":
# If the file object was read from the S3 bucket, use its raw stream; else use the file object extracted from zip or gz

file_handle = file_handle._raw_stream if hasattr(file_handle, "_raw_stream") else file_handle
records = get_records_for_jsonl(
s3_path, sample_rate, file_handle)
check_jsonl_sample_records, records = itertools.tee(
records)
jsonl_sample_records = list(check_jsonl_sample_records)
if len(jsonl_sample_records) == 0:
records = get_records_for_jsonl(s3_path, sample_rate, file_handle)
check_jsonl_sample_records, records = itertools.tee(records)

result = peek(check_jsonl_sample_records)
if result is None:
LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
skipped_files_count = skipped_files_count + 1
check_key_properties_and_date_overrides_for_jsonl_file(
table_spec, jsonl_sample_records, s3_path)

return []
else:
check_jsonl_sample_records = result[1]
check_key_properties_and_date_overrides_for_jsonl_file(table_spec, check_jsonl_sample_records, s3_path)
return records
if extension == "parquet":
records = get_records_for_parquet(s3_bucket, s3_path, sample_rate, config)
check_jsonl_sample_records, records = itertools.tee(records)

result = peek(check_jsonl_sample_records)
if result is None:
LOGGER.warning('Skipping "%s" file as it is empty', s3_path)
skipped_files_count = skipped_files_count + 1
return []
else:
check_jsonl_sample_records = result[1]
check_key_properties_and_date_overrides_for_jsonl_file(table_spec, check_jsonl_sample_records, s3_path)

return records
if extension == "zip":
LOGGER.warning('Skipping "%s" file as it contains nested compression.',s3_path)
@@ -303,7 +352,7 @@ def get_files_to_sample(config, s3_files, max_files):
global skipped_files_count
sampled_files = []

OTHER_FILES = ["csv","gz","jsonl","txt"]
OTHER_FILES = ["csv","gz","jsonl","txt","parquet"]

for s3_file in s3_files:
file_key = s3_file.get('key')
@@ -343,11 +392,12 @@ def get_files_to_sample(config, s3_files, max_files):
def sample_files(config, table_spec, s3_files,
sample_rate=5, max_records=1000, max_files=5):
global skipped_files_count
max_files = config.get("max_sample_files", max_files)
LOGGER.info("Sampling files (max files: %s)", max_files)

for s3_file in itertools.islice(get_files_to_sample(config, s3_files, max_files), max_files):


s3_bucket = config['bucket']
s3_path = s3_file.get("s3_path","")
file_handle = s3_file.get("file_handle")
file_type = s3_file.get("type")
@@ -364,7 +414,7 @@ def sample_files(config, table_spec, s3_files,
max_records,
sample_rate)
try:
yield from itertools.islice(sample_file(table_spec, s3_path, file_handle, sample_rate, extension), max_records)
yield from itertools.islice(sample_file(table_spec, s3_bucket, s3_path, file_handle, sample_rate, extension, config), max_records)
except (UnicodeDecodeError,json.decoder.JSONDecodeError):
# UnicodeDecodeError will be raised if a non-CSV file is parsed by the CSV parser
# JSONDecodeError will be raised if a non-JSONL file is parsed by the JSON parser
Expand All @@ -373,7 +423,7 @@ def sample_files(config, table_spec, s3_files,
skipped_files_count = skipped_files_count + 1

#pylint: disable=global-statement
def get_input_files_for_table(config, table_spec, modified_since=None):
def get_input_files_for_table(config, table_spec, modified_since=None, modified_until=None):
global skipped_files_count
bucket = config['bucket']

@@ -389,13 +439,13 @@ def get_input_files_for_table(config, table_spec, modified_since=None):
"https://docs.python.org/3.5/library/re.html#regular-expression-syntax").format(table_spec['table_name']),
pattern) from e

LOGGER.info(
'Checking bucket "%s" for keys matching "%s"', bucket, pattern)
LOGGER.info('Checking bucket "%s" for keys matching "%s"', bucket, pattern)
LOGGER.info('Window period: since %s until %s',modified_since,modified_until)

matched_files_count = 0
unmatched_files_count = 0
max_files_before_log = 30000
for s3_object in list_files_in_bucket(bucket, table_spec.get('search_prefix')):
for s3_object in list_files_in_bucket(bucket, table_spec.get('search_prefix'), config=config):
key = s3_object['Key']
last_modified = s3_object['LastModified']

@@ -408,10 +458,9 @@ def get_input_files_for_table(config, table_spec, modified_since=None):
if matcher.search(key):
matched_files_count += 1
if modified_since is None or modified_since < last_modified:
LOGGER.info('Will download key "%s" as it was last modified %s',
key,
last_modified)
yield {'key': key, 'last_modified': last_modified}
if modified_until is None or last_modified.replace(tzinfo=pytz.UTC) < parse(modified_until).replace(tzinfo=pytz.UTC):
LOGGER.info('Will download key "%s" as it was last modified %s',key,last_modified)
yield {'key': key, 'last_modified': last_modified}
else:
unmatched_files_count += 1

@@ -431,8 +480,8 @@ def get_input_files_for_table(config, table_spec, modified_since=None):


@retry_pattern()
def list_files_in_bucket(bucket, search_prefix=None):
s3_client = boto3.client('s3')
def list_files_in_bucket(bucket, search_prefix=None, config=None):
s3_client = setup_aws_client(config).client('s3')

s3_object_count = 0

@@ -462,7 +511,7 @@ def list_files_in_bucket(bucket, search_prefix=None):
@retry_pattern()
def get_file_handle(config, s3_path):
bucket = config['bucket']
s3_client = boto3.resource('s3')
s3_client = setup_aws_client(config).resource('s3')

s3_bucket = s3_client.Bucket(bucket)
s3_object = s3_bucket.Object(s3_path)
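With these changes, discovery windows S3 keys between start_date and an optional end_date. A rough sketch of the resulting call shape (a sketch only, mirroring get_sampled_schema_for_table above; the config and table_spec values are illustrative):

import pytz
from dateutil.parser import parse
from tap_s3_csv import s3

config = {"bucket": "my-example-bucket", "start_date": "2021-01-01T00:00:00Z", "end_date": "2021-06-01T00:00:00Z"}
table_spec = {"table_name": "events", "search_prefix": "", "search_pattern": ".*\\.parquet"}

files = list(s3.get_input_files_for_table(
    config,
    table_spec,
    modified_since=parse(config["start_date"]).replace(tzinfo=pytz.UTC),
    modified_until=config.get("end_date"),  # optional upper bound; None disables the check
))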