Skip to content

Commit

Permalink
make code more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
sgandhi1311 committed Dec 13, 2023
1 parent 2e3b150 commit 146d713
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 26 deletions.
3 changes: 1 addition & 2 deletions tests/test_sftp_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def test_run(self):
headers = self.get_headers_for_table(tap_stream_id)

for i in range(0, len(initial_records)):
initial_record = initial_records[i]
extracted_record = [extracted_messages[i]["data"][key] for key in headers]
self.assertEqual(initial_record, extracted_record)
self.assertEqual(initial_records[i], extracted_record)

45 changes: 26 additions & 19 deletions tests/unittests/test_encoding_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,14 @@ def tearDown(self):
# Reset sys.stdout to its original state
sys.stdout = sys.__stdout__

@parameterized.expand(
[
[
"utf-8",
],
[
"latin_1",
],
]
)

@patch("tap_sftp.is_valid_encoding", return_value=True)
@patch("tap_sftp.discover_streams", return_value=["stream1", "stream2"])
def test_do_discover_valid_encoding(
self, encoding_format, mock_discover_streams, mock_is_valid_encoding
def test_do_discover_valid_encoding_utf_8(
self, mock_discover_streams, mock_is_valid_encoding
):
"""Test do_discover with a valid encoding format"""
config = {"encoding_format": encoding_format}
config = {"encoding_format": "utf-8"}
captured_output = sys.stdout

with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
Expand All @@ -44,15 +35,31 @@ def test_do_discover_valid_encoding(
{"streams": ["stream1", "stream2"]}, indent=2
)
self.assertEqual(output, expected_output)

if encoding_format == "utf-8":
# Bypassing encoding check for `utf-8` as it is widely used
mock_is_valid_encoding.assert_not_called()
else:
mock_is_valid_encoding.assert_called_with("latin_1")
mock_discover_streams.assert_called_with(config, encoding_format)
mock_discover_streams.assert_called_with(config, "utf-8")
self.assertEqual(captured_output, sys.stdout) # Ensure sys.stdout is restored


@patch("tap_sftp.is_valid_encoding", return_value=True)
@patch("tap_sftp.discover_streams", return_value=["stream1", "stream2"])
def test_do_discover_encoding_latin_1(
self, mock_discover_streams, mock_is_valid_encoding
):
"""Test do_discover with a valid encoding format"""
config = {"encoding_format": "latin_1"}
captured_output = sys.stdout

with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
do_discover(config)
output = mock_stdout.getvalue().strip()
expected_output = json.dumps(
{"streams": ["stream1", "stream2"]}, indent=2
)
self.assertEqual(output, expected_output)

mock_is_valid_encoding.assert_called_with("latin_1")
self.assertEqual(captured_output, sys.stdout) # Ensure sys.stdout is restored

@patch("tap_sftp.is_valid_encoding", return_value=False)
def test_do_discover_invalid_encoding(self, mock_is_valid_encoding):
"""Test do_discover with an invalid encoding format."""
Expand Down
19 changes: 15 additions & 4 deletions tests/unittests/test_permission_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def test_no_error_during_sync(self, mocked_get_row_iterators, mocked_stats, mock

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
rows_synced = sync.sync_file(conn,
{"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"},
None,
{"key_properties": ["id"], "delimiter": ","},
encoding_format=DEFAULT_ENCODING_FORMAT)
# check if "csv.get_row_iterators" is called if it is called then error has not occurred
# if it is not called then error has occured and function returned from the except block
self.assertEquals(1, mocked_get_row_iterators.call_count)
Expand All @@ -69,8 +73,11 @@ def test_permisison_error_during_sync(self, mocked_get_row_iterators, mocked_log

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, \
{"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
rows_synced = sync.sync_file(conn,
{"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"},
None,
{"key_properties": ["id"], "delimiter": ","},
encoding_format=DEFAULT_ENCODING_FORMAT)
# check if "csv.get_row_iterators" is called if it is called then error has not occurred
# if it is not called then error has occured and function returned from the except block
self.assertEquals(0, mocked_get_row_iterators.call_count)
Expand All @@ -84,7 +91,11 @@ def test_oserror_during_sync(self, mocked_get_row_iterators, mocked_logger, mock

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
rows_synced = sync.sync_file(conn,
{"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"},
None,
{"key_properties": ["id"], "delimiter": ","},
encoding_format=DEFAULT_ENCODING_FORMAT)
# check if "csv.get_row_iterators" is called if it is called then error has not occurred
# if it is not called then error has occured and function returned from the except block
self.assertEquals(0, mocked_get_row_iterators.call_count)
Expand Down
5 changes: 4 additions & 1 deletion tests/unittests/test_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ def test_timeout_backoff__sync_file(self, mocked_get_row_iterators, mocked_get_f
conn = client.connection(config=config)
with self.assertRaises(socket.timeout):
# function call
sync.sync_file(conn=conn, f=file, stream="test_stream", table_spec=table_spec, encoding_format=encoding_format)
sync.sync_file(conn=conn,
f=file, stream="test_stream",
table_spec=table_spec,
encoding_format=encoding_format)

# verify that the tap backoff for 5 times
self.assertEquals(mocked_get_row_iterators.call_count, 5)
Expand Down

0 comments on commit 146d713

Please sign in to comment.