diff --git a/datacompy/__init__.py b/datacompy/__init__.py index 3e5cb57b..6bc1e6ea 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.11.1" +__version__ = "0.11.2" from datacompy.core import * from datacompy.fugue import ( diff --git a/datacompy/spark.py b/datacompy/spark.py index f27d2e58..cf32ff17 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -17,6 +17,7 @@ from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union +from warnings import warn try: import pyspark @@ -25,6 +26,13 @@ pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality +warn( + f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ", + DeprecationWarning, + stacklevel=2, +) + + class MatchType(Enum): MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3) @@ -383,7 +391,7 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame": base_rows.createOrReplaceTempView("baseRows") self.base_df.createOrReplaceTempView("baseTable") join_condition = " AND ".join( - ["A." + name + "<=>B." + name for name in self._join_column_names] + ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names] ) sql_query = "select A.* from baseTable as A, baseRows as B where {}".format( join_condition @@ -403,7 +411,7 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]: compare_rows.createOrReplaceTempView("compareRows") self.compare_df.createOrReplaceTempView("compareTable") where_condition = " AND ".join( - ["A." + name + "<=>B." + name for name in self._join_column_names] + ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names] ) sql_query = ( "select A.* from compareTable as A, compareRows as B where {}".format( @@ -435,15 +443,15 @@ def _generate_select_statement(self, match_data: bool = True) -> str: [self._create_select_statement(name=column_name)] ) elif column_name in base_only: - select_statement = select_statement + ",".join(["A." + column_name]) + select_statement = select_statement + ",".join(["A.`" + column_name + "`"]) elif column_name in compare_only: if match_data: - select_statement = select_statement + ",".join(["B." + column_name]) + select_statement = select_statement + ",".join(["B.`" + column_name + "`"]) else: - select_statement = select_statement + ",".join(["A." + column_name]) + select_statement = select_statement + ",".join(["A.`" + column_name + "`"]) elif column_name in self._join_column_names: - select_statement = select_statement + ",".join(["A." + column_name]) + select_statement = select_statement + ",".join(["A.`" + column_name + "`"]) if column_name != sorted_list[-1]: select_statement = select_statement + " , " @@ -465,7 +473,7 @@ def _merge_dataframes(self) -> None: self._all_matched_rows.createOrReplaceTempView("matched_table") where_cond = " OR ".join( - ["A." + name + "_match= False" for name in self.columns_compared] + ["A.`" + name + "_match`= False" for name in self.columns_compared] ) mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond) self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy( @@ -475,7 +483,7 @@ def _merge_dataframes(self) -> None: def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame": if self._joined_dataframe is None: join_condition = " AND ".join( - ["A." + name + "<=>B." + name for name in self._join_column_names] + ["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names] ) select_statement = self._generate_select_statement(match_data=True) @@ -566,22 +574,22 @@ def _create_select_statement(self, name: str) -> str: match_type_comparison = "" for k in MatchType: match_type_comparison += ( - " WHEN (A.{name}={match_value}) THEN '{match_name}'".format( + " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format( name=name, match_value=str(k.value), match_name=k.name ) ) - return "A.{name}_base, A.{name}_compare, (CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END) AS {name}_match, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS {name}_match_type ".format( + return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format( name=name, match_failure=MatchType.MISMATCH.value, match_type_comparison=match_type_comparison, ) else: - return "A.{name}_base, A.{name}_compare, CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END AS {name}_match ".format( + return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format( name=name, match_failure=MatchType.MISMATCH.value ) def _create_case_statement(self, name: str) -> str: - equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"] + equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"] known_diff_comparisons = ["(FALSE)"] base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0] @@ -592,30 +600,30 @@ def _create_case_statement(self, name: str) -> str: compare_dtype in NUMERIC_SPARK_TYPES ): # numeric tolerance comparison equal_comparisons.append( - "((A.{name}=B.{name}) OR ((abs(A.{name}-B.{name}))<=(" + "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=(" + str(self.abs_tol) + "+(" + str(self.rel_tol) - + "*abs(A.{name})))))" + + "*abs(A.`{name}`)))))" ) else: # non-numeric comparison - equal_comparisons.append("((A.{name}=B.{name}))") + equal_comparisons.append("((A.`{name}`=B.`{name}`))") if self._known_differences: - new_input = "B.{name}" + new_input = "B.`{name}`" for kd in self._known_differences: if compare_dtype in kd["types"]: if "flags" in kd and "nullcheck" in kd["flags"]: known_diff_comparisons.append( "((" + kd["transformation"].format(new_input, input=new_input) - + ") is null AND A.{name} is null)" + + ") is null AND A.`{name}` is null)" ) else: known_diff_comparisons.append( "((" + kd["transformation"].format(new_input, input=new_input) - + ") = A.{name})" + + ") = A.`{name}`)" ) case_string = ( @@ -624,7 +632,7 @@ def _create_case_statement(self, name: str) -> str: + ") THEN {match_success} WHEN (" + " OR ".join(known_diff_comparisons) + ") THEN {match_known_difference} ELSE {match_failure} END) " - + "AS {name}, A.{name} AS {name}_base, B.{name} AS {name}_compare" + + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`" ) return case_string.format( diff --git a/tests/test_spark.py b/tests/test_spark.py index 488e1008..92e9d877 100644 --- a/tests/test_spark.py +++ b/tests/test_spark.py @@ -2093,3 +2093,16 @@ def text_alignment_validator( if not at_column_section and section_start in line: at_column_section = True + +def test_unicode_columns(spark_session): + df1 = spark_session.createDataFrame( + [{"a": 1, "例": 2}, {"a": 1, "例": 3}] + ) + df2 = spark_session.createDataFrame( + [{"a": 1, "例": 2}, {"a": 1, "例": 3}] + ) + compare = SparkCompare( + spark_session, df1, df2, join_columns=["例"] + ) + # Just render the report to make sure it renders. + compare.report()