This is an automated email from the ASF dual-hosted git repository.
marin-ma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2db6253b3c [VL] Simplify the logic in AppendBatchResizeForShuffleInputAndOutput (#12087)
2db6253b3c is described below
commit 2db6253b3c5ba421fd7123aadafb88e5238d3f68
Author: Rong Ma <[email protected]>
AuthorDate: Fri May 29 13:54:55 2026 +0100
[VL] Simplify the logic in AppendBatchResizeForShuffleInputAndOutput (#12087)
---
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 2 +-
...AppendBatchResizeForShuffleInputAndOutput.scala | 108 ++++++++++-----------
2 files changed, 52 insertions(+), 58 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 9ce400de56..d63928527d 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -108,7 +108,7 @@ object VeloxRuleApi {
offloads))
// Legacy: Post-transform rules.
- injector.injectPostTransform(_ =>
AppendBatchResizeForShuffleInputAndOutput())
+ injector.injectPostTransform(c =>
AppendBatchResizeForShuffleInputAndOutput(c.caller.isAqe()))
injector.injectPostTransform(_ =>
GpuBufferBatchResizeForShuffleInputOutput())
injector.injectPostTransform(_ => UnionTransformerRule())
injector.injectPostTransform(_ => PartialFallbackRules())
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
index fcce64d652..c21335da19 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala
@@ -22,17 +22,18 @@ import org.apache.gluten.execution.VeloxResizeBatchesExec
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
/**
* Try to append [[VeloxResizeBatchesExec]] for shuffle input and output to
make the batch sizes in
* good shape.
*/
-case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan]
{
+case class AppendBatchResizeForShuffleInputAndOutput(isAdaptiveContext:
Boolean)
+ extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
if (VeloxConfig.get.enableColumnarCudf) {
return plan
}
+
val resizeBatchesShuffleInputEnabled =
VeloxConfig.get.veloxResizeBatchesShuffleInput
val resizeBatchesShuffleOutputEnabled =
VeloxConfig.get.veloxResizeBatchesShuffleOutput
if (!resizeBatchesShuffleInputEnabled &&
!resizeBatchesShuffleOutputEnabled) {
@@ -41,65 +42,58 @@ case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] {
val range = VeloxConfig.get.veloxResizeBatchesShuffleInputOutputRange
val preferredBatchBytes = VeloxConfig.get.veloxPreferredBatchBytes
+
+ val newPlan = if (resizeBatchesShuffleInputEnabled) {
+ addResizeBatchesForShuffleInput(plan, range.min, range.max,
preferredBatchBytes)
+ } else {
+ plan
+ }
+
+ val resultPlan = if (isAdaptiveContext &&
resizeBatchesShuffleOutputEnabled) {
+ addResizeBatchesForShuffleOutput(newPlan, range.min, range.max,
preferredBatchBytes)
+ } else {
+ newPlan
+ }
+
+ resultPlan
+ }
+
+ private def addResizeBatchesForShuffleInput(
+ plan: SparkPlan,
+ min: Int,
+ max: Int,
+ preferredBatchBytes: Long): SparkPlan = {
plan.transformUp {
case shuffle: ColumnarShuffleExchangeExec
- if resizeBatchesShuffleInputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleInput =>
+ if shuffle.shuffleWriterType.requiresResizingShuffleInput =>
val appendBatches =
- VeloxResizeBatchesExec(shuffle.child, range.min, range.max,
preferredBatchBytes)
+ VeloxResizeBatchesExec(shuffle.child, min, max, preferredBatchBytes)
shuffle.withNewChildren(Seq(appendBatches))
- case a @ AQEShuffleReadExec(
- ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _),
- _)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(a, range.min, range.max, preferredBatchBytes)
- case a @ AQEShuffleReadExec(
- ShuffleQueryStageExec(
- _,
- ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec),
- _),
- _)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(a, range.min, range.max, preferredBatchBytes)
- // Since it's transformed in a bottom to up order, so we may first
encounter
- // ShuffeQueryStageExec, which is transformed to
VeloxResizeBatchesExec(ShuffeQueryStageExec),
- // then we see AQEShuffleReadExec
- case a @ AQEShuffleReadExec(
- VeloxResizeBatchesExec(
- s @ ShuffleQueryStageExec(_, shuffle:
ColumnarShuffleExchangeExec, _),
- _,
- _,
- _),
- _)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max,
preferredBatchBytes)
- case a @ AQEShuffleReadExec(
- VeloxResizeBatchesExec(
- s @ ShuffleQueryStageExec(
- _,
- ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec),
- _),
- _,
- _,
- _),
- _)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max,
preferredBatchBytes)
- case s @ ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec,
_)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(s, range.min, range.max, preferredBatchBytes)
- case s @ ShuffleQueryStageExec(
- _,
- ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec),
- _)
- if resizeBatchesShuffleOutputEnabled &&
- shuffle.shuffleWriterType.requiresResizingShuffleOutput =>
- VeloxResizeBatchesExec(s, range.min, range.max, preferredBatchBytes)
+ }
+ }
+
+ private def addResizeBatchesForShuffleOutput(
+ plan: SparkPlan,
+ min: Int,
+ max: Int,
+ preferredBatchBytes: Long): SparkPlan = {
+ plan match {
+ case s: ShuffleQueryStageExec if requiresResizingShuffleOutput(s) =>
+ VeloxResizeBatchesExec(s, min, max, preferredBatchBytes)
+ case a @ AQEShuffleReadExec(s @ ShuffleQueryStageExec(_, _, _), _)
+ if requiresResizingShuffleOutput(s) =>
+ VeloxResizeBatchesExec(a, min, max, preferredBatchBytes)
+ case other =>
+ other.mapChildren(addResizeBatchesForShuffleOutput(_, min, max,
preferredBatchBytes))
+ }
+ }
+
+ private def requiresResizingShuffleOutput(s: ShuffleQueryStageExec): Boolean
= {
+ s.shuffle match {
+ case c: ColumnarShuffleExchangeExec
+ if c.shuffleWriterType.requiresResizingShuffleOutput =>
+ true
+ case _ => false
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]