This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b2de91cb6c11 [SPARK-45918][PS] Optimize `MultiIndex.symmetric_difference`
b2de91cb6c11 is described below

commit b2de91cb6c117ac7deb099c89fafa7e6fccb34b3
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Nov 15 11:16:11 2023 +0900

    [SPARK-45918][PS] Optimize `MultiIndex.symmetric_difference`
    
    ### What changes were proposed in this pull request?
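    Optimize `MultiIndex.symmetric_difference` by replacing the `union`/`subtract`/`intersect` chain with a single tag-and-aggregate pass over the union of both indexes.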
    
    ### Why are the changes needed?
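    Currently, the `XOR` operation `a.union(b).subtract(a.intersect(b))` is not optimal: the intersection sub-plan (an `Aggregate` over a `LeftSemi` join) is duplicated under both branches of the `Union`, each branch adds a `LeftAnti` join, and the optimized plan ends up with six `LogicalRDD` leaves for two inputs. For example: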
    
    ```
            >>> import pandas as pd
            >>> import pyspark.pandas as ps
            >>> midx1 = pd.MultiIndex([['lama', 'cow', 'falcon'],
            ...                        ['speed', 'weight', 'length']],
            ...                       [[0, 0, 0, 1, 1, 1, 2, 2, 2],
            ...                        [0, 0, 0, 0, 1, 2, 0, 1, 2]])
            >>> midx2 = pd.MultiIndex([['pandas-on-Spark', 'cow', 'falcon'],
            ...                        ['speed', 'weight', 'length']],
            ...                       [[0, 0, 0, 1, 1, 1, 2, 2, 2],
            ...                        [0, 0, 0, 0, 1, 2, 0, 1, 2]])
            >>> s1 = ps.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
            ...                index=midx1)
            >>> s2 = ps.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
            ...                index=midx2)
            >>> s1.index.symmetric_difference(s2.index)._internal.spark_frame.explain("extended")
    ```
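To make the semantics of the old formula concrete, here is a minimal pure-Python sketch (no Spark involved) that emulates the three DataFrame set operations it chains: `union` keeps duplicates like SQL `UNION ALL`, while `intersect` and `subtract` return distinct rows like SQL `INTERSECT` and `EXCEPT DISTINCT`. The helpers `build_index`, `spark_union`, `spark_intersect`, `spark_subtract` and `old_symmetric_difference` are hypothetical names for illustration only; the index tuples are built from the `midx1`/`midx2` levels and codes in the example above.

```python
from typing import List, Sequence, Set, Tuple

Row = Tuple[str, ...]


def build_index(levels: Sequence[Sequence[str]], codes: Sequence[Sequence[int]]) -> List[Row]:
    """Expand MultiIndex-style levels/codes into one tuple per row (duplicates kept)."""
    return [tuple(levels[i][c] for i, c in enumerate(row)) for row in zip(*codes)]


def spark_union(a: List[Row], b: List[Row]) -> List[Row]:
    # Mirrors DataFrame.union: like SQL UNION ALL, duplicates are kept.
    return a + b


def spark_intersect(a: List[Row], b: List[Row]) -> Set[Row]:
    # Mirrors DataFrame.intersect: like SQL INTERSECT, distinct rows found in both.
    return set(a) & set(b)


def spark_subtract(a: List[Row], b: Set[Row]) -> Set[Row]:
    # Mirrors DataFrame.subtract: like SQL EXCEPT DISTINCT.
    return set(a) - set(b)


def old_symmetric_difference(a: List[Row], b: List[Row]) -> Set[Row]:
    # a.union(b).subtract(a.intersect(b)): the intersection is a separate
    # sub-computation over both inputs, subtracted from the union.
    return spark_subtract(spark_union(a, b), spark_intersect(a, b))


codes = [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 0, 0, 0, 1, 2, 0, 1, 2]]
midx1 = build_index([["lama", "cow", "falcon"], ["speed", "weight", "length"]], codes)
midx2 = build_index([["pandas-on-Spark", "cow", "falcon"], ["speed", "weight", "length"]], codes)

print(sorted(old_symmetric_difference(midx1, midx2)))
# [('lama', 'speed'), ('pandas-on-Spark', 'speed')]
```

Only the `('lama', 'speed')` and `('pandas-on-Spark', 'speed')` keys differ between the two indexes, so they are the whole symmetric difference.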
    
    Before this PR:
    ```
    == Optimized Logical Plan ==
    Aggregate [__index_level_0__#0, __index_level_1__#1], [__index_level_0__#0, __index_level_1__#1, monotonically_increasing_id() AS __natural_order__#161L]
    +- Union false, false
       :- Join LeftAnti, ((__index_level_0__#0 <=> __index_level_0__#145) AND (__index_level_1__#1 <=> __index_level_1__#146))
       :  :- Project [__index_level_0__#0, __index_level_1__#1]
       :  :  +- LogicalRDD [__index_level_0__#0, __index_level_1__#1, 0#2], false
       :  +- Aggregate [__index_level_0__#145, __index_level_1__#146], [__index_level_0__#145, __index_level_1__#146]
       :     +- Join LeftSemi, ((__index_level_0__#145 <=> __index_level_0__#149) AND (__index_level_1__#146 <=> __index_level_1__#150))
       :        :- Project [__index_level_0__#145, __index_level_1__#146]
       :        :  +- LogicalRDD [__index_level_0__#145, __index_level_1__#146, 0#147], false
       :        +- Project [__index_level_0__#149, __index_level_1__#150]
       :           +- LogicalRDD [__index_level_0__#149, __index_level_1__#150, 0#151], false
       +- Join LeftAnti, ((__index_level_0__#11 <=> __index_level_0__#145) AND (__index_level_1__#12 <=> __index_level_1__#146))
          :- Project [__index_level_0__#11, __index_level_1__#12]
          :  +- LogicalRDD [__index_level_0__#11, __index_level_1__#12, 0#13], false
          +- Aggregate [__index_level_0__#145, __index_level_1__#146], [__index_level_0__#145, __index_level_1__#146]
             +- Join LeftSemi, ((__index_level_0__#145 <=> __index_level_0__#149) AND (__index_level_1__#146 <=> __index_level_1__#150))
                :- Project [__index_level_0__#145, __index_level_1__#146]
                :  +- LogicalRDD [__index_level_0__#145, __index_level_1__#146, 0#147], false
                +- Project [__index_level_0__#149, __index_level_1__#150]
                   +- LogicalRDD [__index_level_0__#149, __index_level_1__#150, 0#151], false
    ```
    
    After this PR:
    ```
    == Optimized Logical Plan ==
    Project [__index_level_0__#0, __index_level_1__#1, monotonically_increasing_id() AS __natural_order__#64L]
    +- Filter ((isnotnull(__multi_index_min_tag__#46) AND isnotnull(__multi_index_max_tag__#47)) AND (__multi_index_min_tag__#46 = __multi_index_max_tag__#47))
       +- Aggregate [__index_level_0__#0, __index_level_1__#1], [__index_level_0__#0, __index_level_1__#1, min(__multi_index_tag__#30) AS __multi_index_min_tag__#46, max(__multi_index_tag__#30) AS __multi_index_max_tag__#47]
          +- Union false, false
             :- Project [__index_level_0__#0, __index_level_1__#1, 0 AS __multi_index_tag__#30]
             :  +- LogicalRDD [__index_level_0__#0, __index_level_1__#1, 0#2], false
             +- Project [__index_level_0__#11, __index_level_1__#12, 1 AS __multi_index_tag__#34]
                +- LogicalRDD [__index_level_0__#11, __index_level_1__#12, 0#13], false
    ```
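The new plan reads as four steps: tag each input with a constant (0 for `self`, 1 for `other`), union the tagged rows, aggregate the min and max tag per index key, and keep keys where `min == max`. A key present in both inputs ends up with min 0 and max 1 and is dropped; a key present on only one side, even several times, keeps equal min and max. Below is a minimal pure-Python sketch of that logic without Spark; `tagged_symmetric_difference` is a hypothetical name, and the inputs are the same index tuples as in the `midx1`/`midx2` example.

```python
from typing import Dict, Iterable, List, Tuple

Row = Tuple[str, ...]


def tagged_symmetric_difference(a: Iterable[Row], b: Iterable[Row], sort: bool = False) -> List[Row]:
    # Step 1: tag each row with its source (0 for `self`, 1 for `other`) and
    # concatenate, like withColumn(tag, lit(...)) followed by union.
    tagged = [(row, 0) for row in a] + [(row, 1) for row in b]

    # Step 2: group by the index values and track the min and max tag per key,
    # like groupBy(...).agg(min(tag), max(tag)).
    bounds: Dict[Row, Tuple[int, int]] = {}
    for row, tag in tagged:
        lo, hi = bounds.get(row, (tag, tag))
        bounds[row] = (min(lo, tag), max(hi, tag))

    # Step 3: min == max means every occurrence came from the same input.
    result = [row for row, (lo, hi) in bounds.items() if lo == hi]

    # Step 4: optional ordering, like the `sort` branch of the method.
    return sorted(result) if sort else result


shared = [(x, y) for x in ("cow", "falcon") for y in ("speed", "weight", "length")]
midx1 = [("lama", "speed")] * 3 + shared
midx2 = [("pandas-on-Spark", "speed")] * 3 + shared

print(tagged_symmetric_difference(midx1, midx2, sort=True))
# [('lama', 'speed'), ('pandas-on-Spark', 'speed')]
```

Each input is consumed once and a single aggregation decides membership, which corresponds to the one `Union` feeding one `Aggregate` in the optimized plan.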
    
    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43795 from zhengruifeng/ps_multi_index_opt.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/indexes/multi.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 9fbc608c12a4..62b42c1fcd02 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -810,7 +810,18 @@ class MultiIndex(Index):
         sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)
 
-        sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
+        tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__")
+        tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__")
+        tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__")
+
+        sdf_symdiff = (
+            sdf_self.withColumn(tmp_tag_col, F.lit(0))
+            .union(sdf_other.withColumn(tmp_tag_col, F.lit(1)))
+            .groupBy(*self._internal.index_spark_column_names)
+            .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col))
+            .where(F.col(tmp_min_col) == F.col(tmp_max_col))
+            .select(*self._internal.index_spark_column_names)
+        )
 
         if sort:
             sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)

