From 7ee85ed611d125cb4c1622b033071d9f663e2fd6 Mon Sep 17 00:00:00 2001 From: Faisal Dosani Date: Thu, 13 Jun 2024 19:23:44 -0400 Subject: [PATCH] feedback from review, switch to monotonic and simplify checks --- datacompy/spark/sql.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/datacompy/spark/sql.py b/datacompy/spark/sql.py index cc67d777..cdc3fcdc 100644 --- a/datacompy/spark/sql.py +++ b/datacompy/spark/sql.py @@ -291,14 +291,8 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: LOG.debug("Duplicate rows found, deduping by order of remaining fields") # setting internal index LOG.info("Adding internal index to dataframes") - df1 = df1.withColumn( - "__index", - row_number().over(Window.orderBy(monotonically_increasing_id())) - 1, - ) - df2 = df2.withColumn( - "__index", - row_number().over(Window.orderBy(monotonically_increasing_id())) - 1, - ) + df1 = df1.withColumn("__index", monotonically_increasing_id()) + df2 = df2.withColumn("__index", monotonically_increasing_id()) # Create order column for uniqueness of match order_column = temp_column_name(df1, df2) @@ -1113,26 +1107,14 @@ def _generate_id_within_group( Original dataframe with the ID column that's unique in each group """ default_value = "DATACOMPY_NULL" + null_cols = [f"any(isnull({c}))" for c in join_columns] + default_cols = [f"any({c} == '{default_value}')" for c in join_columns] - if len(join_columns) > 1: - isnull_check = dataframe.select( - greatest(*[isnull(c) for c in join_columns]).alias("isnull") - ).filter("isnull == True") - isdefault_check = dataframe.select( - greatest(*[col(c) == default_value for c in join_columns]).alias( - "isdefault" - ) - ).filter("isdefault == True") - else: # greatest doesn't work for single joincolumns - isnull_check = dataframe.select(isnull(*join_columns).alias("isnull")).filter( - "isnull == True" - ) - isdefault_check = dataframe.select( - (col(*join_columns) == default_value).alias("isdefault") - ).filter("isdefault == True") + null_check = any(list(dataframe.selectExpr(null_cols).first())) + default_check = any(list(dataframe.selectExpr(default_cols).first())) - if isnull_check.count() > 0: - if isdefault_check.count() > 0: + if null_check: + if default_check: raise ValueError(f"{default_value} was found in your join columns") return (