This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 57b2a065c418 [SPARK-54852][SQL] `NOT IN` subquery returns incorrect 
results with a collated table
57b2a065c418 is described below

commit 57b2a065c41868d442ea4f7357f13268a9938302
Author: ilicmarkodb <[email protected]>
AuthorDate: Wed Jan 7 16:11:23 2026 +0800

    [SPARK-54852][SQL] `NOT IN` subquery returns incorrect results with a 
collated table
    
    ### What changes were proposed in this pull request?
    ```
    create or replace table t1 (c1 string collate utf8_lcase_rtrim);
    create or replace table t2 (c1 string collate utf8_lcase_rtrim);
    insert into t1 values ('a');
    insert into t2 values ('A ');
    
    select * from t1 where c1 not in (select * from t2);
    -- should return no data, but it returns one row
    ```
    
    When performing a hash join on collated columns, the analyzer wraps the join 
keys in `CollationKey`, whose hash is collation-aware. The problem with this 
query is that there is no join during the analysis phase (the plan only contains 
a `NOT IN` subquery). The join is introduced later, during the optimization 
phase, when `NOT IN` is rewritten into a null-aware anti join. As a result, that 
join operates on the raw columns, whose hashes are not collation-aware. Values 
that are equal under the collation (e.g. 'a' and 'A ' under UTF8_LCASE_RTRIM) 
therefore fail to match.
    
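    This PR fixes the issue by normalizing the join keys whenever a hash join node 
is built. The new `apply` factories of `BroadcastHashJoinExec` and 
`ShuffledHashJoinExec` call the new `HashJoin.normalizeJoinKeys`. That helper 
uses `CollationKey.injectCollationKey` to wrap non-binary-stable collated strings 
in `CollationKey`, including strings nested inside structs and arrays. The 
`HashJoin` trait now asserts that all of its join keys are binary stable.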
    
    ### Why are the changes needed?
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
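    New tests in `CollationSuite` cover `NOT IN` subqueries over collated string, 
array and struct columns, with and without NULLs on either side. 
`BroadcastJoinSuite` now derives the operator names from 
`classOf[...].getSimpleName`.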
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53622 from ilicmarkodb/fix_not_in.
    
    Authored-by: ilicmarkodb <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/expressions/CollationKey.scala    | 61 +++++++++++++++-
 .../execution/joins/BroadcastHashJoinExec.scala    | 28 +++++++-
 .../spark/sql/execution/joins/HashJoin.scala       | 16 +++++
 .../sql/execution/joins/ShuffledHashJoinExec.scala | 26 ++++++-
 .../spark/sql/collation/CollationSuite.scala       | 81 ++++++++++++++++++++++
 .../sql/execution/joins/BroadcastJoinSuite.scala   |  4 +-
 6 files changed, 210 insertions(+), 6 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
index 5d2fd14eee29..9a0aaea75f81 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 case class CollationKey(expr: Expression) extends UnaryExpression with 
ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] =
@@ -46,3 +47,61 @@ case class CollationKey(expr: Expression) extends 
UnaryExpression with ExpectsIn
 
   override def child: Expression = expr
 }
+
+object CollationKey {
+  /**
+   * Recursively process the expression in order to recursively replace 
non-binary collated strings
+   * with their associated collation key.
+   */
+  def injectCollationKey(expr: Expression): Expression = {
+    injectCollationKey(expr, expr.dataType)
+  }
+
+  private def injectCollationKey(expr: Expression, dt: DataType): Expression = 
{
+    dt match {
+      // For binary stable expressions, no special handling is needed.
+      case _ if UnsafeRowUtils.isBinaryStable(dt) =>
+        expr
+
+      // Inject CollationKey for non-binary collated strings.
+      case _: StringType =>
+        CollationKey(expr)
+
+      // Recursively process struct fields for non-binary structs.
+      case StructType(fields) =>
+        val transformed = fields.zipWithIndex.map { case (f, i) =>
+          val originalField = GetStructField(expr, i, Some(f.name))
+          val injected = injectCollationKey(originalField, f.dataType)
+          (f, injected, injected.fastEquals(originalField))
+        }
+        val anyChanged = transformed.exists { case (_, _, same) => !same }
+        if (!anyChanged) {
+          expr
+        } else {
+          val struct = CreateNamedStruct(
+            transformed.flatMap { case (f, injected, _) =>
+              Seq(Literal(f.name), injected)
+            }.toImmutableArraySeq)
+          if (expr.nullable) {
+            If(IsNull(expr), Literal(null, struct.dataType), struct)
+          } else {
+            struct
+          }
+        }
+
+      // Recursively process array elements for non-binary arrays.
+      case ArrayType(et, containsNull) =>
+        val param: NamedExpression = NamedLambdaVariable("a", et, containsNull)
+        val funcBody: Expression = injectCollationKey(param, et)
+        if (!funcBody.fastEquals(param)) {
+          ArrayTransform(expr, LambdaFunction(funcBody, Seq(param)))
+        } else {
+          expr
+        }
+
+      // Joins are not supported on maps, so there's no special handling for 
MapType.
+      case _ =>
+        expr
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b62d8f0798b6..944ee3b05909 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, 
Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, 
UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  * broadcast relation.  This data is then placed in a Spark broadcast 
variable.  The streamed
  * relation is not shuffled.
  */
-case class BroadcastHashJoinExec(
+case class BroadcastHashJoinExec private(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
@@ -245,3 +245,27 @@ case class BroadcastHashJoinExec(
       newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec =
     copy(left = newLeft, right = newRight)
 }
+
+object BroadcastHashJoinExec extends JoinSelectionHelper {
+  def apply(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      buildSide: BuildSide,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isNullAwareAntiJoin: Boolean = false): BroadcastHashJoinExec = {
+    val (normalizedLeftKeys, normalizedRightKeys) = 
HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+    new BroadcastHashJoinExec(
+      normalizedLeftKeys,
+      normalizedRightKeys,
+      joinType,
+      buildSide,
+      condition,
+      left,
+      right,
+      isNullAwareAntiJoin)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index a1abb64e262d..fab14dba444d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, 
RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,6 +42,9 @@ private[joins] case class HashedRelationInfo(
     isEmpty: Boolean)
 
 trait HashJoin extends JoinCodegenSupport {
+  assert(leftKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+  assert(rightKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+
   def buildSide: BuildSide
 
   override def simpleStringWithNodeId(): String = {
@@ -724,6 +728,18 @@ trait HashJoin extends JoinCodegenSupport {
 
 object HashJoin extends CastSupport with SQLConfHelper {
 
+  /**
+   * Normalize join keys by injecting `CollationKey` when the keys are 
collated.
+   */
+  def normalizeJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+    (
+      leftKeys.map(CollationKey.injectCollationKey),
+      rightKeys.map(CollationKey.injectCollationKey)
+    )
+  }
+
   private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
     // TODO: support BooleanType, DateType and TimestampType
     keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 97ca74aee30c..0f90f443ad41 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet}
 /**
  * Performs a hash join of two child relations by first shuffling the data 
using the join keys.
  */
-case class ShuffledHashJoinExec(
+case class ShuffledHashJoinExec private (
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
@@ -659,3 +659,27 @@ case class ShuffledHashJoinExec(
       newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
     copy(left = newLeft, right = newRight)
 }
+
+object ShuffledHashJoinExec {
+  def apply(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      buildSide: BuildSide,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isSkewJoin: Boolean = false): ShuffledHashJoinExec = {
+    val (normalizedLeftKeys, normalizedRightKeys) = 
HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+    new ShuffledHashJoinExec(
+      normalizedLeftKeys,
+      normalizedRightKeys,
+      joinType,
+      buildSide,
+      condition,
+      left,
+      right,
+      isSkewJoin)
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 6cdf681d65ca..c84647066f25 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -2114,4 +2114,85 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
       sql(s"CREATE TABLE t (c STRING COLLATE system.builtin.UTF8_LCASE)")
     }
   }
+
+  test("null aware anti join from NOT IN with collated columns") {
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+      sql(s"CREATE TABLE $tableName2 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+      sql(s"INSERT INTO $tableName1 VALUES ('a')")
+      sql(s"INSERT INTO $tableName2 VALUES ('A   ')")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES ('b')")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer ++ Seq(Row("b")))
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)" +
+        s" AND C1 = 'B  '"), Row("b"))
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)" +
+        s" AND C1 > 'b'"), Seq())
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)" +
+        s" AND C1 = 'c'"), Seq())
+
+      // This case results in empty output due to NULL in the t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        Seq())
+    }
+  }
+
+  test("null aware anti join from NOT IN with collated columns in array type") 
{
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1 (C1 ARRAY<STRING COLLATE 
UTF8_LCASE_RTRIM>)")
+      sql(s"CREATE TABLE $tableName2 (C1 ARRAY<STRING COLLATE 
UTF8_LCASE_RTRIM>)")
+      sql(s"INSERT INTO $tableName1 VALUES (ARRAY('a  ', 'Aa '))")
+      sql(s"INSERT INTO $tableName2 VALUES (ARRAY('A', 'aa'))")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      // This case results in empty output due to NULL in the t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        Seq())
+    }
+  }
+
+  test("null aware anti join from NOT IN with collated columns in struct 
type") {
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1 (C1 STRUCT<x: STRING COLLATE 
UTF8_LCASE_RTRIM," +
+        s" y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"CREATE TABLE $tableName2 (C1 STRUCT<x: STRING COLLATE 
UTF8_LCASE_RTRIM," +
+        s" y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"INSERT INTO $tableName1 VALUES (named_struct('x', 'a  ', 'y', 'Aa 
'))")
+      sql(s"INSERT INTO $tableName2 VALUES (named_struct('x', 'A', 'y', 
'aa'))")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        expectedAnswer)
+
+      // This case results in empty output due to NULL in the t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * 
FROM $tableName2)"),
+        Seq())
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 69dd04e07d55..9bd858608cb9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -397,8 +397,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
     }
   }
 
-  private val bh = BroadcastHashJoinExec.toString
-  private val bl = BroadcastNestedLoopJoinExec.toString
+  private val bh = classOf[BroadcastHashJoinExec].getSimpleName
+  private val bl = classOf[BroadcastNestedLoopJoinExec].getSimpleName
 
   private def assertJoinBuildSide(sqlStr: String, joinMethod: String, 
buildSide: BuildSide): Any = {
     val executedPlan = stripAQEPlan(sql(sqlStr).queryExecution.executedPlan)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to