Merge pull request #282 from capitalone/develop
Release v0.11.2
fdosani authored Mar 19, 2024
2 parents 4119130 + 2a35632 commit bfb6ddf
Showing 4 changed files with 50 additions and 71 deletions.
2 changes: 1 addition & 1 deletion datacompy/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.11.1"
__version__ = "0.11.2"

from datacompy.core import *
from datacompy.fugue import (
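A quick way to confirm which release an environment actually has (a minimal sketch; datacompy.__version__ is the attribute bumped above):

    import datacompy

    # Prints "0.11.2" once this release is installed.
    print(datacompy.__version__)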
46 changes: 27 additions & 19 deletions datacompy/spark.py
@@ -17,6 +17,7 @@
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
+from warnings import warn

try:
import pyspark
@@ -25,6 +26,13 @@
pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality


+warn(
+    f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ",
+    DeprecationWarning,
+    stacklevel=2,
+)


class MatchType(Enum):
MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3)

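Because the warning fires at module import time, downstream code that is not yet ready to migrate can silence just this notice with the standard-library warnings filters (a rough sketch, assuming datacompy has not already been imported in the session; the message regex mirrors the text added above):

    import warnings

    # Ignore only the datacompy.spark deprecation notice; other warnings still surface.
    warnings.filterwarnings(
        "ignore",
        message=r"The module datacompy\.spark is deprecated",
        category=DeprecationWarning,
    )

    import datacompy.spark  # no longer emits the notice with the filter in place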
@@ -383,7 +391,7 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
base_rows.createOrReplaceTempView("baseRows")
self.base_df.createOrReplaceTempView("baseTable")
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
join_condition
@@ -403,7 +411,7 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
compare_rows.createOrReplaceTempView("compareRows")
self.compare_df.createOrReplaceTempView("compareTable")
where_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
sql_query = (
"select A.* from compareTable as A, compareRows as B where {}".format(
@@ -435,15 +443,15 @@ def _generate_select_statement(self, match_data: bool = True) -> str:
[self._create_select_statement(name=column_name)]
)
elif column_name in base_only:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])

elif column_name in compare_only:
if match_data:
select_statement = select_statement + ",".join(["B." + column_name])
select_statement = select_statement + ",".join(["B.`" + column_name + "`"])
else:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])
elif column_name in self._join_column_names:
select_statement = select_statement + ",".join(["A." + column_name])
select_statement = select_statement + ",".join(["A.`" + column_name + "`"])

if column_name != sorted_list[-1]:
select_statement = select_statement + " , "
@@ -465,7 +473,7 @@ def _merge_dataframes(self) -> None:
self._all_matched_rows.createOrReplaceTempView("matched_table")

where_cond = " OR ".join(
["A." + name + "_match= False" for name in self.columns_compared]
["A.`" + name + "_match`= False" for name in self.columns_compared]
)
mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
@@ -475,7 +483,7 @@
def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
if self._joined_dataframe is None:
join_condition = " AND ".join(
["A." + name + "<=>B." + name for name in self._join_column_names]
["A.`" + name + "`<=>B.`" + name + "`" for name in self._join_column_names]
)
select_statement = self._generate_select_statement(match_data=True)

@@ -566,22 +574,22 @@ def _create_select_statement(self, name: str) -> str:
match_type_comparison = ""
for k in MatchType:
match_type_comparison += (
" WHEN (A.{name}={match_value}) THEN '{match_name}'".format(
" WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
name=name, match_value=str(k.value), match_name=k.name
)
)
return "A.{name}_base, A.{name}_compare, (CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END) AS {name}_match, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS {name}_match_type ".format(
return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
name=name,
match_failure=MatchType.MISMATCH.value,
match_type_comparison=match_type_comparison,
)
else:
return "A.{name}_base, A.{name}_compare, CASE WHEN (A.{name}={match_failure}) THEN False ELSE True END AS {name}_match ".format(
return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
name=name, match_failure=MatchType.MISMATCH.value
)

def _create_case_statement(self, name: str) -> str:
equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"]
equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
known_diff_comparisons = ["(FALSE)"]

base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
@@ -592,30 +600,30 @@ def _create_case_statement(self, name: str) -> str:
compare_dtype in NUMERIC_SPARK_TYPES
): # numeric tolerance comparison
equal_comparisons.append(
"((A.{name}=B.{name}) OR ((abs(A.{name}-B.{name}))<=("
"((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
+ str(self.abs_tol)
+ "+("
+ str(self.rel_tol)
+ "*abs(A.{name})))))"
+ "*abs(A.`{name}`)))))"
)
else: # non-numeric comparison
-equal_comparisons.append("((A.{name}=B.{name}))")
+equal_comparisons.append("((A.`{name}`=B.`{name}`))")

if self._known_differences:
new_input = "B.{name}"
new_input = "B.`{name}`"
for kd in self._known_differences:
if compare_dtype in kd["types"]:
if "flags" in kd and "nullcheck" in kd["flags"]:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") is null AND A.{name} is null)"
+ ") is null AND A.`{name}` is null)"
)
else:
known_diff_comparisons.append(
"(("
+ kd["transformation"].format(new_input, input=new_input)
+ ") = A.{name})"
+ ") = A.`{name}`)"
)

case_string = (
@@ -624,7 +632,7 @@
+ ") THEN {match_success} WHEN ("
+ " OR ".join(known_diff_comparisons)
+ ") THEN {match_known_difference} ELSE {match_failure} END) "
+ "AS {name}, A.{name} AS {name}_base, B.{name} AS {name}_compare"
+ "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
)

return case_string.format(
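The same quoting carries through to the CASE expressions that _create_case_statement generates for value comparison. A worked sketch of the equality clause for a hypothetical numeric column named 金額 with abs_tol=0.01 and rel_tol=0 (values chosen only for illustration), assembled from the same string pieces shown above:

    name = "金額"
    abs_tol, rel_tol = 0.01, 0.0

    equal_comparisons = [
        "(A.`{name}` IS NULL AND B.`{name}` IS NULL)",
        "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
        + str(abs_tol) + "+(" + str(rel_tol) + "*abs(A.`{name}`)))))",
    ]

    print(" OR ".join(equal_comparisons).format(name=name))
    # (A.`金額` IS NULL AND B.`金額` IS NULL) OR ((A.`金額`=B.`金額`) OR ((abs(A.`金額`-B.`金額`))<=(0.01+(0.0*abs(A.`金額`)))))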
60 changes: 9 additions & 51 deletions pyproject.toml
@@ -11,12 +11,7 @@ maintainers = [
{ name="Faisal Dosani", email="[email protected]" }
]
license = {text = "Apache Software License"}
-dependencies = [
-"pandas<=2.2.0,>=0.25.0",
-"numpy<=1.26.4,>=1.22.0",
-"ordered-set<=4.1.0,>=4.0.2",
-"fugue<=0.8.7,>=0.8.7",
-]
+dependencies = ["pandas<=2.2.1,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"]
requires-python = ">=3.8.0"
classifiers = [
"Intended Audience :: Developers",
@@ -61,54 +56,17 @@ python-tag = "py3"
[project.optional-dependencies]
duckdb = ["fugue[duckdb]"]
polars = ["polars"]
-spark = [
-"pyspark>=3.1.1; python_version < '3.11'",
-"pyspark>=3.4; python_version >= '3.11'",
-]
+spark = ["pyspark>=3.1.1; python_version < \"3.11\"", "pyspark>=3.4; python_version >= \"3.11\""]
dask = ["fugue[dask]"]
ray = ["fugue[ray]"]
-docs = [
-"sphinx",
-"furo",
-"myst-parser",
-]
-tests = [
-"pytest",
-"pytest-cov",
-]
+docs = ["sphinx", "furo", "myst-parser"]
+tests = ["pytest", "pytest-cov"]

-tests-spark = [
-"pytest",
-"pytest-cov",
-"pytest-spark",
-"spark",
-]
-qa = [
-"pre-commit",
-"black",
-"isort",
-"mypy",
-"pandas-stubs",
-]
-build = [
-"build",
-"twine",
-"wheel",
-]
-edgetest = [
-"edgetest",
-"edgetest-conda",
-]
-dev = [
-"datacompy[duckdb]",
-"datacompy[polars]",
-"datacompy[spark]",
-"datacompy[docs]",
-"datacompy[tests]",
-"datacompy[tests-spark]",
-"datacompy[qa]",
-"datacompy[build]",
-]
+tests-spark = ["pytest", "pytest-cov", "pytest-spark", "spark"]
+qa = ["pre-commit", "black", "isort", "mypy", "pandas-stubs"]
+build = ["build", "twine", "wheel"]
+edgetest = ["edgetest", "edgetest-conda"]
+dev = ["datacompy[duckdb]", "datacompy[polars]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"]

[tool.isort]
multi_line_output = 3
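The dependency changes above raise the pandas ceiling from 2.2.0 to 2.2.1 and collapse the optional-dependency lists onto single lines; the version ranges are otherwise unchanged. A small sketch for checking that an existing environment falls within the new pins (importlib.metadata is standard library on the Python >= 3.8 floor declared here):

    from importlib.metadata import version

    # Expected to satisfy >=0.25.0,<=2.2.1 and >=1.22.0,<=1.26.4 respectively.
    print("pandas", version("pandas"))
    print("numpy", version("numpy"))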
13 changes: 13 additions & 0 deletions tests/test_spark.py
@@ -2093,3 +2093,16 @@ def text_alignment_validator(

if not at_column_section and section_start in line:
at_column_section = True

+def test_unicode_columns(spark_session):
+    df1 = spark_session.createDataFrame(
+        [{"a": 1, "例": 2}, {"a": 1, "例": 3}]
+    )
+    df2 = spark_session.createDataFrame(
+        [{"a": 1, "例": 2}, {"a": 1, "例": 3}]
+    )
+    compare = SparkCompare(
+        spark_session, df1, df2, join_columns=["例"]
+    )
+    # Just render the report to make sure it renders.
+    compare.report()
