Skip to content

Commit

Permalink
Cast arrays using Python instead of Postgres db
Browse files Browse the repository at this point in the history
  • Loading branch information
drdee committed Nov 20, 2019
1 parent 4fdaac8 commit 29c2e87
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 78 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import setup

setup(name='tap-postgres',
version='0.0.65',
version='0.0.66',
description='Singer.io tap for extracting data from PostgreSQL',
author='Stitch',
url='https://singer.io',
Expand Down
107 changes: 30 additions & 77 deletions tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#!/usr/bin/env python3
# pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,invalid-name,too-many-return-statements,too-many-branches,len-as-condition,too-many-nested-blocks,wrong-import-order,duplicate-code, anomalous-backslash-in-string, too-many-statements, singleton-comparison, consider-using-in

import singer
from functools import reduce
from select import select
import copy
import csv
import datetime
import decimal
import json
import re

from dateutil.parser import parse
import psycopg2
import singer
from singer import utils, get_bookmark
import singer.metadata as metadata
import tap_postgres.db as post_db
import tap_postgres.sync_strategies.common as sync_common
from dateutil.parser import parse
import psycopg2
from psycopg2 import sql
import copy
from select import select
from functools import reduce
import json
import re


LOGGER = singer.get_logger()

Expand Down Expand Up @@ -69,9 +71,6 @@ def tuples_to_map(accum, t):
accum[t[0]] = t[1]
return accum

def create_hstore_elem_query(elem):
return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem))

def create_hstore_elem(conn_info, elem):
with post_db.open_connection(conn_info) as conn:
with conn.cursor() as cur:
Expand All @@ -81,65 +80,15 @@ def create_hstore_elem(conn_info, elem):
hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {})
return hstore_elem

def create_array_elem(elem, sql_datatype, conn_info):
def create_array_elem(elem):
if elem is None:
return None

with post_db.open_connection(conn_info) as conn:
with conn.cursor() as cur:
if sql_datatype == 'bit[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'boolean[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'character varying[]':
cast_datatype = 'character varying[]'
elif sql_datatype == 'cidr[]':
cast_datatype = 'cidr[]'
elif sql_datatype == 'citext[]':
cast_datatype = 'text[]'
elif sql_datatype == 'date[]':
cast_datatype = 'text[]'
elif sql_datatype == 'double precision[]':
cast_datatype = 'double precision[]'
elif sql_datatype == 'hstore[]':
cast_datatype = 'text[]'
elif sql_datatype == 'integer[]':
cast_datatype = 'integer[]'
elif sql_datatype == 'bigint[]':
cast_datatype = 'bigint[]'
elif sql_datatype == 'inet[]':
cast_datatype = 'inet[]'
elif sql_datatype == 'json[]':
cast_datatype = 'text[]'
elif sql_datatype == 'jsonb[]':
cast_datatype = 'text[]'
elif sql_datatype == 'macaddr[]':
cast_datatype = 'macaddr[]'
elif sql_datatype == 'money[]':
cast_datatype = 'text[]'
elif sql_datatype == 'numeric[]':
cast_datatype = 'text[]'
elif sql_datatype == 'real[]':
cast_datatype = 'real[]'
elif sql_datatype == 'smallint[]':
cast_datatype = 'smallint[]'
elif sql_datatype == 'text[]':
cast_datatype = 'text[]'
elif sql_datatype in ('time without time zone[]', 'time with time zone[]'):
cast_datatype = 'text[]'
elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'):
cast_datatype = 'text[]'
elif sql_datatype == 'uuid[]':
cast_datatype = 'text[]'

else:
#custom datatypes like enums
cast_datatype = 'text[]'

sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype)
cur.execute(sql_stmt)
res = cur.fetchone()[0]
return res
elem = [elem[1:-1]]
reader = csv.reader(elem, delimiter=',', escapechar='\\' , quotechar='"')
array = next(reader)
array = [None if element.lower() == 'null' else element for element in array]
return array

#pylint: disable=too-many-branches,too-many-nested-blocks
def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
Expand All @@ -166,17 +115,21 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
#for ordinary bits, elem will == '1'
return elem == '1' or elem == True
if sql_datatype == 'boolean':
return elem
return bool(elem)
if sql_datatype == 'hstore':
return create_hstore_elem(conn_info, elem)
if 'numeric' in sql_datatype:
return decimal.Decimal(str(elem))
if isinstance(elem, int):
return elem
if isinstance(elem, float):
return elem
if isinstance(elem, str):
return elem
return decimal.Decimal(elem)
if sql_datatype == 'money':
return decimal.Decimal(elem[1:])
if sql_datatype in ('integer', 'smallint', 'bigint'):
return int(elem)
if sql_datatype in ('double precision', 'real', 'float'):
return float(elem)
if sql_datatype in ('text', 'character varying'):
return elem # return as string
if sql_datatype in ('cidr', 'citext', 'json', 'jsonb', 'inet', 'macaddr', 'uuid'):
return elem # return as string

