This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8235f1d56bf2 [SPARK-46219][SQL] Unwrap cast in join predicates
8235f1d56bf2 is described below

commit 8235f1d56bf232bb713fe24ff6f2ffdaf49d2fcc
Author: Yuming Wang <yumw...@ebay.com>
AuthorDate: Tue Dec 5 08:37:34 2023 -0800

    [SPARK-46219][SQL] Unwrap cast in join predicates
    
    ### What changes were proposed in this pull request?
    
    In a large data platform, it is very common to join columns of different data types. Similar to [`reorderJoinPredicates`](https://github.com/apache/spark/blob/b03afa7bde5a050eb95284b275eae0aac2257f63/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala#L321-L338), this PR adds a function to `EnsureRequirements` that unwraps casts in join predicates to reduce shuffles when the join keys are integral types.
    
    The key idea is that, for integral join keys, casting to either side's type does not affect the join result. For example, in `a.intCol = try_cast(b.bigIntCol AS int)`, if the value of `bigIntCol` exceeds the range of `int`, then `try_cast(b.bigIntCol AS int)` evaluates to `null`, so the join condition evaluates to `null` and the row does not match. This is consistent with `cast(a.intCol AS bigint) = b.bigIntCol`.
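    
    A quick way to sanity-check this equivalence from spark-shell (a sketch, not part of the patch) is to observe that an out-of-range value becomes `null` under `try_cast` and therefore never matches:
    ```scala
    // Sketch only: why unwrapping the cast is safe for integral join keys.
    // A bigint value outside the int range becomes NULL under try_cast to int,
    // so it can never equal any int key, matching the behavior of widening the
    // int side with cast(intCol AS bigint) = bigIntCol instead.
    spark.sql("SELECT try_cast(9223372036854775807L AS int) AS v").collect()
    // => Array([null])
    spark.sql("SELECT 1 = try_cast(9223372036854775807L AS int) AS matched").collect()
    // => Array([null]); a join condition treats null as "no match"
    ```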
    
    ### Why are the changes needed?
    
    Reduce shuffle to improve query performance.
    Case 1: Shuffle before join
    ```sql
    CREATE TABLE t1(id int) USING parquet;
    CREATE TABLE t2(id int) USING parquet;
    CREATE TABLE t3(id bigint) USING parquet;
    SET spark.sql.autoBroadcastJoinThreshold=-1;
    explain SELECT * FROM t1 JOIN t2 ON t1.id = t2.id JOIN t3 ON t1.id = t3.id;
    explain SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY id ORDER BY id) AS rn FROM t1) t JOIN t2 ON t.id = t2.id WHERE rn = 1;
    ```
    The plan differences after this PR:
    ```diff
     == Physical Plan ==
     AdaptiveSparkPlan isFinalPlan=false
    -+- SortMergeJoin [cast(id#10 as bigint)], [id#12L], Inner
    -   :- Sort [cast(id#10 as bigint) ASC NULLS FIRST], false, 0
    -   :  +- Exchange hashpartitioning(cast(id#10 as bigint), 5), ENSURE_REQUIREMENTS, [plan_id=54]
    -   :     +- SortMergeJoin [id#10], [id#11], Inner
    -   :        :- Sort [id#10 ASC NULLS FIRST], false, 0
    -   :        :  +- Exchange hashpartitioning(id#10, 5), ENSURE_REQUIREMENTS, [plan_id=47]
    -   :        :     +- Filter isnotnull(id#10)
    -   :        :        +- FileScan parquet spark_catalog.default.t1[id#10]
    -   :        +- Sort [id#11 ASC NULLS FIRST], false, 0
    -   :           +- Exchange hashpartitioning(id#11, 5), ENSURE_REQUIREMENTS, [plan_id=48]
    -   :              +- Filter isnotnull(id#11)
    -   :                 +- FileScan parquet spark_catalog.default.t2[id#11]
    -   +- Sort [id#12L ASC NULLS FIRST], false, 0
    -      +- Exchange hashpartitioning(id#12L, 5), ENSURE_REQUIREMENTS, [plan_id=55]
    -         +- Filter isnotnull(id#12L)
    -            +- FileScan parquet spark_catalog.default.t3[id#12L]
    ++- SortMergeJoin [id#20], [try_cast(id#22L as int)], Inner
    +   :- SortMergeJoin [id#20], [id#21], Inner
    +   :  :- Sort [id#20 ASC NULLS FIRST], false, 0
    +   :  :  +- Exchange hashpartitioning(id#20, 5), ENSURE_REQUIREMENTS, [plan_id=50]
    +   :  :     +- Filter isnotnull(id#20)
    +   :  :        +- FileScan parquet spark_catalog.default.t1[id#20]
    +   :  +- Sort [id#21 ASC NULLS FIRST], false, 0
    +   :     +- Exchange hashpartitioning(id#21, 5), ENSURE_REQUIREMENTS, [plan_id=51]
    +   :        +- Filter isnotnull(id#21)
    +   :           +- FileScan parquet spark_catalog.default.t2[id#21]
    +   +- Sort [try_cast(id#22L as int) ASC NULLS FIRST], false, 0
    +      +- Exchange hashpartitioning(try_cast(id#22L as int), 5), ENSURE_REQUIREMENTS, [plan_id=58]
    +         +- Filter isnotnull(id#22L)
    +            +- FileScan parquet spark_catalog.default.t3[id#22L]
    ```
    
    ```diff
     == Physical Plan ==
     AdaptiveSparkPlan isFinalPlan=false
    -+- SortMergeJoin [cast(id#22 as bigint)], [id#23L], Inner
    -   :- Sort [cast(id#22 as bigint) ASC NULLS FIRST], false, 0
    -   :  +- Exchange hashpartitioning(cast(id#22 as bigint), 5), ENSURE_REQUIREMENTS, [plan_id=62]
    -   :     +- Filter (rn#20 = 1)
    -   :        +- Window [row_number() windowspecdefinition(id#22, id#22 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#20], [id#22], [id#22 ASC NULLS FIRST]
    -   :           +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST], row_number(), 1, Final
    -   :              +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST], false, 0
    -   :                 +- Exchange hashpartitioning(id#22, 5), ENSURE_REQUIREMENTS, [plan_id=55]
    -   :                    +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST], row_number(), 1, Partial
    -   :                       +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST], false, 0
    -   :                          +- Filter isnotnull(id#22)
    -   :                             +- FileScan parquet spark_catalog.default.t1[id#22]
    -   +- Sort [id#23L ASC NULLS FIRST], false, 0
    -      +- Exchange hashpartitioning(id#23L, 5), ENSURE_REQUIREMENTS, [plan_id=63]
    ++- SortMergeJoin [id#22], [try_cast(id#23L as int)], Inner
    +   :- Filter (rn#20 = 1)
    +   :  +- Window [row_number() windowspecdefinition(id#22, id#22 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#20], [id#22], [id#22 ASC NULLS FIRST]
    +   :     +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST], row_number(), 1, Final
    +   :        +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST], false, 0
    +   :           +- Exchange hashpartitioning(id#22, 5), ENSURE_REQUIREMENTS, [plan_id=55]
    +   :              +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST], row_number(), 1, Partial
    +   :                 +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST], false, 0
    +   :                    +- Filter isnotnull(id#22)
    +   :                       +- FileScan parquet spark_catalog.default.t1[id#22]
    +   +- Sort [try_cast(id#23L as int) ASC NULLS FIRST], false, 0
    +      +- Exchange hashpartitioning(try_cast(id#23L as int), 5), ENSURE_REQUIREMENTS, [plan_id=63]
              +- Filter isnotnull(id#23L)
                 +- FileScan parquet spark_catalog.default.t2[id#23L]
    ```
    
    Case 2: Bucket table
    ```sql
    CREATE TABLE t1(id bigint) USING parquet CLUSTERED BY (id) INTO 200 buckets;
    CREATE TABLE t2(id decimal(18, 0)) USING parquet CLUSTERED BY (id) INTO 200 buckets;
    SET spark.sql.autoBroadcastJoinThreshold=-1;
    explain SELECT * FROM t1 JOIN t2 ON t1.id = t2.id;
    ```
    The plan differences after this PR:
    ```diff
     == Physical Plan ==
     AdaptiveSparkPlan isFinalPlan=false
    -+- SortMergeJoin [cast(id#10L as decimal(20,0))], [cast(id#11 as decimal(20,0))], Inner
    -   :- Sort [cast(id#10L as decimal(20,0)) ASC NULLS FIRST], false, 0
    -   :  +- Exchange hashpartitioning(cast(id#10L as decimal(20,0)), 5), ENSURE_REQUIREMENTS, [plan_id=38]
    -   :     +- Filter isnotnull(id#10L)
    -   :        +- FileScan parquet spark_catalog.default.t1[id#10L]
    -   +- Sort [cast(id#11 as decimal(20,0)) ASC NULLS FIRST], false, 0
    -      +- Exchange hashpartitioning(cast(id#11 as decimal(20,0)), 5), ENSURE_REQUIREMENTS, [plan_id=42]
    -         +- Filter isnotnull(id#11)
    -            +- FileScan parquet spark_catalog.default.t2[id#11]
    ++- SortMergeJoin [id#20L], [try_cast(id#21 as bigint)], Inner
    +   :- Sort [id#20L ASC NULLS FIRST], false, 0
    +   :  +- Filter isnotnull(id#20L)
    +   :     +- FileScan parquet spark_catalog.default.t1[id#20L] Bucketed: true, SelectedBucketsCount: 200 out of 200
    +   +- Sort [try_cast(id#21 as bigint) ASC NULLS FIRST], false, 0
    +      +- Exchange hashpartitioning(try_cast(id#21 as bigint), 200), ENSURE_REQUIREMENTS, [plan_id=42]
    +         +- Filter isnotnull(id#21)
    +            +- FileScan parquet spark_catalog.default.t2[id#21] Bucketed: false (disabled by query planner)
    ```
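    
    For reference, the plans above can be reproduced by toggling the configuration added in this PR, `spark.sql.unwrapCastInJoinCondition.enabled` (default: `true`). A minimal spark-shell sketch, assuming the Case 1 tables exist:
    ```scala
    // Sketch: compare the physical plans with and without cast unwrapping.
    // The config key is the one introduced by this PR; everything else is standard Spark API.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
    spark.conf.set("spark.sql.unwrapCastInJoinCondition.enabled", "false")
    spark.sql("SELECT * FROM t1 JOIN t2 ON t1.id = t2.id JOIN t3 ON t1.id = t3.id").explain()
    spark.conf.set("spark.sql.unwrapCastInJoinCondition.enabled", "true")
    spark.sql("SELECT * FROM t1 JOIN t2 ON t1.id = t2.id JOIN t3 ON t1.id = t3.id").explain()
    ```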
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44133 from wangyum/SPARK-46219.
    
    Authored-by: Yuming Wang <yumw...@ebay.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 ++
 .../bucketing/CoalesceBucketsInJoin.scala          |  22 +---
 .../execution/exchange/EnsureRequirements.scala    |  25 ++++-
 ...ractJoinWithUnwrappedCastInJoinPredicates.scala | 114 +++++++++++++++++++++
 .../spark/sql/execution/joins/ShuffledJoin.scala   |  14 ++-
 .../apache/spark/sql/execution/PlannerSuite.scala  |  74 +++++++++++++
 .../spark/sql/sources/BucketedReadSuite.scala      |  65 ++++++++++++
 7 files changed, 301 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 080928baf8a9..9918d583d49e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -564,6 +564,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED =
+    buildConf("spark.sql.unwrapCastInJoinCondition.enabled")
+      .doc("When true, unwrap the cast in the join condition to reduce shuffle if they are " +
+        "integral types.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val MAX_SINGLE_PARTITION_BYTES = buildConf("spark.sql.maxSinglePartitionBytes")
     .doc("The maximum number of bytes allowed for a single partition. Otherwise, The planner " +
       "will introduce shuffle to improve parallelism.")
@@ -5043,6 +5051,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
 
+  def unwrapCastInJoinConditionEnabled: Boolean = getConf(UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED)
+
   def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)
 
   def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
index d1464b4ac4ee..ab0eaa044dea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
@@ -20,9 +20,7 @@ package org.apache.spark.sql.execution.bucketing
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
@@ -131,27 +129,11 @@ object ExtractJoinWithBuckets {
     }
   }
 
-  /**
-   * The join keys should match with expressions for output partitioning. Note that
-   * the ordering does not matter because it will be handled in `EnsureRequirements`.
-   */
-  private def satisfiesOutputPartitioning(
-      keys: Seq[Expression],
-      partitioning: Partitioning): Boolean = {
-    partitioning match {
-      case HashPartitioning(exprs, _) if exprs.length == keys.length =>
-        exprs.forall(e => keys.exists(_.semanticEquals(e)))
-      case PartitioningCollection(partitionings) =>
-        partitionings.exists(satisfiesOutputPartitioning(keys, _))
-      case _ => false
-    }
-  }
-
   private def isApplicable(j: ShuffledJoin): Boolean = {
     hasScanOperation(j.left) &&
       hasScanOperation(j.right) &&
-      satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
-      satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
+      j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
+      j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
   }
 
   private def isDivisible(numBuckets1: Int, numBuckets2: Int): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 2a7c1206bb41..38a8b5db2695 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -337,6 +337,28 @@ case class EnsureRequirements(
     }
   }
 
+  /**
+   * Unwrap the cast in join predicates to reduce shuffle.
+   */
+  private def unwrapCastInJoinPredicates(plan: SparkPlan): SparkPlan = {
+    if (conf.unwrapCastInJoinConditionEnabled) {
+      plan match {
+        case ExtractJoinWithUnwrappedCastInJoinPredicates(join, joinKeys) =>
+          val (leftKeys, rightKeys) = joinKeys.unzip
+          join match {
+            case j: SortMergeJoinExec =>
+              j.copy(leftKeys = leftKeys, rightKeys = rightKeys)
+            case j: ShuffledHashJoinExec =>
+              j.copy(leftKeys = leftKeys, rightKeys = rightKeys)
+            case other => other
+          }
+        case _ => plan
+      }
+    } else {
+      plan
+    }
+  }
+
   /**
    * Checks whether two children, `left` and `right`, of a join operator have compatible
    * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -605,7 +627,8 @@ case class EnsureRequirements(
         }
 
       case operator: SparkPlan =>
-        val reordered = reorderJoinPredicates(operator)
+        val unwrapped = unwrapCastInJoinPredicates(operator)
+        val reordered = reorderJoinPredicates(unwrapped)
         val newChildren = ensureDistributionAndOrdering(
           Some(reordered),
           reordered.children,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
new file mode 100644
index 000000000000..5d46fac90985
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion.findWiderTypeForTwo
+import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.ShuffledJoin
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegralType}
+
+/**
+ * An extractor that extracts `SortMergeJoinExec` and `ShuffledHashJoin`,
+ * where one side can do a bucketed read after unwrapping the cast in the join keys.
+ */
+object ExtractJoinWithUnwrappedCastInJoinPredicates {
+  private def isIntegralType(dt: DataType): Boolean = dt match {
+    case _: IntegralType => true
+    case DecimalType.Fixed(_, 0) => true
+    case _ => false
+  }
+
+  private def unwrapCastInJoinKeys(joinKeys: Seq[Expression]): Seq[Expression] = {
+    joinKeys.map {
+      case c: Cast if isIntegralType(c.child.dataType) => c.child
+      case e => e
+    }
+  }
+
+  // Casts the left or right side of join keys to the same data type.
+  private def coerceJoinKeyType(
+      unwrapLeftKeys: Seq[Expression],
+      unwrapRightKeys: Seq[Expression],
+      isAddCastToLeftSide: Boolean): Seq[(Expression, Expression)] = {
+    unwrapLeftKeys.zip(unwrapRightKeys).map {
+      case (l, r) if l.dataType != r.dataType =>
+        // Use TRY mode to avoid runtime exception in ANSI mode or data issue in non-ANSI mode.
+        if (isAddCastToLeftSide) {
+          Cast(l, r.dataType, evalMode = EvalMode.TRY) -> r
+        } else {
+          l -> Cast(r, l.dataType, evalMode = EvalMode.TRY)
+        }
+      case (l, r) => l -> r
+    }
+  }
+
+  private def unwrapCastInJoinPredicates(j: ShuffledJoin): Option[Seq[(Expression, Expression)]] = {
+    val leftKeys = unwrapCastInJoinKeys(j.leftKeys)
+    val rightKeys = unwrapCastInJoinKeys(j.rightKeys)
+    // Make sure cast to wider type.
+    // For example, we do not support: cast(longCol as int) = cast(decimalCol as int).
+    val isCastToWiderType = leftKeys.zip(rightKeys).zipWithIndex.forall {
+      case ((e1, e2), i) =>
+        findWiderTypeForTwo(e1.dataType, e2.dataType).contains(j.leftKeys(i).dataType)
+    }
+    if (isCastToWiderType) {
+      val leftSatisfies = j.satisfiesOutputPartitioning(leftKeys, j.left.outputPartitioning)
+      val rightSatisfies = j.satisfiesOutputPartitioning(rightKeys, j.right.outputPartitioning)
+      if (leftSatisfies && rightSatisfies) {
+        // If there is a bucketed read, their number of partitions may be inconsistent.
+        // If the number of partitions on the left side is less than the number of partitions
+        // on the right side, cast the left side keys to the data type of the right side keys.
+        // Otherwise, cast the right side keys to the data type of the left side keys.
+        Some(coerceJoinKeyType(leftKeys, rightKeys,
+          j.left.outputPartitioning.numPartitions < j.right.outputPartitioning.numPartitions))
+      } else if (leftSatisfies) {
+        Some(coerceJoinKeyType(leftKeys, rightKeys, false))
+      } else if (rightSatisfies) {
+        Some(coerceJoinKeyType(leftKeys, rightKeys, true))
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+  }
+
+  private def isTryToUnwrapCastInJoinPredicates(j: ShuffledJoin): Boolean = {
+    (j.leftKeys.exists(_.isInstanceOf[Cast]) || j.rightKeys.exists(_.isInstanceOf[Cast])) &&
+      !j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
+      !j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning) &&
+      j.children.map(_.outputPartitioning).exists { _ match {
+        case _: PartitioningCollection => true
+        case _: HashPartitioning => true
+        case _ => false
+      }}
+  }
+
+  def unapply(plan: SparkPlan): Option[(ShuffledJoin, Seq[(Expression, Expression)])] = {
+    plan match {
+      case j: ShuffledJoin if isTryToUnwrapCastInJoinPredicates(j) =>
+        unwrapCastInJoinPredicates(j) match {
+          case Some(joinKeys) => Some(j, joinKeys)
+          case _ => None
+        }
+      case _ => None
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index 7c4628c8576c..9591218b099b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
 
 /**
  * Holds common logic for join operators by shuffling two child relations
@@ -56,6 +56,16 @@ trait ShuffledJoin extends JoinCodegenSupport {
         s"ShuffledJoin should not take $x as the JoinType")
   }
 
+  def satisfiesOutputPartitioning(keys: Seq[Expression], partitioning: Partitioning): Boolean = {
+    partitioning match {
+      case HashPartitioning(exprs, _) if exprs.length == keys.length =>
+        exprs.forall(e => keys.exists(_.semanticEquals(e)))
+      case PartitioningCollection(partitionings) =>
+        partitionings.exists(satisfiesOutputPartitioning(keys, _))
+      case _ => false
+    }
+  }
+
   override def output: Seq[Attribute] = {
     joinType match {
       case _: InnerLike =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c5b1e68fb912..8565e06ba9fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1372,6 +1372,80 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       assert(numOutputPartitioning.size == 8)
     }
   }
+
+  test("SPARK-46219: Unwrap cast in join condition") {
+    val intExpr = Literal(1)
+    val longExpr = Literal(1L)
+    val smjExec = SortMergeJoinExec(
+      leftKeys = Cast(intExpr, LongType) :: Nil,
+      rightKeys = longExpr :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = DummySparkPlan(outputPartitioning = HashPartitioning(intExpr:: Nil, 5)),
+      right = DummySparkPlan())
+
+    Seq(true, false).foreach { unwrapCast =>
+      withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> unwrapCast.toString) {
+        val outputPlan = EnsureRequirements.apply(smjExec)
+        if (unwrapCast) {
+          outputPlan match {
+            case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+              SortExec(_, _, _: DummySparkPlan, _),
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
+              assert(leftKeys === Seq(intExpr))
+              assert(rightKeys === Seq(Cast(longExpr, IntegerType, evalMode = EvalMode.TRY)))
+            case _ => fail(outputPlan.toString)
+          }
+        } else {
+          outputPlan match {
+            case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _),
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
+              assert(leftKeys === smjExec.leftKeys)
+              assert(rightKeys === smjExec.rightKeys)
+            case _ => fail(outputPlan.toString)
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-46219: Number of partitions may be inconsistent") {
+    val longExpr = Literal(1L)
+    val decimalExpr = Literal(Decimal(1L, 18, 0))
+    val smjExec = SortMergeJoinExec(
+      leftKeys = Cast(longExpr, DecimalType(20, 0)) :: Nil,
+      rightKeys = Cast(decimalExpr, DecimalType(20, 0)) :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = DummySparkPlan(outputPartitioning = HashPartitioning(longExpr :: Nil, 10)),
+      right = DummySparkPlan(outputPartitioning = HashPartitioning(decimalExpr :: Nil, 5)))
+
+    Seq(true, false).foreach { unwrapCast =>
+      withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> unwrapCast.toString) {
+        val outputPlan = EnsureRequirements.apply(smjExec)
+        if (unwrapCast) {
+          outputPlan match {
+            case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+              SortExec(_, _, _: DummySparkPlan, _),
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
+              assert(leftKeys === Seq(longExpr))
+              assert(rightKeys === Seq(Cast(decimalExpr, LongType, evalMode = EvalMode.TRY)))
+            case _ => fail(outputPlan.toString)
+          }
+        } else {
+          outputPlan match {
+            case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _),
+              SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
+              assert(leftKeys === smjExec.leftKeys)
+              assert(rightKeys === smjExec.rightKeys)
+            case _ => fail(outputPlan.toString)
+          }
+        }
+      }
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3573bafe482c..52a316e63a81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, LongType}
 import org.apache.spark.tags.SlowSQLTest
 import org.apache.spark.util.collection.BitSet
 
@@ -1088,4 +1089,68 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
       }
     }
   }
+
+  test("SPARK-46219: Unwrap cast in join condition") {
+    def verify(
+        query: String,
+        expectedNumShuffles: Int,
+        numPartitions: Option[Int] = None,
+        partitioningKeyTypes: Option[Seq[DataType]] = None): Unit = {
+      Seq(true, false).foreach { ansiEnabled =>
+        Seq(true, false).foreach { aqeEnabled =>
+          withSQLConf(
+            SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+            SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
+            val df = sql(query)
+            val plan = df.queryExecution.executedPlan
+            val shuffles = collect(plan) {
+              case s: ShuffleExchangeExec => s
+            }
+            assert(shuffles.size === expectedNumShuffles)
+            if (shuffles.size == 1) {
+              val outputPartitioning = shuffles.head.outputPartitioning
+              assert(outputPartitioning.numPartitions === numPartitions.get)
+              assert(outputPartitioning.asInstanceOf[HashPartitioning]
+                .expressions.map(_.dataType) === partitioningKeyTypes.get)
+
+              collect(plan) { case s: SortMergeJoinExec => s }.flatMap(_.expressions).foreach {
+                case c: Cast => assert(c.evalMode === EvalMode.TRY) // The EvalMode should be try.
+                case _ =>
+              }
+
+              checkAnswer(df, Row(1, 1) :: Nil)
+            }
+          }
+        }
+      }
+    }
+
+    withTable("t1", "t2", "t3", "t4") {
+      sql(
+        s"""
+           |CREATE TABLE t1 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
+           |SELECT CAST(v AS int) AS i FROM values(1), (${Int.MaxValue}) AS data(v)
+           |""".stripMargin)
+      sql(
+        s"""
+           |CREATE TABLE t2 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
+           |SELECT CAST(v AS bigint) AS i FROM values(1), (${Long.MaxValue}) AS data(v)
+           |""".stripMargin)
+      sql(
+        s"""
+           |CREATE TABLE t3 USING parquet CLUSTERED BY (i) INTO 4 buckets AS
+           |SELECT CAST(v AS decimal(18, 0)) AS i FROM values(1), (${"9" * 18}) AS data(v)
+           |""".stripMargin)
+      spark.table("t2").write.saveAsTable("t4")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+        SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> "true") {
+        verify("SELECT * FROM t2 JOIN t3 ON t2.i = t3.i", 1, Some(8), Some(Seq(LongType)))
+        verify("SELECT * FROM t1 JOIN t4 ON t1.i = t4.i", 1, Some(8), Some(Seq(IntegerType)))
+        verify("SELECT * FROM t3 JOIN t4 ON t3.i = t4.i", 1, Some(4), Some(Seq(DecimalType(18, 0))))
+        // Do not unwrap cast if it is added by user.
+        verify("SELECT * FROM t2 JOIN t3 ON cast(t2.i as int) = cast(t3.i as int)", 2)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
