This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new e8549007ef7f [SPARK-56903][SQL] Spread NULL outer join keys across shuffle partitions
e8549007ef7f is described below
commit e8549007ef7f17a95d7689971861142c66635d0a
Author: Chao Sun <[email protected]>
AuthorDate: Fri May 22 09:40:05 2026 -0700
[SPARK-56903][SQL] Spread NULL outer join keys across shuffle partitions
### What changes were proposed in this pull request?
This PR reduces shuffle skew for null-heavy shuffled outer equi-joins.
For `LEFT OUTER`, `RIGHT OUTER`, and `FULL OUTER` joins, preserved rows with a `NULL`
shuffle key may not need to stay concentrated on one reducer. Today those rows can all
collapse into the same shuffle partition, which creates avoidable skew on NULL-heavy inputs.
This change adds a feature-flagged null-aware shuffle partitioning mode for shuffled outer
joins:
- Non-NULL shuffle keys keep the existing hash partitioning behavior.
- Rows with any `NULL` shuffle key are spread across reducers instead of collapsing into one
  partition.
- The behavior is disabled by default behind `spark.sql.shuffle.spreadNullJoinKeys.enabled`.
- The optimization is considered only for `LEFT OUTER`, `RIGHT OUTER`, and `FULL OUTER`
  equi-joins whose preserved side has nullable join keys.
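The first two rules above can be modeled without Spark. The following is a minimal sketch under stated assumptions, not the code in this patch: `NullAwarePartitionSketch`, `partitioner`, and `seed` are hypothetical names, join keys are modeled as `Option[Int]`, and a toy `hashCode` stands in for Spark's Murmur3-based hash.

```scala
// Hypothetical, Spark-free model of the null-aware partitioning rules.
object NullAwarePartitionSketch {
  // Non-negative modulo, so negative hash codes still map into [0, n).
  def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

  // Builds a stateful key-to-partition function for one simulated map task.
  // `seed` (an assumption) picks where the NULL round-robin counter starts.
  def partitioner(numPartitions: Int, seed: Int): Seq[Option[Int]] => Int = {
    var nextNullPartition = pmod(seed, numPartitions)
    (keys: Seq[Option[Int]]) => {
      if (keys.exists(_.isEmpty)) {
        // At least one NULL key: hand out partitions round-robin.
        val p = nextNullPartition
        nextNullPartition = (nextNullPartition + 1) % numPartitions
        p
      } else {
        // All keys non-NULL: ordinary hash partitioning (toy hash).
        pmod(keys.flatten.hashCode(), numPartitions)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val assign = partitioner(numPartitions = 4, seed = 0)
    val rows = Seq(Seq(Some(1)), Seq(None), Seq(None), Seq(Some(1)), Seq(None))
    rows.foreach(r => println(s"$r -> partition ${assign(r)}"))
  }
}
```

In this model, rows with equal non-NULL keys always share a partition, while consecutive NULL-keyed rows advance through the partitions.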
Spreading remains result-safe for null-safe equality (`<=>`) outer joins:
- For ordinary extracted `<=>` join keys, Spark rewrites them into non-null shuffle-key
  expressions using `coalesce(...)` and `isnull(...)`, so there are no `NULL` shuffle keys for
  this feature to redistribute.
- The only remaining corner is `NullType`, where the shuffle key can still be `NULL`. In that
  case, shuffled join execution already treats the row as unmatched, so redistributing those
  rows does not change query results.
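A rough model of the `coalesce(...)`/`isnull(...)` rewrite shows why no `NULL` shuffle key survives it. This is an illustrative sketch only: `NullSafeKeySketch` and `rewrite` are hypothetical names, a nullable key is modeled as `Option[Int]`, and `0` is assumed as the default value substituted for `NULL`.

```scala
// Hypothetical model of rewriting a `<=>` join key into non-null shuffle keys.
object NullSafeKeySketch {
  // Returns the pair (coalesce(k, default), isnull(k)). Neither component can
  // be NULL, and the boolean keeps a real `default` distinct from a NULL.
  def rewrite(k: Option[Int], default: Int = 0): (Int, Boolean) =
    (k.getOrElse(default), k.isEmpty)

  def main(args: Array[String]): Unit = {
    println(rewrite(None))    // NULL key
    println(rewrite(Some(0))) // real key equal to the default
    println(rewrite(Some(7))) // any other real key
  }
}
```

Two `NULL` keys map to the same pair, so they still meet on one reducer and match under `<=>`, which is why such rows are not candidates for spreading.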
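The implementation wires this through the planner and runtime pieces that need to understand
the new partitioning contract:
- `ClusteredDistribution` can opt into null-aware spreading via `allowNullKeySpreading`.
- New null-aware partitioning and shuffle-spec variants (`NullAwareHashPartitioning`,
  `CoalescedNullAwareHashPartitioning`, `NullAwareHashShuffleSpec`) preserve compatibility checks
  without pretending to satisfy ordinary clustered distributions.
- Shuffle execution spreads `NULL`-keyed rows round-robin within each map task. Because that
  assignment depends on row order, it reuses the round-robin retry safeguards: a local sort when
  `sortBeforeRepartition` is enabled, and an order-sensitive marking otherwise.
- AQE/coalesced shuffle reads preserve the new partitioning shape.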
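The retry-safety point can be illustrated with a small model. This is a sketch under assumptions, not Spark code: `RetryDeterminismSketch` and `roundRobin` are hypothetical names, rows are plain distinct strings, and a lexicographic sort stands in for the binary row sort Spark applies before order-dependent partition assignment.

```scala
// Hypothetical model: round-robin placement depends on row visitation order.
object RetryDeterminismSketch {
  // Assigns each (distinct) row to a partition in the order rows are visited.
  def roundRobin(rows: Seq[String], n: Int, start: Int): Map[String, Int] =
    rows.zipWithIndex.map { case (row, i) => row -> ((start + i) % n) }.toMap

  def main(args: Array[String]): Unit = {
    val firstAttempt = Seq("a", "b", "c")
    val retryAttempt = Seq("c", "a", "b") // same rows, different arrival order

    // Without sorting, the two attempts can disagree on where a row goes.
    println(roundRobin(firstAttempt, 2, 0) == roundRobin(retryAttempt, 2, 0))
    // Sorting first makes the assignment independent of arrival order.
    println(roundRobin(firstAttempt.sorted, 2, 0) == roundRobin(retryAttempt.sorted, 2, 0))
  }
}
```

If a retried map task placed a row in a different partition than the first attempt did, reducers that already fetched the first attempt's output could miss or duplicate that row, which is the failure mode the local sort guards against.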
When the feature flag is enabled, the null-aware join output partitioning intentionally does not
satisfy a strict `ClusteredDistribution`. That can require an extra downstream shuffle for
grouping, windowing, or another equi-join on the same key. Also, if one side is already hash
partitioned, only the other side may be reshuffled into the null-aware layout, so the
pre-shuffled side can keep its NULL skew.
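The asymmetry described above can be written down as a tiny decision table. This is a simplified sketch, assuming the clustering keys already match: `SatisfiesSketch`, `Layout`, and `Required` are hypothetical stand-ins for the partitioning and distribution classes touched by this patch.

```scala
// Hypothetical model of which layout satisfies which clustered requirement.
object SatisfiesSketch {
  sealed trait Layout
  case object Hash extends Layout          // NULL-keyed rows are co-located
  case object NullAwareHash extends Layout // NULL-keyed rows may be spread

  // Stand-in for a clustered requirement; only the opt-in flag is modeled.
  final case class Required(allowNullKeySpreading: Boolean)

  def satisfies(layout: Layout, required: Required): Boolean = layout match {
    // Ordinary hash partitioning is strictly stronger, so it always qualifies.
    case Hash => true
    // The relaxed layout qualifies only when the consumer opted in.
    case NullAwareHash => required.allowNullKeySpreading
  }

  def main(args: Array[String]): Unit = {
    for (layout <- Seq(Hash, NullAwareHash); flag <- Seq(false, true)) {
      println(s"$layout satisfies Required($flag): ${satisfies(layout, Required(flag))}")
    }
  }
}
```

The `NullAwareHash` with `Required(false)` case is the one that forces an extra shuffle for a downstream aggregate or window, and the `Hash` with `Required(true)` case is why an already hash-partitioned input is left as-is and keeps its NULL skew.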
This PR intentionally stays scoped to outer joins. Left anti joins may also have skewed
preserved-side `NULL` rows for ordinary `=` predicates and are worth evaluating separately, but
they need their own correctness and planning review rather than being folded into this patch.
### Why are the changes needed?
Outer joins can preserve large numbers of unmatched rows from the outer side. When many of those
rows have `NULL` shuffle keys, sending them all to one reducer creates skew even though they do
not require one shared reducer for correctness.
Example:
```sql
SELECT *
FROM fact f
LEFT OUTER JOIN dim d
ON f.k = d.k
```
If `fact.k` contains many `NULL` values, those rows must remain in the result as unmatched
left-side rows, but they do not need to be grouped together for correctness. Spreading them
reduces needless reducer concentration while leaving normal key matching unchanged.
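The effect on reducer load can be approximated with a small simulation. This is a toy sketch, not a benchmark: `NullSkewSketch`, `hashedCounts`, and `spreadCounts` are hypothetical names, keys are modeled as `Option[Int]`, and `hashCode` stands in for Spark's hash function.

```scala
// Hypothetical simulation of rows per reducer with and without NULL spreading.
object NullSkewSketch {
  // Non-negative modulo, so negative hash codes still map into [0, n).
  def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

  // Every key, including NULL (None), is placed by its hash.
  def hashedCounts(keys: Seq[Option[Int]], n: Int): Vector[Int] = {
    val counts = Array.fill(n)(0)
    keys.foreach(k => counts(pmod(k.hashCode(), n)) += 1)
    counts.toVector
  }

  // NULL keys are placed round-robin; non-NULL keys keep their hash placement.
  def spreadCounts(keys: Seq[Option[Int]], n: Int): Vector[Int] = {
    val counts = Array.fill(n)(0)
    var next = 0
    keys.foreach {
      case None =>
        counts(next) += 1
        next = (next + 1) % n
      case key =>
        counts(pmod(key.hashCode(), n)) += 1
    }
    counts.toVector
  }

  def main(args: Array[String]): Unit = {
    // 1000 NULL-keyed rows plus 8 rows with distinct non-NULL keys.
    val keys = Seq.fill(1000)(Option.empty[Int]) ++ (1 to 8).map(i => Option(i))
    println(s"hashed: ${hashedCounts(keys, 4)}")
    println(s"spread: ${spreadCounts(keys, 4)}")
  }
}
```

With hashing alone, all 1000 NULL-keyed rows land on a single reducer. With spreading, each of the 4 reducers receives 250 of them, and the non-NULL keys stay where hashing put them.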
### Does this PR introduce _any_ user-facing change?
Yes, in execution behavior only. Query results are unchanged, but when the feature flag is
enabled, shuffle partitioning for eligible NULL-heavy outer equi-joins becomes less skewed.
### How was this patch tested?
- Added and updated unit tests covering outer-join planning, FULL OUTER JOIN result correctness
  with `NULL` keys, null-safe outer-join behavior, shuffle-level `NULL` spreading, retry
  determinism, shuffle-spec compatibility, and AQE preservation of null-aware coalesced reads.
- Ran focused plan-stability verification for the affected TPC-DS cases locally.
- Ran `git diff --check`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Codex GPT-5
Closes #55927 from sunchao/dev/chao/codex/null-aware-outer-join-apache.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../sql/catalyst/plans/physical/partitioning.scala | 216 +++++++++++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 14 +
.../spark/sql/catalyst/ShuffleSpecSuite.scala | 60 ++++
.../execution/adaptive/AQEShuffleReadExec.scala | 9 +-
.../execution/exchange/ShuffleExchangeExec.scala | 47 ++-
.../spark/sql/execution/joins/ShuffledJoin.scala | 22 ++
.../DistributionAndOrderingSuiteBase.scala | 2 +-
.../connector/KeyGroupedPartitioningSuite.scala | 2 +-
.../apache/spark/sql/execution/ExchangeSuite.scala | 37 ++-
.../adaptive/AdaptiveQueryExecSuite.scala | 151 +++++++---
.../spark/sql/execution/joins/OuterJoinSuite.scala | 318 ++++++++++++++++++++-
11 files changed, 806 insertions(+), 72 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f331cd124759..e92bb0f7c0d6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -81,12 +81,17 @@ case object AllTuples extends Distribution {
*
* @param requireAllClusterKeys When true, `Partitioning` which satisfies this
distribution,
* must match all `clustering` expressions in the
same ordering.
+ * @param allowNullKeySpreading When true, the default partitioning may spread
rows whose
+ * clustering keys contain NULL values. This is a
permission for
+ * consumers that do not require NULL-key
co-location; ordinary
+ * [[HashPartitioning]] can still satisfy this
distribution.
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
requireAllClusterKeys: Boolean = SQLConf.get.getConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
- requiredNumPartitions: Option[Int] = None) extends Distribution {
+ requiredNumPartitions: Option[Int] = None,
+ allowNullKeySpreading: Boolean = false) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil.
" +
@@ -97,7 +102,11 @@ case class ClusteredDistribution(
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get ==
numPartitions,
s"This ClusteredDistribution requires ${requiredNumPartitions.get}
partitions, but " +
s"the actual number of partitions is $numPartitions.")
- HashPartitioning(clustering, numPartitions)
+ if (allowNullKeySpreading) {
+ NullAwareHashPartitioning(clustering, numPartitions)
+ } else {
+ HashPartitioning(clustering, numPartitions)
+ }
}
/**
@@ -282,7 +291,7 @@ trait HashPartitioningLike extends Expression with
Partitioning with Unevaluable
expressions.length == h.expressions.length &&
expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
- case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _) =>
+ case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _, _) =>
if (requireAllClusterKeys) {
// Checks `HashPartitioning` is partitioned on exactly same
clustering keys of
// `ClusteredDistribution`.
@@ -324,6 +333,45 @@ case class HashPartitioning(expressions: Seq[Expression],
numPartitions: Int)
newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions
= newChildren)
}
+/**
+ * Represents a hash partitioning for equi-join inputs where rows with a NULL
join key do not need
+ * to be co-located. Non-NULL join keys preserve the same partitioning
contract as
+ * [[HashPartitioning]], while rows with any NULL join key may be spread
across partitions. As a
+ * result, this partitioning intentionally does not satisfy a strict
[[ClusteredDistribution]].
+ */
+case class NullAwareHashPartitioning(expressions: Seq[Expression],
numPartitions: Int)
+ extends HashPartitioningLike {
+
+ override def satisfies0(required: Distribution): Boolean = {
+ (required match {
+ case UnspecifiedDistribution => true
+ case AllTuples => numPartitions == 1
+ case _ => false
+ }) || {
+ // Stateful operators require strict NULL-key co-location and therefore
cannot consume
+ // null-aware hash partitioning as a compatible clustered layout.
+ required match {
+ case c @ ClusteredDistribution(
+ requiredClustering, requireAllClusterKeys, _,
allowNullKeySpreading)
+ if allowNullKeySpreading =>
+ if (requireAllClusterKeys) {
+ c.areAllClusterKeysMatched(expressions)
+ } else {
+ expressions.forall(x =>
requiredClustering.exists(_.semanticEquals(x)))
+ }
+ case _ => false
+ }
+ }
+ }
+
+ override def createShuffleSpec(distribution: ClusteredDistribution):
ShuffleSpec =
+ NullAwareHashShuffleSpec(this, distribution)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): NullAwareHashPartitioning =
+ copy(expressions = newChildren)
+}
+
case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
/**
@@ -345,6 +393,47 @@ case class CoalescedHashPartitioning(from:
HashPartitioning, partitions: Seq[Coa
copy(from = from.copy(expressions = newChildren))
}
+/**
+ * Represents a null-aware hash partitioning whose reducer ranges have been
coalesced into fewer
+ * partitions. It preserves the same relaxed NULL-key co-location contract as
+ * [[NullAwareHashPartitioning]].
+ */
+case class CoalescedNullAwareHashPartitioning(
+ from: NullAwareHashPartitioning,
+ partitions: Seq[CoalescedBoundary]) extends HashPartitioningLike {
+
+ override def expressions: Seq[Expression] = from.expressions
+
+ override def satisfies0(required: Distribution): Boolean = {
+ (required match {
+ case UnspecifiedDistribution => true
+ case AllTuples => numPartitions == 1
+ case _ => false
+ }) || {
+ required match {
+ case c @ ClusteredDistribution(
+ requiredClustering, requireAllClusterKeys, _,
allowNullKeySpreading)
+ if allowNullKeySpreading =>
+ if (requireAllClusterKeys) {
+ c.areAllClusterKeysMatched(expressions)
+ } else {
+ expressions.forall(x =>
requiredClustering.exists(_.semanticEquals(x)))
+ }
+ case _ => false
+ }
+ }
+ }
+
+ override def createShuffleSpec(distribution: ClusteredDistribution):
ShuffleSpec =
+ CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)
+
+ override val numPartitions: Int = partitions.length
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): CoalescedNullAwareHashPartitioning
=
+ copy(from = from.copy(expressions = newChildren))
+}
+
/**
* Represents a partitioning where rows are split across partitions based on
transforms defined by
* `expressions`.
@@ -482,7 +571,7 @@ case class KeyedPartitioning(
def groupedSatisfies(required: Distribution): Boolean = {
required match {
- case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _) =>
+ case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _, _) =>
if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same
clustering keys of
// `ClusteredDistribution`.
@@ -657,7 +746,7 @@ case class RangePartitioning(ordering: Seq[SortOrder],
numPartitions: Int)
// `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a,
b)`.
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
- case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _) =>
+ case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _, _) =>
val expressions = ordering.map(_.child)
if (requireAllClusterKeys) {
// Checks `RangePartitioning` is partitioned on exactly same
clustering keys of
@@ -838,7 +927,7 @@ case class ShufflePartitionIdPassThrough(
super.satisfies0(required) || {
required match {
// TODO(SPARK-53428): Support Direct Passthrough Partitioning in the
Streaming Joins
- case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _) =>
+ case c @ ClusteredDistribution(requiredClustering,
requireAllClusterKeys, _, _) =>
val partitioningExpressions = expr.child :: Nil
if (requireAllClusterKeys) {
c.areAllClusterKeysMatched(partitioningExpressions)
@@ -919,6 +1008,25 @@ case class RangeShuffleSpec(
}
}
+private object HashShuffleSpecCompatibility {
+ def isCompatible(
+ leftDistribution: ClusteredDistribution,
+ leftNumPartitions: Int,
+ leftExpressions: Seq[Expression],
+ leftHashKeyPositions: Seq[mutable.BitSet],
+ rightDistribution: ClusteredDistribution,
+ rightNumPartitions: Int,
+ rightExpressions: Seq[Expression],
+ rightHashKeyPositions: Seq[mutable.BitSet]): Boolean = {
+ leftDistribution.clustering.length == rightDistribution.clustering.length
&&
+ leftNumPartitions == rightNumPartitions &&
+ leftExpressions.length == rightExpressions.length &&
+ leftHashKeyPositions.zip(rightHashKeyPositions).forall { case (left,
right) =>
+ left.intersect(right).nonEmpty
+ }
+ }
+}
+
case class HashShuffleSpec(
partitioning: HashPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {
@@ -951,14 +1059,26 @@ case class HashShuffleSpec(
// 3. both partitioning have the same number of expressions
// 4. each pair of partitioning expression from both sides has
overlapping positions in their
// corresponding distributions.
- distribution.clustering.length == otherDistribution.clustering.length &&
- partitioning.numPartitions == otherPartitioning.numPartitions &&
- partitioning.expressions.length == otherPartitioning.expressions.length
&& {
- val otherHashKeyPositions = otherHashSpec.hashKeyPositions
- hashKeyPositions.zip(otherHashKeyPositions).forall { case (left,
right) =>
- left.intersect(right).nonEmpty
- }
- }
+ HashShuffleSpecCompatibility.isCompatible(
+ distribution,
+ partitioning.numPartitions,
+ partitioning.expressions,
+ hashKeyPositions,
+ otherDistribution,
+ otherPartitioning.numPartitions,
+ otherPartitioning.expressions,
+ otherHashSpec.hashKeyPositions)
+ case otherNullAwareSpec @ NullAwareHashShuffleSpec(otherPartitioning,
otherDistribution)
+ if distribution.allowNullKeySpreading &&
otherDistribution.allowNullKeySpreading =>
+ HashShuffleSpecCompatibility.isCompatible(
+ distribution,
+ partitioning.numPartitions,
+ partitioning.expressions,
+ hashKeyPositions,
+ otherDistribution,
+ otherPartitioning.numPartitions,
+ otherPartitioning.expressions,
+ otherNullAwareSpec.hashKeyPositions)
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ =>
@@ -979,7 +1099,73 @@ case class HashShuffleSpec(
override def createPartitioning(clustering: Seq[Expression]): Partitioning =
{
val exprs = hashKeyPositions.map(v => clustering(v.head))
- HashPartitioning(exprs, partitioning.numPartitions)
+ if (distribution.allowNullKeySpreading) {
+ NullAwareHashPartitioning(exprs, partitioning.numPartitions)
+ } else {
+ HashPartitioning(exprs, partitioning.numPartitions)
+ }
+ }
+
+ override def numPartitions: Int = partitioning.numPartitions
+}
+
+/**
+ * Shuffle specification for [[NullAwareHashPartitioning]]. It is compatible
only with shuffle
+ * layouts whose distributions explicitly allow NULL-key spreading.
+ */
+case class NullAwareHashShuffleSpec(
+ partitioning: NullAwareHashPartitioning,
+ distribution: ClusteredDistribution) extends ShuffleSpec {
+
+ lazy val hashKeyPositions: Seq[mutable.BitSet] = {
+ val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+ distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos)
=>
+ distKeyToPos.getOrElseUpdate(distKey.canonicalized,
mutable.BitSet.empty).add(distKeyPos)
+ }
+ partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized,
mutable.BitSet.empty))
+ }
+
+ override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+ case SinglePartitionShuffleSpec =>
+ partitioning.numPartitions == 1
+ case otherSpec @ NullAwareHashShuffleSpec(otherPartitioning,
otherDistribution) =>
+ HashShuffleSpecCompatibility.isCompatible(
+ distribution,
+ partitioning.numPartitions,
+ partitioning.expressions,
+ hashKeyPositions,
+ otherDistribution,
+ otherPartitioning.numPartitions,
+ otherPartitioning.expressions,
+ otherSpec.hashKeyPositions)
+ case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution)
+ if distribution.allowNullKeySpreading &&
otherDistribution.allowNullKeySpreading =>
+ HashShuffleSpecCompatibility.isCompatible(
+ distribution,
+ partitioning.numPartitions,
+ partitioning.expressions,
+ hashKeyPositions,
+ otherDistribution,
+ otherPartitioning.numPartitions,
+ otherPartitioning.expressions,
+ otherHashSpec.hashKeyPositions)
+ case ShuffleSpecCollection(specs) =>
+ specs.exists(isCompatibleWith)
+ case _ =>
+ false
+ }
+
+ override def canCreatePartitioning: Boolean = {
+ if
(SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
+ distribution.areAllClusterKeysMatched(partitioning.expressions)
+ } else {
+ true
+ }
+ }
+
+ override def createPartitioning(clustering: Seq[Expression]): Partitioning =
{
+ val exprs = hashKeyPositions.map(v => clustering(v.head))
+ NullAwareHashPartitioning(exprs, partitioning.numPartitions)
}
override def numPartitions: Int = partitioning.numPartitions
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c34b52e15dbc..8ab725350448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -967,6 +967,20 @@ object SQLConf {
.checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be
positive")
.createWithDefault(200)
+ val SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED =
+ buildConf("spark.sql.shuffle.spreadNullJoinKeys.enabled")
+ .doc("When true, Spark may spread rows with NULL equi-join keys across
shuffle partitions " +
+ "for shuffled LEFT, RIGHT, and FULL OUTER equi-joins on nullable keys
to reduce " +
+ "shuffle skew. Null-aware join output partitioning does not satisfy a
strict " +
+ "ClusteredDistribution, so downstream grouping, windowing, or
equi-joins may require " +
+ "an extra shuffle. If one input is already hash partitioned, only the
other input may " +
+ "be reshuffled into the null-aware layout, so the pre-shuffled input
can keep its NULL " +
+ "skew.")
+ .version("4.1.0")
+ .withBindingPolicy(ConfigBindingPolicy.SESSION)
+ .booleanConf
+ .createWithDefault(false)
+
val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED =
buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled")
.doc("Whether to calculate order independent checksum for the shuffle
data or not. If " +
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index 85d285aa76c0..cb5d77d44512 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -453,6 +453,66 @@ class ShuffleSpecSuite extends SparkFunSuite with
SQLHelper {
)
}
+ test("compatibility: NullAwareHashShuffleSpec") {
+ val spreadAB = ClusteredDistribution(Seq($"a", $"b"),
allowNullKeySpreading = true)
+ val spreadCD = ClusteredDistribution(Seq($"c", $"d"),
allowNullKeySpreading = true)
+ val regularAB = ClusteredDistribution(Seq($"a", $"b"))
+
+ val nullAwareAB = NullAwareHashShuffleSpec(
+ NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadAB)
+ val nullAwareCD = NullAwareHashShuffleSpec(
+ NullAwareHashPartitioning(Seq($"c", $"d"), 10), spreadCD)
+ val regularABSpec = HashShuffleSpec(
+ HashPartitioning(Seq($"a", $"b"), 10), regularAB)
+ val spreadABHashSpec = HashShuffleSpec(
+ HashPartitioning(Seq($"a", $"b"), 10), spreadAB)
+
+ checkCompatible(nullAwareAB, nullAwareCD, expected = true)
+ checkCompatible(nullAwareAB, SinglePartitionShuffleSpec, expected = false)
+ checkCompatible(
+ NullAwareHashShuffleSpec(NullAwareHashPartitioning(Seq($"a", $"b"), 1),
spreadAB),
+ SinglePartitionShuffleSpec,
+ expected = true)
+ checkCompatible(nullAwareAB, regularABSpec, expected = false)
+ checkCompatible(nullAwareAB, spreadABHashSpec, expected = true)
+ checkCompatible(spreadABHashSpec, nullAwareAB, expected = true)
+ }
+
+ test("canCreatePartitioning: NullAwareHashShuffleSpec") {
+ val spreadDistribution =
+ ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)
+ val partialSpec = NullAwareHashShuffleSpec(
+ NullAwareHashPartitioning(Seq($"a"), 10), spreadDistribution)
+ val fullSpec = NullAwareHashShuffleSpec(
+ NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadDistribution)
+
+ withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key ->
"false") {
+ assert(partialSpec.canCreatePartitioning)
+ }
+ withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key ->
"true") {
+ assert(!partialSpec.canCreatePartitioning)
+ assert(fullSpec.canCreatePartitioning)
+ }
+ }
+
+ test("createPartitioning: NullAwareHashShuffleSpec") {
+ checkCreatePartitioning(
+ NullAwareHashShuffleSpec(
+ NullAwareHashPartitioning(Seq($"a"), 10),
+ ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)),
+ ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true),
+ NullAwareHashPartitioning(Seq($"c"), 10)
+ )
+
+ checkCreatePartitioning(
+ HashShuffleSpec(
+ HashPartitioning(Seq($"a"), 10),
+ ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)),
+ ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true),
+ NullAwareHashPartitioning(Seq($"c"), 10)
+ )
+ }
+
test("createPartitioning: other specs") {
val distribution = ClusteredDistribution(Seq($"a", $"b"))
checkCreatePartitioning(SinglePartitionShuffleSpec,
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index eba0346a94bd..bff86983961c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary,
CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning,
RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary,
CoalescedHashPartitioning, CoalescedNullAwareHashPartitioning,
HashPartitioning, NullAwareHashPartitioning, Partitioning, RangePartitioning,
RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec,
ShuffleExchangeLike}
@@ -83,6 +83,13 @@ case class AQEShuffleReadExec private(
throw SparkException.internalError(s"Unexpected
ShufflePartitionSpec: $unexpected")
}
CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h,
partitions))
+ case h: NullAwareHashPartitioning =>
+ val partitions = partitionSpecs.map {
+ case CoalescedPartitionSpec(start, end, _) =>
CoalescedBoundary(start, end)
+ case unexpected =>
+ throw SparkException.internalError(s"Unexpected
ShufflePartitionSpec: $unexpected")
+ }
+
CurrentOrigin.withOrigin(h.origin)(CoalescedNullAwareHashPartitioning(h,
partitions))
case r: RangePartitioning =>
CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions =
partitionSpecs.length))
// This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 744438422916..114f221c52f6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,7 +30,9 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter,
ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
UnsafeProjection, UnsafeRow, UnsafeRowChecksum}
+import org.apache.spark.sql.catalyst.expressions.{
+ Attribute, BoundReference, CollationAwareMurmur3Hash, Literal, Pmod,
UnsafeProjection,
+ UnsafeRow, UnsafeRowChecksum}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -349,6 +351,10 @@ object ShuffleExchangeExec {
// For HashPartitioning, the partitioning key is already a valid
partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning
key.
new PartitionIdPassthrough(n)
+ case NullAwareHashPartitioning(_, n) =>
+ // The null-aware extractor below produces partition IDs directly:
+ // Pmod(hash, n) for non-NULL keys, and a round-robin counter for NULL
keys.
+ new PartitionIdPassthrough(n)
case ShufflePartitionIdPassThrough(_, n) =>
// For ShufflePartitionIdPassThrough, the DirectShufflePartitionID
expression directly
// produces partition IDs, so we use PartitionIdPassthrough to pass
them through directly.
@@ -403,6 +409,32 @@ object ShuffleExchangeExec {
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression ::
Nil, outputAttributes)
row => projection(row).getInt(0)
+ case h: NullAwareHashPartitioning =>
+ // Non-NULL keys must produce the same partition id as
+ // HashPartitioning.partitionIdExpression so opted-in HashShuffleSpec
and
+ // NullAwareHashShuffleSpec inputs stay aligned.
+ val joinKeyProjection = UnsafeProjection.create(h.expressions,
outputAttributes)
+ val boundJoinKeys = h.expressions.zipWithIndex.map { case (expr,
index) =>
+ BoundReference(index, expr.dataType, expr.nullable)
+ }
+ val partitionIdExpression = Pmod(
+ new CollationAwareMurmur3Hash(boundJoinKeys),
+ Literal(h.numPartitions))
+ val partitionIdProjection =
UnsafeProjection.create(partitionIdExpression :: Nil)
+ var nullKeyPartition =
+ new
XORShiftRandom(TaskContext.get().partitionId()).nextInt(h.numPartitions)
+ row => {
+ val joinKeys = joinKeyProjection(row)
+ if (joinKeys.anyNull()) {
+ // NULL join keys cannot match under ordinary equi-join semantics.
Spread them
+ // round-robin within each map task so identical rows do not
collapse to one reducer.
+ val partition = nullKeyPartition
+ nullKeyPartition = (nullKeyPartition + 1) % h.numPartitions
+ partition
+ } else {
+ partitionIdProjection(joinKeys).getInt(0)
+ }
+ }
case RangePartitioning(sortingExpressions, _) =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
@@ -419,9 +451,14 @@ object ShuffleExchangeExec {
val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
newPartitioning.numPartitions > 1
+ val isNullAwareHashPartitioning =
+ newPartitioning.isInstanceOf[NullAwareHashPartitioning] &&
+ newPartitioning.numPartitions > 1
+ val needsDeterministicLocalSort =
+ (isRoundRobin || isNullAwareHashPartitioning) &&
SQLConf.get.sortBeforeRepartition
val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
- // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning
is deterministic,
+ // [SPARK-23207] Have to make sure stateful row-to-partition assignment
is deterministic,
// otherwise a retry task may output different rows and thus lead to
data loss.
//
// Currently we following the most straight-forward way that perform a
local sort before
@@ -429,7 +466,7 @@ object ShuffleExchangeExec {
//
// Note that we don't perform local sort if the new partitioning has
only 1 partition, under
// that case all output rows go to the same partition.
- val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
+ val newRdd = if (needsDeterministicLocalSort) {
rdd.mapPartitionsInternal { iter =>
val recordComparatorSupplier = new Supplier[RecordComparator] {
override def get: RecordComparator = new RecordBinaryComparator()
@@ -468,7 +505,9 @@ object ShuffleExchangeExec {
}
// round-robin function is order sensitive if we don't sort the input.
- val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
+ // Stateful partition assignment is order-sensitive when it depends on
row visitation order.
+ val isOrderSensitive =
+ (isRoundRobin || isNullAwareHashPartitioning) &&
!SQLConf.get.sortBeforeRepartition
if (needToCopyObjectsBeforeShuffle(part)) {
newRdd.mapPartitionsWithIndexInternal((_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index 3fb968bfea7a..179f88c99af6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter,
InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution,
Distribution, Partitioning, PartitioningCollection, UnknownPartitioning,
UnspecifiedDistribution}
+import org.apache.spark.sql.internal.SQLConf
/**
* Holds common logic for join operators by shuffling two child relations
@@ -28,6 +29,24 @@ import
org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Dist
trait ShuffledJoin extends JoinCodegenSupport {
def isSkewJoin: Boolean
+ private lazy val canSpreadNullJoinKeys: Boolean = {
+ // Only NULL keys on the preserved side can create this skew: they must be
emitted, but
+ // cannot satisfy ordinary equi-join predicates. Non-preserved NULL-keyed
rows are filtered
+ // out by `=` and never emitted, so their reducer placement does not
matter here.
+ //
+ // Null-safe equality usually rewrites to non-null shuffle keys. The
NullType corner can still
+ // produce NULL shuffle keys, but shuffled join execution already treats
those rows as
+ // unmatched, so spreading them does not change the result.
+ val preservedSideHasNullableKeys = joinType match {
+ case LeftOuter => leftKeys.exists(_.nullable)
+ case RightOuter => rightKeys.exists(_.nullable)
+ case FullOuter => leftKeys.exists(_.nullable) ||
rightKeys.exists(_.nullable)
+ case _ => false
+ }
+ conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) &&
+ preservedSideHasNullableKeys
+ }
+
override def nodeName: String = {
if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
}
@@ -39,6 +58,9 @@ trait ShuffledJoin extends JoinCodegenSupport {
// We re-arrange the shuffle partitions to deal with skew join, and the
new children
// partitioning doesn't satisfy `ClusteredDistribution`.
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ } else if (canSpreadNullJoinKeys) {
+ ClusteredDistribution(leftKeys, allowNullKeySpreading = true) ::
+ ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil
} else {
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index 0273a5d6dd49..c1741cac8ad3 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -67,7 +67,7 @@ abstract class DistributionAndOrderingSuiteBase
protected def resolveDistribution[T <: QueryPlan[T]](
distribution: physical.Distribution,
plan: QueryPlan[T]): physical.Distribution = distribution match {
- case physical.ClusteredDistribution(clustering, numPartitions, _) =>
+ case physical.ClusteredDistribution(clustering, numPartitions, _, _) =>
physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)),
numPartitions)
case physical.OrderedDistribution(ordering) =>
physical.OrderedDistribution(ordering.map(resolveAttrs(_,
plan).asInstanceOf[SortOrder]))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 2a0ab52c3693..711f6dbdcdb1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -233,7 +233,7 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase with
}.head
resolveDistribution(distribution, relation) match {
- case physical.ClusteredDistribution(clustering, _, _) =>
+ case physical.ClusteredDistribution(clustering, _, _, _) =>
assert(relation.keyGroupedPartitioning.isDefined &&
relation.keyGroupedPartitioning.get == clustering)
case _ =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 554cf5111bea..b7798b0bde5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution
import scala.util.Random
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
IdentityBroadcastMode, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
IdentityBroadcastMode, NullAwareHashPartitioning, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.internal.SQLConf
@@ -59,6 +59,39 @@ class ExchangeSuite extends SharedSparkSession {
)
}
+ test("null-aware hash shuffle spreads identical NULL keys from one mapper") {
+ val input = Seq.fill(64)(Tuple1(null.asInstanceOf[Integer])).toDF("k").coalesce(1)
+ val plan = input.queryExecution.executedPlan
+ val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+ val partitionSizes = exchange.execute().collectPartitions().map(_.length)
+
+ assert(partitionSizes.sorted === Array(16, 16, 16, 16))
+ }
+
+ test("null-aware hash shuffle preserves retry determinism with local sorting") {
+ withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") {
+ val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k")
+ val plan = input.queryExecution.executedPlan
+ val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+
+ assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED)
+ assert(exchange.shuffleDependency.rdd.outputDeterministicLevel !=
+ DeterministicLevel.INDETERMINATE)
+ }
+ }
+
+ test("null-aware hash shuffle marks unsorted repartitioning as order-sensitive") {
+ withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "false") {
+ val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k")
+ val plan = input.queryExecution.executedPlan
+ val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+
+ assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED)
+ assert(exchange.shuffleDependency.rdd.outputDeterministicLevel ==
+ DeterministicLevel.INDETERMINATE)
+ }
+ }
+
test("BroadcastMode.canonicalized") {
val mode1 = IdentityBroadcastMode
val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 50322905f29f..0e7ba599e0fb 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference,
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint,
LocalRelation, LogicalPlan}
+import
org.apache.spark.sql.catalyst.plans.physical.CoalescedNullAwareHashPartitioning
import org.apache.spark.sql.classic.Strategy
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
@@ -2089,55 +2090,80 @@ class AdaptiveQueryExecSuite
|ON CAST(value AS INT) = b
""".stripMargin)
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
- // Repartition with no partition num specified.
- checkBHJ(df.repartition($"b"),
- // The top shuffle from repartition is optimized out.
- optimizeOutRepartition = true, probeSideLocalRead = false,
probeSideCoalescedRead = true)
-
- // Repartition with default partition num (5 in test env) specified.
- checkBHJ(df.repartition(5, $"b"),
- // The top shuffle from repartition is optimized out
- // The final plan must have 5 partitions, no optimization can be
made to the probe side.
- optimizeOutRepartition = true, probeSideLocalRead = false,
probeSideCoalescedRead = false)
-
- // Repartition with non-default partition num specified.
- checkBHJ(df.repartition(4, $"b"),
- // The top shuffle from repartition is not optimized out
- optimizeOutRepartition = false, probeSideLocalRead = true,
probeSideCoalescedRead = true)
+ def checkRepartitionOptimization(
+ df: Dataset[Row],
+ useNullAwarePartitioning: Boolean): Unit = {
+ val optimizeDefaultRepartition = !useNullAwarePartitioning
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ // Repartition with no partition num specified.
+ checkBHJ(df.repartition($"b"),
+ optimizeOutRepartition = optimizeDefaultRepartition,
+ probeSideLocalRead = useNullAwarePartitioning,
+ probeSideCoalescedRead = !useNullAwarePartitioning)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkBHJ(df.repartition(5, $"b"),
+ optimizeOutRepartition = optimizeDefaultRepartition,
+ probeSideLocalRead = useNullAwarePartitioning,
+ probeSideCoalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkBHJ(df.repartition(4, $"b"),
+ optimizeOutRepartition = false,
+ probeSideLocalRead = true,
+ probeSideCoalescedRead = true)
+
+ // Repartition by col and project away the partition cols
+ checkBHJ(df.repartition($"b").select($"key"),
+ optimizeOutRepartition = false,
+ probeSideLocalRead = true,
+ probeSideCoalescedRead = true)
+ }
- // Repartition by col and project away the partition cols
- checkBHJ(df.repartition($"b").select($"key"),
- // The top shuffle from repartition is not optimized out
- optimizeOutRepartition = false, probeSideLocalRead = true,
probeSideCoalescedRead = true)
+ // Force skew join
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_ENABLED.key -> "true",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
+ // Repartition with no partition num specified.
+ checkSMJ(df.repartition($"b"),
+ optimizeOutRepartition = optimizeDefaultRepartition,
+ optimizeSkewJoin = useNullAwarePartitioning,
+ coalescedRead = !useNullAwarePartitioning)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkSMJ(df.repartition(5, $"b"),
+ optimizeOutRepartition = optimizeDefaultRepartition,
+ optimizeSkewJoin = useNullAwarePartitioning,
+ coalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkSMJ(df.repartition(4, $"b"),
+ optimizeOutRepartition = false, optimizeSkewJoin = true,
coalescedRead = false)
+
+ // Repartition by col and project away the partition cols
+ checkSMJ(df.repartition($"b").select($"key"),
+ optimizeOutRepartition = false, optimizeSkewJoin = true,
coalescedRead = false)
+ }
}
- // Force skew join
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
- SQLConf.SKEW_JOIN_ENABLED.key -> "true",
- SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
- SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
- SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
- // Repartition with no partition num specified.
- checkSMJ(df.repartition($"b"),
- // The top shuffle from repartition is optimized out.
- optimizeOutRepartition = true, optimizeSkewJoin = false,
coalescedRead = true)
-
- // Repartition with default partition num (5 in test env) specified.
- checkSMJ(df.repartition(5, $"b"),
- // The top shuffle from repartition is optimized out.
- // The final plan must have 5 partitions, can't do coalesced read.
- optimizeOutRepartition = true, optimizeSkewJoin = false,
coalescedRead = false)
-
- // Repartition with non-default partition num specified.
- checkSMJ(df.repartition(4, $"b"),
- // The top shuffle from repartition is not optimized out.
- optimizeOutRepartition = false, optimizeSkewJoin = true,
coalescedRead = false)
-
- // Repartition by col and project away the partition cols
- checkSMJ(df.repartition($"b").select($"key"),
- // The top shuffle from repartition is not optimized out.
- optimizeOutRepartition = false, optimizeSkewJoin = true,
coalescedRead = false)
+ checkRepartitionOptimization(df, useNullAwarePartitioning = false)
+ withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ // Null-aware join output partitioning is not equivalent to ordinary hash repartitioning.
+ val nullablePreservedSideDf = sql(
+ """
+ |SELECT * FROM (
+ | SELECT * FROM testData WHERE key = 1
+ |)
+ |RIGHT OUTER JOIN (
+ | SELECT a, b FROM testData2
+ | UNION ALL
+ | SELECT CAST(NULL AS INT) AS a, CAST(NULL AS INT) AS b
+ |)
+ |ON CAST(value AS INT) = b
+ """.stripMargin)
+ checkRepartitionOptimization(nullablePreservedSideDf, useNullAwarePartitioning = true)
}
}
}
@@ -2604,6 +2630,39 @@ class AdaptiveQueryExecSuite
}
}
+ test("AQE preserves coalesced null-aware partitioning for outer equi-join") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "8",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null-1"),
+ (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val df = nullableLeft.join(
+ nullableRight, nullableLeft("k") === nullableRight("k"), "left_outer")
+
+ checkAnswer(df, Seq(
+ Row(1, "left-1", 1, "right-1"),
+ Row(null, "left-null-1", null, null),
+ Row(null, "left-null-2", null, null)))
+
+ val coalescedNullAwareReads = collect(df.queryExecution.executedPlan) {
+ case read: AQEShuffleReadExec
+ if read.hasCoalescedPartition &&
+ read.outputPartitioning.isInstanceOf[CoalescedNullAwareHashPartitioning] =>
+ read
+ }
+ assert(coalescedNullAwareReads.nonEmpty)
+ }
+ }
+
test("SPARK-35794: Allow custom plugin for cost evaluator") {
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName,
spark.sparkContext.getConf)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 73e739e261b7..2deb452c3a09 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,15 +17,16 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression,
LessThan}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
NullAwareHashPartitioning}
import org.apache.spark.sql.classic.DataFrame
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements,
ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestData}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
@@ -36,6 +37,18 @@ class OuterJoinSuite extends SharedSparkSession with
SQLTestData {
private val EnsureRequirements = new EnsureRequirements()
+ private def extractJoinParts(
+ left: DataFrame,
+ right: DataFrame,
+ condition: Column): ExtractEquiJoinKeys.ReturnType = {
+ val analyzedJoin = left.join(right, condition, "inner")
+ .queryExecution.analyzed
+ .collectFirst { case join: Join => join }
+ .getOrElse(fail("Failed to build analyzed equi-join"))
+ ExtractEquiJoinKeys.unapply(analyzedJoin)
+ .getOrElse(fail("Failed to extract equi-join keys"))
+ }
+
private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),
@@ -345,4 +358,305 @@ class OuterJoinSuite extends SharedSparkSession with
SQLTestData {
val df2 = join("SHUFFLE_MERGE(t1)")
checkAnswer(df1, identity, df2.collect().toSeq)
}
+
+ test("ordinary outer equi-join spreads NULL keys in shuffle partitioning") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null-1"),
+ (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
left, right)),
+ Seq(
+ Row(1, "left-1", 1, "right-1"),
+ Row(null, "left-null-1", null, null),
+ Row(null, "left-null-2", null, null)),
+ sortAnswers = true)
+ }
+ }
+
+ test("ordinary outer equi-join keeps hash partitioning when null-aware shuffle is disabled") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "4") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+ }
+ }
+
+ test("ordinary outer equi-join keeps hash partitioning for non-nullable join keys") {
+ val nonNullableLeft = spark.range(3).toDF("k")
+ val nonNullableRight = spark.range(3).toDF("k")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(
+ nonNullableLeft,
+ nonNullableRight,
+ nonNullableLeft("k") === nonNullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nonNullableLeft.queryExecution.sparkPlan,
nonNullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+ }
+ }
+
+ test("ordinary right outer equi-join spreads NULL keys in shuffle partitioning") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null-1"),
+ (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition,
left, right)),
+ Seq(
+ Row(1, "left-1", 1, "right-1"),
+ Row(null, null, null, "right-null-1"),
+ Row(null, null, null, "right-null-2")),
+ sortAnswers = true)
+ }
+ }
+
+ test("ordinary full outer equi-join keeps NULL keys unmatched while spreading them") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null-1"),
+ (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null-1"),
+ (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition,
left, right)),
+ Seq(
+ Row(1, "left-1", 1, "right-1"),
+ Row(null, "left-null-1", null, null),
+ Row(null, "left-null-2", null, null),
+ Row(null, null, null, "right-null-1"),
+ Row(null, null, null, "right-null-2")),
+ sortAnswers = true)
+ }
+ }
+
+ test("ordinary outer equi-join preserves null-aware shuffle beside existing hash partitioning") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val existingLeftShuffle = ShuffleExchangeExec(
+ HashPartitioning(leftKeys, 4),
+ nullableLeft.queryExecution.sparkPlan)
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ existingLeftShuffle, nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+
+ assert(partitionings.size == 2)
+ assert(partitionings.count(_.isInstanceOf[HashPartitioning]) == 1)
+ assert(partitionings.count(_.isInstanceOf[NullAwareHashPartitioning]) ==
1)
+ }
+ }
+
+ test("mixed ordinary and null-safe outer equi-join can use null-aware shuffle partitioning") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), null.asInstanceOf[Integer], "left-match"),
+ (Integer.valueOf(2), null.asInstanceOf[Integer], "left-no-match"))
+ .toDF("k1", "k2", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), null.asInstanceOf[Integer], "right-match"),
+ (Integer.valueOf(2), Integer.valueOf(3), "right-no-match"))
+ .toDF("k1", "k2", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(
+ nullableLeft,
+ nullableRight,
+ nullableLeft("k1") === nullableRight("k1") &&
+ nullableLeft("k2").eqNullSafe(nullableRight("k2")))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
left, right)),
+ Seq(
+ Row(1, null, "left-match", 1, null, "right-match"),
+ Row(2, null, "left-no-match", null, null, null)),
+ sortAnswers = true)
+ }
+ }
+
+ test("null-safe outer equi-join keeps hash partitioning for non-null shuffle keys") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null"))
+ .toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(
+ nullableLeft,
+ nullableRight,
+ nullableLeft("k").eqNullSafe(nullableRight("k")))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+ }
+ }
+
+ test("ordinary outer equi-join spreads NULL keys for shuffled hash join") {
+ val nullableLeft = Seq(
+ (Integer.valueOf(1), "left-1"),
+ (null.asInstanceOf[Integer], "left-null-1"),
+ (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+ val nullableRight = Seq(
+ (Integer.valueOf(1), "right-1"),
+ (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") ===
nullableRight("k"))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ ShuffledHashJoinExec(leftKeys, rightKeys, LeftOuter, BuildRight,
boundCondition,
+ nullableLeft.queryExecution.sparkPlan,
nullableRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ ShuffledHashJoinExec(
+ leftKeys, rightKeys, LeftOuter, BuildRight, boundCondition, left,
right)),
+ Seq(
+ Row(1, "left-1", 1, "right-1"),
+ Row(null, "left-null-1", null, null),
+ Row(null, "left-null-2", null, null)),
+ sortAnswers = true)
+ }
+ }
+
+ test("NullType null-safe outer equi-join remains result-safe with null-aware shuffle") {
+ val nullTypeLeft = spark.range(2).selectExpr("NULL AS k", "id AS lv")
+ val nullTypeRight = spark.range(1).selectExpr("NULL AS k", "id AS rv")
+ val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+ extractJoinParts(
+ nullTypeLeft,
+ nullTypeRight,
+ nullTypeLeft("k").eqNullSafe(nullTypeRight("k")))
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+ SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+ val plan = EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+ nullTypeLeft.queryExecution.sparkPlan,
nullTypeRight.queryExecution.sparkPlan))
+ val partitionings = plan.collect {
+ case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+ }
+ assert(partitionings.size == 2)
+ assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+ checkAnswer2(nullTypeLeft, nullTypeRight, (left: SparkPlan, right:
SparkPlan) =>
+ EnsureRequirements.apply(
+ SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
left, right)),
+ Seq(
+ Row(null, 0L, null, null),
+ Row(null, 1L, null, null)),
+ sortAnswers = true)
+ }
+ }
}
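
The partitioning behavior exercised by these tests can be illustrated with a small standalone sketch. It is not Spark's `NullAwareHashPartitioning` implementation, which this part of the diff does not show. The object `NullAwarePartitionerSketch`, the method `partitionIds`, the use of `Option[Int]` for a nullable key, and the round-robin cursor are all assumptions made for illustration. Round-robin is chosen only because it is one simple scheme consistent with the even 16/16/16/16 split the ExchangeSuite test expects from a single mapper.

```scala
// Illustrative sketch only, NOT Spark's implementation.
// All names here are hypothetical; None stands in for a NULL join key.
object NullAwarePartitionerSketch {

  /** Assigns a reducer index to each key emitted by one (assumed) map task. */
  def partitionIds(keys: Seq[Option[Int]], numPartitions: Int): Seq[Int] = {
    require(numPartitions > 0, "numPartitions must be positive")
    var nextNullSlot = 0 // round-robin cursor, assumed local to a single map task
    keys.map {
      case Some(k) =>
        // Non-NULL keys keep hash placement, so equal keys land on the same reducer.
        Math.floorMod(k.hashCode, numPartitions)
      case None =>
        // A NULL key cannot match under `=`, so any reducer is acceptable.
        val id = nextNullSlot
        nextNullSlot = (nextNullSlot + 1) % numPartitions
        id
    }
  }

  def main(args: Array[String]): Unit = {
    val keys = Seq.fill(64)(Option.empty[Int]) ++ Seq(Some(7), Some(7))
    val ids = partitionIds(keys, 4)
    // Rows per reducer for the 64 NULL-keyed rows.
    println(ids.take(64).groupBy(identity).map { case (p, rows) => p -> rows.size })
    // Reducers chosen for the two rows with key 7.
    println(ids.drop(64))
  }
}
```

The property the sketch is meant to show is the one the tests rely on: rows with equal non-NULL keys still meet on one reducer, while NULL-keyed rows, which are emitted unmatched by an outer join anyway, no longer pile up on a single partition.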