Skip to content

Commit

Permalink
bug fix and clean up of temp_column_name (capitalone#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
fdosani authored Jun 3, 2024
1 parent a8f43cc commit be17763
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 138 deletions.
30 changes: 30 additions & 0 deletions datacompy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,33 @@ def report(
html_file: Optional[str] = None,
) -> str:
pass


def temp_column_name(*dataframes) -> str:
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
----------
dataframes : list of DataFrames
The DataFrames to create a temporary column name for
Returns
-------
str
String column name that looks like '_temp_x' for some integer x
"""
i = 0
columns = []
for dataframe in dataframes:
columns = columns + list(dataframe.columns)
columns = set(columns)

while True:
temp_column = f"_temp_{i}"
unique = True

if temp_column in columns:
i += 1
unique = False
if unique:
return temp_column
27 changes: 1 addition & 26 deletions datacompy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import pandas as pd
from ordered_set import OrderedSet

from datacompy.base import BaseCompare
from datacompy.base import BaseCompare, temp_column_name

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -890,31 +890,6 @@ def get_merged_columns(
return columns


def temp_column_name(*dataframes: pd.DataFrame) -> str:
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
----------
dataframes : list of Pandas.DataFrame
The DataFrames to create a temporary column name for
Returns
-------
str
String column name that looks like '_temp_x' for some integer x
"""
i = 0
while True:
temp_column = f"_temp_{i}"
unique = True
for dataframe in dataframes:
if temp_column in dataframe.columns:
i += 1
unique = False
if unique:
return temp_column


def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> float:
"""Get a maximum difference between two columns
Expand Down
45 changes: 14 additions & 31 deletions datacompy/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import numpy as np
from ordered_set import OrderedSet

from datacompy.base import BaseCompare
from datacompy.base import BaseCompare, temp_column_name

try:
import polars as pl
Expand Down Expand Up @@ -278,11 +278,17 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:

# process merge indicator
outer_join = outer_join.with_columns(
pl.when((pl.col("_merge_left") == True) & (pl.col("_merge_right") == True))
pl.when(
(pl.col("_merge_left") == True) & (pl.col("_merge_right") == True)
) # noqa: E712
.then(pl.lit("both"))
.when((pl.col("_merge_left") == True) & (pl.col("_merge_right").is_null()))
.when(
(pl.col("_merge_left") == True) & (pl.col("_merge_right").is_null())
) # noqa: E712
.then(pl.lit("left_only"))
.when((pl.col("_merge_left").is_null()) & (pl.col("_merge_right") == True))
.when(
(pl.col("_merge_left").is_null()) & (pl.col("_merge_right") == True)
) # noqa: E712
.then(pl.lit("right_only"))
.alias("_merge")
)
Expand Down Expand Up @@ -497,7 +503,9 @@ def sample_mismatch(
col_match = self.intersect_rows[column + "_match"]
match_cnt = col_match.sum()
sample_count = min(sample_count, row_cnt - match_cnt) # type: ignore
sample = self.intersect_rows.filter(pl.col(column + "_match") != True).sample(
sample = self.intersect_rows.filter(
pl.col(column + "_match") != True
).sample( # noqa: E712
sample_count
)
return_cols = self.join_columns + [
Expand Down Expand Up @@ -558,7 +566,7 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
)
return (
self.intersect_rows.with_columns(__all=pl.all_horizontal(match_list))
.filter(pl.col("__all") != True)
.filter(pl.col("__all") != True) # noqa: E712
.select(self.join_columns + return_list)
)

Expand Down Expand Up @@ -899,31 +907,6 @@ def get_merged_columns(
return columns


def temp_column_name(*dataframes: "pl.DataFrame") -> str:
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
----------
dataframes : list of Polars.DataFrame
The DataFrames to create a temporary column name for
Returns
-------
str
String column name that looks like '_temp_x' for some integer x
"""
i = 0
while True:
temp_column = f"_temp_{i}"
unique = True
for dataframe in dataframes:
if temp_column in dataframe.columns:
i += 1
unique = False
if unique:
return temp_column


def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
"""Get a maximum difference between two columns
Expand Down
36 changes: 7 additions & 29 deletions datacompy/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pandas as pd
from ordered_set import OrderedSet

from datacompy.base import BaseCompare
from datacompy.base import BaseCompare, temp_column_name

try:
import pyspark.pandas as ps
Expand Down Expand Up @@ -301,15 +301,18 @@ def _dataframe_merge(self, ignore_spaces):

# process merge indicator
outer_join["_merge"] = outer_join._merge.mask(
(outer_join["_merge_left"] == True) & (outer_join["_merge_right"] == True),
(outer_join["_merge_left"] == True)
& (outer_join["_merge_right"] == True), # noqa: E712
"both",
)
outer_join["_merge"] = outer_join._merge.mask(
(outer_join["_merge_left"] == True) & (outer_join["_merge_right"] != True),
(outer_join["_merge_left"] == True)
& (outer_join["_merge_right"] != True), # noqa: E712
"left_only",
)
outer_join["_merge"] = outer_join._merge.mask(
(outer_join["_merge_left"] != True) & (outer_join["_merge_right"] == True),
(outer_join["_merge_left"] != True)
& (outer_join["_merge_right"] == True), # noqa: E712
"right_only",
)

Expand Down Expand Up @@ -913,31 +916,6 @@ def get_merged_columns(original_df, merged_df, suffix):
return columns


def temp_column_name(*dataframes):
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
----------
dataframes : list of pyspark.pandas.frame.DataFrame
The DataFrames to create a temporary column name for
Returns
-------
str
String column name that looks like '_temp_x' for some integer x
"""
i = 0
while True:
temp_column = f"_temp_{i}"
unique = True
for dataframe in dataframes:
if temp_column in dataframe.columns:
i += 1
unique = False
if unique:
return temp_column


def calculate_max_diff(col_1, col_2):
"""Get a maximum difference between two columns
Expand Down
36 changes: 18 additions & 18 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,11 @@ def test_mixed_column_with_ignore_spaces_and_case():
def test_compare_df_setter_bad():
df = pd.DataFrame([{"a": 1, "A": 2}, {"a": 2, "A": 2}])
with raises(TypeError, match="df1 must be a pandas DataFrame"):
compare = datacompy.Compare("a", "a", ["a"])
datacompy.Compare("a", "a", ["a"])
with raises(ValueError, match="df1 must have all columns from join_columns"):
compare = datacompy.Compare(df, df.copy(), ["b"])
datacompy.Compare(df, df.copy(), ["b"])
with raises(ValueError, match="df1 must have unique column names"):
compare = datacompy.Compare(df, df.copy(), ["a"])
datacompy.Compare(df, df.copy(), ["a"])
df_dupe = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 3}])
assert datacompy.Compare(df_dupe, df_dupe.copy(), ["a", "b"]).df1.equals(df_dupe)

Expand Down Expand Up @@ -450,15 +450,15 @@ def test_compare_df_setter_different_cases():
def test_compare_df_setter_bad_index():
df = pd.DataFrame([{"a": 1, "A": 2}, {"a": 2, "A": 2}])
with raises(TypeError, match="df1 must be a pandas DataFrame"):
compare = datacompy.Compare("a", "a", on_index=True)
datacompy.Compare("a", "a", on_index=True)
with raises(ValueError, match="df1 must have unique column names"):
compare = datacompy.Compare(df, df.copy(), on_index=True)
datacompy.Compare(df, df.copy(), on_index=True)


def test_compare_on_index_and_join_columns():
df = pd.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}])
with raises(Exception, match="Only provide on_index or join_columns"):
compare = datacompy.Compare(df, df.copy(), on_index=True, join_columns=["a"])
datacompy.Compare(df, df.copy(), on_index=True, join_columns=["a"])


def test_compare_df_setter_good_index():
Expand Down Expand Up @@ -647,7 +647,7 @@ def test_temp_column_name_one_has():
assert actual == "_temp_1"


def test_temp_column_name_both_have():
def test_temp_column_name_both_have_temp_1():
df1 = pd.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}])
df2 = pd.DataFrame(
[
Expand All @@ -660,7 +660,7 @@ def test_temp_column_name_both_have():
assert actual == "_temp_1"


def test_temp_column_name_both_have():
def test_temp_column_name_both_have_temp_2():
df1 = pd.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}])
df2 = pd.DataFrame(
[
Expand Down Expand Up @@ -693,7 +693,7 @@ def test_simple_dupes_one_field():
compare = datacompy.Compare(df1, df2, join_columns=["a"])
assert compare.matches()
# Just render the report to make sure it renders.
t = compare.report()
compare.report()


def test_simple_dupes_two_fields():
Expand All @@ -702,7 +702,7 @@ def test_simple_dupes_two_fields():
compare = datacompy.Compare(df1, df2, join_columns=["a", "b"])
assert compare.matches()
# Just render the report to make sure it renders.
t = compare.report()
compare.report()


def test_simple_dupes_index():
Expand All @@ -714,19 +714,19 @@ def test_simple_dupes_index():
compare = datacompy.Compare(df1, df2, on_index=True)
assert compare.matches()
# Just render the report to make sure it renders.
t = compare.report()
compare.report()


def test_simple_dupes_one_field_two_vals():
def test_simple_dupes_one_field_two_vals_1():
df1 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
df2 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
compare = datacompy.Compare(df1, df2, join_columns=["a"])
assert compare.matches()
# Just render the report to make sure it renders.
t = compare.report()
compare.report()


def test_simple_dupes_one_field_two_vals():
def test_simple_dupes_one_field_two_vals_2():
df1 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
df2 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 0}])
compare = datacompy.Compare(df1, df2, join_columns=["a"])
Expand All @@ -735,7 +735,7 @@ def test_simple_dupes_one_field_two_vals():
assert len(compare.df2_unq_rows) == 1
assert len(compare.intersect_rows) == 1
# Just render the report to make sure it renders.
t = compare.report()
compare.report()


def test_simple_dupes_one_field_three_to_two_vals():
Expand All @@ -747,7 +747,7 @@ def test_simple_dupes_one_field_three_to_two_vals():
assert len(compare.df2_unq_rows) == 0
assert len(compare.intersect_rows) == 2
# Just render the report to make sure it renders.
t = compare.report()
compare.report()

assert "(First 1 Columns)" in compare.report(column_count=1)
assert "(First 2 Columns)" in compare.report(column_count=2)
Expand Down Expand Up @@ -786,8 +786,8 @@ def test_dupes_from_real_data():
)
assert compare_unq.matches()
# Just render the report to make sure it renders.
t = compare_acct.report()
r = compare_unq.report()
compare_acct.report()
compare_unq.report()


def test_strings_with_joins_with_ignore_spaces():
Expand Down
Loading

0 comments on commit be17763

Please sign in to comment.