raise Exception("do not know how to marshall value of type {}".format(elem.__class__))

Expand All @@ -189,7 +142,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info):
def selected_value_to_singer_value(elem, sql_datatype, conn_info):
#are we dealing with an array?
if sql_datatype.find('[]') > 0:
cleaned_elem = create_array_elem(elem, sql_datatype, conn_info)
cleaned_elem = create_array_elem(elem)
return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or [])))

return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info)
Expand Down
108 changes: 108 additions & 0 deletions tests/test_logical_replication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from decimal import Decimal
import unittest
from unittest.mock import patch

from utils import get_test_connection_config
from tap_postgres.sync_strategies import logical_replication


class TestHandlingArrays(unittest.TestCase):
def setUp(self):
self.env = patch.dict('os.environ',
{'TAP_POSTGRES_HOST':'test'},
{'TAP_POSTGRES_USER':'test'}
{'TAP_POSTGRES_PASSWORD':'test'}
{'TAP_POSTGRES_PORT':'5432'}
)

self.arrays = [
'{10,01,NULL}',
'{t,f,NULL}',
'{127.0.0.1/32,10.0.0.0/32,NULL}',
'{CASE_INSENSITIVE,case_insensitive,NULL,"CASE,,INSENSITIVE"}',
'{2000-12-31,2001-01-01,NULL}',
'{3.14159265359,3.1415926,NULL}',
'{"\\"foo\\"=>\\"bar\\"","\\"foo\\"=>NULL",NULL,"\\"foo\\"=>\\"bar\\""}',
'{1,2,NULL}',
'{9223372036854775807,NULL}',
'{198.24.10.0/24,NULL}',
'{"{\\"foo\\":\\"bar\\"}",NULL}',
'{"{\\"foo\\": \\"bar\\"}",NULL}',
'{08:00:2b:01:02:03,NULL}',
'{$19.99,NULL}',
'{19.9999999,NULL}',
'{3.14159,NULL}',
'{0,1,NULL}',
'{foo,bar,NULL,"foo,bar","diederik\'s motel "}',
'{16:38:47,NULL}',
'{"2019-11-19 11:38:47-05",NULL}',
'{123e4567-e89b-12d3-a456-426655440000,NULL}'
]

self.sql_datatypes = {
'bit[]': bool,
'boolean[]': bool,
'cidr[]': str,
'citext[]': str,
'date[]': str,
'double precision[]': float,
'hstore[]': dict,
'integer[]': int,
'bigint[]': int,
'inet[]': str,
'json[]': str,
'jsonb[]': str,
'macaddr[]': str,
'money[]': Decimal,
'numeric[]': Decimal,
'real[]': float,
'smallint[]': int,
'text[]': str,
'time with time zone[]': str,
'timestamp with time zone[]': str,
'uuid[]': str,
}

def test_create_array_elem(self):
expected_arrays = [
['10', '01' ,None],
['t', 'f', None],
['127.0.0.1/32', '10.0.0.0/32', None],
['CASE_INSENSITIVE', 'case_insensitive', None,"CASE,,INSENSITIVE"],
['2000-12-31', '2001-01-01', None],
['3.14159265359','3.1415926', None],
None,
['1','2',None],
['9223372036854775807', None],
['198.24.10.0/24', None],
["{\"foo\":\"bar\"}", None],
["{\"foo\": \"bar\"}", None],
['08:00:2b:01:02:03', None],
['$19.99', None],
['19.9999999', None],
['3.14159', None],
['0','1', None],
['foo','bar',None,"foo,bar","diederik\'s motel "],
['16:38:47',None],
["2019-11-19 11:38:47-05",None],
['123e4567-e89b-12d3-a456-426655440000', None],
]
for elem, expected_array in zip(self.arrays, expected_arrays):
array = logical_replication.create_array_elem(elem)
self.assertEqual(array, expected_array)

def test_selected_value_to_singer_value_impl(self):
with self.env:
conn_info = get_test_connection_config()
for elem, sql_datatype in zip(self.arrays, self.sql_datatypes.keys()):
array = logical_replication.selected_value_to_singer_value(elem, sql_datatype, conn_info)

for element in array:
python_datatype = self.sql_datatypes[sql_datatype]
if element:
self.assertIsInstance(element, python_datatype)

if __name__== "__main__":
test1 = TestHandlingArrays()
test1.setUp()
test1.test_selected_value_to_singer_value_impl()

0 comments on commit 29c2e87

Please sign in to comment.