This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 2dd054775bda [SPARK-54852][SQL] `NOT IN` subquery returns incorrect results with a collated table
2dd054775bda is described below
commit 2dd054775bdac3fbbc505eccf0841946eae02f6d
Author: ilicmarkodb <[email protected]>
AuthorDate: Wed Jan 7 16:11:23 2026 +0800
[SPARK-54852][SQL] `NOT IN` subquery returns incorrect results with a collated table
### What changes were proposed in this pull request?
```
create or replace table t1 (c1 string collate utf8_lcase_rtrim);
create or replace table t2 (c1 string collate utf8_lcase_rtrim);
insert into t1 values ('a');
insert into t2 values ('A ');
select * from t1 where c1 not in (select * from t2);
-- should return no data, but it returns one row
```
When performing a hash join on collated columns, we first wrap the column
with `CollationKey` during analysis. This is because the hash of `CollationKey`
is collation-aware. The problem with this query is that there is no join during
the analysis phase (we have `NOT IN`), and the join is added during the
optimization phase. As a result, the join operates on raw columns, which are
not collation-aware.
This PR fixes the issue by rewriting the join keys in the `HashJoin` trait.
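For intuition, here is a minimal, self-contained sketch of the collation-key idea. It is not Spark's implementation (that lives in `CollationKey` and `CollationFactory`); `java.text.Collator` and `stripTrailing()` are stand-ins assumed here to approximate the UTF8_LCASE_RTRIM semantics:
```
import java.text.Collator
import java.util.{Arrays, Locale}

object CollationKeyDemo {
  def main(args: Array[String]): Unit = {
    // PRIMARY strength ignores case, mimicking the LCASE part of the collation.
    val collator = Collator.getInstance(Locale.ROOT)
    collator.setStrength(Collator.PRIMARY)

    val left = "a"   // row in t1
    val right = "A " // row in t2

    // A hash join on the raw columns compares bytes, so these never match
    // and the anti join incorrectly keeps 'a'.
    println(left == right) // false

    // Keying each side first makes byte equality agree with the collation;
    // stripTrailing() mimics the RTRIM part (an assumption for this sketch).
    def key(s: String): Array[Byte] =
      collator.getCollationKey(s.stripTrailing()).toByteArray

    println(Arrays.equals(key(left), key(right))) // true
  }
}
```
Hashing the key instead of the raw string is the normalization that the new `HashJoin.normalizeJoinKeys` applies to the join keys below.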
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53622 from ilicmarkodb/fix_not_in.
Authored-by: ilicmarkodb <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 57b2a065c41868d442ea4f7357f13268a9938302)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/CollationKey.scala | 61 +++++++++++++++-
.../execution/joins/BroadcastHashJoinExec.scala | 28 +++++++-
.../spark/sql/execution/joins/HashJoin.scala | 16 +++++
.../sql/execution/joins/ShuffledHashJoinExec.scala | 26 ++++++-
.../spark/sql/collation/CollationSuite.scala | 81 ++++++++++++++++++++++
.../sql/execution/joins/BroadcastJoinSuite.scala | 4 +-
6 files changed, 210 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
index 5d2fd14eee29..9a0aaea75f81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
@@ -18,10 +18,11 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
@@ -46,3 +47,61 @@ case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsIn
override def child: Expression = expr
}
+
+object CollationKey {
+ /**
+ * Recursively process the expression in order to recursively replace non-binary collated strings
+ * with their associated collation key.
+ */
+ def injectCollationKey(expr: Expression): Expression = {
+ injectCollationKey(expr, expr.dataType)
+ }
+
+ private def injectCollationKey(expr: Expression, dt: DataType): Expression = {
+ dt match {
+ // For binary stable expressions, no special handling is needed.
+ case _ if UnsafeRowUtils.isBinaryStable(dt) =>
+ expr
+
+ // Inject CollationKey for non-binary collated strings.
+ case _: StringType =>
+ CollationKey(expr)
+
+ // Recursively process struct fields for non-binary structs.
+ case StructType(fields) =>
+ val transformed = fields.zipWithIndex.map { case (f, i) =>
+ val originalField = GetStructField(expr, i, Some(f.name))
+ val injected = injectCollationKey(originalField, f.dataType)
+ (f, injected, injected.fastEquals(originalField))
+ }
+ val anyChanged = transformed.exists { case (_, _, same) => !same }
+ if (!anyChanged) {
+ expr
+ } else {
+ val struct = CreateNamedStruct(
+ transformed.flatMap { case (f, injected, _) =>
+ Seq(Literal(f.name), injected)
+ }.toImmutableArraySeq)
+ if (expr.nullable) {
+ If(IsNull(expr), Literal(null, struct.dataType), struct)
+ } else {
+ struct
+ }
+ }
+
+ // Recursively process array elements for non-binary arrays.
+ case ArrayType(et, containsNull) =>
+ val param: NamedExpression = NamedLambdaVariable("a", et, containsNull)
+ val funcBody: Expression = injectCollationKey(param, et)
+ if (!funcBody.fastEquals(param)) {
+ ArrayTransform(expr, LambdaFunction(funcBody, Seq(param)))
+ } else {
+ expr
+ }
+
+ // Joins are not supported on maps, so there's no special handling for MapType.
+ case _ =>
+ expr
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b62d8f0798b6..944ee3b05909 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
* relation is not shuffled.
*/
-case class BroadcastHashJoinExec(
+case class BroadcastHashJoinExec private(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
@@ -245,3 +245,27 @@ case class BroadcastHashJoinExec(
newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec =
copy(left = newLeft, right = newRight)
}
+
+object BroadcastHashJoinExec extends JoinSelectionHelper {
+ def apply(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isNullAwareAntiJoin: Boolean = false): BroadcastHashJoinExec = {
+ val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+ new BroadcastHashJoinExec(
+ normalizedLeftKeys,
+ normalizedRightKeys,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ isNullAwareAntiJoin)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index a1abb64e262d..fab14dba444d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,6 +42,9 @@ private[joins] case class HashedRelationInfo(
isEmpty: Boolean)
trait HashJoin extends JoinCodegenSupport {
+ assert(leftKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+ assert(rightKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+
def buildSide: BuildSide
override def simpleStringWithNodeId(): String = {
@@ -724,6 +728,18 @@ trait HashJoin extends JoinCodegenSupport {
object HashJoin extends CastSupport with SQLConfHelper {
+ /**
+ * Normalize join keys by injecting `CollationKey` when the keys are collated.
+ */
+ def normalizeJoinKeys(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ (
+ leftKeys.map(CollationKey.injectCollationKey),
+ rightKeys.map(CollationKey.injectCollationKey)
+ )
+ }
+
private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
// TODO: support BooleanType, DateType and TimestampType
keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 97ca74aee30c..0f90f443ad41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet}
/**
* Performs a hash join of two child relations by first shuffling the data using the join keys.
*/
-case class ShuffledHashJoinExec(
+case class ShuffledHashJoinExec private (
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
@@ -659,3 +659,27 @@ case class ShuffledHashJoinExec(
newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
copy(left = newLeft, right = newRight)
}
+
+object ShuffledHashJoinExec {
+ def apply(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false): ShuffledHashJoinExec = {
+ val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+ new ShuffledHashJoinExec(
+ normalizedLeftKeys,
+ normalizedRightKeys,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ isSkewJoin)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 6cdf681d65ca..c84647066f25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -2114,4 +2114,85 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
sql(s"CREATE TABLE t (c STRING COLLATE system.builtin.UTF8_LCASE)")
}
}
+
+ test("null aware anti join from NOT IN with collated columns") {
+ val expectedAnswer = Seq()
+ val (tableName1, tableName2) = ("t1", "t2")
+ withTable(tableName1, tableName2) {
+ sql(s"CREATE TABLE $tableName1 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+ sql(s"CREATE TABLE $tableName2 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+ sql(s"INSERT INTO $tableName1 VALUES ('a')")
+ sql(s"INSERT INTO $tableName2 VALUES ('A ')")
+
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ sql(s"INSERT INTO $tableName1 VALUES ('b')")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer ++ Seq(Row("b")))
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+ s" AND C1 = 'B '"), Row("b"))
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+ s" AND C1 > 'b'"), Seq())
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+ s" AND C1 = 'c'"), Seq())
+
+ // This case results in empty output due to the NULL in t2.
+ sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ Seq())
+ }
+ }
+
+ test("null aware anti join from NOT IN with collated columns in array type")
{
+ val expectedAnswer = Seq()
+ val (tableName1, tableName2) = ("t1", "t2")
+ withTable(tableName1, tableName2) {
+ sql(s"CREATE TABLE $tableName1 (C1 ARRAY<STRING COLLATE
UTF8_LCASE_RTRIM>)")
+ sql(s"CREATE TABLE $tableName2 (C1 ARRAY<STRING COLLATE
UTF8_LCASE_RTRIM>)")
+ sql(s"INSERT INTO $tableName1 VALUES (ARRAY('a ', 'Aa '))")
+ sql(s"INSERT INTO $tableName2 VALUES (ARRAY('A', 'aa'))")
+
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ // This case results in empty output due to the NULL in t2.
+ sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ Seq())
+ }
+ }
+
+ test("null aware anti join from NOT IN with collated columns in struct
type") {
+ val expectedAnswer = Seq()
+ val (tableName1, tableName2) = ("t1", "t2")
+ withTable(tableName1, tableName2) {
+ sql(s"CREATE TABLE $tableName1 (C1 STRUCT<x: STRING COLLATE
UTF8_LCASE_RTRIM," +
+ s" y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+ sql(s"CREATE TABLE $tableName2 (C1 STRUCT<x: STRING COLLATE
UTF8_LCASE_RTRIM," +
+ s" y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+ sql(s"INSERT INTO $tableName1 VALUES (named_struct('x', 'a ', 'y', 'Aa
'))")
+ sql(s"INSERT INTO $tableName2 VALUES (named_struct('x', 'A', 'y',
'aa'))")
+
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ expectedAnswer)
+
+ // This case results in empty output due to the NULL in t2.
+ sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+ checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+ Seq())
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 69dd04e07d55..9bd858608cb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -397,8 +397,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
}
}
- private val bh = BroadcastHashJoinExec.toString
- private val bl = BroadcastNestedLoopJoinExec.toString
+ private val bh = classOf[BroadcastHashJoinExec].getSimpleName
+ private val bl = classOf[BroadcastNestedLoopJoinExec].getSimpleName
private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
val executedPlan = stripAQEPlan(sql(sqlStr).queryExecution.executedPlan)