Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Athena catalog column name to default #1888

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
6 changes: 2 additions & 4 deletions soda/athena/soda/data_sources/athena_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
session_token=data_source_properties.get("session_token"),
region_name=data_source_properties.get("region_name"),
profile_name=data_source_properties.get("profile_name"),
external_id=data_source_properties.get("external_id")
)

def connect(self):
Expand All @@ -45,6 +46,7 @@ def connect(self):
s3_staging_dir=self.athena_staging_dir,
region_name=self.aws_credentials.region_name,
role_arn=self.aws_credentials.role_arn,
external_id=self.aws_credentials.external_id,
catalog_name=self.catalog,
work_group=self.work_group,
schema_name=self.schema,
Expand Down Expand Up @@ -100,10 +102,6 @@ def quote_column(self, column_name: str) -> str:
def regex_replace_flags(self) -> str:
    """Return the flags suffix used in regex-replace SQL; Athena takes none."""
    return ""

@staticmethod
def column_metadata_catalog_column() -> str:
return "table_schema"

def default_casify_table_name(self, identifier: str) -> str:
    """Return *identifier* folded to lower case, Athena's default casing."""
    lowered = identifier.lower()
    return lowered

Expand Down
5 changes: 4 additions & 1 deletion soda/core/soda/common/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ def __init__(
session_token: Optional[str] = None,
profile_name: Optional[str] = None,
region_name: Optional[str] = "eu-west-1",
external_id: Optional[str] = None,
):
self.access_key_id = access_key_id
self.secret_access_key = secret_access_key
self.role_arn = role_arn
self.external_id = external_id
self.session_token = session_token
self.profile_name = profile_name
self.region_name = region_name
Expand All @@ -32,6 +34,7 @@ def from_configuration(cls, configuration: dict):
access_key_id=access_key_id,
secret_access_key=configuration.get("secret_access_key"),
role_arn=configuration.get("role_arn"),
external_id=configuration.get("external_id"),
session_token=configuration.get("session_token"),
profile_name=configuration.get("profile_name"),
region_name=configuration.get("region", "eu-west-1"),
Expand All @@ -55,7 +58,7 @@ def assume_role(self, role_session_name: str):
aws_session_token=self.session_token,
)

assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, RoleSessionName=role_session_name)
assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, ExternalId=self.external_id, RoleSessionName=role_session_name)
credentials_dict = assumed_role_object["Credentials"]
return AwsCredentials(
region_name=self.region_name,
Expand Down
10 changes: 7 additions & 3 deletions soda/redshift/soda/data_sources/redshift_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
self.connect_timeout = data_source_properties.get("connection_timeout_sec")
self.username = data_source_properties.get("username")
self.password = data_source_properties.get("password")
self.dbuser = data_source_properties.get("dbuser")
self.dbname = data_source_properties.get("dbname")
self.cluster_id = data_source_properties.get("cluster_id")

if not self.username or not self.password:
aws_credentials = AwsCredentials(
Expand All @@ -31,6 +34,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
session_token=data_source_properties.get("session_token"),
region_name=data_source_properties.get("region", "eu-west-1"),
profile_name=data_source_properties.get("profile_name"),
external_id=data_source_properties.get("external_id"),
)
self.username, self.password = self.__get_cluster_credentials(aws_credentials)

Expand Down Expand Up @@ -60,9 +64,9 @@ def __get_cluster_credentials(self, aws_credentials: AwsCredentials):
aws_session_token=resolved_aws_credentials.session_token,
)

cluster_name = self.host.split(".")[0]
username = self.username
db_name = self.database
cluster_name = self.cluster_id if self.cluster_id else self.host.split(".")[0]
username = self.dbuser if self.dbuser else self.username
db_name = self.dbname if self.dbname else self.database
cluster_creds = client.get_cluster_credentials(
DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600
)
Expand Down
4 changes: 4 additions & 0 deletions soda/spark/soda/data_sources/spark_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ class SparkDataSource(SparkSQLBase):

def __init__(self, logs: Logs, data_source_name: str, data_source_properties: dict):
super().__init__(logs, data_source_name, data_source_properties)
self.NUMERIC_TYPES_FOR_PROFILING = ["integer", "int", "double", "float", "decimal", "bigint"]

self.method = data_source_properties.get("method", "hive")
self.host = data_source_properties.get("host", "localhost")
Expand Down Expand Up @@ -476,3 +477,6 @@ def connect(self):
self.connection = connection
except Exception as e:
raise DataSourceConnectionError(self.type, e)

def cast_to_text(self, expr: str) -> str:
    """Wrap *expr* in a SQL CAST to a bounded VARCHAR so it compares as text."""
    return "CAST({} AS VARCHAR(100))".format(expr)