This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2bb36fb271b [SPARK-45936][PS] Optimize `Index.symmetric_difference` 2bb36fb271b is described below commit 2bb36fb271b60dda68567b92613a3664a7aae2b8 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Nov 16 10:40:05 2023 +0900 [SPARK-45936][PS] Optimize `Index.symmetric_difference` ### What changes were proposed in this pull request? Add a helper function for `XOR` (`xor` in `pyspark.pandas.utils`), and use it to optimize `Index.symmetric_difference`. ### Why are the changes needed? The old plan is too complex: `Index.symmetric_difference` computed `union(...).subtract(intersect(...))`, which the helper replaces with a single tagged `union` followed by one aggregation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43816 from zhengruifeng/ps_base_diff. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/pandas/indexes/base.py | 4 ++-- python/pyspark/pandas/indexes/multi.py | 15 ++------------- python/pyspark/pandas/utils.py | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 6c6ee9ae0d7..a515a79dcd7 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -72,6 +72,7 @@ from pyspark.pandas.utils import ( validate_index_loc, ERROR_MESSAGE_CANNOT_COMBINE, log_advice, + xor, ) from pyspark.pandas.internal import ( InternalField, @@ -1468,8 +1469,7 @@ class Index(IndexOpsMixin): sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns) - - sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other)) + sdf_symdiff = xor(sdf_self, sdf_other) if sort: sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names) diff --git a/python/pyspark/pandas/indexes/multi.py 
b/python/pyspark/pandas/indexes/multi.py index 62b42c1fcd0..7d2712cbb53 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -38,6 +38,7 @@ from pyspark.pandas.utils import ( scol_for, verify_temp_column_name, validate_index_loc, + xor, ) from pyspark.pandas.internal import ( InternalField, @@ -809,19 +810,7 @@ class MultiIndex(Index): sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns) - - tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__") - tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__") - tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__") - - sdf_symdiff = ( - sdf_self.withColumn(tmp_tag_col, F.lit(0)) - .union(sdf_other.withColumn(tmp_tag_col, F.lit(1))) - .groupBy(*self._internal.index_spark_column_names) - .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col)) - .where(F.col(tmp_min_col) == F.col(tmp_max_col)) - .select(*self._internal.index_spark_column_names) - ) + sdf_symdiff = xor(sdf_self, sdf_other) if sort: sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names) diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index 9f372a53079..57c1ddbe6ae 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -1033,6 +1033,23 @@ def validate_index_loc(index: "Index", loc: int) -> None: ) +def xor(df1: PySparkDataFrame, df2: PySparkDataFrame) -> PySparkDataFrame: + colNames = df1.columns + + tmp_tag_col = verify_temp_column_name(df1, "__temporary_tag__") + tmp_max_col = verify_temp_column_name(df1, "__temporary_max_tag__") + tmp_min_col = verify_temp_column_name(df1, "__temporary_min_tag__") + + return ( + df1.withColumn(tmp_tag_col, F.lit(0)) + .union(df2.withColumn(tmp_tag_col, F.lit(1))) + .groupBy(*colNames) + 
.agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col)) + .where(F.col(tmp_min_col) == F.col(tmp_max_col)) + .select(*colNames) + ) + + def _test() -> None: import os import doctest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org