Unicode column names fix #281

Merged: 5 commits, Mar 19, 2024
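
This PR fixes the legacy SparkCompare for DataFrames whose column names contain non-ASCII characters: every piece of SQL the comparator builds by string concatenation now wraps column names in backticks, since Spark's SQL parser only accepts ASCII letters, digits, and underscores in unquoted identifiers. The PR also bumps the version to 0.11.2 and adds a DeprecationWarning on import of datacompy.spark, ahead of the 0.12.0 refactor that will move this logic to LegacySparkCompare.

Below is a minimal sketch of the failure mode, written for this summary rather than taken from the PR; the session setup and the view name t are illustrative:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.createDataFrame([{"a": 1, "例": 2}]).createOrReplaceTempView("t")

    # Unquoted non-ASCII identifier: Spark's SQL parser rejects it.
    # spark.sql("SELECT A.例 FROM t AS A")  # raises a ParseException

    # Backtick-quoted identifier: parses and runs.
    spark.sql("SELECT A.`例` FROM t AS A").show()
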
2 changes: 1 addition & 1 deletion datacompy/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.11.1"
+__version__ = "0.11.2"
 
 from datacompy.core import *
 from datacompy.fugue import (
46 changes: 27 additions & 19 deletions datacompy/spark.py
@@ -17,6 +17,7 @@
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
+from warnings import warn
 
 try:
     import pyspark
@@ -25,6 +26,13 @@
     pass  # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
 
 
+warn(
+    f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+
 class MatchType(Enum):
     MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3)
 
@@ -383,7 +391,7 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
         base_rows.createOrReplaceTempView("baseRows")
         self.base_df.createOrReplaceTempView("baseTable")
         join_condition = " AND ".join(
-            ["A." + name + "<=>B." + name for name in self._join_column_names]
+            ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
         )
         sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
             join_condition
@@ -403,7 +411,7 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
         compare_rows.createOrReplaceTempView("compareRows")
         self.compare_df.createOrReplaceTempView("compareTable")
         where_condition = " AND ".join(
-            ["A." + name + "<=>B." + name for name in self._join_column_names]
+            ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
         )
         sql_query = (
             "select A.* from compareTable as A, compareRows as B where {}".format(
@@ -435,15 +443,15 @@ def _generate_select_statement(self, match_data: bool = True) -> str:
                         [self._create_select_statement(name=column_name)]
                     )
             elif column_name in base_only:
-                select_statement = select_statement + ",".join(["A." + column_name])
+                select_statement = select_statement + ",".join(["A.`" + column_name + "`"])
 
             elif column_name in compare_only:
                 if match_data:
-                    select_statement = select_statement + ",".join(["B." + column_name])
+                    select_statement = select_statement + ",".join(["B.`" + column_name + "`"])
                 else:
-                    select_statement = select_statement + ",".join(["A." + column_name])
+                    select_statement = select_statement + ",".join(["A.`" + column_name + "`"])
             elif column_name in self._join_column_names:
-                select_statement = select_statement + ",".join(["A." + column_name])
+                select_statement = select_statement + ",".join(["A.`" + column_name + "`"])
 
             if column_name != sorted_list[-1]:
                 select_statement = select_statement + " , "
@@ -465,7 +473,7 @@ def _merge_dataframes(self) -> None:
         self._all_matched_rows.createOrReplaceTempView("matched_table")
 
         where_cond = " OR ".join(
-            ["A." + name + "_match= False" for name in self.columns_compared]
+            ["A.`" + name + "_match`= False" for name in self.columns_compared]
         )
         mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
         self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
@@ -475,7 +483,7 @@
     def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
         if self._joined_dataframe is None:
             join_condition = " AND ".join(
-                ["A." + name + "<=>B." + name for name in self._join_column_names]
+                ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
             )
             select_statement = self._generate_select_statement(match_data=True)

@@ -566,22 +574,22 @@ def _create_select_statement(self, name: str) -> str:
             match_type_comparison = ""
             for k in MatchType:
                 match_type_comparison += (
-                    " WHEN (A.{name}={match_value}) THEN '{match_name}'".format(
+                    " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
                         name=name, match_value=str(k.value), match_name=k.name
                     )
                 )
-            return "A.{name}_base, A.{name}_compare, (CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END) AS {name}_match, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS {name}_match_type ".format(
+            return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
                 name=name,
                 match_failure=MatchType.MISMATCH.value,
                 match_type_comparison=match_type_comparison,
             )
         else:
-            return "A.{name}_base, A.{name}_compare, CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END AS {name}_match ".format(
+            return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
                 name=name, match_failure=MatchType.MISMATCH.value
             )
 
     def _create_case_statement(self, name: str) -> str:
-        equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"]
+        equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
         known_diff_comparisons = ["(FALSE)"]
 
         base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
@@ -592,30 +600,30 @@ def _create_case_statement(self, name: str) -> str:
             compare_dtype in NUMERIC_SPARK_TYPES
         ):  # numeric tolerance comparison
             equal_comparisons.append(
-                "((A.{name}=B.{name}) OR ((abs(A.{name}-B.{name}))<=("
+                "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
                 + str(self.abs_tol)
                 + "+("
                 + str(self.rel_tol)
-                + "*abs(A.{name})))))"
+                + "*abs(A.`{name}`)))))"
             )
         else:  # non-numeric comparison
-            equal_comparisons.append("((A.{name}=B.{name}))")
+            equal_comparisons.append("((A.`{name}`=B.`{name}`))")
 
         if self._known_differences:
-            new_input = "B.{name}"
+            new_input = "B.`{name}`"
             for kd in self._known_differences:
                 if compare_dtype in kd["types"]:
                     if "flags" in kd and "nullcheck" in kd["flags"]:
                         known_diff_comparisons.append(
                             "(("
                             + kd["transformation"].format(new_input, input=new_input)
-                            + ") is null AND A.{name} is null)"
+                            + ") is null AND A.`{name}` is null)"
                         )
                     else:
                         known_diff_comparisons.append(
                             "(("
                             + kd["transformation"].format(new_input, input=new_input)
-                            + ") = A.{name})"
+                            + ") = A.`{name}`)"
                         )
 
         case_string = (
@@ -624,7 +632,7 @@ def _create_case_statement(self, name: str) -> str:
             + ") THEN {match_success} WHEN ("
             + " OR ".join(known_diff_comparisons)
             + ") THEN {match_known_difference} ELSE {match_failure} END) "
-            + "AS {name}, A.{name} AS {name}_base, B.{name} AS {name}_compare"
+            + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
         )
 
         return case_string.format(
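
Every change above applies the same mechanical rule to the string-built SQL: wrap each identifier in backticks. A hypothetical helper, not part of this PR, that captures the rule (including Spark SQL's convention of doubling a literal backtick inside a quoted identifier):

    def quote_identifier(name: str) -> str:
        # Double any literal backtick, then wrap the whole name in backticks.
        return "`" + name.replace("`", "``") + "`"

    join_condition = " AND ".join(
        "A.{0}<=>B.{0}".format(quote_identifier(name))
        for name in ["例", "año", "id"]
    )
    # A.`例`<=>B.`例` AND A.`año`<=>B.`año` AND A.`id`<=>B.`id`

The PR inlines the quoting at each call site rather than centralizing it in a helper like this.
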
13 changes: 13 additions & 0 deletions tests/test_spark.py
@@ -2093,3 +2093,16 @@ def text_alignment_validator(
 
         if not at_column_section and section_start in line:
             at_column_section = True
+
+def test_unicode_columns(spark_session):
+    df1 = spark_session.createDataFrame(
+        [{"a": 1, "例": 2}, {"a": 1, "例": 3}]
+    )
+    df2 = spark_session.createDataFrame(
+        [{"a": 1, "例": 2}, {"a": 1, "例": 3}]
+    )
+    compare = SparkCompare(
+        spark_session, df1, df2, join_columns=["例"]
+    )
+    # Just render the report to make sure it renders.
+    compare.report()
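
The new test is a smoke test by design: the two DataFrames are identical, the join key is the non-ASCII column 例, and report() is called only to confirm that every SQL statement SparkCompare generates parses end to end. Without the backtick quoting above, the unquoted 例 in the generated join and select statements would be rejected by Spark's SQL parser before any comparison ran.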