diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 26b18def..dc50777d 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -64,17 +64,27 @@ jobs: java-version: '8' distribution: 'adopt' - - name: Install Spark and datacompy + - name: Install Spark, Pandas, and Numpy run: | python -m pip install --upgrade pip python -m pip install pytest pytest-spark pypandoc python -m pip install pyspark[connect]==${{ matrix.spark-version }} python -m pip install pandas==${{ matrix.pandas-version }} python -m pip install numpy==${{ matrix.numpy-version }} + + - name: Install Datacompy without Snowflake/Snowpark if Python 3.12 + if: ${{ matrix.python-version == '3.12' }} + run: | + python -m pip install .[dev_no_snowflake] + + - name: Install Datacompy with all dev dependencies if Python 3.9, 3.10, or 3.11 + if: ${{ matrix.python-version != '3.12' }} + run: | python -m pip install .[dev] + - name: Test with pytest run: | - python -m pytest tests/ + python -m pytest tests/ --ignore=tests/test_snowflake.py test-bare-install: @@ -101,7 +111,7 @@ jobs: python -m pip install .[tests] - name: Test with pytest run: | - python -m pytest tests/ + python -m pytest tests/ --ignore=tests/test_snowflake.py test-fugue-install-no-spark: @@ -127,4 +137,4 @@ jobs: python -m pip install .[tests,duckdb,polars,dask,ray] - name: Test with pytest run: | - python -m pytest tests/ + python -m pytest tests/ --ignore=tests/test_snowflake.py diff --git a/README.md b/README.md index b60457ef..cf6070df 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ pip install datacompy[spark] pip install datacompy[dask] pip install datacompy[duckdb] pip install datacompy[ray] +pip install datacompy[snowflake] ``` @@ -95,6 +96,7 @@ with the Pandas on Spark implementation. Spark plans to support Pandas 2 in [Spa - Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html)) - Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html)) - Polars: ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html)) +- Snowflake/Snowpark: ([See documentation](https://capitalone.github.io/datacompy/snowflake_usage.html)) - Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow, Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data across these backends. Please note that Fugue will use the Pandas (Native) logic at its lowest level diff --git a/datacompy/__init__.py b/datacompy/__init__.py index a6d331e8..74154839 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -43,12 +43,14 @@ unq_columns, ) from datacompy.polars import PolarsCompare +from datacompy.snowflake import SnowflakeCompare from datacompy.spark.sql import SparkSQLCompare __all__ = [ "BaseCompare", "Compare", "PolarsCompare", + "SnowflakeCompare", "SparkSQLCompare", "all_columns_match", "all_rows_overlap", diff --git a/datacompy/snowflake.py b/datacompy/snowflake.py new file mode 100644 index 00000000..19f63978 --- /dev/null +++ b/datacompy/snowflake.py @@ -0,0 +1,1202 @@ +# +# Copyright 2024 Capital One Services, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compare two Snowpark SQL DataFrames and Snowflake tables. + +Originally this package was meant to provide similar functionality to +PROC COMPARE in SAS - i.e. human-readable reporting on the difference between +two dataframes. +""" + +import logging +import os +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union, cast + +import pandas as pd +from ordered_set import OrderedSet + +try: + import snowflake.snowpark as sp + from snowflake.snowpark import Window + from snowflake.snowpark.exceptions import SnowparkSQLException + from snowflake.snowpark.functions import ( + abs, + col, + concat, + contains, + is_null, + lit, + monotonically_increasing_id, + row_number, + trim, + when, + ) +except ImportError: + pass # for non-snowflake users +from datacompy.base import BaseCompare +from datacompy.spark.sql import decimal_comparator + +LOG = logging.getLogger(__name__) + + +NUMERIC_SNOWPARK_TYPES = [ + "tinyint", + "smallint", + "int", + "bigint", + "float", + "double", + decimal_comparator(), +] + + +class SnowflakeCompare(BaseCompare): + """Comparison class to be used to compare whether two Snowpark dataframes are equal. + + df1 and df2 can refer to either a Snowpark dataframe or the name of a valid Snowflake table. + The data structures which df1 and df2 represent must contain all of the join_columns, + with unique column names. Differences between values are compared to + abs_tol + rel_tol * abs(df2['value']). + + Parameters + ---------- + session: snowflake.snowpark.session + Session with the required connection session info for user and targeted tables + df1 : Union[str, sp.Dataframe] + First table to check, provided either as the table's name or as a Snowpark DF. + df2 : Union[str, sp.Dataframe] + Second table to check, provided either as the table's name or as a Snowpark DF. + join_columns : list or str, optional + Column(s) to join dataframes on. If a string is passed in, that one + column will be used. + abs_tol : float, optional + Absolute tolerance between two values. + rel_tol : float, optional + Relative tolerance between two values. + df1_name : str, optional + A string name for the first dataframe. If used alongside a snowflake table, + overrides the default convention of naming the dataframe after the table. + df2_name : str, optional + A string name for the second dataframe. + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns (including any join + columns). 
+ + Attributes + ---------- + df1_unq_rows : sp.DataFrame + All records that are only in df1 (based on a join on join_columns) + df2_unq_rows : sp.DataFrame + All records that are only in df2 (based on a join on join_columns) + """ + + def __init__( + self, + session: "sp.Session", + df1: Union[str, "sp.DataFrame"], + df2: Union[str, "sp.DataFrame"], + join_columns: Optional[Union[List[str], str]], + abs_tol: float = 0, + rel_tol: float = 0, + df1_name: Optional[str] = None, + df2_name: Optional[str] = None, + ignore_spaces: bool = False, + ) -> None: + if join_columns is None: + errmsg = "join_columns cannot be None" + raise ValueError(errmsg) + elif not join_columns: + errmsg = "join_columns is empty" + raise ValueError(errmsg) + elif isinstance(join_columns, (str, int, float)): + self.join_columns = [str(join_columns).replace('"', "").upper()] + else: + self.join_columns = [ + str(col).replace('"', "").upper() + for col in cast(List[str], join_columns) + ] + + self._any_dupes: bool = False + self.session = session + self.df1 = (df1, df1_name) + self.df2 = (df2, df2_name) + self.abs_tol = abs_tol + self.rel_tol = rel_tol + self.ignore_spaces = ignore_spaces + self.df1_unq_rows: sp.DataFrame + self.df2_unq_rows: sp.DataFrame + self.intersect_rows: sp.DataFrame + self.column_stats: List[Dict[str, Any]] = [] + self._compare(ignore_spaces=ignore_spaces) + + @property + def df1(self) -> "sp.DataFrame": + """Get the first dataframe.""" + return self._df1 + + @df1.setter + def df1(self, df1: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None: + """Check that df1 is either a Snowpark DF or the name of a valid Snowflake table.""" + (df, df_name) = df1 + if isinstance(df, str): + table_name = [table_comp.upper() for table_comp in df.split(".")] + if len(table_name) != 3: + errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema." + raise ValueError(errmsg) + self.df1_name = df_name.upper() if df_name else table_name[2] + self._df1 = self.session.table(df) + else: + self._df1 = df + self.df1_name = df_name.upper() if df_name else "DF1" + self._validate_dataframe(self.df1_name, "df1") + + @property + def df2(self) -> "sp.DataFrame": + """Get the second dataframe.""" + return self._df2 + + @df2.setter + def df2(self, df2: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None: + """Check that df2 is either a Snowpark DF or the name of a valid Snowflake table.""" + (df, df_name) = df2 + if isinstance(df, str): + table_name = [table_comp.upper() for table_comp in df.split(".")] + if len(table_name) != 3: + errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema." + raise ValueError(errmsg) + self.df2_name = df_name.upper() if df_name else table_name[2] + self._df2 = self.session.table(df) + else: + self._df2 = df + self.df2_name = df_name.upper() if df_name else "DF2" + self._validate_dataframe(self.df2_name, "df2") + + def _validate_dataframe(self, df_name: str, index: str) -> None: + """Validate the provided Snowpark dataframe. + + The dataframe can either be a standalone Snowpark dataframe or a representative + of a Snowflake table - in the latter case we check that the table it represents + is a valid table by forcing a collection. + + Parameters + ---------- + df_name : str + Name of the Snowflake table / Snowpark dataframe + index : str + The "index" of the dataframe - df1 or df2. 
+ """ + df = getattr(self, index) + if not isinstance(df, "sp.DataFrame"): + raise TypeError(f"{df_name} must be a valid sp.Dataframe") + + # force all columns to be non-case-sensitive + if index == "df1": + col_map = dict( + zip( + self._df1.columns, + [str(c).replace('"', "").upper() for c in self._df1.columns], + ) + ) + self._df1 = self._df1.rename(col_map) + if index == "df2": + col_map = dict( + zip( + self._df2.columns, + [str(c).replace('"', "").upper() for c in self._df2.columns], + ) + ) + self._df2 = self._df2.rename(dict(col_map)) + + df = getattr(self, index) # refresh + if not set(self.join_columns).issubset(set(df.columns)): + raise ValueError(f"{df_name} must have all columns from join_columns") + if len(set(df.columns)) < len(df.columns): + raise ValueError(f"{df_name} must have unique column names") + + if df.drop_duplicates(self.join_columns).count() < df.count(): + self._any_dupes = True + + def _compare(self, ignore_spaces: bool) -> None: + """Actually run the comparison. + + This method will log out information about what is different between + the two dataframes. + """ + LOG.info(f"Number of columns in common: {len(self.intersect_columns())}") + LOG.debug("Checking column overlap") + for column in self.df1_unq_columns(): + LOG.info(f"Column in df1 and not in df2: {column}") + LOG.info( + f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}" + ) + for column in self.df2_unq_columns(): + LOG.info(f"Column in df2 and not in df1: {column}") + LOG.info( + f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}" + ) + LOG.debug("Merging dataframes") + self._dataframe_merge(ignore_spaces) + self._intersect_compare(ignore_spaces) + if self.matches(): + LOG.info("df1 matches df2") + else: + LOG.info("df1 does not match df2") + + def df1_unq_columns(self) -> OrderedSet[str]: + """Get columns that are unique to df1.""" + return cast( + OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) + ) + + def df2_unq_columns(self) -> OrderedSet[str]: + """Get columns that are unique to df2.""" + return cast( + OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) + ) + + def intersect_columns(self) -> OrderedSet[str]: + """Get columns that are shared between the two dataframes.""" + return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) + + def _dataframe_merge(self, ignore_spaces: bool) -> None: + """Merge df1 to df2 on the join columns. + + Gets df1 - df2, df2 - df1, and df1 & df2 + joining on the ``join_columns``. 
+ """ + LOG.debug("Outer joining") + + df1 = self.df1 + df2 = self.df2 + temp_join_columns = deepcopy(self.join_columns) + + if self._any_dupes: + LOG.debug("Duplicate rows found, deduping by order of remaining fields") + # setting internal index + LOG.info("Adding internal index to dataframes") + df1 = df1.withColumn("__index", monotonically_increasing_id()) + df2 = df2.withColumn("__index", monotonically_increasing_id()) + + # Create order column for uniqueness of match + order_column = temp_column_name(df1, df2) + df1 = df1.join( + _generate_id_within_group(df1, temp_join_columns, order_column), + on="__index", + how="inner", + ).drop("__index") + df2 = df2.join( + _generate_id_within_group(df2, temp_join_columns, order_column), + on="__index", + how="inner", + ).drop("__index") + temp_join_columns.append(order_column) + + # drop index + LOG.info("Dropping internal index") + df1 = df1.drop("__index") + df2 = df2.drop("__index") + + if ignore_spaces: + for column in self.join_columns: + if "string" in next( + dtype for name, dtype in df1.dtypes if name == column + ): + df1 = df1.withColumn(column, trim(col(column))) + if "string" in next( + dtype for name, dtype in df2.dtypes if name == column + ): + df2 = df2.withColumn(column, trim(col(column))) + + df1 = df1.withColumn("merge", lit(True)) + df2 = df2.withColumn("merge", lit(True)) + + for c in df1.columns: + df1 = df1.withColumnRenamed(c, c + "_" + self.df1_name) + for c in df2.columns: + df2 = df2.withColumnRenamed(c, c + "_" + self.df2_name) + + # NULL SAFE Outer join, not possible with Snowpark Dataframe join + df1.createOrReplaceTempView("df1") + df2.createOrReplaceTempView("df2") + on = " and ".join( + [ + f"EQUAL_NULL(df1.{c}_{self.df1_name}, df2.{c}_{self.df2_name})" + for c in temp_join_columns + ] + ) + outer_join = self.session.sql( + """ + SELECT * FROM + df1 FULL OUTER JOIN df2 + ON + """ + + on + ) + # Create join indicator + outer_join = outer_join.withColumn( + "_merge", + when( + outer_join[f"MERGE_{self.df1_name}"] + & outer_join[f"MERGE_{self.df2_name}"], + lit("BOTH"), + ) + .when( + outer_join[f"MERGE_{self.df1_name}"] + & outer_join[f"MERGE_{self.df2_name}"].is_null(), + lit("LEFT_ONLY"), + ) + .when( + outer_join[f"MERGE_{self.df1_name}"].is_null() + & outer_join[f"MERGE_{self.df2_name}"], + lit("RIGHT_ONLY"), + ), + ) + + df1 = df1.drop(f"MERGE_{self.df1_name}") + df2 = df2.drop(f"MERGE_{self.df2_name}") + + # Clean up temp columns for duplicate row matching + if self._any_dupes: + outer_join = outer_join.select_expr( + f"* EXCLUDE ({order_column}_{self.df1_name}, {order_column}_{self.df2_name})" + ) + df1 = df1.drop(f"{order_column}_{self.df1_name}") + df2 = df2.drop(f"{order_column}_{self.df2_name}") + + # Capitalization required - clean up + df1_cols = get_merged_columns(df1, outer_join, self.df1_name) + df2_cols = get_merged_columns(df2, outer_join, self.df2_name) + + LOG.debug("Selecting df1 unique rows") + self.df1_unq_rows = outer_join[outer_join["_merge"] == "LEFT_ONLY"][df1_cols] + + LOG.debug("Selecting df2 unique rows") + self.df2_unq_rows = outer_join[outer_join["_merge"] == "RIGHT_ONLY"][df2_cols] + LOG.info(f"Number of rows in df1 and not in df2: {self.df1_unq_rows.count()}") + LOG.info(f"Number of rows in df2 and not in df1: {self.df2_unq_rows.count()}") + + LOG.debug("Selecting intersecting rows") + self.intersect_rows = outer_join[outer_join["_merge"] == "BOTH"] + LOG.info( + f"Number of rows in df1 and df2 (not necessarily equal): {self.intersect_rows.count()}" + ) + self.intersect_rows = 
self.intersect_rows.cache_result() + + def _intersect_compare(self, ignore_spaces: bool) -> None: + """Run the comparison on the intersect dataframe. + + This loops through all columns that are shared between df1 and df2, and + creates a column column_match which is True for matches, False + otherwise. + """ + LOG.debug("Comparing intersection") + max_diff: float + null_diff: int + row_cnt = self.intersect_rows.count() + for column in self.intersect_columns(): + if column in self.join_columns: + match_cnt = row_cnt + col_match = "" + max_diff = 0 + null_diff = 0 + else: + col_1 = column + "_" + self.df1_name + col_2 = column + "_" + self.df2_name + col_match = column + "_MATCH" + self.intersect_rows = columns_equal( + self.intersect_rows, + col_1, + col_2, + col_match, + self.rel_tol, + self.abs_tol, + ignore_spaces, + ) + match_cnt = ( + self.intersect_rows.select(col_match) + .where(col(col_match) == True) # noqa: E712 + .count() + ) + max_diff = calculate_max_diff( + self.intersect_rows, + col_1, + col_2, + ) + null_diff = calculate_null_diff(self.intersect_rows, col_1, col_2) + + if row_cnt > 0: + match_rate = float(match_cnt) / row_cnt + else: + match_rate = 0 + LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match") + + col1_dtype, _ = _get_column_dtypes(self.df1, column, column) + col2_dtype, _ = _get_column_dtypes(self.df2, column, column) + + self.column_stats.append( + { + "column": column, + "match_column": col_match, + "match_cnt": match_cnt, + "unequal_cnt": row_cnt - match_cnt, + "dtype1": str(col1_dtype), + "dtype2": str(col2_dtype), + "all_match": all( + ( + col1_dtype == col2_dtype, + row_cnt == match_cnt, + ) + ), + "max_diff": max_diff, + "null_diff": null_diff, + } + ) + + def all_columns_match(self) -> bool: + """Whether the columns all match in the dataframes. + + Returns + ------- + bool + True if all columns in df1 are in df2 and vice versa + """ + return self.df1_unq_columns() == self.df2_unq_columns() == set() + + def all_rows_overlap(self) -> bool: + """Whether the rows are all present in both dataframes. + + Returns + ------- + bool + True if all rows in df1 are in df2 and vice versa (based on + existence for join option) + """ + return self.df1_unq_rows.count() == self.df2_unq_rows.count() == 0 + + def count_matching_rows(self) -> int: + """Count the number of rows match (on overlapping fields). + + Returns + ------- + int + Number of matching rows + """ + conditions = [] + match_columns = [] + for column in self.intersect_columns(): + if column not in self.join_columns: + match_columns.append(column + "_MATCH") + conditions.append(f"{column}_MATCH = True") + if len(conditions) > 0: + match_columns_count = self.intersect_rows.filter( + " and ".join(conditions) + ).count() + else: + match_columns_count = 0 + return match_columns_count + + def intersect_rows_match(self) -> bool: + """Check whether the intersect rows all match.""" + actual_length = self.intersect_rows.count() + return self.count_matching_rows() == actual_length + + def matches(self, ignore_extra_columns: bool = False) -> bool: + """Return True or False if the dataframes match. + + Parameters + ---------- + ignore_extra_columns : bool + Ignores any columns in one dataframe and not in the other. + + Returns + ------- + bool + True or False if the dataframes match. 
+ """ + return not ( + (not ignore_extra_columns and not self.all_columns_match()) + or not self.all_rows_overlap() + or not self.intersect_rows_match() + ) + + def subset(self) -> bool: + """Return True if dataframe 2 is a subset of dataframe 1. + + Dataframe 2 is considered a subset if all of its columns are in + dataframe 1, and all of its rows match rows in dataframe 1 for the + shared columns. + + Returns + ------- + bool + True if dataframe 2 is a subset of dataframe 1. + """ + return not ( + self.df2_unq_columns() != set() + or self.df2_unq_rows.count() != 0 + or not self.intersect_rows_match() + ) + + def sample_mismatch( + self, column: str, sample_count: int = 10, for_display: bool = False + ) -> "sp.DataFrame": + """Return sample mismatches. + + Gets a sub-dataframe which contains the identifying + columns, and df1 and df2 versions of the column. + + Parameters + ---------- + column : str + The raw column name (i.e. without ``_df1`` appended) + sample_count : int, optional + The number of sample records to return. Defaults to 10. + for_display : bool, optional + Whether this is just going to be used for display (overwrite the + column names) + + Returns + ------- + sp.DataFrame + A sample of the intersection dataframe, containing only the + "pertinent" columns, for rows that don't match on the provided + column. + """ + row_cnt = self.intersect_rows.count() + col_match = self.intersect_rows.select(column + "_MATCH") + match_cnt = col_match.where( + col(column + "_MATCH") == True # noqa: E712 + ).count() + sample_count = min(sample_count, row_cnt - match_cnt) + sample = ( + self.intersect_rows.where(col(column + "_MATCH") == False) # noqa: E712 + .drop(column + "_MATCH") + .limit(sample_count) + ) + + for c in self.join_columns: + sample = sample.withColumnRenamed(c + "_" + self.df1_name, c) + + return_cols = [ + *self.join_columns, + column + "_" + self.df1_name, + column + "_" + self.df2_name, + ] + to_return = sample.select(return_cols) + + if for_display: + return to_return.toDF( + *[ + *self.join_columns, + column + " (" + self.df1_name + ")", + column + " (" + self.df2_name + ")", + ] + ) + return to_return + + def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame": + """Get all rows with any columns that have a mismatch. + + Returns all df1 and df2 versions of the columns and join + columns. + + Parameters + ---------- + ignore_matching_cols : bool, optional + Whether showing the matching columns in the output or not. The default is False. + + Returns + ------- + sp.DataFrame + All rows of the intersection dataframe, containing any columns, that don't match. + """ + match_list = [] + return_list = [] + for c in self.intersect_rows.columns: + if c.endswith("_MATCH"): + orig_col_name = c[:-6] + + col_comparison = columns_equal( + self.intersect_rows, + orig_col_name + "_" + self.df1_name, + orig_col_name + "_" + self.df2_name, + c, + self.rel_tol, + self.abs_tol, + self.ignore_spaces, + ) + + if not ignore_matching_cols or ( + ignore_matching_cols + and col_comparison.select(c) + .where(col(c) == False) # noqa: E712 + .count() + > 0 + ): + LOG.debug(f"Adding column {orig_col_name} to the result.") + match_list.append(c) + return_list.extend( + [ + orig_col_name + "_" + self.df1_name, + orig_col_name + "_" + self.df2_name, + ] + ) + elif ignore_matching_cols: + LOG.debug( + f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result." 
+ ) + + mm_rows = self.intersect_rows.withColumn( + "match_array", concat(*match_list) + ).where(contains(col("match_array"), lit("false"))) + + for c in self.join_columns: + mm_rows = mm_rows.withColumnRenamed(c + "_" + self.df1_name, c) + + return mm_rows.select(self.join_columns + return_list) + + def report( + self, + sample_count: int = 10, + column_count: int = 10, + html_file: Optional[str] = None, + ) -> str: + """Return a string representation of a report. + + The representation can + then be printed or saved to a file. + + Parameters + ---------- + sample_count : int, optional + The number of sample records to return. Defaults to 10. + + column_count : int, optional + The number of columns to display in the sample records output. Defaults to 10. + + html_file : str, optional + HTML file name to save report output to. If ``None`` the file creation will be skipped. + + Returns + ------- + str + The report, formatted kinda nicely. + """ + # Header + report = render("header.txt") + df_header = pd.DataFrame( + { + "DataFrame": [self.df1_name, self.df2_name], + "Columns": [len(self.df1.columns), len(self.df2.columns)], + "Rows": [self.df1.count(), self.df2.count()], + } + ) + report += df_header[["DataFrame", "Columns", "Rows"]].to_string() + report += "\n\n" + + # Column Summary + report += render( + "column_summary.txt", + len(self.intersect_columns()), + len(self.df1_unq_columns()), + len(self.df2_unq_columns()), + self.df1_name, + self.df2_name, + ) + + # Row Summary + match_on = ", ".join(self.join_columns) + report += render( + "row_summary.txt", + match_on, + self.abs_tol, + self.rel_tol, + self.intersect_rows.count(), + self.df1_unq_rows.count(), + self.df2_unq_rows.count(), + self.intersect_rows.count() - self.count_matching_rows(), + self.count_matching_rows(), + self.df1_name, + self.df2_name, + "Yes" if self._any_dupes else "No", + ) + + # Column Matching + report += render( + "column_comparison.txt", + len([col for col in self.column_stats if col["unequal_cnt"] > 0]), + len([col for col in self.column_stats if col["unequal_cnt"] == 0]), + sum(col["unequal_cnt"] for col in self.column_stats), + ) + + match_stats = [] + match_sample = [] + any_mismatch = False + for column in self.column_stats: + if not column["all_match"]: + any_mismatch = True + match_stats.append( + { + "Column": column["column"], + f"{self.df1_name} dtype": column["dtype1"], + f"{self.df2_name} dtype": column["dtype2"], + "# Unequal": column["unequal_cnt"], + "Max Diff": column["max_diff"], + "# Null Diff": column["null_diff"], + } + ) + if column["unequal_cnt"] > 0: + match_sample.append( + self.sample_mismatch( + column["column"], sample_count, for_display=True + ) + ) + + if any_mismatch: + report += "Columns with Unequal Values or Types\n" + report += "------------------------------------\n" + report += "\n" + df_match_stats = pd.DataFrame(match_stats) + df_match_stats.sort_values("Column", inplace=True) + # Have to specify again for sorting + report += df_match_stats[ + [ + "Column", + f"{self.df1_name} dtype", + f"{self.df2_name} dtype", + "# Unequal", + "Max Diff", + "# Null Diff", + ] + ].to_string() + report += "\n\n" + + if sample_count > 0: + report += "Sample Rows with Unequal Values\n" + report += "-------------------------------\n" + report += "\n" + for sample in match_sample: + report += sample.toPandas().to_string() + report += "\n\n" + + if min(sample_count, self.df1_unq_rows.count()) > 0: + report += ( + f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n" + ) + 
report += ( + f"---------------------------------------{'-' * len(self.df1_name)}\n" + ) + report += "\n" + columns = self.df1_unq_rows.columns[:column_count] + unq_count = min(sample_count, self.df1_unq_rows.count()) + report += ( + self.df1_unq_rows.limit(unq_count) + .select(columns) + .toPandas() + .to_string() + ) + report += "\n\n" + + if min(sample_count, self.df2_unq_rows.count()) > 0: + report += ( + f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n" + ) + report += ( + f"---------------------------------------{'-' * len(self.df2_name)}\n" + ) + report += "\n" + columns = self.df2_unq_rows.columns[:column_count] + unq_count = min(sample_count, self.df2_unq_rows.count()) + report += ( + self.df2_unq_rows.limit(unq_count) + .select(columns) + .toPandas() + .to_string() + ) + report += "\n\n" + + if html_file: + html_report = report.replace("\n", "
").replace(" ", " ") + html_report = f"
{html_report}
" + with open(html_file, "w") as f: + f.write(html_report) + + return report + + +def render(filename: str, *fields: Union[int, float, str]) -> str: + """Render out an individual template. + + This basically just reads in a + template file, and applies ``.format()`` on the fields. + + Parameters + ---------- + filename : str + The file that contains the template. Will automagically prepend the + templates directory before opening + fields : list + Fields to be rendered out in the template + + Returns + ------- + str + The fully rendered out file. + """ + this_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(this_dir, "templates", filename)) as file_open: + return file_open.read().format(*fields) + + +def columns_equal( + dataframe: "sp.DataFrame", + col_1: str, + col_2: str, + col_match: str, + rel_tol: float = 0, + abs_tol: float = 0, + ignore_spaces: bool = False, +) -> "sp.DataFrame": + """Compare two columns from a dataframe. + + Returns a True/False series with the same index as column 1. + + - Two nulls (np.nan) will evaluate to True. + - A null and a non-null value will evaluate to False. + - Numeric values will use the relative and absolute tolerances. + - Decimal values (decimal.Decimal) will attempt to be converted to floats + before comparing + - Non-numeric values (i.e. where np.isclose can't be used) will just + trigger True on two nulls or exact matches. + + Parameters + ---------- + dataframe: sp.DataFrame + DataFrame to do comparison on + col_1 : str + The first column to look at + col_2 : str + The second column + col_match : str + The matching column denoting if the compare was a match or not + rel_tol : float, optional + Relative tolerance + abs_tol : float, optional + Absolute tolerance + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns + + Returns + ------- + sp.DataFrame + A column of boolean values are added. True == the values match, False == the + values don't match. + """ + base_dtype, compare_dtype = _get_column_dtypes(dataframe, col_1, col_2) + if _is_comparable(base_dtype, compare_dtype): + if (base_dtype in NUMERIC_SNOWPARK_TYPES) and ( + compare_dtype in NUMERIC_SNOWPARK_TYPES + ): # numeric tolerance comparison + dataframe = dataframe.withColumn( + col_match, + when( + (col(col_1).eqNullSafe(col(col_2))) + | ( + abs(col(col_1) - col(col_2)) + <= lit(abs_tol) + (lit(rel_tol) * abs(col(col_2))) + ), + # corner case of col1 != NaN and col2 == Nan returns True incorrectly + when( + (is_null(col(col_1)) == False) # noqa: E712 + & (is_null(col(col_2)) == True), # noqa: E712 + lit(False), + ).otherwise(lit(True)), + ).otherwise(lit(False)), + ) + else: # non-numeric comparison + if ignore_spaces: + when_clause = trim(col(col_1)).eqNullSafe(trim(col(col_2))) + else: + when_clause = col(col_1).eqNullSafe(col(col_2)) + + dataframe = dataframe.withColumn( + col_match, + when(when_clause, lit(True)).otherwise(lit(False)), + ) + else: + LOG.debug( + f"Skipping {col_1}({base_dtype}) and {col_2}({compare_dtype}), columns are not comparable" + ) + dataframe = dataframe.withColumn(col_match, lit(False)) + return dataframe + + +def get_merged_columns( + original_df: "sp.DataFrame", merged_df: "sp.DataFrame", suffix: str +) -> List[str]: + """Get the columns from an original dataframe, in the new merged dataframe. + + Parameters + ---------- + original_df : sp.DataFrame + The original, pre-merge dataframe + merged_df : sp.DataFrame + Post-merge with another dataframe, with suffixes added in. 
+ suffix : str + What suffix was used to distinguish when the original dataframe was + overlapping with the other merged dataframe. + + Returns + ------- + List[str] + Column list of the original dataframe pre suffix + """ + columns = [] + for column in original_df.columns: + if column in merged_df.columns: + columns.append(column) + elif f"{column}_{suffix}" in merged_df.columns: + columns.append(f"{column}_{suffix}") + else: + raise ValueError("Column not found: %s", column) + return columns + + +def calculate_max_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> float: + """Get a maximum difference between two columns. + + Parameters + ---------- + dataframe: sp.DataFrame + DataFrame to do comparison on + col_1 : str + The first column to look at + col_2 : str + The second column + + Returns + ------- + float + max diff + """ + # Attempting to coalesce maximum diff for non-numeric results in error, if error return 0 max diff. + try: + diff = dataframe.select( + (col(col_1).astype("float") - col(col_2).astype("float")).alias("diff") + ) + abs_diff = diff.select(abs(col("diff")).alias("abs_diff")) + max_diff: float = ( + abs_diff.where(is_null(col("abs_diff")) == False) # noqa: E712 + .agg({"abs_diff": "max"}) + .collect()[0][0] + ) + except SnowparkSQLException: + return None + + if pd.isna(max_diff) or pd.isnull(max_diff) or max_diff is None: + return 0 + else: + return max_diff + + +def calculate_null_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> int: + """Get the null differences between two columns. + + Parameters + ---------- + dataframe: sp.DataFrame + DataFrame to do comparison on + col_1 : str + The first column to look at + col_2 : str + The second column + + Returns + ------- + int + null diff + """ + nulls_df = dataframe.withColumn( + "col_1_null", + when(col(col_1).isNull() == True, lit(True)).otherwise( # noqa: E712 + lit(False) + ), + ) + nulls_df = nulls_df.withColumn( + "col_2_null", + when(col(col_2).isNull() == True, lit(True)).otherwise( # noqa: E712 + lit(False) + ), + ).select(["col_1_null", "col_2_null"]) + + # (not a and b) or (a and not b) + null_diff = nulls_df.where( + ((col("col_1_null") == False) & (col("col_2_null") == True)) # noqa: E712 + | ((col("col_1_null") == True) & (col("col_2_null") == False)) # noqa: E712 + ).count() + + if pd.isna(null_diff) or pd.isnull(null_diff) or null_diff is None: + return 0 + else: + return null_diff + + +def _generate_id_within_group( + dataframe: "sp.DataFrame", join_columns: List[str], order_column_name: str +) -> "sp.DataFrame": + """Generate an ID column that can be used to deduplicate identical rows. + + The series generated + is the order within a unique group, and it handles nulls. Requires a ``__index`` column. 
+ + Parameters + ---------- + dataframe : sp.DataFrame + The dataframe to operate on + join_columns : list + List of strings which are the join columns + order_column_name: str + The name of the ``row_number`` column name + + Returns + ------- + sp.DataFrame + Original dataframe with the ID column that's unique in each group + """ + default_value = "DATACOMPY_NULL" + null_check = False + default_check = False + for c in join_columns: + if dataframe.where(col(c).isNull()).limit(1).collect(): + null_check = True + break + for c in [ + column for column, type in dataframe[join_columns].dtypes if "string" in type + ]: + if dataframe.where(col(c).isin(default_value)).limit(1).collect(): + default_check = True + break + + if null_check: + if default_check: + raise ValueError(f"{default_value} was found in your join columns") + + return ( + dataframe.select( + *(col(c).cast("string").alias(c) for c in join_columns + ["__index"]) # noqa: RUF005 + ) + .fillna(default_value) + .withColumn( + order_column_name, + row_number().over(Window.orderBy("__index").partitionBy(join_columns)) + - 1, + ) + .select(["__index", order_column_name]) + ) + else: + return ( + dataframe.select(join_columns + ["__index"]) # noqa: RUF005 + .withColumn( + order_column_name, + row_number().over(Window.orderBy("__index").partitionBy(join_columns)) + - 1, + ) + .select(["__index", order_column_name]) + ) + + +def _get_column_dtypes( + dataframe: "sp.DataFrame", col_1: "str", col_2: "str" +) -> tuple[str, str]: + """Get the dtypes of two columns. + + Parameters + ---------- + dataframe: sp.DataFrame + DataFrame to do comparison on + col_1 : str + The first column to look at + col_2 : str + The second column + + Returns + ------- + Tuple(str, str) + Tuple of base and compare datatype + """ + base_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_1) + compare_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_2) + return base_dtype, compare_dtype + + +def _is_comparable(type1: str, type2: str) -> bool: + """Check if two SnowPark data types can be safely compared. + + Two data types are considered comparable if any of the following apply: + 1. Both data types are the same + 2. Both data types are numeric + + Parameters + ---------- + type1 : str + A string representation of a Snowpark data type + type2 : str + A string representation of a Snowpark data type + + Returns + ------- + bool + True if both data types are comparable + """ + return ( + type1 == type2 + or (type1 in NUMERIC_SNOWPARK_TYPES and type2 in NUMERIC_SNOWPARK_TYPES) + or ("string" in type1 and type2 == "date") + or (type1 == "date" and "string" in type2) + or ("string" in type1 and type2 == "timestamp") + or (type1 == "timestamp" and "string" in type2) + ) + + +def temp_column_name(*dataframes) -> str: + """Get a temp column name that isn't included in columns of any dataframes. 
+ + Parameters + ---------- + dataframes : list of DataFrames + The DataFrames to create a temporary column name for + + Returns + ------- + str + String column name that looks like '_temp_x' for some integer x + """ + i = 0 + columns = [] + for dataframe in dataframes: + columns = columns + list(dataframe.columns) + columns = set(columns) + + while True: + temp_column = f"_TEMP_{i}" + unique = True + + if temp_column in columns: + i += 1 + unique = False + if unique: + return temp_column diff --git a/docs/source/index.rst b/docs/source/index.rst index e6f77d96..b1cb4c4a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,6 +11,7 @@ Contents Installation Pandas Usage Spark Usage + Snowflake Usage Polars Usage Fugue Usage Benchmarks @@ -28,4 +29,4 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` -* :ref:`search` \ No newline at end of file +* :ref:`search` diff --git a/docs/source/snowflake_usage.rst b/docs/source/snowflake_usage.rst new file mode 100644 index 00000000..3c2687e3 --- /dev/null +++ b/docs/source/snowflake_usage.rst @@ -0,0 +1,268 @@ +Snowpark/Snowflake Usage +======================== + +For ``SnowflakeCompare`` + +- ``on_index`` is not supported. +- Joining is done using ``EQUAL_NULL`` which is the equality test that is safe for null values. +- Compares ``snowflake.snowpark.DataFrame``, which can be provided as either raw Snowflake dataframes +or the as the names of full names of valid snowflake tables, which we will process into Snowpark dataframes. + + +SnowflakeCompare Object Setup +--------------------------------------------------- +There are two ways to specify input dataframes for ``SnowflakeCompare`` + +Provide Snowpark dataframes +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from snowflake.snowpark import Session + from snowflake.snowpark import Row + import datetime + import datacompy.snowflake as sp + + connection_parameters = { + ... 
+ } + session = Session.builder.configs(connection_parameters).create() + + data1 = [ + Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555, + date_fld=datetime.date(2017, 1, 1)), + Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0, + date_fld=datetime.date(2017, 1, 1)), + Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None, + date_fld=datetime.date(2017, 1, 1)), + Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12, + date_fld=datetime.date(2017, 1, 1)), + Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None, + date_fld=datetime.date(2017, 1, 1)), + ] + + data2 = [ + Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155), + Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None), + Row(acct_id=None, dollar_amt=1345.0, name='George Bluth', float_fld=1.0), + Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12), + Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0), + ] + + df_1 = session.createDataFrame(data1) + df_2 = session.createDataFrame(data2) + + compare = sp.SnowflakeCompare( + session, + df_1, + df_2, + join_columns=['acct_id'], + rel_tol=1e-03, + abs_tol=1e-04, + ) + compare.matches(ignore_extra_columns=False) + + # This method prints out a human-readable report summarizing and sampling differences + print(compare.report()) + + +Provide the full name (``{db}.{schema}.{table_name}``) of valid Snowflake tables +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Given the dataframes from the prior examples... + +.. code-block:: python + df_1.write.mode("overwrite").save_as_table("toy_table_1") + df_2.write.mode("overwrite").save_as_table("toy_table_2") + + compare = sp.SnowflakeCompare( + session, + f"{db}.{schema}.toy_table_1", + f"{db}.{schema}.toy_table_2", + join_columns=['acct_id'], + rel_tol=1e-03, + abs_tol=1e-04, + ) + compare.matches(ignore_extra_columns=False) + + # This method prints out a human-readable report summarizing and sampling differences + print(compare.report()) + +Reports +------- + +A report is generated by calling ``report()``, which returns a string. +Here is a sample report generated by ``datacompy`` for the two tables above, +joined on ``acct_id`` (Note: the names for your dataframes are extracted from +the name of the provided Snowflake table. 
If you chose to directly use Snowpark +dataframes, then the names will default to ``DF1`` and ``DF2``.):: + + DataComPy Comparison + -------------------- + + DataFrame Summary + ----------------- + + DataFrame Columns Rows + 0 DF1 5 5 + 1 DF2 4 5 + + Column Summary + -------------- + + Number of columns in common: 4 + Number of columns in DF1 but not in DF2: 1 + Number of columns in DF2 but not in DF1: 0 + + Row Summary + ----------- + + Matched on: ACCT_ID + Any duplicates on match values: No + Absolute Tolerance: 0 + Relative Tolerance: 0 + Number of rows in common: 4 + Number of rows in DF1 but not in DF2: 1 + Number of rows in DF2 but not in DF1: 1 + + Number of rows with some compared columns unequal: 4 + Number of rows with all compared columns equal: 0 + + Column Comparison + ----------------- + + Number of columns compared with some values unequal: 3 + Number of columns compared with all values equal: 1 + Total number of values which compare unequal: 6 + + Columns with Unequal Values or Types + ------------------------------------ + + Column DF1 dtype DF2 dtype # Unequal Max Diff # Null Diff + 0 DOLLAR_AMT double double 1 0.0500 0 + 2 FLOAT_FLD double double 3 0.0005 2 + 1 NAME string(16777216) string(16777216) 2 NaN 0 + + Sample Rows with Unequal Values + ------------------------------- + + ACCT_ID DOLLAR_AMT (DF1) DOLLAR_AMT (DF2) + 0 10000001234 123.45 123.4 + + ACCT_ID NAME (DF1) NAME (DF2) + 0 10000001234 George Maharis George Michael Bluth + 1 10000001237 Bob Loblaw Robert Loblaw + + ACCT_ID FLOAT_FLD (DF1) FLOAT_FLD (DF2) + 0 10000001234 14530.1555 14530.155 + 1 10000001235 1.0000 NaN + 2 10000001236 NaN 1.000 + + Sample Rows Only in DF1 (First 10 Columns) + ------------------------------------------ + + ACCT_ID_DF1 DOLLAR_AMT_DF1 NAME_DF1 FLOAT_FLD_DF1 DATE_FLD_DF1 + 0 10000001239 1.05 Lucille Bluth NaN 2017-01-01 + + Sample Rows Only in DF2 (First 10 Columns) + ------------------------------------------ + + ACCT_ID_DF2 DOLLAR_AMT_DF2 NAME_DF2 FLOAT_FLD_DF2 + 0 10000001238 1.05 Loose Seal Bluth 111.0 + + +Convenience Methods +------------------- + +There are a few convenience methods and attributes available after the comparison has been run: + +.. 
code-block:: python + + compare.intersect_rows[['name_df1', 'name_df2', 'name_match']].show() + # -------------------------------------------------------- + # |"NAME_DF1" |"NAME_DF2" |"NAME_MATCH" | + # -------------------------------------------------------- + # |George Maharis |George Michael Bluth |False | + # |Michael Bluth |Michael Bluth |True | + # |George Bluth |George Bluth |True | + # |Bob Loblaw |Robert Loblaw |False | + # -------------------------------------------------------- + + compare.df1_unq_rows.show() + # --------------------------------------------------------------------------------------- + # |"ACCT_ID_DF1" |"DOLLAR_AMT_DF1" |"NAME_DF1" |"FLOAT_FLD_DF1" |"DATE_FLD_DF1" | + # --------------------------------------------------------------------------------------- + # |10000001239 |1.05 |Lucille Bluth |NULL |2017-01-01 | + # --------------------------------------------------------------------------------------- + + compare.df2_unq_rows.show() + # ------------------------------------------------------------------------- + # |"ACCT_ID_DF2" |"DOLLAR_AMT_DF2" |"NAME_DF2" |"FLOAT_FLD_DF2" | + # ------------------------------------------------------------------------- + # |10000001238 |1.05 |Loose Seal Bluth |111.0 | + # ------------------------------------------------------------------------- + + print(compare.intersect_columns()) + # OrderedSet(['acct_id', 'dollar_amt', 'name', 'float_fld']) + + print(compare.df1_unq_columns()) + # OrderedSet(['date_fld']) + + print(compare.df2_unq_columns()) + # OrderedSet() + +Duplicate rows +-------------- + +Datacompy will try to handle rows that are duplicate in the join columns. It does this behind the +scenes by generating a unique ID within each unique group of the join columns. For example, if you +have two dataframes you're trying to join on acct_id: + +=========== ================ +acct_id name +=========== ================ +1 George Maharis +1 Michael Bluth +2 George Bluth +=========== ================ + +=========== ================ +acct_id name +=========== ================ +1 George Maharis +1 Michael Bluth +1 Tony Wonder +2 George Bluth +=========== ================ + +Datacompy will generate a unique temporary ID for joining: + +=========== ================ ======== +acct_id name temp_id +=========== ================ ======== +1 George Maharis 0 +1 Michael Bluth 1 +2 George Bluth 0 +=========== ================ ======== + +=========== ================ ======== +acct_id name temp_id +=========== ================ ======== +1 George Maharis 0 +1 Michael Bluth 1 +1 Tony Wonder 2 +2 George Bluth 0 +=========== ================ ======== + +And then merge the two dataframes on a combination of the join_columns you specified and the temporary +ID, before dropping the temp_id again. So the first two rows in the first dataframe will match the +first two rows in the second dataframe, and the third row in the second dataframe will be recognized +as uniquely in the second. + +Additional considerations +------------------------- +- It is strongly recommended against joining on float columns (or any column with floating point precision). +Columns joining tables are compared on the basis of an exact comparison, therefore if the values comparing +your float columns are not exact, you will likely get unexpected results. +- Case-sensitive columns are only partially supported. We essentially treat case-sensitive +columns as if they are case-insensitive. 
Therefore you may use case-sensitive columns as long as +you don't have several columns with the same name differentiated only be case sensitivity. diff --git a/pyproject.toml b/pyproject.toml index ed9c0f9b..9c86c82f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,16 +56,19 @@ python-tag = "py3" [project.optional-dependencies] duckdb = ["fugue[duckdb]"] spark = ["pyspark[connect]>=3.1.1; python_version < \"3.11\"", "pyspark[connect]>=3.4; python_version >= \"3.11\""] +snowflake = ["snowflake-connector-python", "snowflake-snowpark-python"] dask = ["fugue[dask]"] ray = ["fugue[ray]"] docs = ["sphinx", "furo", "myst-parser"] tests = ["pytest", "pytest-cov"] tests-spark = ["pytest", "pytest-cov", "pytest-spark"] +tests-snowflake = ["snowflake-snowpark-python[localtest]"] qa = ["pre-commit", "ruff==0.5.7", "mypy", "pandas-stubs"] build = ["build", "twine", "wheel"] edgetest = ["edgetest", "edgetest-conda"] -dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"] +dev_no_snowflake = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"] +dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[snowflake]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[tests-snowflake]", "datacompy[qa]", "datacompy[build]"] # Linters, formatters and type checkers [tool.ruff] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..54532b4b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +"""Testing configuration file, currently used for generating a Snowpark local session for testing.""" + +import os + +import pytest + +try: + from snowflake.snowpark.session import Session +except ModuleNotFoundError: + pass + +CONNECTION_PARAMETERS = { + "account": os.environ.get("SF_ACCOUNT"), + "user": os.environ.get("SF_UID"), + "password": os.environ.get("SF_PWD"), + "warehouse": os.environ.get("SF_WAREHOUSE"), + "database": os.environ.get("SF_DATABASE"), + "schema": os.environ.get("SF_SCHEMA"), +} + + +@pytest.fixture(scope="module") +def snowpark_session() -> "Session": + return Session.builder.configs(CONNECTION_PARAMETERS).create() diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py new file mode 100644 index 00000000..548f6043 --- /dev/null +++ b/tests/test_snowflake.py @@ -0,0 +1,1326 @@ +# +# Copyright 2024 Capital One Services, LLC +# +# Licensed under the Apache License, Version 2.0 (the "LICENSE"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Testing out the datacompy functionality +""" + +import io +import logging +import sys +from datetime import datetime +from decimal import Decimal +from io import StringIO +from unittest import mock + +import numpy as np +import pandas as pd +import pytest +from pytest import raises + +pytest.importorskip("pyspark") + + +from datacompy.snowflake import ( + SnowflakeCompare, + _generate_id_within_group, + calculate_max_diff, + columns_equal, + temp_column_name, +) +from pandas.testing import assert_series_equal +from snowflake.snowpark.exceptions import SnowparkSQLException + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + +pd.DataFrame.iteritems = pd.DataFrame.items # Pandas 2+ compatability +np.bool = np.bool_ # Numpy 1.24.3+ comptability + + +def test_numeric_columns_equal_abs(snowpark_session): + data = """A|B|EXPECTED +1|1|True +2|2.1|True +3|4|False +4|NULL|False +NULL|4|False +NULL|NULL|True""" + + df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.2).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_numeric_columns_equal_rel(snowpark_session): + data = """A|B|EXPECTED +1|1|True +2|2.1|True +3|4|False +4|NULL|False +NULL|4|False +NULL|NULL|True""" + df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_string_columns_equal(snowpark_session): + data = """A|B|EXPECTED +Hi|Hi|True +Yo|Yo|True +Hey|Hey |False +résumé|resume|False +résumé|résumé|True +💩|💩|True +💩|🤔|False + | |True + | |False +datacompy|DataComPy|False +something||False +|something|False +||True""" + df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_string_columns_equal_with_ignore_spaces(snowpark_session): + data = """A|B|EXPECTED +Hi|Hi|True +Yo|Yo|True +Hey|Hey |True +résumé|resume|False +résumé|résumé|True +💩|💩|True +💩|🤔|False + | |True + | |True +datacompy|DataComPy|False +something||False +|something|False +||True""" + df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal( + df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True + ).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_date_columns_equal(snowpark_session): + data = """A|B|EXPECTED +2017-01-01|2017-01-01|True +2017-01-02|2017-01-02|True +2017-10-01|2017-10-10|False +2017-01-01||False +|2017-01-01|False +||True""" + pdf = pd.read_csv(io.StringIO(data), sep="|") + df = snowpark_session.createDataFrame(pdf) + # First compare just the strings + actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + # Then compare converted to datetime objects + pdf["A"] = pd.to_datetime(pdf["A"]) + pdf["B"] = pd.to_datetime(pdf["B"]) + df = 
snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + # and reverse + actual_out_rev = columns_equal(df, "B", "A", "ACTUAL", rel_tol=0.2).toPandas()[ + "ACTUAL" + ] + assert_series_equal(expect_out, actual_out_rev, check_names=False) + + +def test_date_columns_equal_with_ignore_spaces(snowpark_session): + data = """A|B|EXPECTED +2017-01-01|2017-01-01 |True +2017-01-02 |2017-01-02|True +2017-10-01 |2017-10-10 |False +2017-01-01||False +|2017-01-01|False +||True""" + pdf = pd.read_csv(io.StringIO(data), sep="|") + df = snowpark_session.createDataFrame(pdf) + # First compare just the strings + actual_out = columns_equal( + df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True + ).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + # Then compare converted to datetime objects + try: # pandas 2 + pdf["A"] = pd.to_datetime(pdf["A"], format="mixed") + pdf["B"] = pd.to_datetime(pdf["B"], format="mixed") + except ValueError: # pandas 1.5 + pdf["A"] = pd.to_datetime(pdf["A"]) + pdf["B"] = pd.to_datetime(pdf["B"]) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal( + df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True + ).toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + # and reverse + actual_out_rev = columns_equal( + df, "B", "A", "ACTUAL", rel_tol=0.2, ignore_spaces=True + ).toPandas()["ACTUAL"] + assert_series_equal(expect_out, actual_out_rev, check_names=False) + + +def test_date_columns_unequal(snowpark_session): + """I want datetime fields to match with dates stored as strings""" + data = [{"A": "2017-01-01", "B": "2017-01-02"}, {"A": "2017-01-01"}] + pdf = pd.DataFrame(data) + pdf["A_DT"] = pd.to_datetime(pdf["A"]) + pdf["B_DT"] = pd.to_datetime(pdf["B"]) + df = snowpark_session.createDataFrame(pdf) + assert columns_equal(df, "A", "A_DT", "ACTUAL").toPandas()["ACTUAL"].all() + assert columns_equal(df, "B", "B_DT", "ACTUAL").toPandas()["ACTUAL"].all() + assert columns_equal(df, "A_DT", "A", "ACTUAL").toPandas()["ACTUAL"].all() + assert columns_equal(df, "B_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all() + assert not columns_equal(df, "B_DT", "A", "ACTUAL").toPandas()["ACTUAL"].any() + assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any() + assert not columns_equal(df, "A", "B_DT", "ACTUAL").toPandas()["ACTUAL"].any() + assert not columns_equal(df, "B", "A_DT", "ACTUAL").toPandas()["ACTUAL"].any() + + +def test_bad_date_columns(snowpark_session): + """If strings can't be coerced into dates then it should be false for the + whole column. + """ + data = [ + {"A": "2017-01-01", "B": "2017-01-01"}, + {"A": "2017-01-01", "B": "217-01-01"}, + ] + pdf = pd.DataFrame(data) + pdf["A_DT"] = pd.to_datetime(pdf["A"]) + df = snowpark_session.createDataFrame(pdf) + assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all() + assert columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any() + + +def test_rounded_date_columns(snowpark_session): + """If strings can't be coerced into dates then it should be false for the + whole column. 
+ """ + data = [ + {"A": "2017-01-01", "B": "2017-01-01 00:00:00.000000", "EXP": True}, + {"A": "2017-01-01", "B": "2017-01-01 00:00:00.123456", "EXP": False}, + {"A": "2017-01-01", "B": "2017-01-01 00:00:01.000000", "EXP": False}, + {"A": "2017-01-01", "B": "2017-01-01 00:00:00", "EXP": True}, + ] + pdf = pd.DataFrame(data) + pdf["A_DT"] = pd.to_datetime(pdf["A"]) + df = snowpark_session.createDataFrame(pdf) + actual = columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"] + expected = df.select("EXP").toPandas()["EXP"] + assert_series_equal(actual, expected, check_names=False) + + +def test_decimal_float_columns_equal(snowpark_session): + data = [ + {"A": Decimal("1"), "B": 1, "EXPECTED": True}, + {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True}, + {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True}, + {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": False}, + {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False}, + {"A": np.nan, "B": np.nan, "EXPECTED": True}, + {"A": np.nan, "B": 1, "EXPECTED": False}, + {"A": Decimal("1"), "B": np.nan, "EXPECTED": False}, + ] + pdf = pd.DataFrame(data) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_decimal_float_columns_equal_rel(snowpark_session): + data = [ + {"A": Decimal("1"), "B": 1, "EXPECTED": True}, + {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True}, + {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True}, + {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": True}, + {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False}, + {"A": np.nan, "B": np.nan, "EXPECTED": True}, + {"A": np.nan, "B": 1, "EXPECTED": False}, + {"A": Decimal("1"), "B": np.nan, "EXPECTED": False}, + ] + pdf = pd.DataFrame(data) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[ + "ACTUAL" + ] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_decimal_columns_equal(snowpark_session): + data = [ + {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True}, + {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True}, + {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True}, + { + "A": Decimal("1.000000004"), + "B": Decimal("1.000000003"), + "EXPECTED": False, + }, + {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False}, + {"A": np.nan, "B": np.nan, "EXPECTED": True}, + {"A": np.nan, "B": Decimal("1"), "EXPECTED": False}, + {"A": Decimal("1"), "B": np.nan, "EXPECTED": False}, + ] + pdf = pd.DataFrame(data) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_decimal_columns_equal_rel(snowpark_session): + data = [ + {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True}, + {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True}, + {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True}, + { + "A": Decimal("1.000000004"), + "B": Decimal("1.000000003"), + "EXPECTED": True, + }, + {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False}, + {"A": np.nan, "B": np.nan, "EXPECTED": True}, + {"A": np.nan, "B": Decimal("1"), 
"EXPECTED": False}, + {"A": Decimal("1"), "B": np.nan, "EXPECTED": False}, + ] + pdf = pd.DataFrame(data) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[ + "ACTUAL" + ] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_infinity_and_beyond(snowpark_session): + # https://spark.apache.org/docs/latest/sql-ref-datatypes.html#positivenegative-infinity-semantics + # Positive/negative infinity multiplied by 0 returns NaN. + # Positive infinity sorts lower than NaN and higher than any other values. + # Negative infinity sorts lower than any other values. + data = [ + {"A": np.inf, "B": np.inf, "EXPECTED": True}, + {"A": -np.inf, "B": -np.inf, "EXPECTED": True}, + {"A": -np.inf, "B": np.inf, "EXPECTED": True}, + {"A": np.inf, "B": -np.inf, "EXPECTED": True}, + {"A": 1, "B": 1, "EXPECTED": True}, + {"A": 1, "B": 0, "EXPECTED": False}, + ] + pdf = pd.DataFrame(data) + df = snowpark_session.createDataFrame(pdf) + actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"] + expect_out = df.select("EXPECTED").toPandas()["EXPECTED"] + assert_series_equal(expect_out, actual_out, check_names=False) + + +def test_compare_table_setter_bad(snowpark_session): + # Invalid table name + with raises(ValueError, match="invalid_table_name_1 is not a valid table name."): + SnowflakeCompare( + snowpark_session, "invalid_table_name_1", "invalid_table_name_2", ["A"] + ) + # Valid table name but table does not exist + with raises(SnowparkSQLException): + SnowflakeCompare( + snowpark_session, "non.existant.table_1", "non.existant.table_2", ["A"] + ) + + +def test_compare_table_setter_good(snowpark_session): + data = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df = pd.read_csv(StringIO(data), sep=",") + database = snowpark_session.get_current_database().replace('"', "") + schema = snowpark_session.get_current_schema().replace('"', "") + full_table_name = f"{database}.{schema}" + toy_table_name_1 = "DC_TOY_TABLE_1" + toy_table_name_2 = "DC_TOY_TABLE_2" + full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}" + full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}" + + snowpark_session.write_pandas( + df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True + ) + snowpark_session.write_pandas( + df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True + ) + + compare = SnowflakeCompare( + snowpark_session, + full_toy_table_name_1, + full_toy_table_name_2, + join_columns=["ACCT_ID"], + ) + assert compare.df1.toPandas().equals(df) + assert compare.join_columns == ["ACCT_ID"] + + +def test_compare_df_setter_bad(snowpark_session): + pdf = pd.DataFrame([{"A": 1, "C": 2}, {"A": 2, "C": 2}]) + df = snowpark_session.createDataFrame(pdf) + with raises(TypeError, match="DF1 must be a valid sp.Dataframe"): + SnowflakeCompare(snowpark_session, 3, 2, ["A"]) + with raises(ValueError, match="DF1 must have all columns from join_columns"): + SnowflakeCompare(snowpark_session, df, df.select("*"), ["B"]) + pdf = pd.DataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 3}]) + df_dupe = snowpark_session.createDataFrame(pdf) + 
pd.testing.assert_frame_equal( + SnowflakeCompare( + snowpark_session, df_dupe, df_dupe.select("*"), ["A", "B"] + ).df1.toPandas(), + pdf, + check_dtype=False, + ) + + +def test_compare_df_setter_good(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert compare.df1.toPandas().equals(df1.toPandas()) + assert compare.join_columns == ["A"] + compare = SnowflakeCompare(snowpark_session, df1, df2, ["A", "B"]) + assert compare.df1.toPandas().equals(df1.toPandas()) + assert compare.join_columns == ["A", "B"] + + +def test_compare_df_setter_different_cases(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert compare.df1.toPandas().equals(df1.toPandas()) + + +def test_columns_overlap(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert compare.df1_unq_columns() == set() + assert compare.df2_unq_columns() == set() + assert compare.intersect_columns() == {"A", "B"} + + +def test_columns_no_overlap(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "D": "OH"}, {"A": 2, "B": 3, "D": "YA"}] + ) + compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert compare.df1_unq_columns() == {"C"} + assert compare.df2_unq_columns() == {"D"} + assert compare.intersect_columns() == {"A", "B"} + + +def test_columns_maintain_order_through_set_operations(snowpark_session): + pdf1 = pd.DataFrame( + { + "JOIN": ["A", "B"], + "F": [0, 0], + "G": [1, 2], + "B": [2, 2], + "H": [3, 3], + "A": [4, 4], + "C": [-2, -3], + } + ) + pdf2 = pd.DataFrame( + { + "JOIN": ["A", "B"], + "E": [0, 1], + "H": [1, 2], + "B": [2, 3], + "A": [-1, -1], + "G": [4, 4], + "D": [-3, -2], + } + ) + df1 = snowpark_session.createDataFrame(pdf1) + df2 = snowpark_session.createDataFrame(pdf2) + compare = SnowflakeCompare(snowpark_session, df1, df2, ["JOIN"]) + assert list(compare.df1_unq_columns()) == ["F", "C"] + assert list(compare.df2_unq_columns()) == ["E", "D"] + assert list(compare.intersect_columns()) == ["JOIN", "G", "B", "H", "A"] + + +def test_10k_rows(snowpark_session): + rng = np.random.default_rng() + pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"]) + pdf.reset_index(inplace=True) + pdf.columns = ["A", "B", "C"] + pdf2 = pdf.copy() + pdf2["B"] = pdf2["B"] + 0.1 + df1 = snowpark_session.createDataFrame(pdf) + df2 = snowpark_session.createDataFrame(pdf2) + compare_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"], abs_tol=0.2) + assert compare_tol.matches() + assert compare_tol.df1_unq_rows.count() == 0 + assert compare_tol.df2_unq_rows.count() == 0 + assert compare_tol.intersect_columns() == {"A", "B", "C"} + assert compare_tol.all_columns_match() + assert compare_tol.all_rows_overlap() + assert compare_tol.intersect_rows_match() + + compare_no_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert not compare_no_tol.matches() + assert 
compare_no_tol.df1_unq_rows.count() == 0 + assert compare_no_tol.df2_unq_rows.count() == 0 + assert compare_no_tol.intersect_columns() == {"A", "B", "C"} + assert compare_no_tol.all_columns_match() + assert compare_no_tol.all_rows_overlap() + assert not compare_no_tol.intersect_rows_match() + + +def test_subset(snowpark_session, caplog): + caplog.set_level(logging.DEBUG) + df1 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}] + ) + df2 = snowpark_session.createDataFrame([{"A": 1, "C": "HI"}]) + comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert comp.subset() + + +def test_not_subset(snowpark_session, caplog): + caplog.set_level(logging.INFO) + df1 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "GREAT"}] + ) + comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert not comp.subset() + assert "C: 1 / 2 (50.00%) match" in caplog.text + + +def test_large_subset(snowpark_session): + rng = np.random.default_rng() + pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"]) + pdf.reset_index(inplace=True) + pdf.columns = ["A", "B", "C"] + pdf2 = pdf[["A", "B"]].head(50).copy() + df1 = snowpark_session.createDataFrame(pdf) + df2 = snowpark_session.createDataFrame(pdf2) + comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"]) + assert not comp.matches() + assert comp.subset() + + +def test_string_joiner(snowpark_session): + df1 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}]) + df2 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, "AB") + assert compare.matches() + + +def test_decimal_with_joins(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": Decimal("1"), "B": 2}, {"A": Decimal("2"), "B": 2}] + ) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_decimal_with_nulls(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": 1, "B": Decimal("2")}, {"A": 2, "B": Decimal("2")}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2}, {"A": 2, "B": 2}, {"A": 3, "B": 2}] + ) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A") + assert not compare.matches() + assert compare.all_columns_match() + assert not compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_strings_with_joins(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_temp_column_name(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}]) + df2 = snowpark_session.createDataFrame( + [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}] + ) + actual = temp_column_name(df1, df2) + assert actual == "_TEMP_0" + + +def 
test_temp_column_name_one_has(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}] + ) + actual = temp_column_name(df1, df2) + assert actual == "_TEMP_1" + + +def test_temp_column_name_both_have_temp_1(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}] + ) + df2 = snowpark_session.createDataFrame( + [ + {"_TEMP_0": "HI", "B": 2}, + {"_TEMP_0": "BYE", "B": 2}, + {"A": "back fo mo", "B": 3}, + ] + ) + actual = temp_column_name(df1, df2) + assert actual == "_TEMP_1" + + +def test_temp_column_name_both_have_temp_2(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}] + ) + df2 = snowpark_session.createDataFrame( + [ + {"_TEMP_0": "HI", "B": 2}, + {"_TEMP_1": "BYE", "B": 2}, + {"A": "back fo mo", "B": 3}, + ] + ) + actual = temp_column_name(df1, df2) + assert actual == "_TEMP_2" + + +def test_temp_column_name_one_already(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"_TEMP_1": "HI", "B": 2}, {"_TEMP_1": "BYE", "B": 2}] + ) + df2 = snowpark_session.createDataFrame( + [ + {"_TEMP_1": "HI", "B": 2}, + {"_TEMP_1": "BYE", "B": 2}, + {"A": "back fo mo", "B": 3}, + ] + ) + actual = temp_column_name(df1, df2) + assert actual == "_TEMP_0" + + +# Duplicate testing! + + +def test_simple_dupes_one_field(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"]) + assert compare.matches() + # Just render the report to make sure it renders. + compare.report() + + +def test_simple_dupes_two_fields(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A", "B"]) + assert compare.matches() + # Just render the report to make sure it renders. + compare.report() + + +def test_simple_dupes_one_field_two_vals_1(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"]) + assert compare.matches() + # Just render the report to make sure it renders. + compare.report() + + +def test_simple_dupes_one_field_two_vals_2(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 0}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"]) + assert not compare.matches() + assert compare.df1_unq_rows.count() == 1 + assert compare.df2_unq_rows.count() == 1 + assert compare.intersect_rows.count() == 1 + # Just render the report to make sure it renders. 
+ compare.report() + + +def test_simple_dupes_one_field_three_to_two_vals(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": 1, "B": 2}, {"A": 1, "B": 0}, {"A": 1, "B": 0}] + ) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"]) + assert not compare.matches() + assert compare.df1_unq_rows.count() == 1 + assert compare.df2_unq_rows.count() == 0 + assert compare.intersect_rows.count() == 2 + # Just render the report to make sure it renders. + compare.report() + assert "(First 1 Columns)" in compare.report(column_count=1) + assert "(First 2 Columns)" in compare.report(column_count=2) + + +def test_dupes_from_real_data(snowpark_session): + data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM +100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0 +200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1, +100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A, +100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0 +200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,, +100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0 +200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G, +100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G, +100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0 +200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,, +100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0 +200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F, +100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0 +200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F, +100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0 +200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A, +100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0 +200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,""" + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep=",")) + df2 = df1.select("*") + compare_acct = SnowflakeCompare( + snowpark_session, df1, df2, join_columns=["ACCT_ID"] + ) + assert compare_acct.matches() + compare_acct.report() + + compare_unq = SnowflakeCompare( + snowpark_session, + df1, + df2, + join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"], + ) + assert compare_unq.matches() + compare_unq.report() + + +def test_table_compare_from_real_data(snowpark_session): + data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM +100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0 +200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1, +100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A, +100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0 +200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,, +100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0 +200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G, 
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G, +100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0 +200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,, +100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0 +200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F, +100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0 +200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F, +100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0 +200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A, +100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0 +200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,""" + df = pd.read_csv(StringIO(data), sep=",") + database = snowpark_session.get_current_database().replace('"', "") + schema = snowpark_session.get_current_schema().replace('"', "") + full_table_name = f"{database}.{schema}" + toy_table_name_1 = "DC_TOY_TABLE_1" + toy_table_name_2 = "DC_TOY_TABLE_2" + full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}" + full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}" + + snowpark_session.write_pandas( + df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True + ) + snowpark_session.write_pandas( + df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True + ) + + compare_acct = SnowflakeCompare( + snowpark_session, + full_toy_table_name_1, + full_toy_table_name_2, + join_columns=["ACCT_ID"], + ) + assert compare_acct.matches() + compare_acct.report() + + compare_unq = SnowflakeCompare( + snowpark_session, + full_toy_table_name_1, + full_toy_table_name_2, + join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"], + ) + assert compare_unq.matches() + compare_unq.report() + + +def test_strings_with_joins_with_ignore_spaces(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": "HI", "B": " A"}, {"A": "BYE", "B": "A"}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A "}] + ) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_decimal_with_joins_with_ignore_spaces(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_joins_with_ignore_spaces(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}]) + + compare = 
SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_joins_with_insensitive_lowercase_cols(snowpark_session): + df1 = snowpark_session.createDataFrame([{"a": 1, "B": "A"}, {"a": 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}]) + + compare = SnowflakeCompare(snowpark_session, df1, df2, "A") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}]) + + compare = SnowflakeCompare(snowpark_session, df1, df2, "a") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_joins_with_sensitive_lowercase_cols(snowpark_session): + df1 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}]) + + compare = SnowflakeCompare(snowpark_session, df1, df2, '"a"') + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +def test_strings_with_ignore_spaces_and_join_columns(snowpark_session): + df1 = snowpark_session.createDataFrame( + [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A"}] + ) + df2 = snowpark_session.createDataFrame( + [{"A": " HI ", "B": "A"}, {"A": " BYE ", "B": "A"}] + ) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert not compare.all_rows_overlap() + assert compare.count_matching_rows() == 0 + + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + +def test_integers_with_ignore_spaces_and_join_columns(snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + +def test_sample_mismatch(snowpark_session): + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.4,George Michael Bluth,14530.155, + 
10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") + + output = compare.sample_mismatch(column="NAME", sample_count=1).toPandas() + assert output.shape[0] == 1 + assert (output.NAME_DF1 != output.NAME_DF2).all() + + output = compare.sample_mismatch(column="NAME", sample_count=2).toPandas() + assert output.shape[0] == 2 + assert (output.NAME_DF1 != output.NAME_DF2).all() + + output = compare.sample_mismatch(column="NAME", sample_count=3).toPandas() + assert output.shape[0] == 2 + assert (output.NAME_DF1 != output.NAME_DF2).all() + + +def test_all_mismatch_not_ignore_matching_cols_no_cols_matching(snowpark_session): + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") + + output = compare.all_mismatch().toPandas() + assert output.shape[0] == 4 + assert output.shape[1] == 9 + + assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2 + assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2 + + assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1 + assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3 + + assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3 + assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1 + + assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4 + assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0 + + +def test_all_mismatch_not_ignore_matching_cols_some_cols_matching(snowpark_session): + # Columns dollar_amt and name are matching + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = 
snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") + + output = compare.all_mismatch().toPandas() + assert output.shape[0] == 4 + assert output.shape[1] == 9 + + assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 0 + assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 4 + + assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 0 + assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 4 + + assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3 + assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1 + + assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4 + assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0 + + +def test_all_mismatch_ignore_matching_cols_some_cols_matching_diff_rows( + snowpark_session, +): + # Case where there are rows on either dataset which don't match up. + # Columns dollar_amt and name are matching + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + 10000001241,1111.05,Lucille Bluth, + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + """ + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") + + output = compare.all_mismatch(ignore_matching_cols=True).toPandas() + + assert output.shape[0] == 4 + assert output.shape[1] == 5 + + assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3 + assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1 + + assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4 + assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0 + + assert not ("NAME_DF1" in output and "NAME_DF2" in output) + assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output) + + +def test_all_mismatch_ignore_matching_cols_some_cols_matching(snowpark_session): + # Columns dollar_amt and name are matching + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") 
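+    # With ignore_matching_cols=True, columns that match on every joined row
+    # (here NAME and DOLLAR_AMT) should be dropped from the mismatch output.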
+ + output = compare.all_mismatch(ignore_matching_cols=True).toPandas() + + assert output.shape[0] == 4 + assert output.shape[1] == 5 + + assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3 + assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1 + + assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4 + assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0 + + assert not ("NAME_DF1" in output and "NAME_DF2" in output) + assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output) + + +def test_all_mismatch_ignore_matching_cols_no_cols_matching(snowpark_session): + data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=",")) + df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=",")) + compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID") + + output = compare.all_mismatch().toPandas() + assert output.shape[0] == 4 + assert output.shape[1] == 9 + + assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2 + assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2 + + assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1 + assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3 + + assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3 + assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1 + + assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4 + assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0 + + +@pytest.mark.parametrize( + "column, expected", + [ + ("BASE", 0), + ("FLOATS", 0.2), + ("DECIMALS", 0.1), + ("NULL_FLOATS", 0.1), + ("STRINGS", 0.1), + ("INFINITY", np.inf), + ], +) +def test_calculate_max_diff(snowpark_session, column, expected): + pdf = pd.DataFrame( + { + "BASE": [1, 1, 1, 1, 1], + "FLOATS": [1.1, 1.1, 1.1, 1.2, 0.9], + "DECIMALS": [ + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + ], + "NULL_FLOATS": [np.nan, 1.1, 1, 1, 1], + "STRINGS": ["1", "1", "1", "1.1", "1"], + "INFINITY": [1, 1, 1, 1, np.inf], + } + ) + MAX_DIFF_DF = snowpark_session.createDataFrame(pdf) + assert np.isclose( + calculate_max_diff(MAX_DIFF_DF, "BASE", column), + expected, + ) + + +def test_dupes_with_nulls_strings(snowpark_session): + pdf1 = pd.DataFrame( + { + "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5], + "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5], + } + ) + pdf2 = pd.DataFrame( + { + "FLD_1": [1, 2, 3, 4, 5], + "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan], + "FLD_3": [1, 2, 3, 4, 5], + } + ) + df1 = snowpark_session.createDataFrame(pdf1) + df2 = snowpark_session.createDataFrame(pdf2) + comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", 
"FLD_2"]) + assert comp.subset() + + +def test_dupes_with_nulls_ints(snowpark_session): + pdf1 = pd.DataFrame( + { + "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5], + "FLD_2": [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5], + } + ) + pdf2 = pd.DataFrame( + { + "FLD_1": [1, 2, 3, 4, 5], + "FLD_2": [1, np.nan, np.nan, np.nan, np.nan], + "FLD_3": [1, 2, 3, 4, 5], + } + ) + df1 = snowpark_session.createDataFrame(pdf1) + df2 = snowpark_session.createDataFrame(pdf2) + comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", "FLD_2"]) + assert comp.subset() + + +def test_generate_id_within_group(snowpark_session): + matrix = [ + ( + pd.DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}), + pd.Series([0, 0, 0]), + ), + ( + pd.DataFrame( + { + "A": ["A", "A", "DATACOMPY_NULL"], + "B": [1, 1, 2], + "__INDEX": [1, 2, 3], + } + ), + pd.Series([0, 1, 0]), + ), + ( + pd.DataFrame({"A": [-999, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}), + pd.Series([0, 0, 0]), + ), + ( + pd.DataFrame( + {"A": [1, np.nan, np.nan], "B": [1, 2, 2], "__INDEX": [1, 2, 3]} + ), + pd.Series([0, 0, 1]), + ), + ( + pd.DataFrame( + {"A": ["1", np.nan, np.nan], "B": ["1", "2", "2"], "__INDEX": [1, 2, 3]} + ), + pd.Series([0, 0, 1]), + ), + ( + pd.DataFrame( + { + "A": [datetime(2018, 1, 1), np.nan, np.nan], + "B": ["1", "2", "2"], + "__INDEX": [1, 2, 3], + } + ), + pd.Series([0, 0, 1]), + ), + ] + for i in matrix: + dataframe = i[0] + expected = i[1] + actual = ( + _generate_id_within_group( + snowpark_session.createDataFrame(dataframe), ["A", "B"], "_TEMP_0" + ) + .orderBy("__INDEX") + .select("_TEMP_0") + .toPandas() + ) + assert (actual["_TEMP_0"] == expected).all() + + +def test_generate_id_within_group_single_join(snowpark_session): + dataframe = snowpark_session.createDataFrame( + [{"A": 1, "B": 2, "__INDEX": 1}, {"A": 1, "B": 2, "__INDEX": 2}] + ) + expected = pd.Series([0, 1]) + actual = ( + _generate_id_within_group(dataframe, ["A"], "_TEMP_0") + .orderBy("__INDEX") + .select("_TEMP_0") + ).toPandas() + assert (actual["_TEMP_0"] == expected).all() + + +@mock.patch("datacompy.snowflake.render") +def test_save_html(mock_render, snowpark_session): + df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}]) + df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}]) + compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"]) + + m = mock.mock_open() + with mock.patch("datacompy.snowflake.open", m, create=True): + # assert without HTML call + compare.report() + assert mock_render.call_count == 4 + m.assert_not_called() + + mock_render.reset_mock() + m = mock.mock_open() + with mock.patch("datacompy.snowflake.open", m, create=True): + # assert with HTML call + compare.report(html_file="test.html") + assert mock_render.call_count == 4 + m.assert_called_with("test.html", "w")