This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bb36fb271b [SPARK-45936][PS] Optimize `Index.symmetric_difference`
2bb36fb271b is described below

commit 2bb36fb271b60dda68567b92613a3664a7aae2b8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Nov 16 10:40:05 2023 +0900

    [SPARK-45936][PS] Optimize `Index.symmetric_difference`
    
    ### What changes were proposed in this pull request?
    Add a helper function for `XOR`, and use it to optimize `Index.symmetric_difference`
    
    ### Why are the changes needed?
    The old query plan for `Index.symmetric_difference` (`union`, `intersect`, then `subtract`) is too complex.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43816 from zhengruifeng/ps_base_diff.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/indexes/base.py  |  4 ++--
 python/pyspark/pandas/indexes/multi.py | 15 ++-------------
 python/pyspark/pandas/utils.py         | 17 +++++++++++++++++
 3 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 6c6ee9ae0d7..a515a79dcd7 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -72,6 +72,7 @@ from pyspark.pandas.utils import (
     validate_index_loc,
     ERROR_MESSAGE_CANNOT_COMBINE,
     log_advice,
+    xor,
 )
 from pyspark.pandas.internal import (
     InternalField,
@@ -1468,8 +1469,7 @@ class Index(IndexOpsMixin):
 
         sdf_self = 
self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = 
other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)
-
-        sdf_symdiff = 
sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
+        sdf_symdiff = xor(sdf_self, sdf_other)
 
         if sort:
             sdf_symdiff = 
sdf_symdiff.sort(*self._internal.index_spark_column_names)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 62b42c1fcd0..7d2712cbb53 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -38,6 +38,7 @@ from pyspark.pandas.utils import (
     scol_for,
     verify_temp_column_name,
     validate_index_loc,
+    xor,
 )
 from pyspark.pandas.internal import (
     InternalField,
@@ -809,19 +810,7 @@ class MultiIndex(Index):
 
         sdf_self = 
self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = 
other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)
-
-        tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__")
-        tmp_max_col = verify_temp_column_name(sdf_self, 
"__multi_index_max_tag__")
-        tmp_min_col = verify_temp_column_name(sdf_self, 
"__multi_index_min_tag__")
-
-        sdf_symdiff = (
-            sdf_self.withColumn(tmp_tag_col, F.lit(0))
-            .union(sdf_other.withColumn(tmp_tag_col, F.lit(1)))
-            .groupBy(*self._internal.index_spark_column_names)
-            .agg(F.min(tmp_tag_col).alias(tmp_min_col), 
F.max(tmp_tag_col).alias(tmp_max_col))
-            .where(F.col(tmp_min_col) == F.col(tmp_max_col))
-            .select(*self._internal.index_spark_column_names)
-        )
+        sdf_symdiff = xor(sdf_self, sdf_other)
 
         if sort:
             sdf_symdiff = 
sdf_symdiff.sort(*self._internal.index_spark_column_names)
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 9f372a53079..57c1ddbe6ae 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -1033,6 +1033,23 @@ def validate_index_loc(index: "Index", loc: int) -> None:
             )
 
 
+def xor(df1: PySparkDataFrame, df2: PySparkDataFrame) -> PySparkDataFrame:
+    colNames = df1.columns
+
+    tmp_tag_col = verify_temp_column_name(df1, "__temporary_tag__")
+    tmp_max_col = verify_temp_column_name(df1, "__temporary_max_tag__")
+    tmp_min_col = verify_temp_column_name(df1, "__temporary_min_tag__")
+
+    return (
+        df1.withColumn(tmp_tag_col, F.lit(0))
+        .union(df2.withColumn(tmp_tag_col, F.lit(1)))
+        .groupBy(*colNames)
+        .agg(F.min(tmp_tag_col).alias(tmp_min_col), 
F.max(tmp_tag_col).alias(tmp_max_col))
+        .where(F.col(tmp_min_col) == F.col(tmp_max_col))
+        .select(*colNames)
+    )
+
+
 def _test() -> None:
     import os
     import doctest


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to