Skip to content

Commit

Permalink
Updated tests and made col_data a local variable returned by generate_uncorrelated_column_data; added generator tests for edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
drahc1R committed Aug 3, 2023
1 parent 853f8eb commit 14bee78
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
17 changes: 9 additions & 8 deletions synthetic_data/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
else:
self.tabular_generator_seed = seed
self.rng = np.random.default_rng(seed=self.tabular_generator_seed)
self.col_data = []

@classmethod
def post_profile_processing_w_data(cls, data, profile):
Expand Down Expand Up @@ -76,17 +75,18 @@ def synthesize(
seed=self.tabular_generator_seed,
)
else:
self.generate_uncorrelated_column_data()
col_data = self.generate_uncorrelated_column_data()

return generate_dataset(
rng=self.rng,
columns_to_generate=self.col_data,
columns_to_generate=col_data,
dataset_length=num_samples,
)

def generate_uncorrelated_column_data(self):
"""Generate column data."""
columns = self.profile.report()["data_stats"]
col_data = []

for col in columns:
generator = col.get("data_type", None)
Expand All @@ -103,7 +103,7 @@ def generate_uncorrelated_column_data(self):
end_date = pd.to_datetime(
col_stats.get("max", None), format=date_format[0]
)
self.col_data.append(
col_data.append(
{
"generator": generator,
"name": "dat",
Expand All @@ -114,7 +114,7 @@ def generate_uncorrelated_column_data(self):
}
)
elif generator == "int":
self.col_data.append(
col_data.append(
{
"generator": "integer",
"name": generator,
Expand All @@ -125,7 +125,7 @@ def generate_uncorrelated_column_data(self):
)

elif generator == "float":
self.col_data.append(
col_data.append(
{
"generator": generator,
"name": "flo",
Expand All @@ -148,7 +148,7 @@ def generate_uncorrelated_column_data(self):
for count in col_stats["categorical_count"].values():
probabilities.append(count / total)

self.col_data.append(
col_data.append(
{
"generator": "categorical",
"name": "cat",
Expand All @@ -158,7 +158,7 @@ def generate_uncorrelated_column_data(self):
}
)
else:
self.col_data.append(
col_data.append(
{
"generator": "text",
"name": "txt",
Expand All @@ -168,6 +168,7 @@ def generate_uncorrelated_column_data(self):
"order": order,
},
)
return col_data


class UnstructuredGenerator(BaseGenerator):
Expand Down
2 changes: 1 addition & 1 deletion tests/distinct_generators/test_int_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_size(self):
self.assertEqual(result.shape[0], 1)

def test_values_range(self):
ranges = [(-1, 1), (-10, 10), (-100, 100)]
ranges = [(-1, 1), (-10, 10), (-100, 100), (0, 0), (50, 50)]
for range in ranges:
result = random_integers(self.rng, range[0], range[1])
for x in result:
Expand Down
10 changes: 9 additions & 1 deletion tests/distinct_generators/test_text_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ def test_return_type(self):
for x in text_arr:
self.assertIsInstance(x, np.str_)

def test_text_length(self):
def test_text_length_range(self):
text_arr = random_text(self.rng, str_len_min=4, str_len_max=5)
self.assertLessEqual(len(text_arr[0]), 5)
self.assertGreaterEqual(len(text_arr[0]), 4)

def test_text_equal_length_range(self):
    """Check that equal min and max lengths yield text of exactly that length.

    Exceptions are deliberately not caught here: an error raised by
    ``random_text`` or a failed assertion must fail the test rather than
    being printed and ignored.
    """
    text_arr = random_text(self.rng, str_len_min=5, str_len_max=5)
    # With str_len_min == str_len_max the only acceptable length is 5.
    self.assertEqual(len(text_arr[0]), 5)

def test_num_rows(self):
num_rows = [1, 5, 10]
for nr in num_rows:
Expand Down
15 changes: 10 additions & 5 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ def test_synthesize_correlated_method(self, mock_make_data):
correlated_tabular_generator.synthesize(num_samples=10)
mock_make_data.assert_called_once()

@mock.patch(
"synthetic_data.generators.TabularGenerator.generate_uncorrelated_column_data"
)
@mock.patch("synthetic_data.generators.generate_dataset")
def test_uncorrelated_synthesize_columns_to_generate(self, mock_generate_dataset):
def test_uncorrelated_synthesize_columns_to_generate(
self, mock_generate_dataset, mock_col_data
):

generator = TabularGenerator(profile=self.profile, is_correlated=False, seed=42)
self.assertFalse(generator.is_correlated)
Expand Down Expand Up @@ -202,16 +207,16 @@ def test_uncorrelated_synthesize_columns_to_generate(self, mock_generate_dataset

mock_generate_dataset.assert_called_once()
for i in range(len(expected_columns_to_generate)):
for key in generator.col_data[i].keys():
if isinstance(generator.col_data[i][key], list):
for key in mock_col_data[i].keys():
if isinstance(mock_col_data[i][key], list):
self.assertTrue(
set(generator.col_data[i][key]).issubset(
set(mock_col_data[i][key]).issubset(
expected_columns_to_generate[i][key]
)
)
else:
self.assertEqual(
generator.col_data[i][key], expected_columns_to_generate[i][key]
mock_col_data[i][key], expected_columns_to_generate[i][key]
)

def test_uncorrelated_synthesize_output(self):
Expand Down

0 comments on commit 14bee78

Please sign in to comment.