diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml
index 26b18def..dc50777d 100644
--- a/.github/workflows/test-package.yml
+++ b/.github/workflows/test-package.yml
@@ -64,17 +64,27 @@ jobs:
java-version: '8'
distribution: 'adopt'
- - name: Install Spark and datacompy
+ - name: Install Spark, Pandas, and Numpy
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-spark pypandoc
python -m pip install pyspark[connect]==${{ matrix.spark-version }}
python -m pip install pandas==${{ matrix.pandas-version }}
python -m pip install numpy==${{ matrix.numpy-version }}
+
+ - name: Install Datacompy without Snowflake/Snowpark if Python 3.12
+ if: ${{ matrix.python-version == '3.12' }}
+ run: |
+ python -m pip install .[dev_no_snowflake]
+
+ - name: Install Datacompy with all dev dependencies if Python 3.9, 3.10, or 3.11
+ if: ${{ matrix.python-version != '3.12' }}
+ run: |
python -m pip install .[dev]
+
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
test-bare-install:
@@ -101,7 +111,7 @@ jobs:
python -m pip install .[tests]
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
test-fugue-install-no-spark:
@@ -127,4 +137,4 @@ jobs:
python -m pip install .[tests,duckdb,polars,dask,ray]
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
diff --git a/README.md b/README.md
index b60457ef..cf6070df 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ pip install datacompy[spark]
pip install datacompy[dask]
pip install datacompy[duckdb]
pip install datacompy[ray]
+pip install datacompy[snowflake]
```
@@ -95,6 +96,7 @@ with the Pandas on Spark implementation. Spark plans to support Pandas 2 in [Spa
- Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html))
- Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html))
- Polars: ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html))
+- Snowflake/Snowpark: ([See documentation](https://capitalone.github.io/datacompy/snowflake_usage.html))
- Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow,
Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data
across these backends. Please note that Fugue will use the Pandas (Native) logic at its lowest level
diff --git a/datacompy/__init__.py b/datacompy/__init__.py
index a6d331e8..74154839 100644
--- a/datacompy/__init__.py
+++ b/datacompy/__init__.py
@@ -43,12 +43,14 @@
unq_columns,
)
from datacompy.polars import PolarsCompare
+from datacompy.snowflake import SnowflakeCompare
from datacompy.spark.sql import SparkSQLCompare
__all__ = [
"BaseCompare",
"Compare",
"PolarsCompare",
+ "SnowflakeCompare",
"SparkSQLCompare",
"all_columns_match",
"all_rows_overlap",
diff --git a/datacompy/snowflake.py b/datacompy/snowflake.py
new file mode 100644
index 00000000..19f63978
--- /dev/null
+++ b/datacompy/snowflake.py
@@ -0,0 +1,1202 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Compare two Snowpark SQL DataFrames and Snowflake tables.
+
+Originally this package was meant to provide similar functionality to
+PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
+two dataframes.
+"""
+
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Union, cast
+
+import pandas as pd
+from ordered_set import OrderedSet
+
+try:
+ import snowflake.snowpark as sp
+ from snowflake.snowpark import Window
+ from snowflake.snowpark.exceptions import SnowparkSQLException
+ from snowflake.snowpark.functions import (
+ abs,
+ col,
+ concat,
+ contains,
+ is_null,
+ lit,
+ monotonically_increasing_id,
+ row_number,
+ trim,
+ when,
+ )
+except ImportError:
+ pass # for non-snowflake users
+from datacompy.base import BaseCompare
+from datacompy.spark.sql import decimal_comparator
+
+LOG = logging.getLogger(__name__)
+
+
+NUMERIC_SNOWPARK_TYPES = [
+ "tinyint",
+ "smallint",
+ "int",
+ "bigint",
+ "float",
+ "double",
+ decimal_comparator(),
+]
+
+
+class SnowflakeCompare(BaseCompare):
+ """Comparison class to be used to compare whether two Snowpark dataframes are equal.
+
+ df1 and df2 can refer to either a Snowpark dataframe or the name of a valid Snowflake table.
+ The data structures which df1 and df2 represent must contain all of the join_columns,
+ with unique column names. Differences between values are compared to
+ abs_tol + rel_tol * abs(df2['value']).
+
+ Parameters
+ ----------
+    session: snowflake.snowpark.Session
+        Snowpark session with the connection information needed to access the targeted tables.
+    df1 : Union[str, sp.DataFrame]
+        First table to check, provided either as the table's name or as a Snowpark DataFrame.
+    df2 : Union[str, sp.DataFrame]
+        Second table to check, provided either as the table's name or as a Snowpark DataFrame.
+ join_columns : list or str, optional
+ Column(s) to join dataframes on. If a string is passed in, that one
+ column will be used.
+ abs_tol : float, optional
+ Absolute tolerance between two values.
+ rel_tol : float, optional
+ Relative tolerance between two values.
+ df1_name : str, optional
+ A string name for the first dataframe. If used alongside a snowflake table,
+ overrides the default convention of naming the dataframe after the table.
+ df2_name : str, optional
+ A string name for the second dataframe.
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns (including any join
+ columns).
+
+ Attributes
+ ----------
+ df1_unq_rows : sp.DataFrame
+ All records that are only in df1 (based on a join on join_columns)
+ df2_unq_rows : sp.DataFrame
+ All records that are only in df2 (based on a join on join_columns)
+ """
+
+ def __init__(
+ self,
+ session: "sp.Session",
+ df1: Union[str, "sp.DataFrame"],
+ df2: Union[str, "sp.DataFrame"],
+ join_columns: Optional[Union[List[str], str]],
+ abs_tol: float = 0,
+ rel_tol: float = 0,
+ df1_name: Optional[str] = None,
+ df2_name: Optional[str] = None,
+ ignore_spaces: bool = False,
+ ) -> None:
+ if join_columns is None:
+ errmsg = "join_columns cannot be None"
+ raise ValueError(errmsg)
+ elif not join_columns:
+ errmsg = "join_columns is empty"
+ raise ValueError(errmsg)
+ elif isinstance(join_columns, (str, int, float)):
+ self.join_columns = [str(join_columns).replace('"', "").upper()]
+ else:
+ self.join_columns = [
+ str(col).replace('"', "").upper()
+ for col in cast(List[str], join_columns)
+ ]
+
+ self._any_dupes: bool = False
+ self.session = session
+ self.df1 = (df1, df1_name)
+ self.df2 = (df2, df2_name)
+ self.abs_tol = abs_tol
+ self.rel_tol = rel_tol
+ self.ignore_spaces = ignore_spaces
+ self.df1_unq_rows: sp.DataFrame
+ self.df2_unq_rows: sp.DataFrame
+ self.intersect_rows: sp.DataFrame
+ self.column_stats: List[Dict[str, Any]] = []
+ self._compare(ignore_spaces=ignore_spaces)
+
+ @property
+ def df1(self) -> "sp.DataFrame":
+ """Get the first dataframe."""
+ return self._df1
+
+ @df1.setter
+ def df1(self, df1: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None:
+ """Check that df1 is either a Snowpark DF or the name of a valid Snowflake table."""
+ (df, df_name) = df1
+ if isinstance(df, str):
+ table_name = [table_comp.upper() for table_comp in df.split(".")]
+ if len(table_name) != 3:
+ errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema."
+ raise ValueError(errmsg)
+ self.df1_name = df_name.upper() if df_name else table_name[2]
+ self._df1 = self.session.table(df)
+ else:
+ self._df1 = df
+ self.df1_name = df_name.upper() if df_name else "DF1"
+ self._validate_dataframe(self.df1_name, "df1")
+
+ @property
+ def df2(self) -> "sp.DataFrame":
+ """Get the second dataframe."""
+ return self._df2
+
+ @df2.setter
+ def df2(self, df2: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None:
+ """Check that df2 is either a Snowpark DF or the name of a valid Snowflake table."""
+ (df, df_name) = df2
+ if isinstance(df, str):
+ table_name = [table_comp.upper() for table_comp in df.split(".")]
+ if len(table_name) != 3:
+ errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema."
+ raise ValueError(errmsg)
+ self.df2_name = df_name.upper() if df_name else table_name[2]
+ self._df2 = self.session.table(df)
+ else:
+ self._df2 = df
+ self.df2_name = df_name.upper() if df_name else "DF2"
+ self._validate_dataframe(self.df2_name, "df2")
+
+ def _validate_dataframe(self, df_name: str, index: str) -> None:
+ """Validate the provided Snowpark dataframe.
+
+ The dataframe can either be a standalone Snowpark dataframe or a representative
+ of a Snowflake table - in the latter case we check that the table it represents
+ is a valid table by forcing a collection.
+
+ Parameters
+ ----------
+ df_name : str
+ Name of the Snowflake table / Snowpark dataframe
+ index : str
+ The "index" of the dataframe - df1 or df2.
+ """
+ df = getattr(self, index)
+        if not isinstance(df, sp.DataFrame):
+            raise TypeError(f"{df_name} must be a valid sp.DataFrame")
+
+ # force all columns to be non-case-sensitive
+ if index == "df1":
+ col_map = dict(
+ zip(
+ self._df1.columns,
+ [str(c).replace('"', "").upper() for c in self._df1.columns],
+ )
+ )
+ self._df1 = self._df1.rename(col_map)
+ if index == "df2":
+ col_map = dict(
+ zip(
+ self._df2.columns,
+ [str(c).replace('"', "").upper() for c in self._df2.columns],
+ )
+ )
+            self._df2 = self._df2.rename(col_map)
+
+ df = getattr(self, index) # refresh
+ if not set(self.join_columns).issubset(set(df.columns)):
+ raise ValueError(f"{df_name} must have all columns from join_columns")
+ if len(set(df.columns)) < len(df.columns):
+ raise ValueError(f"{df_name} must have unique column names")
+
+ if df.drop_duplicates(self.join_columns).count() < df.count():
+ self._any_dupes = True
+
+ def _compare(self, ignore_spaces: bool) -> None:
+ """Actually run the comparison.
+
+ This method will log out information about what is different between
+ the two dataframes.
+ """
+ LOG.info(f"Number of columns in common: {len(self.intersect_columns())}")
+ LOG.debug("Checking column overlap")
+ for column in self.df1_unq_columns():
+ LOG.info(f"Column in df1 and not in df2: {column}")
+ LOG.info(
+ f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}"
+ )
+ for column in self.df2_unq_columns():
+ LOG.info(f"Column in df2 and not in df1: {column}")
+ LOG.info(
+ f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}"
+ )
+ LOG.debug("Merging dataframes")
+ self._dataframe_merge(ignore_spaces)
+ self._intersect_compare(ignore_spaces)
+ if self.matches():
+ LOG.info("df1 matches df2")
+ else:
+ LOG.info("df1 does not match df2")
+
+ def df1_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df1."""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+ )
+
+ def df2_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df2."""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+ )
+
+ def intersect_columns(self) -> OrderedSet[str]:
+ """Get columns that are shared between the two dataframes."""
+ return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)
+
+ def _dataframe_merge(self, ignore_spaces: bool) -> None:
+ """Merge df1 to df2 on the join columns.
+
+ Gets df1 - df2, df2 - df1, and df1 & df2
+ joining on the ``join_columns``.
+ """
+ LOG.debug("Outer joining")
+
+ df1 = self.df1
+ df2 = self.df2
+ temp_join_columns = deepcopy(self.join_columns)
+
+ if self._any_dupes:
+ LOG.debug("Duplicate rows found, deduping by order of remaining fields")
+ # setting internal index
+ LOG.info("Adding internal index to dataframes")
+ df1 = df1.withColumn("__index", monotonically_increasing_id())
+ df2 = df2.withColumn("__index", monotonically_increasing_id())
+
+ # Create order column for uniqueness of match
+ order_column = temp_column_name(df1, df2)
+ df1 = df1.join(
+ _generate_id_within_group(df1, temp_join_columns, order_column),
+ on="__index",
+ how="inner",
+ ).drop("__index")
+ df2 = df2.join(
+ _generate_id_within_group(df2, temp_join_columns, order_column),
+ on="__index",
+ how="inner",
+ ).drop("__index")
+ temp_join_columns.append(order_column)
+
+ # drop index
+ LOG.info("Dropping internal index")
+ df1 = df1.drop("__index")
+ df2 = df2.drop("__index")
+
+ if ignore_spaces:
+ for column in self.join_columns:
+ if "string" in next(
+ dtype for name, dtype in df1.dtypes if name == column
+ ):
+ df1 = df1.withColumn(column, trim(col(column)))
+ if "string" in next(
+ dtype for name, dtype in df2.dtypes if name == column
+ ):
+ df2 = df2.withColumn(column, trim(col(column)))
+
+ df1 = df1.withColumn("merge", lit(True))
+ df2 = df2.withColumn("merge", lit(True))
+
+ for c in df1.columns:
+ df1 = df1.withColumnRenamed(c, c + "_" + self.df1_name)
+ for c in df2.columns:
+ df2 = df2.withColumnRenamed(c, c + "_" + self.df2_name)
+
+ # NULL SAFE Outer join, not possible with Snowpark Dataframe join
+ df1.createOrReplaceTempView("df1")
+ df2.createOrReplaceTempView("df2")
+ on = " and ".join(
+ [
+ f"EQUAL_NULL(df1.{c}_{self.df1_name}, df2.{c}_{self.df2_name})"
+ for c in temp_join_columns
+ ]
+ )
+ outer_join = self.session.sql(
+ """
+ SELECT * FROM
+ df1 FULL OUTER JOIN df2
+ ON
+ """
+ + on
+ )
+ # Create join indicator
+ outer_join = outer_join.withColumn(
+ "_merge",
+ when(
+ outer_join[f"MERGE_{self.df1_name}"]
+ & outer_join[f"MERGE_{self.df2_name}"],
+ lit("BOTH"),
+ )
+ .when(
+ outer_join[f"MERGE_{self.df1_name}"]
+ & outer_join[f"MERGE_{self.df2_name}"].is_null(),
+ lit("LEFT_ONLY"),
+ )
+ .when(
+ outer_join[f"MERGE_{self.df1_name}"].is_null()
+ & outer_join[f"MERGE_{self.df2_name}"],
+ lit("RIGHT_ONLY"),
+ ),
+ )
+
+ df1 = df1.drop(f"MERGE_{self.df1_name}")
+ df2 = df2.drop(f"MERGE_{self.df2_name}")
+
+ # Clean up temp columns for duplicate row matching
+ if self._any_dupes:
+ outer_join = outer_join.select_expr(
+ f"* EXCLUDE ({order_column}_{self.df1_name}, {order_column}_{self.df2_name})"
+ )
+ df1 = df1.drop(f"{order_column}_{self.df1_name}")
+ df2 = df2.drop(f"{order_column}_{self.df2_name}")
+
+ # Capitalization required - clean up
+ df1_cols = get_merged_columns(df1, outer_join, self.df1_name)
+ df2_cols = get_merged_columns(df2, outer_join, self.df2_name)
+
+ LOG.debug("Selecting df1 unique rows")
+ self.df1_unq_rows = outer_join[outer_join["_merge"] == "LEFT_ONLY"][df1_cols]
+
+ LOG.debug("Selecting df2 unique rows")
+ self.df2_unq_rows = outer_join[outer_join["_merge"] == "RIGHT_ONLY"][df2_cols]
+ LOG.info(f"Number of rows in df1 and not in df2: {self.df1_unq_rows.count()}")
+ LOG.info(f"Number of rows in df2 and not in df1: {self.df2_unq_rows.count()}")
+
+ LOG.debug("Selecting intersecting rows")
+ self.intersect_rows = outer_join[outer_join["_merge"] == "BOTH"]
+ LOG.info(
+ f"Number of rows in df1 and df2 (not necessarily equal): {self.intersect_rows.count()}"
+ )
+ self.intersect_rows = self.intersect_rows.cache_result()
+
+ def _intersect_compare(self, ignore_spaces: bool) -> None:
+ """Run the comparison on the intersect dataframe.
+
+ This loops through all columns that are shared between df1 and df2, and
+        creates a ``<column>_MATCH`` column which is True for matches and False
+        otherwise.
+ """
+ LOG.debug("Comparing intersection")
+ max_diff: float
+ null_diff: int
+ row_cnt = self.intersect_rows.count()
+ for column in self.intersect_columns():
+ if column in self.join_columns:
+ match_cnt = row_cnt
+ col_match = ""
+ max_diff = 0
+ null_diff = 0
+ else:
+ col_1 = column + "_" + self.df1_name
+ col_2 = column + "_" + self.df2_name
+ col_match = column + "_MATCH"
+ self.intersect_rows = columns_equal(
+ self.intersect_rows,
+ col_1,
+ col_2,
+ col_match,
+ self.rel_tol,
+ self.abs_tol,
+ ignore_spaces,
+ )
+ match_cnt = (
+ self.intersect_rows.select(col_match)
+ .where(col(col_match) == True) # noqa: E712
+ .count()
+ )
+ max_diff = calculate_max_diff(
+ self.intersect_rows,
+ col_1,
+ col_2,
+ )
+ null_diff = calculate_null_diff(self.intersect_rows, col_1, col_2)
+
+ if row_cnt > 0:
+ match_rate = float(match_cnt) / row_cnt
+ else:
+ match_rate = 0
+ LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")
+
+ col1_dtype, _ = _get_column_dtypes(self.df1, column, column)
+ col2_dtype, _ = _get_column_dtypes(self.df2, column, column)
+
+ self.column_stats.append(
+ {
+ "column": column,
+ "match_column": col_match,
+ "match_cnt": match_cnt,
+ "unequal_cnt": row_cnt - match_cnt,
+ "dtype1": str(col1_dtype),
+ "dtype2": str(col2_dtype),
+ "all_match": all(
+ (
+ col1_dtype == col2_dtype,
+ row_cnt == match_cnt,
+ )
+ ),
+ "max_diff": max_diff,
+ "null_diff": null_diff,
+ }
+ )
+
+ def all_columns_match(self) -> bool:
+ """Whether the columns all match in the dataframes.
+
+ Returns
+ -------
+ bool
+ True if all columns in df1 are in df2 and vice versa
+ """
+ return self.df1_unq_columns() == self.df2_unq_columns() == set()
+
+ def all_rows_overlap(self) -> bool:
+ """Whether the rows are all present in both dataframes.
+
+ Returns
+ -------
+ bool
+ True if all rows in df1 are in df2 and vice versa (based on
+ existence for join option)
+ """
+ return self.df1_unq_rows.count() == self.df2_unq_rows.count() == 0
+
+ def count_matching_rows(self) -> int:
+ """Count the number of rows match (on overlapping fields).
+
+ Returns
+ -------
+ int
+ Number of matching rows
+ """
+ conditions = []
+ match_columns = []
+ for column in self.intersect_columns():
+ if column not in self.join_columns:
+ match_columns.append(column + "_MATCH")
+ conditions.append(f"{column}_MATCH = True")
+ if len(conditions) > 0:
+ match_columns_count = self.intersect_rows.filter(
+ " and ".join(conditions)
+ ).count()
+ else:
+ match_columns_count = 0
+ return match_columns_count
+
+ def intersect_rows_match(self) -> bool:
+ """Check whether the intersect rows all match."""
+ actual_length = self.intersect_rows.count()
+ return self.count_matching_rows() == actual_length
+
+ def matches(self, ignore_extra_columns: bool = False) -> bool:
+ """Return True or False if the dataframes match.
+
+ Parameters
+ ----------
+ ignore_extra_columns : bool
+ Ignores any columns in one dataframe and not in the other.
+
+ Returns
+ -------
+ bool
+ True or False if the dataframes match.
+ """
+ return not (
+ (not ignore_extra_columns and not self.all_columns_match())
+ or not self.all_rows_overlap()
+ or not self.intersect_rows_match()
+ )
+
+ def subset(self) -> bool:
+ """Return True if dataframe 2 is a subset of dataframe 1.
+
+ Dataframe 2 is considered a subset if all of its columns are in
+ dataframe 1, and all of its rows match rows in dataframe 1 for the
+ shared columns.
+
+ Returns
+ -------
+ bool
+ True if dataframe 2 is a subset of dataframe 1.
+ """
+ return not (
+ self.df2_unq_columns() != set()
+ or self.df2_unq_rows.count() != 0
+ or not self.intersect_rows_match()
+ )
+
+ def sample_mismatch(
+ self, column: str, sample_count: int = 10, for_display: bool = False
+ ) -> "sp.DataFrame":
+ """Return sample mismatches.
+
+ Gets a sub-dataframe which contains the identifying
+ columns, and df1 and df2 versions of the column.
+
+ Parameters
+ ----------
+ column : str
+ The raw column name (i.e. without ``_df1`` appended)
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+ for_display : bool, optional
+ Whether this is just going to be used for display (overwrite the
+ column names)
+
+ Returns
+ -------
+ sp.DataFrame
+ A sample of the intersection dataframe, containing only the
+ "pertinent" columns, for rows that don't match on the provided
+ column.
+ """
+ row_cnt = self.intersect_rows.count()
+ col_match = self.intersect_rows.select(column + "_MATCH")
+ match_cnt = col_match.where(
+ col(column + "_MATCH") == True # noqa: E712
+ ).count()
+ sample_count = min(sample_count, row_cnt - match_cnt)
+ sample = (
+ self.intersect_rows.where(col(column + "_MATCH") == False) # noqa: E712
+ .drop(column + "_MATCH")
+ .limit(sample_count)
+ )
+
+ for c in self.join_columns:
+ sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
+
+ return_cols = [
+ *self.join_columns,
+ column + "_" + self.df1_name,
+ column + "_" + self.df2_name,
+ ]
+ to_return = sample.select(return_cols)
+
+ if for_display:
+ return to_return.toDF(
+ *[
+ *self.join_columns,
+ column + " (" + self.df1_name + ")",
+ column + " (" + self.df2_name + ")",
+ ]
+ )
+ return to_return
+
+ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
+ """Get all rows with any columns that have a mismatch.
+
+ Returns all df1 and df2 versions of the columns and join
+ columns.
+
+ Parameters
+ ----------
+ ignore_matching_cols : bool, optional
+        Whether to exclude columns whose values all match from the output. Defaults to False.
+
+ Returns
+ -------
+ sp.DataFrame
+        All rows of the intersection dataframe that have at least one mismatched column.
+ """
+ match_list = []
+ return_list = []
+ for c in self.intersect_rows.columns:
+ if c.endswith("_MATCH"):
+ orig_col_name = c[:-6]
+
+ col_comparison = columns_equal(
+ self.intersect_rows,
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ c,
+ self.rel_tol,
+ self.abs_tol,
+ self.ignore_spaces,
+ )
+
+ if not ignore_matching_cols or (
+ ignore_matching_cols
+ and col_comparison.select(c)
+ .where(col(c) == False) # noqa: E712
+ .count()
+ > 0
+ ):
+ LOG.debug(f"Adding column {orig_col_name} to the result.")
+ match_list.append(c)
+ return_list.extend(
+ [
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ ]
+ )
+ elif ignore_matching_cols:
+ LOG.debug(
+ f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
+ )
+
+ mm_rows = self.intersect_rows.withColumn(
+ "match_array", concat(*match_list)
+ ).where(contains(col("match_array"), lit("false")))
+
+ for c in self.join_columns:
+ mm_rows = mm_rows.withColumnRenamed(c + "_" + self.df1_name, c)
+
+ return mm_rows.select(self.join_columns + return_list)
+
+ def report(
+ self,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
+ ) -> str:
+ """Return a string representation of a report.
+
+ The representation can
+ then be printed or saved to a file.
+
+ Parameters
+ ----------
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+
+ column_count : int, optional
+ The number of columns to display in the sample records output. Defaults to 10.
+
+ html_file : str, optional
+ HTML file name to save report output to. If ``None`` the file creation will be skipped.
+
+ Returns
+ -------
+ str
+ The report, formatted kinda nicely.
+ """
+ # Header
+ report = render("header.txt")
+ df_header = pd.DataFrame(
+ {
+ "DataFrame": [self.df1_name, self.df2_name],
+ "Columns": [len(self.df1.columns), len(self.df2.columns)],
+ "Rows": [self.df1.count(), self.df2.count()],
+ }
+ )
+ report += df_header[["DataFrame", "Columns", "Rows"]].to_string()
+ report += "\n\n"
+
+ # Column Summary
+ report += render(
+ "column_summary.txt",
+ len(self.intersect_columns()),
+ len(self.df1_unq_columns()),
+ len(self.df2_unq_columns()),
+ self.df1_name,
+ self.df2_name,
+ )
+
+ # Row Summary
+ match_on = ", ".join(self.join_columns)
+ report += render(
+ "row_summary.txt",
+ match_on,
+ self.abs_tol,
+ self.rel_tol,
+ self.intersect_rows.count(),
+ self.df1_unq_rows.count(),
+ self.df2_unq_rows.count(),
+ self.intersect_rows.count() - self.count_matching_rows(),
+ self.count_matching_rows(),
+ self.df1_name,
+ self.df2_name,
+ "Yes" if self._any_dupes else "No",
+ )
+
+ # Column Matching
+ report += render(
+ "column_comparison.txt",
+ len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
+ len([col for col in self.column_stats if col["unequal_cnt"] == 0]),
+ sum(col["unequal_cnt"] for col in self.column_stats),
+ )
+
+ match_stats = []
+ match_sample = []
+ any_mismatch = False
+ for column in self.column_stats:
+ if not column["all_match"]:
+ any_mismatch = True
+ match_stats.append(
+ {
+ "Column": column["column"],
+ f"{self.df1_name} dtype": column["dtype1"],
+ f"{self.df2_name} dtype": column["dtype2"],
+ "# Unequal": column["unequal_cnt"],
+ "Max Diff": column["max_diff"],
+ "# Null Diff": column["null_diff"],
+ }
+ )
+ if column["unequal_cnt"] > 0:
+ match_sample.append(
+ self.sample_mismatch(
+ column["column"], sample_count, for_display=True
+ )
+ )
+
+ if any_mismatch:
+ report += "Columns with Unequal Values or Types\n"
+ report += "------------------------------------\n"
+ report += "\n"
+ df_match_stats = pd.DataFrame(match_stats)
+ df_match_stats.sort_values("Column", inplace=True)
+ # Have to specify again for sorting
+ report += df_match_stats[
+ [
+ "Column",
+ f"{self.df1_name} dtype",
+ f"{self.df2_name} dtype",
+ "# Unequal",
+ "Max Diff",
+ "# Null Diff",
+ ]
+ ].to_string()
+ report += "\n\n"
+
+ if sample_count > 0:
+ report += "Sample Rows with Unequal Values\n"
+ report += "-------------------------------\n"
+ report += "\n"
+ for sample in match_sample:
+ report += sample.toPandas().to_string()
+ report += "\n\n"
+
+ if min(sample_count, self.df1_unq_rows.count()) > 0:
+ report += (
+ f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df1_name)}\n"
+ )
+ report += "\n"
+ columns = self.df1_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df1_unq_rows.count())
+ report += (
+ self.df1_unq_rows.limit(unq_count)
+ .select(columns)
+ .toPandas()
+ .to_string()
+ )
+ report += "\n\n"
+
+ if min(sample_count, self.df2_unq_rows.count()) > 0:
+ report += (
+ f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df2_name)}\n"
+ )
+ report += "\n"
+ columns = self.df2_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df2_unq_rows.count())
+ report += (
+ self.df2_unq_rows.limit(unq_count)
+ .select(columns)
+ .toPandas()
+ .to_string()
+ )
+ report += "\n\n"
+
+ if html_file:
+            html_report = report.replace("\n", "<br>").replace(" ", "&nbsp;")
+            html_report = f"<pre>{html_report}</pre>"
+ with open(html_file, "w") as f:
+ f.write(html_report)
+
+ return report
+
+
+def render(filename: str, *fields: Union[int, float, str]) -> str:
+ """Render out an individual template.
+
+ This basically just reads in a
+ template file, and applies ``.format()`` on the fields.
+
+ Parameters
+ ----------
+ filename : str
+ The file that contains the template. Will automagically prepend the
+ templates directory before opening
+ fields : list
+ Fields to be rendered out in the template
+
+ Returns
+ -------
+ str
+ The fully rendered out file.
+ """
+ this_dir = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(this_dir, "templates", filename)) as file_open:
+ return file_open.read().format(*fields)
+
+
+def columns_equal(
+ dataframe: "sp.DataFrame",
+ col_1: str,
+ col_2: str,
+ col_match: str,
+ rel_tol: float = 0,
+ abs_tol: float = 0,
+ ignore_spaces: bool = False,
+) -> "sp.DataFrame":
+ """Compare two columns from a dataframe.
+
+    Adds a boolean match column to the dataframe and returns the result.
+
+ - Two nulls (np.nan) will evaluate to True.
+ - A null and a non-null value will evaluate to False.
+ - Numeric values will use the relative and absolute tolerances.
+ - Decimal values (decimal.Decimal) will attempt to be converted to floats
+ before comparing
+ - Non-numeric values (i.e. where np.isclose can't be used) will just
+ trigger True on two nulls or exact matches.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+ col_match : str
+ The matching column denoting if the compare was a match or not
+ rel_tol : float, optional
+ Relative tolerance
+ abs_tol : float, optional
+ Absolute tolerance
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns
+
+ Returns
+ -------
+ sp.DataFrame
+        A column of boolean values is added. True == the values match, False == the
+ values don't match.
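+
+    Examples
+    --------
+    Illustrative sketch only; assumes an active Snowpark ``session`` and uses made-up
+    column names::
+
+        from snowflake.snowpark import Row
+
+        df = session.createDataFrame([Row(a=1.0, b=1.05), Row(a=2.0, b=3.0)])
+        out = columns_equal(df, "A", "B", "A_B_MATCH", abs_tol=0.1)
+        # adds a boolean column A_B_MATCH with values [True, False]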
+ """
+ base_dtype, compare_dtype = _get_column_dtypes(dataframe, col_1, col_2)
+ if _is_comparable(base_dtype, compare_dtype):
+ if (base_dtype in NUMERIC_SNOWPARK_TYPES) and (
+ compare_dtype in NUMERIC_SNOWPARK_TYPES
+ ): # numeric tolerance comparison
+ dataframe = dataframe.withColumn(
+ col_match,
+ when(
+ (col(col_1).eqNullSafe(col(col_2)))
+ | (
+ abs(col(col_1) - col(col_2))
+ <= lit(abs_tol) + (lit(rel_tol) * abs(col(col_2)))
+ ),
+                    # corner case: col_1 non-null while col_2 is null would otherwise return True incorrectly
+ when(
+ (is_null(col(col_1)) == False) # noqa: E712
+ & (is_null(col(col_2)) == True), # noqa: E712
+ lit(False),
+ ).otherwise(lit(True)),
+ ).otherwise(lit(False)),
+ )
+ else: # non-numeric comparison
+ if ignore_spaces:
+ when_clause = trim(col(col_1)).eqNullSafe(trim(col(col_2)))
+ else:
+ when_clause = col(col_1).eqNullSafe(col(col_2))
+
+ dataframe = dataframe.withColumn(
+ col_match,
+ when(when_clause, lit(True)).otherwise(lit(False)),
+ )
+ else:
+ LOG.debug(
+ f"Skipping {col_1}({base_dtype}) and {col_2}({compare_dtype}), columns are not comparable"
+ )
+ dataframe = dataframe.withColumn(col_match, lit(False))
+ return dataframe
+
+
+def get_merged_columns(
+ original_df: "sp.DataFrame", merged_df: "sp.DataFrame", suffix: str
+) -> List[str]:
+ """Get the columns from an original dataframe, in the new merged dataframe.
+
+ Parameters
+ ----------
+ original_df : sp.DataFrame
+ The original, pre-merge dataframe
+ merged_df : sp.DataFrame
+ Post-merge with another dataframe, with suffixes added in.
+ suffix : str
+ What suffix was used to distinguish when the original dataframe was
+ overlapping with the other merged dataframe.
+
+ Returns
+ -------
+ List[str]
+ Column list of the original dataframe pre suffix
+ """
+ columns = []
+ for column in original_df.columns:
+ if column in merged_df.columns:
+ columns.append(column)
+ elif f"{column}_{suffix}" in merged_df.columns:
+ columns.append(f"{column}_{suffix}")
+ else:
+ raise ValueError("Column not found: %s", column)
+ return columns
+
+
+def calculate_max_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> float:
+ """Get a maximum difference between two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ float
+ max diff
+ """
+ # Attempting to coalesce maximum diff for non-numeric results in error, if error return 0 max diff.
+ try:
+ diff = dataframe.select(
+ (col(col_1).astype("float") - col(col_2).astype("float")).alias("diff")
+ )
+ abs_diff = diff.select(abs(col("diff")).alias("abs_diff"))
+ max_diff: float = (
+ abs_diff.where(is_null(col("abs_diff")) == False) # noqa: E712
+ .agg({"abs_diff": "max"})
+ .collect()[0][0]
+ )
+ except SnowparkSQLException:
+        return 0
+
+ if pd.isna(max_diff) or pd.isnull(max_diff) or max_diff is None:
+ return 0
+ else:
+ return max_diff
+
+
+def calculate_null_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> int:
+ """Get the null differences between two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ int
+ null diff
+ """
+ nulls_df = dataframe.withColumn(
+ "col_1_null",
+ when(col(col_1).isNull() == True, lit(True)).otherwise( # noqa: E712
+ lit(False)
+ ),
+ )
+ nulls_df = nulls_df.withColumn(
+ "col_2_null",
+ when(col(col_2).isNull() == True, lit(True)).otherwise( # noqa: E712
+ lit(False)
+ ),
+ ).select(["col_1_null", "col_2_null"])
+
+ # (not a and b) or (a and not b)
+ null_diff = nulls_df.where(
+ ((col("col_1_null") == False) & (col("col_2_null") == True)) # noqa: E712
+ | ((col("col_1_null") == True) & (col("col_2_null") == False)) # noqa: E712
+ ).count()
+
+ if pd.isna(null_diff) or pd.isnull(null_diff) or null_diff is None:
+ return 0
+ else:
+ return null_diff
+
+
+def _generate_id_within_group(
+ dataframe: "sp.DataFrame", join_columns: List[str], order_column_name: str
+) -> "sp.DataFrame":
+ """Generate an ID column that can be used to deduplicate identical rows.
+
+ The series generated
+ is the order within a unique group, and it handles nulls. Requires a ``__index`` column.
+
+ Parameters
+ ----------
+ dataframe : sp.DataFrame
+ The dataframe to operate on
+ join_columns : list
+ List of strings which are the join columns
+ order_column_name: str
+ The name of the ``row_number`` column name
+
+ Returns
+ -------
+ sp.DataFrame
+ Original dataframe with the ID column that's unique in each group
+ """
+ default_value = "DATACOMPY_NULL"
+ null_check = False
+ default_check = False
+ for c in join_columns:
+ if dataframe.where(col(c).isNull()).limit(1).collect():
+ null_check = True
+ break
+ for c in [
+ column for column, type in dataframe[join_columns].dtypes if "string" in type
+ ]:
+ if dataframe.where(col(c).isin(default_value)).limit(1).collect():
+ default_check = True
+ break
+
+ if null_check:
+ if default_check:
+ raise ValueError(f"{default_value} was found in your join columns")
+
+ return (
+ dataframe.select(
+ *(col(c).cast("string").alias(c) for c in join_columns + ["__index"]) # noqa: RUF005
+ )
+ .fillna(default_value)
+ .withColumn(
+ order_column_name,
+ row_number().over(Window.orderBy("__index").partitionBy(join_columns))
+ - 1,
+ )
+ .select(["__index", order_column_name])
+ )
+ else:
+ return (
+ dataframe.select(join_columns + ["__index"]) # noqa: RUF005
+ .withColumn(
+ order_column_name,
+ row_number().over(Window.orderBy("__index").partitionBy(join_columns))
+ - 1,
+ )
+ .select(["__index", order_column_name])
+ )
+
+
+def _get_column_dtypes(
+ dataframe: "sp.DataFrame", col_1: "str", col_2: "str"
+) -> tuple[str, str]:
+ """Get the dtypes of two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ Tuple(str, str)
+ Tuple of base and compare datatype
+ """
+ base_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_1)
+ compare_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_2)
+ return base_dtype, compare_dtype
+
+
+def _is_comparable(type1: str, type2: str) -> bool:
+ """Check if two SnowPark data types can be safely compared.
+
+    Two data types are considered comparable if any of the following apply:
+    1. Both data types are the same
+    2. Both data types are numeric
+    3. One data type is a string and the other is a date or timestamp
+
+ Parameters
+ ----------
+ type1 : str
+ A string representation of a Snowpark data type
+ type2 : str
+ A string representation of a Snowpark data type
+
+ Returns
+ -------
+ bool
+ True if both data types are comparable
+ """
+ return (
+ type1 == type2
+ or (type1 in NUMERIC_SNOWPARK_TYPES and type2 in NUMERIC_SNOWPARK_TYPES)
+ or ("string" in type1 and type2 == "date")
+ or (type1 == "date" and "string" in type2)
+ or ("string" in type1 and type2 == "timestamp")
+ or (type1 == "timestamp" and "string" in type2)
+ )
+
+
+def temp_column_name(*dataframes) -> str:
+ """Get a temp column name that isn't included in columns of any dataframes.
+
+ Parameters
+ ----------
+ dataframes : list of DataFrames
+ The DataFrames to create a temporary column name for
+
+ Returns
+ -------
+ str
+ String column name that looks like '_temp_x' for some integer x
+ """
+ i = 0
+ columns = []
+ for dataframe in dataframes:
+ columns = columns + list(dataframe.columns)
+ columns = set(columns)
+
+ while True:
+ temp_column = f"_TEMP_{i}"
+ unique = True
+
+ if temp_column in columns:
+ i += 1
+ unique = False
+ if unique:
+ return temp_column
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e6f77d96..b1cb4c4a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -11,6 +11,7 @@ Contents
Installation
Pandas Usage
Spark Usage
+ Snowflake Usage
Polars Usage
Fugue Usage
Benchmarks
@@ -28,4 +29,4 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
-* :ref:`search`
\ No newline at end of file
+* :ref:`search`
diff --git a/docs/source/snowflake_usage.rst b/docs/source/snowflake_usage.rst
new file mode 100644
index 00000000..3c2687e3
--- /dev/null
+++ b/docs/source/snowflake_usage.rst
@@ -0,0 +1,268 @@
+Snowpark/Snowflake Usage
+========================
+
+For ``SnowflakeCompare``:
+
+- ``on_index`` is not supported.
+- Joining is done using ``EQUAL_NULL``, the null-safe equality test (see the short sketch below).
+- Compares ``snowflake.snowpark.DataFrame`` objects, which can be provided either as raw Snowpark dataframes
+  or as the fully qualified names of valid Snowflake tables, which we will read into Snowpark dataframes.
+
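+A minimal sketch of the null-safe semantics (assuming an active Snowpark ``session``; the column
+names are only illustrative). Plain ``==`` yields ``NULL`` when either side is ``NULL``, while
+``eqNullSafe`` (``EQUAL_NULL``) treats two ``NULL`` values as equal:
+
+.. code-block:: python
+
+    from snowflake.snowpark import Row
+    from snowflake.snowpark.functions import col
+
+    df = session.createDataFrame([Row(a=1, b=1), Row(a=None, b=None), Row(a=1, b=None)])
+    df.select(
+        (col("a") == col("b")).alias("plain_eq"),          # NULL when either side is NULL
+        col("a").eqNullSafe(col("b")).alias("null_safe"),  # True for NULL vs NULL
+    ).show()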
+
+SnowflakeCompare Object Setup
+---------------------------------------------------
+There are two ways to specify input dataframes for ``SnowflakeCompare``:
+
+Provide Snowpark dataframes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ from snowflake.snowpark import Session
+ from snowflake.snowpark import Row
+ import datetime
+ import datacompy.snowflake as sp
+
+ connection_parameters = {
+ ...
+ }
+ session = Session.builder.configs(connection_parameters).create()
+
+ data1 = [
+ Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None,
+ date_fld=datetime.date(2017, 1, 1)),
+ ]
+
+ data2 = [
+ Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155),
+ Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None),
+ Row(acct_id=None, dollar_amt=1345.0, name='George Bluth', float_fld=1.0),
+ Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12),
+ Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0),
+ ]
+
+ df_1 = session.createDataFrame(data1)
+ df_2 = session.createDataFrame(data2)
+
+ compare = sp.SnowflakeCompare(
+ session,
+ df_1,
+ df_2,
+ join_columns=['acct_id'],
+ rel_tol=1e-03,
+ abs_tol=1e-04,
+ )
+ compare.matches(ignore_extra_columns=False)
+
+ # This method prints out a human-readable report summarizing and sampling differences
+ print(compare.report())
+
+
+Provide the full name (``{db}.{schema}.{table_name}``) of valid Snowflake tables
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Given the dataframes from the prior examples...
+
+.. code-block:: python
+
+ df_1.write.mode("overwrite").save_as_table("toy_table_1")
+ df_2.write.mode("overwrite").save_as_table("toy_table_2")
+
+ compare = sp.SnowflakeCompare(
+ session,
+ f"{db}.{schema}.toy_table_1",
+ f"{db}.{schema}.toy_table_2",
+ join_columns=['acct_id'],
+ rel_tol=1e-03,
+ abs_tol=1e-04,
+ )
+ compare.matches(ignore_extra_columns=False)
+
+ # This method prints out a human-readable report summarizing and sampling differences
+ print(compare.report())
+
+Reports
+-------
+
+A report is generated by calling ``report()``, which returns a string.
+Here is a sample report generated by ``datacompy`` for the two tables above,
+joined on ``acct_id`` (Note: the names for your dataframes are extracted from
+the names of the provided Snowflake tables. If you choose to directly provide Snowpark
+dataframes, then the names will default to ``DF1`` and ``DF2``.)::
+
+ DataComPy Comparison
+ --------------------
+
+ DataFrame Summary
+ -----------------
+
+ DataFrame Columns Rows
+ 0 DF1 5 5
+ 1 DF2 4 5
+
+ Column Summary
+ --------------
+
+ Number of columns in common: 4
+ Number of columns in DF1 but not in DF2: 1
+ Number of columns in DF2 but not in DF1: 0
+
+ Row Summary
+ -----------
+
+ Matched on: ACCT_ID
+ Any duplicates on match values: No
+ Absolute Tolerance: 0
+ Relative Tolerance: 0
+ Number of rows in common: 4
+ Number of rows in DF1 but not in DF2: 1
+ Number of rows in DF2 but not in DF1: 1
+
+ Number of rows with some compared columns unequal: 4
+ Number of rows with all compared columns equal: 0
+
+ Column Comparison
+ -----------------
+
+ Number of columns compared with some values unequal: 3
+ Number of columns compared with all values equal: 1
+ Total number of values which compare unequal: 6
+
+ Columns with Unequal Values or Types
+ ------------------------------------
+
+ Column DF1 dtype DF2 dtype # Unequal Max Diff # Null Diff
+ 0 DOLLAR_AMT double double 1 0.0500 0
+ 2 FLOAT_FLD double double 3 0.0005 2
+ 1 NAME string(16777216) string(16777216) 2 NaN 0
+
+ Sample Rows with Unequal Values
+ -------------------------------
+
+ ACCT_ID DOLLAR_AMT (DF1) DOLLAR_AMT (DF2)
+ 0 10000001234 123.45 123.4
+
+ ACCT_ID NAME (DF1) NAME (DF2)
+ 0 10000001234 George Maharis George Michael Bluth
+ 1 10000001237 Bob Loblaw Robert Loblaw
+
+ ACCT_ID FLOAT_FLD (DF1) FLOAT_FLD (DF2)
+ 0 10000001234 14530.1555 14530.155
+ 1 10000001235 1.0000 NaN
+ 2 10000001236 NaN 1.000
+
+ Sample Rows Only in DF1 (First 10 Columns)
+ ------------------------------------------
+
+ ACCT_ID_DF1 DOLLAR_AMT_DF1 NAME_DF1 FLOAT_FLD_DF1 DATE_FLD_DF1
+ 0 10000001239 1.05 Lucille Bluth NaN 2017-01-01
+
+ Sample Rows Only in DF2 (First 10 Columns)
+ ------------------------------------------
+
+ ACCT_ID_DF2 DOLLAR_AMT_DF2 NAME_DF2 FLOAT_FLD_DF2
+ 0 10000001238 1.05 Loose Seal Bluth 111.0
+
+
+Convenience Methods
+-------------------
+
+There are a few convenience methods and attributes available after the comparison has been run:
+
+.. code-block:: python
+
+ compare.intersect_rows[['name_df1', 'name_df2', 'name_match']].show()
+ # --------------------------------------------------------
+ # |"NAME_DF1" |"NAME_DF2" |"NAME_MATCH" |
+ # --------------------------------------------------------
+ # |George Maharis |George Michael Bluth |False |
+ # |Michael Bluth |Michael Bluth |True |
+ # |George Bluth |George Bluth |True |
+ # |Bob Loblaw |Robert Loblaw |False |
+ # --------------------------------------------------------
+
+ compare.df1_unq_rows.show()
+ # ---------------------------------------------------------------------------------------
+ # |"ACCT_ID_DF1" |"DOLLAR_AMT_DF1" |"NAME_DF1" |"FLOAT_FLD_DF1" |"DATE_FLD_DF1" |
+ # ---------------------------------------------------------------------------------------
+ # |10000001239 |1.05 |Lucille Bluth |NULL |2017-01-01 |
+ # ---------------------------------------------------------------------------------------
+
+ compare.df2_unq_rows.show()
+ # -------------------------------------------------------------------------
+ # |"ACCT_ID_DF2" |"DOLLAR_AMT_DF2" |"NAME_DF2" |"FLOAT_FLD_DF2" |
+ # -------------------------------------------------------------------------
+ # |10000001238 |1.05 |Loose Seal Bluth |111.0 |
+ # -------------------------------------------------------------------------
+
+ print(compare.intersect_columns())
+ # OrderedSet(['acct_id', 'dollar_amt', 'name', 'float_fld'])
+
+ print(compare.df1_unq_columns())
+ # OrderedSet(['date_fld'])
+
+ print(compare.df2_unq_columns())
+ # OrderedSet()
+
+Duplicate rows
+--------------
+
+Datacompy will try to handle rows that are duplicated in the join columns. It does this behind the
+scenes by generating a unique ID within each unique group of the join columns. For example, if you
+have two dataframes you're trying to join on ``acct_id``:
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+2 George Bluth
+=========== ================
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+1 Tony Wonder
+2 George Bluth
+=========== ================
+
+Datacompy will generate a unique temporary ID for joining:
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+2 George Bluth 0
+=========== ================ ========
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+1 Tony Wonder 2
+2 George Bluth 0
+=========== ================ ========
+
+And then merge the two dataframes on a combination of the join_columns you specified and the temporary
+ID, before dropping the temp_id again. So the first two rows in the first dataframe will match the
+first two rows in the second dataframe, and the third row in the second dataframe will be recognized
+as existing only in the second.
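+
+A rough pandas sketch of the idea (an analogy only; the actual implementation uses ``ROW_NUMBER``
+over a window partitioned by the join columns):
+
+.. code-block:: python
+
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {"acct_id": [1, 1, 2], "name": ["George Maharis", "Michael Bluth", "George Bluth"]}
+    )
+    # temp_id numbers the rows within each acct_id group, giving duplicates a unique join key
+    df["temp_id"] = df.groupby("acct_id").cumcount()
+    print(df)
+    #    acct_id            name  temp_id
+    # 0        1  George Maharis        0
+    # 1        1   Michael Bluth        1
+    # 2        2    George Bluth        0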
+
+Additional considerations
+-------------------------
+
+- Joining on float columns (or any column subject to floating point imprecision) is strongly
+  discouraged: join columns are compared exactly, so if the float values being joined on are not
+  exactly equal you will likely get unexpected results (see the quick illustration below).
+- Case-sensitive columns are only partially supported. We essentially treat case-sensitive
+  columns as if they are case-insensitive. Therefore you may use case-sensitive columns as long as
+  you don't have several columns with the same name differentiated only by case sensitivity.
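+
+A quick illustration of the floating point risk (plain Python; the values are only an example):
+
+.. code-block:: python
+
+    x = 0.1 + 0.2
+    y = 0.3
+    print(x == y)      # False - the two values differ in the last few bits
+    print(abs(x - y))  # ~5.5e-17, enough to break an exact join on this column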
diff --git a/pyproject.toml b/pyproject.toml
index ed9c0f9b..9c86c82f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,16 +56,19 @@ python-tag = "py3"
[project.optional-dependencies]
duckdb = ["fugue[duckdb]"]
spark = ["pyspark[connect]>=3.1.1; python_version < \"3.11\"", "pyspark[connect]>=3.4; python_version >= \"3.11\""]
+snowflake = ["snowflake-connector-python", "snowflake-snowpark-python"]
dask = ["fugue[dask]"]
ray = ["fugue[ray]"]
docs = ["sphinx", "furo", "myst-parser"]
tests = ["pytest", "pytest-cov"]
tests-spark = ["pytest", "pytest-cov", "pytest-spark"]
+tests-snowflake = ["snowflake-snowpark-python[localtest]"]
qa = ["pre-commit", "ruff==0.5.7", "mypy", "pandas-stubs"]
build = ["build", "twine", "wheel"]
edgetest = ["edgetest", "edgetest-conda"]
-dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"]
+dev_no_snowflake = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"]
+dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[snowflake]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[tests-snowflake]", "datacompy[qa]", "datacompy[build]"]
# Linters, formatters and type checkers
[tool.ruff]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..54532b4b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,24 @@
+"""Testing configuration file, currently used for generating a Snowpark local session for testing."""
+
+import os
+
+import pytest
+
+try:
+ from snowflake.snowpark.session import Session
+except ModuleNotFoundError:
+ pass
+
+CONNECTION_PARAMETERS = {
+ "account": os.environ.get("SF_ACCOUNT"),
+ "user": os.environ.get("SF_UID"),
+ "password": os.environ.get("SF_PWD"),
+ "warehouse": os.environ.get("SF_WAREHOUSE"),
+ "database": os.environ.get("SF_DATABASE"),
+ "schema": os.environ.get("SF_SCHEMA"),
+}
+
+
+@pytest.fixture(scope="module")
+def snowpark_session() -> "Session":
+ return Session.builder.configs(CONNECTION_PARAMETERS).create()
diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py
new file mode 100644
index 00000000..548f6043
--- /dev/null
+++ b/tests/test_snowflake.py
@@ -0,0 +1,1326 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Testing out the datacompy functionality
+"""
+
+import io
+import logging
+import sys
+from datetime import datetime
+from decimal import Decimal
+from io import StringIO
+from unittest import mock
+
+import numpy as np
+import pandas as pd
+import pytest
+from pytest import raises
+
+pytest.importorskip("pyspark")
+
+
+from datacompy.snowflake import (
+ SnowflakeCompare,
+ _generate_id_within_group,
+ calculate_max_diff,
+ columns_equal,
+ temp_column_name,
+)
+from pandas.testing import assert_series_equal
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+
+pd.DataFrame.iteritems = pd.DataFrame.items  # Pandas 2+ compatibility
+np.bool = np.bool_  # Numpy 1.24.3+ compatibility
+
+
+def test_numeric_columns_equal_abs(snowpark_session):
+ data = """A|B|EXPECTED
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_numeric_columns_equal_rel(snowpark_session):
+ data = """A|B|EXPECTED
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_string_columns_equal(snowpark_session):
+ data = """A|B|EXPECTED
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |False
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |False
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_string_columns_equal_with_ignore_spaces(snowpark_session):
+ data = """A|B|EXPECTED
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |True
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |True
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_date_columns_equal(snowpark_session):
+ data = """A|B|EXPECTED
+2017-01-01|2017-01-01|True
+2017-01-02|2017-01-02|True
+2017-10-01|2017-10-10|False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ pdf = pd.read_csv(io.StringIO(data), sep="|")
+ df = snowpark_session.createDataFrame(pdf)
+ # First compare just the strings
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+ # Then compare converted to datetime objects
+ pdf["A"] = pd.to_datetime(pdf["A"])
+ pdf["B"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+ # and reverse
+ actual_out_rev = columns_equal(df, "B", "A", "ACTUAL", rel_tol=0.2).toPandas()[
+ "ACTUAL"
+ ]
+ assert_series_equal(expect_out, actual_out_rev, check_names=False)
+
+
+def test_date_columns_equal_with_ignore_spaces(snowpark_session):
+ data = """A|B|EXPECTED
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ pdf = pd.read_csv(io.StringIO(data), sep="|")
+ df = snowpark_session.createDataFrame(pdf)
+ # First compare just the strings
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+ # Then compare converted to datetime objects
+ try: # pandas 2
+ pdf["A"] = pd.to_datetime(pdf["A"], format="mixed")
+ pdf["B"] = pd.to_datetime(pdf["B"], format="mixed")
+ except ValueError: # pandas 1.5
+ pdf["A"] = pd.to_datetime(pdf["A"])
+ pdf["B"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+ # and reverse
+ actual_out_rev = columns_equal(
+ df, "B", "A", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ assert_series_equal(expect_out, actual_out_rev, check_names=False)
+
+
+def test_date_columns_unequal(snowpark_session):
+ """I want datetime fields to match with dates stored as strings"""
+ data = [{"A": "2017-01-01", "B": "2017-01-02"}, {"A": "2017-01-01"}]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ pdf["B_DT"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ assert columns_equal(df, "A", "A_DT", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "B", "B_DT", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "A_DT", "A", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "B_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert not columns_equal(df, "B_DT", "A", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "A", "B_DT", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "B", "A_DT", "ACTUAL").toPandas()["ACTUAL"].any()
+
+
+def test_bad_date_columns(snowpark_session):
+ """If strings can't be coerced into dates then it should be false for the
+ whole column.
+ """
+ data = [
+ {"A": "2017-01-01", "B": "2017-01-01"},
+ {"A": "2017-01-01", "B": "217-01-01"},
+ ]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ df = snowpark_session.createDataFrame(pdf)
+ assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any()
+
+
+def test_rounded_date_columns(snowpark_session):
+ """If strings can't be coerced into dates then it should be false for the
+ whole column.
+ """
+ data = [
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00.000000", "EXP": True},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00.123456", "EXP": False},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:01.000000", "EXP": False},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00", "EXP": True},
+ ]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual = columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expected = df.select("EXP").toPandas()["EXP"]
+ assert_series_equal(actual, expected, check_names=False)
+
+
+def test_decimal_float_columns_equal(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": 1, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True},
+ {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": False},
+ {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": 1, "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_float_columns_equal_rel(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": 1, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True},
+ {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": 1, "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[
+ "ACTUAL"
+ ]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_columns_equal(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True},
+ {
+ "A": Decimal("1.000000004"),
+ "B": Decimal("1.000000003"),
+ "EXPECTED": False,
+ },
+ {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": Decimal("1"), "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_columns_equal_rel(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True},
+ {
+ "A": Decimal("1.000000004"),
+ "B": Decimal("1.000000003"),
+ "EXPECTED": True,
+ },
+ {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": Decimal("1"), "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[
+ "ACTUAL"
+ ]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_infinity_and_beyond(snowpark_session):
+ # https://spark.apache.org/docs/latest/sql-ref-datatypes.html#positivenegative-infinity-semantics
+ # Positive/negative infinity multiplied by 0 returns NaN.
+ # Positive infinity sorts lower than NaN and higher than any other values.
+ # Negative infinity sorts lower than any other values.
+ data = [
+ {"A": np.inf, "B": np.inf, "EXPECTED": True},
+ {"A": -np.inf, "B": -np.inf, "EXPECTED": True},
+ {"A": -np.inf, "B": np.inf, "EXPECTED": True},
+ {"A": np.inf, "B": -np.inf, "EXPECTED": True},
+ {"A": 1, "B": 1, "EXPECTED": True},
+ {"A": 1, "B": 0, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
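+# SnowflakeCompare should accept either Snowpark DataFrames or fully qualified
+# "database.schema.table" names; anything else should raise.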
+def test_compare_table_setter_bad(snowpark_session):
+ # Invalid table name
+ with raises(ValueError, match="invalid_table_name_1 is not a valid table name."):
+ SnowflakeCompare(
+ snowpark_session, "invalid_table_name_1", "invalid_table_name_2", ["A"]
+ )
+ # Valid table name but table does not exist
+ with raises(SnowparkSQLException):
+ SnowflakeCompare(
+ snowpark_session, "non.existant.table_1", "non.existant.table_2", ["A"]
+ )
+
+
+def test_compare_table_setter_good(snowpark_session):
+ data = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df = pd.read_csv(StringIO(data), sep=",")
+ database = snowpark_session.get_current_database().replace('"', "")
+ schema = snowpark_session.get_current_schema().replace('"', "")
+ full_table_name = f"{database}.{schema}"
+ toy_table_name_1 = "DC_TOY_TABLE_1"
+ toy_table_name_2 = "DC_TOY_TABLE_2"
+ full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}"
+ full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}"
+
+ snowpark_session.write_pandas(
+ df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True
+ )
+ snowpark_session.write_pandas(
+ df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True
+ )
+
+ compare = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID"],
+ )
+ assert compare.df1.toPandas().equals(df)
+ assert compare.join_columns == ["ACCT_ID"]
+
+
+def test_compare_df_setter_bad(snowpark_session):
+ pdf = pd.DataFrame([{"A": 1, "C": 2}, {"A": 2, "C": 2}])
+ df = snowpark_session.createDataFrame(pdf)
+ with raises(TypeError, match="DF1 must be a valid sp.Dataframe"):
+ SnowflakeCompare(snowpark_session, 3, 2, ["A"])
+ with raises(ValueError, match="DF1 must have all columns from join_columns"):
+ SnowflakeCompare(snowpark_session, df, df.select("*"), ["B"])
+ pdf = pd.DataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 3}])
+ df_dupe = snowpark_session.createDataFrame(pdf)
+ pd.testing.assert_frame_equal(
+ SnowflakeCompare(
+ snowpark_session, df_dupe, df_dupe.select("*"), ["A", "B"]
+ ).df1.toPandas(),
+ pdf,
+ check_dtype=False,
+ )
+
+
+def test_compare_df_setter_good(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+ assert compare.join_columns == ["A"]
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A", "B"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+ assert compare.join_columns == ["A", "B"]
+
+
+def test_compare_df_setter_different_cases(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+
+
+def test_columns_overlap(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1_unq_columns() == set()
+ assert compare.df2_unq_columns() == set()
+ assert compare.intersect_columns() == {"A", "B"}
+
+
+def test_columns_no_overlap(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "D": "OH"}, {"A": 2, "B": 3, "D": "YA"}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1_unq_columns() == {"C"}
+ assert compare.df2_unq_columns() == {"D"}
+ assert compare.intersect_columns() == {"A", "B"}
+
+
+def test_columns_maintain_order_through_set_operations(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "JOIN": ["A", "B"],
+ "F": [0, 0],
+ "G": [1, 2],
+ "B": [2, 2],
+ "H": [3, 3],
+ "A": [4, 4],
+ "C": [-2, -3],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "JOIN": ["A", "B"],
+ "E": [0, 1],
+ "H": [1, 2],
+ "B": [2, 3],
+ "A": [-1, -1],
+ "G": [4, 4],
+ "D": [-3, -2],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["JOIN"])
+ assert list(compare.df1_unq_columns()) == ["F", "C"]
+ assert list(compare.df2_unq_columns()) == ["E", "D"]
+ assert list(compare.intersect_columns()) == ["JOIN", "G", "B", "H", "A"]
+
+
+def test_10k_rows(snowpark_session):
+ rng = np.random.default_rng()
+ pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"])
+ pdf.reset_index(inplace=True)
+ pdf.columns = ["A", "B", "C"]
+ pdf2 = pdf.copy()
+ pdf2["B"] = pdf2["B"] + 0.1
+ df1 = snowpark_session.createDataFrame(pdf)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ compare_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"], abs_tol=0.2)
+ assert compare_tol.matches()
+ assert compare_tol.df1_unq_rows.count() == 0
+ assert compare_tol.df2_unq_rows.count() == 0
+ assert compare_tol.intersect_columns() == {"A", "B", "C"}
+ assert compare_tol.all_columns_match()
+ assert compare_tol.all_rows_overlap()
+ assert compare_tol.intersect_rows_match()
+
+ compare_no_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not compare_no_tol.matches()
+ assert compare_no_tol.df1_unq_rows.count() == 0
+ assert compare_no_tol.df2_unq_rows.count() == 0
+ assert compare_no_tol.intersect_columns() == {"A", "B", "C"}
+ assert compare_no_tol.all_columns_match()
+ assert compare_no_tol.all_rows_overlap()
+ assert not compare_no_tol.intersect_rows_match()
+
+
+def test_subset(snowpark_session, caplog):
+ caplog.set_level(logging.DEBUG)
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "C": "HI"}])
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert comp.subset()
+
+
+def test_not_subset(snowpark_session, caplog):
+ caplog.set_level(logging.INFO)
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "GREAT"}]
+ )
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not comp.subset()
+ assert "C: 1 / 2 (50.00%) match" in caplog.text
+
+
+def test_large_subset(snowpark_session):
+ rng = np.random.default_rng()
+ pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"])
+ pdf.reset_index(inplace=True)
+ pdf.columns = ["A", "B", "C"]
+ pdf2 = pdf[["A", "B"]].head(50).copy()
+ df1 = snowpark_session.createDataFrame(pdf)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not comp.matches()
+ assert comp.subset()
+
+
+def test_string_joiner(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}])
+ df2 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "AB")
+ assert compare.matches()
+
+
+def test_decimal_with_joins(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": Decimal("1"), "B": 2}, {"A": Decimal("2"), "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_decimal_with_nulls(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": Decimal("2")}, {"A": 2, "B": Decimal("2")}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2}, {"A": 2, "B": 2}, {"A": 3, "B": 2}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_strings_with_joins(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
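+# temp_column_name should return the lowest-numbered "_TEMP_N" name that is not
+# already a column in either dataframe.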
+def test_temp_column_name(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_0"
+
+
+def test_temp_column_name_one_has(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_1"
+
+
+def test_temp_column_name_both_have_temp_1(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_0": "HI", "B": 2},
+ {"_TEMP_0": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_1"
+
+
+def test_temp_column_name_both_have_temp_2(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_0": "HI", "B": 2},
+ {"_TEMP_1": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_2"
+
+
+def test_temp_column_name_one_already(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_1": "HI", "B": 2}, {"_TEMP_1": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_1": "HI", "B": 2},
+ {"_TEMP_1": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_0"
+
+
+# Duplicate testing!
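+# Rows sharing a join key should be paired by their order within the group, so exact
+# duplicates still match one-to-one (see the _generate_id_within_group tests below).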
+
+
+def test_simple_dupes_one_field(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_two_fields(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A", "B"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_two_vals_1(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_two_vals_2(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert not compare.matches()
+ assert compare.df1_unq_rows.count() == 1
+ assert compare.df2_unq_rows.count() == 1
+ assert compare.intersect_rows.count() == 1
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_three_to_two_vals(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2}, {"A": 1, "B": 0}, {"A": 1, "B": 0}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert not compare.matches()
+ assert compare.df1_unq_rows.count() == 1
+ assert compare.df2_unq_rows.count() == 0
+ assert compare.intersect_rows.count() == 2
+ # Just render the report to make sure it renders.
+ compare.report()
+ assert "(First 1 Columns)" in compare.report(column_count=1)
+ assert "(First 2 Columns)" in compare.report(column_count=2)
+
+
+def test_dupes_from_real_data(snowpark_session):
+ data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM
+100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0
+200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1,
+100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A,
+100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0
+200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,,
+100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0
+200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0
+200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,,
+100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0
+200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F,
+100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0
+200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F,
+100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0
+200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A,
+100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0
+200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,"""
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep=","))
+ df2 = df1.select("*")
+ compare_acct = SnowflakeCompare(
+ snowpark_session, df1, df2, join_columns=["ACCT_ID"]
+ )
+ assert compare_acct.matches()
+ compare_acct.report()
+
+ compare_unq = SnowflakeCompare(
+ snowpark_session,
+ df1,
+ df2,
+ join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"],
+ )
+ assert compare_unq.matches()
+ compare_unq.report()
+
+
+def test_table_compare_from_real_data(snowpark_session):
+ data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM
+100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0
+200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1,
+100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A,
+100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0
+200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,,
+100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0
+200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0
+200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,,
+100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0
+200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F,
+100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0
+200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F,
+100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0
+200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A,
+100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0
+200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,"""
+ df = pd.read_csv(StringIO(data), sep=",")
+ database = snowpark_session.get_current_database().replace('"', "")
+ schema = snowpark_session.get_current_schema().replace('"', "")
+ full_table_name = f"{database}.{schema}"
+ toy_table_name_1 = "DC_TOY_TABLE_1"
+ toy_table_name_2 = "DC_TOY_TABLE_2"
+ full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}"
+ full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}"
+
+ snowpark_session.write_pandas(
+ df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True
+ )
+ snowpark_session.write_pandas(
+ df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True
+ )
+
+ compare_acct = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID"],
+ )
+ assert compare_acct.matches()
+ compare_acct.report()
+
+ compare_unq = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"],
+ )
+ assert compare_unq.matches()
+ compare_unq.report()
+
+
+def test_strings_with_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": " A"}, {"A": "BYE", "B": "A"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A "}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_decimal_with_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
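+# Unquoted Snowflake identifiers are case-insensitive (folded to uppercase); quoted
+# identifiers such as '"a"' are case-sensitive. The next two tests cover both cases.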
+def test_joins_with_insensitive_lowercase_cols(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"a": 1, "B": "A"}, {"a": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "a")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_joins_with_sensitive_lowercase_cols(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, '"a"')
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_strings_with_ignore_spaces_and_join_columns(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": " HI ", "B": "A"}, {"A": " BYE ", "B": "A"}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.count_matching_rows() == 0
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
+def test_integers_with_ignore_spaces_and_join_columns(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
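+# sample_mismatch should return at most sample_count mismatching rows for the
+# requested column, with the df1 and df2 values side by side.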
+def test_sample_mismatch(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.sample_mismatch(column="NAME", sample_count=1).toPandas()
+ assert output.shape[0] == 1
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+ output = compare.sample_mismatch(column="NAME", sample_count=2).toPandas()
+ assert output.shape[0] == 2
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+ output = compare.sample_mismatch(column="NAME", sample_count=3).toPandas()
+ assert output.shape[0] == 2
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+
+def test_all_mismatch_not_ignore_matching_cols_no_cols_matching(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
+def test_all_mismatch_not_ignore_matching_cols_some_cols_matching(snowpark_session):
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 0
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 4
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 0
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 4
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
+def test_all_mismatch_ignore_matching_cols_some_cols_matching_diff_rows(
+ snowpark_session,
+):
+ # Case where there are rows on either dataset which don't match up.
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ 10000001241,1111.05,Lucille Bluth,
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch(ignore_matching_cols=True).toPandas()
+
+ assert output.shape[0] == 4
+ assert output.shape[1] == 5
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+ assert not ("NAME_DF1" in output and "NAME_DF2" in output)
+ assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output)
+
+
+def test_all_mismatch_ignore_matching_cols_some_cols_matching(snowpark_session):
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch(ignore_matching_cols=True).toPandas()
+
+ assert output.shape[0] == 4
+ assert output.shape[1] == 5
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+ assert not ("NAME_DF1" in output and "NAME_DF2" in output)
+ assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output)
+
+
+def test_all_mismatch_ignore_matching_cols_no_cols_matching(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
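+# calculate_max_diff should return the largest absolute difference between the BASE
+# column and the given column, coercing decimals and numeric strings where possible.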
+@pytest.mark.parametrize(
+ "column, expected",
+ [
+ ("BASE", 0),
+ ("FLOATS", 0.2),
+ ("DECIMALS", 0.1),
+ ("NULL_FLOATS", 0.1),
+ ("STRINGS", 0.1),
+ ("INFINITY", np.inf),
+ ],
+)
+def test_calculate_max_diff(snowpark_session, column, expected):
+ pdf = pd.DataFrame(
+ {
+ "BASE": [1, 1, 1, 1, 1],
+ "FLOATS": [1.1, 1.1, 1.1, 1.2, 0.9],
+ "DECIMALS": [
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ ],
+ "NULL_FLOATS": [np.nan, 1.1, 1, 1, 1],
+ "STRINGS": ["1", "1", "1", "1.1", "1"],
+ "INFINITY": [1, 1, 1, 1, np.inf],
+ }
+ )
+ MAX_DIFF_DF = snowpark_session.createDataFrame(pdf)
+ assert np.isclose(
+ calculate_max_diff(MAX_DIFF_DF, "BASE", column),
+ expected,
+ )
+
+
+def test_dupes_with_nulls_strings(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 3, 4, 5],
+ "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 3, 4, 5],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", "FLD_2"])
+ assert comp.subset()
+
+
+def test_dupes_with_nulls_ints(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "FLD_2": [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 3, 4, 5],
+ "FLD_2": [1, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 3, 4, 5],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", "FLD_2"])
+ assert comp.subset()
+
+
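+# _generate_id_within_group should assign each row an ordinal id (0, 1, ...) within its
+# group of join-column values, treating nulls as joinable values rather than dropping them.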
+def test_generate_id_within_group(snowpark_session):
+ matrix = [
+ (
+ pd.DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}),
+ pd.Series([0, 0, 0]),
+ ),
+ (
+ pd.DataFrame(
+ {
+ "A": ["A", "A", "DATACOMPY_NULL"],
+ "B": [1, 1, 2],
+ "__INDEX": [1, 2, 3],
+ }
+ ),
+ pd.Series([0, 1, 0]),
+ ),
+ (
+ pd.DataFrame({"A": [-999, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}),
+ pd.Series([0, 0, 0]),
+ ),
+ (
+ pd.DataFrame(
+ {"A": [1, np.nan, np.nan], "B": [1, 2, 2], "__INDEX": [1, 2, 3]}
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ (
+ pd.DataFrame(
+ {"A": ["1", np.nan, np.nan], "B": ["1", "2", "2"], "__INDEX": [1, 2, 3]}
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ (
+ pd.DataFrame(
+ {
+ "A": [datetime(2018, 1, 1), np.nan, np.nan],
+ "B": ["1", "2", "2"],
+ "__INDEX": [1, 2, 3],
+ }
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ ]
+ for dataframe, expected in matrix:
+ actual = (
+ _generate_id_within_group(
+ snowpark_session.createDataFrame(dataframe), ["A", "B"], "_TEMP_0"
+ )
+ .orderBy("__INDEX")
+ .select("_TEMP_0")
+ .toPandas()
+ )
+ assert (actual["_TEMP_0"] == expected).all()
+
+
+def test_generate_id_within_group_single_join(snowpark_session):
+ dataframe = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "__INDEX": 1}, {"A": 1, "B": 2, "__INDEX": 2}]
+ )
+ expected = pd.Series([0, 1])
+ actual = (
+ _generate_id_within_group(dataframe, ["A"], "_TEMP_0")
+ .orderBy("__INDEX")
+ .select("_TEMP_0")
+ ).toPandas()
+ assert (actual["_TEMP_0"] == expected).all()
+
+
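+# report() should render its sections without touching the filesystem unless html_file
+# is provided, in which case the rendered report is written to that path.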
+@mock.patch("datacompy.snowflake.render")
+def test_save_html(mock_render, snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+
+ m = mock.mock_open()
+ with mock.patch("datacompy.snowflake.open", m, create=True):
+ # assert without HTML call
+ compare.report()
+ assert mock_render.call_count == 4
+ m.assert_not_called()
+
+ mock_render.reset_mock()
+ m = mock.mock_open()
+ with mock.patch("datacompy.snowflake.open", m, create=True):
+ # assert with HTML call
+ compare.report(html_file="test.html")
+ assert mock_render.call_count == 4
+ m.assert_called_with("test.html", "w")