This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 11abc64a731d [SPARK-47094][SQL] SPJ: Dynamically rebalance number of buckets when they are not equal
11abc64a731d is described below

commit 11abc64a731d0e75d837994183396e6da9c45310
Author: Szehon Ho <szehon.apa...@gmail.com>
AuthorDate: Fri Apr 5 20:11:54 2024 -0700

    [SPARK-47094][SQL] SPJ: Dynamically rebalance number of buckets when they are not equal
    
    ### What changes were proposed in this pull request?
    -- Allow SPJ between 'compatible' bucket functions
    -- Add a mechanism to define 'reducible' functions, one function whose output can be 'reduced' to another for all inputs.
    
    ### Why are the changes needed?
    -- SPJ currently applies only if the partition transform expressions on both sides are identical.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added new tests in KeyGroupedPartitioningSuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45267 from szehon-ho/spj-uneven-buckets.
    
    Authored-by: Szehon Ho <szehon.apa...@gmail.com>
    Signed-off-by: Chao Sun <c...@openai.com>
---
 .../sql/connector/catalog/functions/Reducer.java   |  42 ++
 .../catalog/functions/ReducibleFunction.java       | 106 +++++
 .../catalyst/expressions/TransformExpression.scala |  57 ++-
 .../sql/catalyst/plans/physical/partitioning.scala |  50 ++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  15 +
 .../execution/datasources/v2/BatchScanExec.scala   |  20 +-
 .../execution/exchange/EnsureRequirements.scala    |  50 ++-
 .../connector/KeyGroupedPartitioningSuite.scala    | 474 +++++++++++++++++++++
 .../catalog/functions/transformFunctions.scala     |  22 +-
 9 files changed, 821 insertions(+), 15 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
new file mode 100644
index 000000000000..561d66092d64
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A 'reducer' for the output of user-defined functions.
+ *
+ * @see ReducibleFunction
+ *
+ * A user-defined function f_source(x) is 'reducible' on another user-defined function
+ * f_target(x) if
+ * <ul>
+ *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for
+ *        all input x, or </li>
+ *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
+ *        r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
+ * </ul>
+ *
+ * @param <I> reducer input type
+ * @param <O> reducer output type
+ * @since 4.0.0
+ */
+@Evolving
+public interface Reducer<I, O> {
+  O reduce(I arg);
+}
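
For illustration, a minimal Reducer matching the r(x) = x % 2 example above could be written in Scala roughly as follows (ModReducer is a hypothetical name, not something added by this patch):

    import org.apache.spark.sql.connector.catalog.functions.Reducer

    // Hypothetical reducer for the example above: taking bucket IDs modulo 2 maps
    // the output of bucket(4, x) onto the output of bucket(2, x).
    case class ModReducer(divisor: Int) extends Reducer[Int, Int] {
      override def reduce(bucket: Int): Int = bucket % divisor
    }

    // ModReducer(2).reduce(3) == 1: rows in bucket 3 of a 4-bucket layout land in
    // bucket 1 of a 2-bucket layout.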
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
new file mode 100644
index 000000000000..ef1a14e50cda
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Base class for user-defined functions that can be 'reduced' on another function.
+ *
+ * A function f_source(x) is 'reducible' on another function f_target(x) if
+ * <ul>
+ *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x)
+ *        for all input x, or </li>
+ *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
+ *        r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
+ * </ul>
+ * <p>
+ * Examples:
+ * <ul>
+ *    <li>Bucket functions where one side has reducer
+ *    <ul>
+ *        <li>f_source(x) = bucket(4, x)</li>
+ *        <li>f_target(x) = bucket(2, x)</li>
+ *        <li>r(x) = x % 2</li>
+ *    </ul>
+ *
+ *    <li>Bucket functions where both sides have reducer
+ *    <ul>
+ *        <li>f_source(x) = bucket(16, x)</li>
+ *        <li>f_target(x) = bucket(12, x)</li>
+ *        <li>r1(x) = x % 4</li>
+ *        <li>r2(x) = x % 4</li>
+ *    </ul>
+ *
+ *    <li>Date functions
+ *    <ul>
+ *        <li>f_source(x) = days(x)</li>
+ *        <li>f_target(x) = hours(x)</li>
+ *        <li>r(x) = x / 24</li>
+ *     </ul>
+ * </ul>
+ * @param <I> reducer function input type
+ * @param <O> reducer function output type
+ * @since 4.0.0
+ */
+@Evolving
+public interface ReducibleFunction<I, O> {
+
+  /**
+   * This method is for the bucket function.
+   *
+   * If this bucket function is 'reducible' on another bucket function,
+   * return the {@link Reducer} function.
+   * <p>
+   * For example, to return the reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x):
+   * <ul>
+   *     <li>thisBucketFunction = bucket</li>
+   *     <li>thisNumBuckets = 4</li>
+   *     <li>otherBucketFunction = bucket</li>
+   *     <li>otherNumBuckets = 2</li>
+   * </ul>
+   *
+   * @param thisNumBuckets parameter for this function
+   * @param otherBucketFunction the other parameterized function
+   * @param otherNumBuckets parameter for the other function
+   * @return a reduction function if it is reducible, null if not
+   */
+  default Reducer<I, O> reducer(
+      int thisNumBuckets,
+      ReducibleFunction<?, ?> otherBucketFunction,
+      int otherNumBuckets) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * This method is for all other functions.
+   *
+   * If this function is 'reducible' on another function, return the {@link Reducer} function.
+   * <p>
+   * Example of reducing f_source = days(x) on f_target = hours(x)
+   * <ul>
+   *     <li>thisFunction = days</li>
+   *     <li>otherFunction = hours</li>
+   * </ul>
+   *
+   * @param otherFunction the other function
+   * @return a reduction function if it is reducible, null if not.
+   */
+  default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
+    throw new UnsupportedOperationException();
+  }
+}
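
To make the bucket-specific overload concrete, a connector's bucket function could implement reducer() much like the BucketFunction test implementation further down in this patch: reduce bucket IDs modulo the greatest common divisor of the two bucket counts, and return null when this side already has the coarser granularity. MyBucketFunction below is a hypothetical sketch, reusing the ModReducer sketched after Reducer.java above:

    import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}

    // Hypothetical connector bucket function: bucket(thisNumBuckets, x) is reducible on
    // bucket(otherNumBuckets, x) by taking bucket IDs modulo their greatest common divisor.
    object MyBucketFunction extends ReducibleFunction[Int, Int] {
      override def reducer(
          thisNumBuckets: Int,
          otherBucketFunction: ReducibleFunction[_, _],
          otherNumBuckets: Int): Reducer[Int, Int] = {
        if (otherBucketFunction == MyBucketFunction) {
          val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
          if (gcd != thisNumBuckets) {
            // This side must fold its buckets down to the common granularity.
            return ModReducer(gcd)
          }
        }
        // Already at the target granularity, or the other function is unrelated.
        null
      }
    }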
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index 8412de554b71..d37c9d9f6452 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.connector.catalog.functions.BoundFunction
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, 
Reducer, ReducibleFunction}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -54,6 +54,61 @@ case class TransformExpression(
       false
   }
 
+  /**
+   * Whether this [[TransformExpression]]'s function is compatible with the `other`
+   * [[TransformExpression]]'s function.
+   *
+   * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x)
+   * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x.
+   *
+   * @param other the transform expression to compare to
+   * @return true if compatible, false if not
+   */
+  def isCompatible(other: TransformExpression): Boolean = {
+    if (isSameFunction(other)) {
+      true
+    } else {
+      (function, other.function) match {
+        case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
+          val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
+          val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
+          thisReducer.isDefined || otherReducer.isDefined
+        case _ => false
+      }
+    }
+  }
+
+  /**
+   * Return a [[Reducer]] for this transform expression's function on the `other`
+   * transform expression's function.
+   * <p>
+   * A [[Reducer]] exists for a transform expression function if it is
+   * 'reducible' on the other expression function.
+   * <p>
+   * @return reducer function or None if not reducible on the other transform expression
+   */
+  def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
+    (function, other.function) match {
+      case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
+        reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
+      case _ => None
+    }
+  }
+
+  // Return a Reducer for a reducible function on another reducible function
+  private def reducer(
+      thisFunction: ReducibleFunction[_, _],
+      thisNumBucketsOpt: Option[Int],
+      otherFunction: ReducibleFunction[_, _],
+      otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
+    val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
+      case (Some(numBuckets), Some(otherNumBuckets)) =>
+        thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
+      case _ => thisFunction.reducer(otherFunction)
+    }
+    Option(res)
+  }
+
   override def dataType: DataType = function.resultType()
 
   override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Expression =
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index c98a2a92a3ab..2364130f79e4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, 
SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -833,10 +834,42 @@ case class KeyGroupedShuffleSpec(
     (left, right) match {
       case (_: LeafExpression, _: LeafExpression) => true
       case (left: TransformExpression, right: TransformExpression) =>
-        left.isSameFunction(right)
+        if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
+          !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+          SQLConf.get.v2BucketingAllowCompatibleTransforms) {
+          left.isCompatible(right)
+        } else {
+          left.isSameFunction(right)
+        }
       case _ => false
     }
 
+  /**
+   * Return a sequence of [[Reducer]]s for the partition expressions of this shuffle spec,
+   * on the partition expressions of another shuffle spec.
+   * <p>
+   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
+   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
+   * <p>
+   * If a value is returned, there must be one [[Reducer]] per partition expression.
+   * A None value in the sequence indicates that the particular partition expression is not
+   * reducible on the corresponding expression of the other shuffle spec.
+   * <p>
+   * Returning None overall indicates that none of the partition expressions can be reduced on
+   * the corresponding expressions of the other shuffle spec.
+   *
+   * @param other other key-grouped shuffle spec
+   */
+  def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, 
_]]]] = {
+     val results = 
partitioning.expressions.zip(other.partitioning.expressions).map {
+       case (e1: TransformExpression, e2: TransformExpression) => 
e1.reducers(e2)
+       case (_, _) => None
+     }
+
+    // optimize to not return a value, if none of the partition expressions 
are reducible
+    if (results.forall(p => p.isEmpty)) None else Some(results)
+  }
+
   override def canCreatePartitioning: Boolean = 
SQLConf.get.v2BucketingShuffleEnabled &&
     // Only support partition expressions are AttributeReference for now
     partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
@@ -846,6 +879,21 @@ case class KeyGroupedShuffleSpec(
   }
 }
 
+object KeyGroupedShuffleSpec {
+  def reducePartitionValue(
+      row: InternalRow,
+      expressions: Seq[Expression],
+      reducers: Seq[Option[Reducer[_, _]]]):
+    InternalRowComparableWrapper = {
+    val partitionVals = row.toSeq(expressions.map(_.dataType))
+    val reducedRow = partitionVals.zip(reducers).map{
+      case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
+      case (v, _) => v
+    }.toArray
+    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), 
expressions)
+  }
+}
+
 case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
   override def isCompatibleWith(other: ShuffleSpec): Boolean = {
     specs.exists(_.isCompatibleWith(other))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9f07722528e8..73cb4fba8637 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1558,6 +1558,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
+    buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled")
+      .doc("Whether to allow storage-partition join in the case where the partition transforms " +
+        "are compatible but not identical.  This config requires both " +
+        s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
+        s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+        "to be disabled."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = 
buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -5323,6 +5335,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
   def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
     getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
 
+  def v2BucketingAllowCompatibleTransforms: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
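
As a usage sketch, the new flag would typically be combined with the prerequisites named in its doc string. Only the allowCompatibleTransforms key is taken verbatim from this patch; the other key spellings are the usual v2 bucketing confs and may vary by Spark version:

    // Assumes `spark` is an active SparkSession.
    spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
    spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
    spark.conf.set(
      "spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "false")
    spark.conf.set(
      "spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled", "true")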
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 7cce59904018..f949dbf71a37 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, 
Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, 
KeyGroupedShuffleSpec, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.{truncatedString, 
InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.util.ArrayImplicits._
 
@@ -164,6 +165,18 @@ case class BatchScanExec(
               (groupedParts, expressions)
           }
 
+          // Also re-group the partitions if we are reducing compatible 
partition expressions
+          val finalGroupedPartitions = spjParams.reducers match {
+            case Some(reducers) =>
+              val result = groupedPartitions.groupBy { case (row, _) =>
+                KeyGroupedShuffleSpec.reducePartitionValue(row, 
partExpressions, reducers)
+              }.map { case (wrapper, splits) => (wrapper.row, 
splits.flatMap(_._2)) }.toSeq
+              val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
+                partExpressions.map(_.dataType))
+              result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+            case _ => groupedPartitions
+          }
+
           // When partially clustered, the input partitions are not grouped by 
partition
           // values. Here we'll need to check `commonPartitionValues` and 
decide how to group
           // and replicate splits within a partition.
@@ -174,7 +187,7 @@ case class BatchScanExec(
                 .get
                 .map(t => (InternalRowComparableWrapper(t._1, 
partExpressions), t._2))
                 .toMap
-            val nestGroupedPartitions = groupedPartitions.map { case 
(partValue, splits) =>
+            val nestGroupedPartitions = finalGroupedPartitions.map { case 
(partValue, splits) =>
               // `commonPartValuesMap` should contain the part value since 
it's the super set.
               val numSplits = commonPartValuesMap
                   .get(InternalRowComparableWrapper(partValue, 
partExpressions))
@@ -207,7 +220,7 @@ case class BatchScanExec(
           } else {
             // either `commonPartitionValues` is not defined, or it is defined 
but
             // `applyPartialClustering` is false.
-            val partitionMapping = groupedPartitions.map { case (partValue, 
splits) =>
+            val partitionMapping = finalGroupedPartitions.map { case 
(partValue, splits) =>
               InternalRowComparableWrapper(partValue, partExpressions) -> 
splits
             }.toMap
 
@@ -259,6 +272,7 @@ case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
   override def equals(other: Any): Boolean = other match {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 2a7c1206bb41..a0f74ef6c3d0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
@@ -505,11 +506,28 @@ case class EnsureRequirements(
           }
         }
 
-        // Now we need to push-down the common partition key to the scan in 
each child
-        newLeft = populatePartitionValues(left, mergedPartValues, 
leftSpec.joinKeyPositions,
-          applyPartialClustering, replicateLeftSide)
-        newRight = populatePartitionValues(right, mergedPartValues, 
rightSpec.joinKeyPositions,
-          applyPartialClustering, replicateRightSide)
+        // in case of compatible but not identical partition expressions, we 
apply 'reduce'
+        // transforms to group one side's partitions as well as the common 
partition values
+        val leftReducers = leftSpec.reducers(rightSpec)
+        val rightReducers = rightSpec.reducers(leftSpec)
+
+        if (leftReducers.isDefined || rightReducers.isDefined) {
+          mergedPartValues = reduceCommonPartValues(mergedPartValues,
+            leftSpec.partitioning.expressions,
+            leftReducers)
+          mergedPartValues = reduceCommonPartValues(mergedPartValues,
+            rightSpec.partitioning.expressions,
+            rightReducers)
+          val rowOrdering = RowOrdering
+            .createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
+          mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: 
(InternalRow, _)) => t._1))
+        }
+
+        // Now we need to push-down the common partition information to the 
scan in each child
+        newLeft = populateCommonPartitionInfo(left, mergedPartValues, 
leftSpec.joinKeyPositions,
+          leftReducers, applyPartialClustering, replicateLeftSide)
+        newRight = populateCommonPartitionInfo(right, mergedPartValues, 
rightSpec.joinKeyPositions,
+          rightReducers, applyPartialClustering, replicateRightSide)
       }
     }
 
@@ -527,11 +545,12 @@ case class EnsureRequirements(
         joinType == LeftAnti || joinType == LeftOuter
   }
 
-  // Populate the common partition values down to the scan nodes
-  private def populatePartitionValues(
+  // Populate the common partition information down to the scan nodes
+  private def populateCommonPartitionInfo(
       plan: SparkPlan,
       values: Seq[(InternalRow, Int)],
       joinKeyPositions: Option[Seq[Int]],
+      reducers: Option[Seq[Option[Reducer[_, _]]]],
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
@@ -539,13 +558,26 @@ case class EnsureRequirements(
         spjParams = scan.spjParams.copy(
           commonPartitionValues = Some(values),
           joinKeyPositions = joinKeyPositions,
+          reducers = reducers,
           applyPartialClustering = applyPartialClustering,
           replicatePartitions = replicatePartitions
         )
       )
     case node =>
-      node.mapChildren(child => populatePartitionValues(
-        child, values, joinKeyPositions, applyPartialClustering, 
replicatePartitions))
+      node.mapChildren(child => populateCommonPartitionInfo(
+        child, values, joinKeyPositions, reducers, applyPartialClustering, 
replicatePartitions))
+  }
+
+  private def reduceCommonPartValues(
+      commonPartValues: Seq[(InternalRow, Int)],
+      expressions: Seq[Expression],
+      reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
+    reducers match {
+      case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
+        KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
+      }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) 
}.toSeq
+      case _ => commonPartValues
+    }
   }
 
   /**
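
The effect of reduceCommonPartValues can be seen on a simplified, self-contained model that uses plain Ints in place of InternalRow: group the common partition values under the reduced key and sum the split counts of the groups that merge. This is a hypothetical sketch, not code from this patch:

    // Simplified model: each entry is (bucket id, number of splits).
    def reduceSketch(values: Seq[(Int, Int)], reduce: Int => Int): Seq[(Int, Int)] =
      values
        .groupBy { case (bucket, _) => reduce(bucket) }
        .map { case (reduced, group) => (reduced, group.map(_._2).sum) }
        .toSeq

    // bucket(4, x) partition values collapsed onto bucket(2, x) via _ % 2:
    // reduceSketch(Seq((0, 5), (1, 3), (2, 4), (3, 2)), _ % 2)
    // returns, ordering aside, Seq((0, 9), (1, 5)).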
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 7fdc703007c2..ec275fe101fd 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -63,11 +63,17 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     Collections.emptyMap[String, String]
   }
   private val table: String = "tbl"
+
   private val columns: Array[Column] = Array(
     Column.create("id", IntegerType),
     Column.create("data", StringType),
     Column.create("ts", TimestampType))
 
+  private val columns2: Array[Column] = Array(
+      Column.create("store_id", IntegerType),
+      Column.create("dept_id", IntegerType),
+      Column.create("data", StringType))
+
   test("clustered distribution: output partitioning should be 
KeyGroupedPartitioning") {
     val partitions: Array[Transform] = Array(Expressions.years("ts"))
 
@@ -1309,6 +1315,474 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-47094: Support compatible buckets") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq(
+      ((2, 4), (4, 2)),
+      ((4, 2), (2, 4)),
+      ((2, 2), (4, 6)),
+      ((6, 2), (2, 2))).foreach {
+      case ((table1buckets1, table1buckets2), (table2buckets1, 
table2buckets2)) =>
+        catalog.clearTables()
+
+        val partition1 = Array(bucket(table1buckets1, "store_id"),
+          bucket(table1buckets2, "dept_id"))
+        val partition2 = Array(bucket(table2buckets1, "store_id"),
+          bucket(table2buckets2, "dept_id"))
+
+        Seq((table1, partition1), (table2, partition2)).foreach { case (tab, 
part) =>
+          createTable(tab, columns2, part)
+          val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+            "(0, 0, 'aa'), " +
+            "(0, 0, 'ab'), " + // duplicate partition key
+            "(0, 1, 'ac'), " +
+            "(0, 2, 'ad'), " +
+            "(0, 3, 'ae'), " +
+            "(0, 4, 'af'), " +
+            "(0, 5, 'ag'), " +
+            "(1, 0, 'ah'), " +
+            "(1, 0, 'ai'), " + // duplicate partition key
+            "(1, 1, 'aj'), " +
+            "(1, 2, 'ak'), " +
+            "(1, 3, 'al'), " +
+            "(1, 4, 'am'), " +
+            "(1, 5, 'an'), " +
+            "(2, 0, 'ao'), " +
+            "(2, 0, 'ap'), " + // duplicate partition key
+            "(2, 1, 'aq'), " +
+            "(2, 2, 'ar'), " +
+            "(2, 3, 'as'), " +
+            "(2, 4, 'at'), " +
+            "(2, 5, 'au'), " +
+            "(3, 0, 'av'), " +
+            "(3, 0, 'aw'), " + // duplicate partition key
+            "(3, 1, 'ax'), " +
+            "(3, 2, 'ay'), " +
+            "(3, 3, 'az'), " +
+            "(3, 4, 'ba'), " +
+            "(3, 5, 'bb'), " +
+            "(4, 0, 'bc'), " +
+            "(4, 0, 'bd'), " + // duplicate partition key
+            "(4, 1, 'be'), " +
+            "(4, 2, 'bf'), " +
+            "(4, 3, 'bg'), " +
+            "(4, 4, 'bh'), " +
+            "(4, 5, 'bi'), " +
+            "(5, 0, 'bj'), " +
+            "(5, 0, 'bk'), " + // duplicate partition key
+            "(5, 1, 'bl'), " +
+            "(5, 2, 'bm'), " +
+            "(5, 3, 'bn'), " +
+            "(5, 4, 'bo'), " +
+            "(5, 5, 'bp')"
+
+            // additional unmatched partitions to test push down
+            val finalStr = if (tab == table1) {
+              insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
+            } else {
+              insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
+            }
+
+            sql(finalStr)
+        }
+
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key 
-> "false",
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key 
->
+              allowJoinKeysSubsetOfPartitionKeys.toString,
+            SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+            val df = sql(
+              s"""
+                 |${selectWithMergeJoinHint("t1", "t2")}
+                 |t1.store_id, t1.dept_id, t1.data, t2.data
+                 |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                 |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+                 |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+                 |""".stripMargin)
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "SPJ should be triggered")
+
+            val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+              partitions.length)
+            val expectedBuckets = Math.min(table1buckets1, table2buckets1) *
+              Math.min(table1buckets2, table2buckets2)
+            assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+            checkAnswer(df, Seq(
+              Row(0, 0, "aa", "aa"),
+              Row(0, 0, "aa", "ab"),
+              Row(0, 0, "ab", "aa"),
+              Row(0, 0, "ab", "ab"),
+              Row(0, 1, "ac", "ac"),
+              Row(0, 2, "ad", "ad"),
+              Row(0, 3, "ae", "ae"),
+              Row(0, 4, "af", "af"),
+              Row(0, 5, "ag", "ag"),
+              Row(1, 0, "ah", "ah"),
+              Row(1, 0, "ah", "ai"),
+              Row(1, 0, "ai", "ah"),
+              Row(1, 0, "ai", "ai"),
+              Row(1, 1, "aj", "aj"),
+              Row(1, 2, "ak", "ak"),
+              Row(1, 3, "al", "al"),
+              Row(1, 4, "am", "am"),
+              Row(1, 5, "an", "an"),
+              Row(2, 0, "ao", "ao"),
+              Row(2, 0, "ao", "ap"),
+              Row(2, 0, "ap", "ao"),
+              Row(2, 0, "ap", "ap"),
+              Row(2, 1, "aq", "aq"),
+              Row(2, 2, "ar", "ar"),
+              Row(2, 3, "as", "as"),
+              Row(2, 4, "at", "at"),
+              Row(2, 5, "au", "au"),
+              Row(3, 0, "av", "av"),
+              Row(3, 0, "av", "aw"),
+              Row(3, 0, "aw", "av"),
+              Row(3, 0, "aw", "aw"),
+              Row(3, 1, "ax", "ax"),
+              Row(3, 2, "ay", "ay"),
+              Row(3, 3, "az", "az"),
+              Row(3, 4, "ba", "ba"),
+              Row(3, 5, "bb", "bb"),
+              Row(4, 0, "bc", "bc"),
+              Row(4, 0, "bc", "bd"),
+              Row(4, 0, "bd", "bc"),
+              Row(4, 0, "bd", "bd"),
+              Row(4, 1, "be", "be"),
+              Row(4, 2, "bf", "bf"),
+              Row(4, 3, "bg", "bg"),
+              Row(4, 4, "bh", "bh"),
+              Row(4, 5, "bi", "bi"),
+              Row(5, 0, "bj", "bj"),
+              Row(5, 0, "bj", "bk"),
+              Row(5, 0, "bk", "bj"),
+              Row(5, 0, "bk", "bk"),
+              Row(5, 1, "bl", "bl"),
+              Row(5, 2, "bm", "bm"),
+              Row(5, 3, "bn", "bn"),
+              Row(5, 4, "bo", "bo"),
+              Row(5, 5, "bp", "bp")
+            ))
+          }
+        }
+    }
+  }
+
+  test("SPARK-47094: Support compatible buckets with common divisor") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq(
+      ((6, 4), (4, 6)),
+      ((6, 6), (4, 4)),
+      ((4, 4), (6, 6)),
+      ((4, 6), (6, 4))).foreach {
+      case ((table1buckets1, table1buckets2), (table2buckets1, 
table2buckets2)) =>
+        catalog.clearTables()
+
+        val partition1 = Array(bucket(table1buckets1, "store_id"),
+          bucket(table1buckets2, "dept_id"))
+        val partition2 = Array(bucket(table2buckets1, "store_id"),
+          bucket(table2buckets2, "dept_id"))
+
+        Seq((table1, partition1), (table2, partition2)).foreach { case (tab, 
part) =>
+          createTable(tab, columns2, part)
+          val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+            "(0, 0, 'aa'), " +
+            "(0, 0, 'ab'), " + // duplicate partition key
+            "(0, 1, 'ac'), " +
+            "(0, 2, 'ad'), " +
+            "(0, 3, 'ae'), " +
+            "(0, 4, 'af'), " +
+            "(0, 5, 'ag'), " +
+            "(1, 0, 'ah'), " +
+            "(1, 0, 'ai'), " + // duplicate partition key
+            "(1, 1, 'aj'), " +
+            "(1, 2, 'ak'), " +
+            "(1, 3, 'al'), " +
+            "(1, 4, 'am'), " +
+            "(1, 5, 'an'), " +
+            "(2, 0, 'ao'), " +
+            "(2, 0, 'ap'), " + // duplicate partition key
+            "(2, 1, 'aq'), " +
+            "(2, 2, 'ar'), " +
+            "(2, 3, 'as'), " +
+            "(2, 4, 'at'), " +
+            "(2, 5, 'au'), " +
+            "(3, 0, 'av'), " +
+            "(3, 0, 'aw'), " + // duplicate partition key
+            "(3, 1, 'ax'), " +
+            "(3, 2, 'ay'), " +
+            "(3, 3, 'az'), " +
+            "(3, 4, 'ba'), " +
+            "(3, 5, 'bb'), " +
+            "(4, 0, 'bc'), " +
+            "(4, 0, 'bd'), " + // duplicate partition key
+            "(4, 1, 'be'), " +
+            "(4, 2, 'bf'), " +
+            "(4, 3, 'bg'), " +
+            "(4, 4, 'bh'), " +
+            "(4, 5, 'bi'), " +
+            "(5, 0, 'bj'), " +
+            "(5, 0, 'bk'), " + // duplicate partition key
+            "(5, 1, 'bl'), " +
+            "(5, 2, 'bm'), " +
+            "(5, 3, 'bn'), " +
+            "(5, 4, 'bo'), " +
+            "(5, 5, 'bp')"
+
+            // additional unmatched partitions to test push down
+            val finalStr = if (tab == table1) {
+              insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
+            } else {
+              insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
+            }
+
+            sql(finalStr)
+        }
+
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key 
-> "false",
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key 
->
+              allowJoinKeysSubsetOfPartitionKeys.toString,
+            SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+            val df = sql(
+              s"""
+                 |${selectWithMergeJoinHint("t1", "t2")}
+                 |t1.store_id, t1.dept_id, t1.data, t2.data
+                 |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                 |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+                 |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+                 |""".stripMargin)
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "SPJ should be triggered")
+
+            val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+              partitions.length)
+
+            def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
+            val expectedBuckets = gcd(table1buckets1, table2buckets1) *
+              gcd(table1buckets2, table2buckets2)
+            assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+            checkAnswer(df, Seq(
+              Row(0, 0, "aa", "aa"),
+              Row(0, 0, "aa", "ab"),
+              Row(0, 0, "ab", "aa"),
+              Row(0, 0, "ab", "ab"),
+              Row(0, 1, "ac", "ac"),
+              Row(0, 2, "ad", "ad"),
+              Row(0, 3, "ae", "ae"),
+              Row(0, 4, "af", "af"),
+              Row(0, 5, "ag", "ag"),
+              Row(1, 0, "ah", "ah"),
+              Row(1, 0, "ah", "ai"),
+              Row(1, 0, "ai", "ah"),
+              Row(1, 0, "ai", "ai"),
+              Row(1, 1, "aj", "aj"),
+              Row(1, 2, "ak", "ak"),
+              Row(1, 3, "al", "al"),
+              Row(1, 4, "am", "am"),
+              Row(1, 5, "an", "an"),
+              Row(2, 0, "ao", "ao"),
+              Row(2, 0, "ao", "ap"),
+              Row(2, 0, "ap", "ao"),
+              Row(2, 0, "ap", "ap"),
+              Row(2, 1, "aq", "aq"),
+              Row(2, 2, "ar", "ar"),
+              Row(2, 3, "as", "as"),
+              Row(2, 4, "at", "at"),
+              Row(2, 5, "au", "au"),
+              Row(3, 0, "av", "av"),
+              Row(3, 0, "av", "aw"),
+              Row(3, 0, "aw", "av"),
+              Row(3, 0, "aw", "aw"),
+              Row(3, 1, "ax", "ax"),
+              Row(3, 2, "ay", "ay"),
+              Row(3, 3, "az", "az"),
+              Row(3, 4, "ba", "ba"),
+              Row(3, 5, "bb", "bb"),
+              Row(4, 0, "bc", "bc"),
+              Row(4, 0, "bc", "bd"),
+              Row(4, 0, "bd", "bc"),
+              Row(4, 0, "bd", "bd"),
+              Row(4, 1, "be", "be"),
+              Row(4, 2, "bf", "bf"),
+              Row(4, 3, "bg", "bg"),
+              Row(4, 4, "bh", "bh"),
+              Row(4, 5, "bi", "bi"),
+              Row(5, 0, "bj", "bj"),
+              Row(5, 0, "bj", "bk"),
+              Row(5, 0, "bk", "bj"),
+              Row(5, 0, "bk", "bk"),
+              Row(5, 1, "bl", "bl"),
+              Row(5, 2, "bm", "bm"),
+              Row(5, 3, "bn", "bn"),
+              Row(5, 4, "bo", "bo"),
+              Row(5, 5, "bp", "bp")
+            ))
+          }
+        }
+    }
+  }
+
+  test("SPARK-47094: Support compatible buckets with less join keys than 
partition keys") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach {
+      case (table1buckets, table2buckets) =>
+        catalog.clearTables()
+
+        val partition1 = Array(identity("data"),
+          bucket(table1buckets, "dept_id"))
+        val partition2 = Array(bucket(3, "store_id"),
+          bucket(table2buckets, "dept_id"))
+
+        createTable(table1, columns2, partition1)
+        sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+          "(0, 0, 'aa'), " +
+          "(1, 0, 'ab'), " +
+          "(2, 1, 'ac'), " +
+          "(3, 2, 'ad'), " +
+          "(4, 3, 'ae'), " +
+          "(5, 4, 'af'), " +
+          "(6, 5, 'ag'), " +
+
+          // value without other side match
+          "(6, 6, 'xx')"
+        )
+
+        createTable(table2, columns2, partition2)
+        sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+          "(6, 0, '01'), " +
+          "(5, 1, '02'), " + // duplicate partition key
+          "(5, 1, '03'), " +
+          "(4, 2, '04'), " +
+          "(3, 3, '05'), " +
+          "(2, 4, '06'), " +
+          "(1, 5, '07'), " +
+
+          // value without other side match
+          "(7, 7, '99')"
+        )
+
+
+        withSQLConf(
+          SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"false",
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true",
+          SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+          val df = sql(
+            s"""
+               |${selectWithMergeJoinHint("t1", "t2")}
+               |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, 
t2.data
+               |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+               |ON t1.dept_id = t2.dept_id
+               |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+               |""".stripMargin)
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.isEmpty, "SPJ should be triggered")
+
+          val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+            partitions.length)
+
+          val expectedBuckets = Math.min(table1buckets, table2buckets)
+
+          assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+          checkAnswer(df, Seq(
+            Row(0, 6, 0, 0, "aa", "01"),
+            Row(1, 6, 0, 0, "ab", "01"),
+            Row(2, 5, 1, 1, "ac", "02"),
+            Row(2, 5, 1, 1, "ac", "03"),
+            Row(3, 4, 2, 2, "ad", "04"),
+            Row(4, 3, 3, 3, "ae", "05"),
+            Row(5, 2, 4, 4, "af", "06"),
+            Row(6, 1, 5, 5, "ag", "07")
+          ))
+        }
+      }
+  }
+
+  test("SPARK-47094: Compatible buckets does not support SPJ with " +
+    "push-down values or partially-clustered") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    val partition1 = Array(bucket(4, "store_id"),
+      bucket(2, "dept_id"))
+    val partition2 = Array(bucket(2, "store_id"),
+      bucket(2, "dept_id"))
+
+    createTable(table1, columns2, partition1)
+    sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+          "(0, 0, 'aa'), " +
+          "(1, 1, 'bb'), " +
+          "(2, 2, 'cc')"
+        )
+
+    createTable(table2, columns2, partition2)
+    sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+          "(0, 0, 'aa'), " +
+          "(1, 1, 'bb'), " +
+          "(2, 2, 'cc')"
+        )
+
+    Seq(true, false).foreach{ allowPushDown =>
+      Seq(true, false).foreach{ partiallyClustered =>
+        withSQLConf(
+          SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> 
allowPushDown.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+            partiallyClustered.toString,
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true",
+          SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+          val df = sql(
+                s"""
+                   |${selectWithMergeJoinHint("t1", "t2")}
+                   |t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, 
t2.data
+                   |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                   |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+                   |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+                   |""".stripMargin)
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+            partitions.length)
+
+          (allowPushDown, partiallyClustered) match {
+            case (true, false) =>
+              assert(shuffles.isEmpty, "SPJ should be triggered")
+              assert(scans == Seq(2, 2))
+            case (_, _) =>
+              assert(shuffles.nonEmpty, "SPJ should not be triggered")
+              assert(scans == Seq(3, 2))
+          }
+
+          checkAnswer(df, Seq(
+              Row(0, 0, 0, 0, "aa", "aa"),
+              Row(1, 1, 1, 1, "bb", "bb"),
+              Row(2, 2, 2, 2, "cc", "cc")
+            ))
+          }
+      }
+    }
+  }
+
   test("SPARK-44647: test join key is the second cluster key") {
     val table1 = "tab1e1"
     val table2 = "table2"
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 61895d49c4a2..5cdb90090105 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
-object BucketFunction extends ScalarFunction[Int] {
+object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, 
Int] {
   override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
   override def name(): String = "bucket"
@@ -85,6 +85,26 @@ object BucketFunction extends ScalarFunction[Int] {
   override def produceResult(input: InternalRow): Int = {
     (input.getLong(1) % input.getInt(0)).toInt
   }
+
+  override def reducer(
+      thisNumBuckets: Int,
+      otherFunc: ReducibleFunction[_, _],
+      otherNumBuckets: Int): Reducer[Int, Int] = {
+
+    if (otherFunc == BucketFunction) {
+      val gcd = this.gcd(thisNumBuckets, otherNumBuckets)
+      if (gcd != thisNumBuckets) {
+        return BucketReducer(thisNumBuckets, gcd)
+      }
+    }
+    null
+  }
+
+  private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
+}
+
+case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends 
Reducer[Int, Int] {
+  override def reduce(bucket: Int): Int = bucket % divisor
 }
 
 object UnboundStringSelfFunction extends UnboundFunction {

