Skip to content

Commit

Permalink
Unicode column names fix (#281)
Browse files Browse the repository at this point in the history
* fixes #280

* Update test_spark.py

* Update test_spark.py

* Update test_spark.py

* bumping version
  • Loading branch information
fdosani authored Mar 19, 2024
1 parent 132ced1 commit 191c3a3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
2 changes: 1 addition & 1 deletion datacompy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.11.1"
__version__ = "0.11.2"

from datacompy.core import *
from datacompy.fugue import (
Expand Down
46 changes: 27 additions & 19 deletions datacompy/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
from warnings import warn

try:
import pyspark
Expand All @@ -25,6 +26,13 @@
pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality


warn(
f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ",
DeprecationWarning,
stacklevel=2,
)


class MatchType(Enum):
MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3)

Expand Down Expand Up @@ -383,7 +391,7 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
base_rows.createOrReplaceTempView("baseRows")
self.base_df.createOrReplaceTempView("baseTable")
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
join_condition
Expand All @@ -403,7 +411,7 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
compare_rows.createOrReplaceTempView("compareRows")
self.compare_df.createOrReplaceTempView("compareTable")
where_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
sql_query = (
"select A.* from compareTable as A, compareRows as B where {}".format(
Expand Down Expand Up @@ -435,15 +443,15 @@ def _generate_select_statement(self, match_data: bool = True) -> str:
[self._create_select_statement(name=column_name)]
)
elif column_name in base_only:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])

elif column_name in compare_only:
if match_data:
select_statement = select_statement + ",".join(["B." + column_name])
select_statement = select_statement + ",".join(["B.`" + column_name + "`"])
else:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])
elif column_name in self._join_column_names:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])

if column_name != sorted_list[-1]:
select_statement = select_statement + " , "
Expand All @@ -465,7 +473,7 @@ def _merge_dataframes(self) -> None:
self._all_matched_rows.createOrReplaceTempView("matched_table")

where_cond = " OR ".join(
["A." + name + "_match= False" for name in self.columns_compared]
["A.`" + name + "_match`= False" for name in self.columns_compared]
)
mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
Expand All @@ -475,7 +483,7 @@ def _merge_dataframes(self) -> None:
def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
if self._joined_dataframe is None:
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
select_statement = self._generate_select_statement(match_data=True)

Expand Down Expand Up @@ -566,22 +574,22 @@ def _create_select_statement(self, name: str) -> str:
match_type_comparison = ""
for k in MatchType:
match_type_comparison += (
" WHEN (A.{name}={match_value}) THEN '{match_name}'".format(
" WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
name=name, match_value=str(k.value), match_name=k.name
)
)
return "A.{name}_base, A.{name}_compare, (CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END) AS {name}_match, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS {name}_match_type ".format(
return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
name=name,
match_failure=MatchType.MISMATCH.value,
match_type_comparison=match_type_comparison,
)
else:
return "A.{name}_base, A.{name}_compare, CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END AS {name}_match ".format(
return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
name=name, match_failure=MatchType.MISMATCH.value
)

def _create_case_statement(self, name: str) -> str:
equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"]
equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
known_diff_comparisons = ["(FALSE)"]

base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
Expand All @@ -592,30 +600,30 @@ def _create_case_statement(self, name: str) -> str:
compare_dtype in NUMERIC_SPARK_TYPES
): # numeric tolerance comparison
equal_comparisons.append(
"((A.{name}=B.{name}) OR ((abs(A.{name}-B.{name}))<=("
"((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
+ str(self.abs_tol)
+ "+("
+ str(self.rel_tol)
+ "*abs(A.{name})))))"
+ "*abs(A.`{name}`)))))"
)
else: # non-numeric comparison
equal_comparisons.append("((A.{name}=B.{name}))")
equal_comparisons.append("((A.`{name}`=B.`{name}`))")

if self._known_differences:
new_input = "B.{name}"
new_input = "B.`{name}`"
for kd in self._known_differences:
if compare_dtype in kd["types"]:
if "flags" in kd and "nullcheck" in kd["flags"]:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") is null AND A.{name} is null)"
+ ") is null AND A.`{name}` is null)"
)
else:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") = A.{name})"
+ ") = A.`{name}`)"
)

case_string = (
Expand All @@ -624,7 +632,7 @@ def _create_case_statement(self, name: str) -> str:
+ ") THEN {match_success} WHEN ("
+ " OR ".join(known_diff_comparisons)
+ ") THEN {match_known_difference} ELSE {match_failure} END) "
+ "AS {name}, A.{name} AS {name}_base, B.{name} AS {name}_compare"
+ "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
)

return case_string.format(
Expand Down
13 changes: 13 additions & 0 deletions tests/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,3 +2093,16 @@ def text_alignment_validator(

if not at_column_section and section_start in line:
at_column_section = True

def test_unicode_columns(spark_session):
df1 = spark_session.createDataFrame(
[{"a": 1, "例": 2}, {"a": 1, "例": 3}]
)
df2 = spark_session.createDataFrame(
[{"a": 1, "例": 2}, {"a": 1, "例": 3}]
)
compare = SparkCompare(
spark_session, df1, df2, join_columns=["例"]
)
# Just render the report to make sure it renders.
compare.report()

0 comments on commit 191c3a3

Please sign in to comment.