refreshes the session on expiry
rdeshmukh15 committed Oct 29, 2024
1 parent a36b253 commit a88f034
Showing 2 changed files with 65 additions and 21 deletions.
tap_s3_csv/__init__.py (1 change: 0 additions & 1 deletion)
@@ -1,7 +1,6 @@
 import json
 import sys
 import singer
-import boto3
 
 from singer import metadata
 from tap_s3_csv.discover import discover_streams
tap_s3_csv/s3.py (85 changes: 65 additions & 20 deletions)
@@ -8,6 +8,7 @@
 import backoff
 import boto3
 import singer
+import time
 
 from botocore.credentials import (
     AssumeRoleCredentialFetcher,
@@ -38,6 +39,7 @@

 # timeout request after 300 seconds
 REQUEST_TIMEOUT = 300
+SESSION_DURATION = 900
 
 def is_access_denied_error(error):
     """
@@ -89,16 +91,17 @@ def fetch_credentials(self):
             RoleArn=self.role_arn,
             RoleSessionName=self.extra_args['RoleSessionName'],
             ExternalId=self.extra_args.get('ExternalId'),
-            DurationSeconds=self.extra_args.get('DurationSeconds', 3600)
+            DurationSeconds=self.extra_args.get('DurationSeconds', SESSION_DURATION)
         )
+        LOGGER.info("fetch_credentials:%s",response)
         return {
             'access_key': response['Credentials']['AccessKeyId'],
             'secret_key': response['Credentials']['SecretAccessKey'],
             'token': response['Credentials']['SessionToken'],
         }
 
 @retry_pattern
-def setup_aws_client(config):
+def setup_aws_client(config, flag=False):
     proxy_role_arn = f"arn:aws:iam::{config['proxy_account_id']}:role/{config['proxy_role_name']}"
     session = boto3.Session()
@@ -111,7 +114,7 @@ def setup_aws_client(config):
         session.get_credentials(),
         proxy_role_arn,
         extra_args={
-            'DurationSeconds': 3600,
+            'DurationSeconds': SESSION_DURATION,
             'RoleSessionName': 'TapProxySession',
             'ExternalId': config['proxy_external_id']
         },
@@ -137,7 +140,7 @@ def setup_aws_client(config):
         proxy_session.get_credentials(),
         cust_role_arn,
         extra_args={
-            'DurationSeconds': 3600,
+            'DurationSeconds': SESSION_DURATION,
             'RoleSessionName': 'TapS3CSV',
             'ExternalId': config['cust_external_id']
         },
@@ -153,6 +156,10 @@ def setup_aws_client(config):
         aws_session_token=cust_credentials['token']
     )
 
+    if flag==False:
+        LOGGER.info("sleeping after default session for 15 mins")
+        time.sleep(950)
+
 
 def get_sampled_schema_for_table(config, table_spec):
     LOGGER.info('Sampling records to determine table schema.')
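
The sleep on the flag==False path is 950 seconds, slightly past the 900-second session (the log message rounds this to "15 mins"), so the next S3 call should fail with ExpiredToken and exercise the refresh path. If a fixed interval ever proves brittle, a hedged alternative, not part of this commit, would be to key off the STS-reported expiry:

    # Hypothetical helper: sleep until the STS-reported Expiration has passed,
    # independent of the configured session duration.
    import time
    from datetime import datetime, timezone

    def sleep_until_expired(expiration: datetime, slack_seconds: int = 60):
        """Sleep until slack_seconds past the credential expiry timestamp."""
        remaining = (expiration - datetime.now(timezone.utc)).total_seconds()
        time.sleep(max(0, remaining + slack_seconds))
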
@@ -523,13 +530,20 @@ def get_request_timeout(config):
         request_timeout = REQUEST_TIMEOUT
     return request_timeout
 
+def refresh_session(config):
+    # This function calls setup_aws_client to refresh the credentials
+    LOGGER.info("Refreshing AWS session...")
+    setup_aws_client(config, True)
+
 @retry_pattern
 def list_files_in_bucket(config, search_prefix=None):
-    # Set connect and read timeout for resource
-    timeout = get_request_timeout(config)
-    client_config = Config(connect_timeout=timeout, read_timeout=timeout)
-    s3_client = boto3.client('s3', config=client_config)
+    def create_s3_client():
+        timeout = get_request_timeout(config)
+        client_config = Config(connect_timeout=timeout, read_timeout=timeout)
+        return boto3.client('s3', config=client_config)
 
+    s3_client = create_s3_client()
+    LOGGER.info("in list_files_in_bucket.......1")
     s3_object_count = 0
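
Why re-creating the client matters: boto3 clients capture credentials when they are built, so refresh_session only takes effect for clients created afterwards. A sketch of the mechanism, assuming setup_aws_client finishes with boto3.setup_default_session as upstream tap-s3-csv does (the literals are placeholders):

    # Sketch: installing refreshed credentials into the default boto3 session.
    import boto3

    boto3.setup_default_session(
        aws_access_key_id="AKIA...",       # placeholder
        aws_secret_access_key="...",       # placeholder
        aws_session_token="...",           # placeholder
    )
    # Clients built after this call (e.g. via create_s3_client()) pick up the
    # new token; clients built before it keep the expired one.
    s3 = boto3.client("s3")
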

max_results = 1000
Expand All @@ -543,12 +557,29 @@ def list_files_in_bucket(config, search_prefix=None):
args['Prefix'] = search_prefix

paginator = s3_client.get_paginator('list_objects_v2')
LOGGER.info("in list_files_in_bucket.......2")
pages = 0
for page in paginator.paginate(**args):
pages += 1
LOGGER.debug("On page %s", pages)
s3_object_count += len(page['Contents'])
yield from page['Contents']

while True:
try:
for page in paginator.paginate(**args):
LOGGER.info("in list_files_in_bucket.......3")
pages += 1
LOGGER.debug("On page %s", pages)
s3_object_count += len(page.get('Contents', []))
yield from page.get('Contents', [])
break # Break if pagination is successful
except ClientError as e:
# Check if the error is due to an expired token
if e.response['Error']['Code'] == 'ExpiredToken':
LOGGER.warning("Token expired, refreshing credentials...")
refresh_session(config)
# Re-create the S3 client with new credentials
s3_client = create_s3_client()
paginator = s3_client.get_paginator('list_objects_v2') # Recreate paginator with new client
else:
LOGGER.error("Failed to list files: %s", e)
raise

if s3_object_count > 0:
LOGGER.info("Found %s files.", s3_object_count)
@@ -558,12 +589,26 @@

 @retry_pattern
 def get_file_handle(config, s3_path):
-    bucket = config['bucket']
-    # Set connect and read timeout for resource
-    timeout = get_request_timeout(config)
-    client_config = Config(connect_timeout=timeout, read_timeout=timeout)
-    s3_client = boto3.resource('s3', config=client_config)
+    def create_s3_resource():
+        timeout = get_request_timeout(config)
+        client_config = Config(connect_timeout=timeout, read_timeout=timeout)
+        return boto3.resource('s3', config=client_config)
 
-    s3_bucket = s3_client.Bucket(bucket)
+    s3_resource = create_s3_resource()
+    bucket = config['bucket']
+    s3_bucket = s3_resource.Bucket(bucket)
     s3_object = s3_bucket.Object(s3_path)
-    return s3_object.get()['Body']
+
+    while True:
+        try:
+            return s3_object.get()['Body']
+        except ClientError as e:
+            if e.response['Error']['Code'] == 'ExpiredToken':
+                LOGGER.warning("Token expired, refreshing credentials...")
+                refresh_session(config)
+                s3_resource = create_s3_resource()
+                s3_bucket = s3_resource.Bucket(bucket)
+                s3_object = s3_bucket.Object(s3_path)
+            else:
+                LOGGER.error("Failed to get file handle: %s", e)
+                raise
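
Both call sites now follow the same catch-refresh-rebuild loop. A minimal generic sketch of that shared pattern (hypothetical names, not code from this commit):

    # Sketch: run an S3 operation; on ExpiredToken, refresh credentials,
    # rebuild the client, and retry. All three callables are hypothetical.
    from botocore.exceptions import ClientError

    def with_session_refresh(make_client, refresh, op):
        client = make_client()
        while True:
            try:
                return op(client)
            except ClientError as err:
                if err.response["Error"]["Code"] != "ExpiredToken":
                    raise
                refresh()               # e.g. refresh_session(config)
                client = make_client()  # fresh client sees the new token

This fits get_file_handle directly; the generator in list_files_in_bucket needs the inline form above. Note that when pagination dies mid-stream, the retry restarts list_objects_v2 from the first page, so already-yielded keys can repeat; resuming in place would require carrying S3's continuation token across the retry.
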
