This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c6a0271e minor: Move shuffle logic from `CometExecRule` to `CometShuffleExchangeExec` serde implementation (#2853)
5c6a0271e is described below

commit 5c6a0271e55acc647b984f9feead993fec83d616
Author: Andy Grove <[email protected]>
AuthorDate: Fri Dec 5 16:57:04 2025 -0700

    minor: Move shuffle logic from `CometExecRule` to `CometShuffleExchangeExec` serde implementation (#2853)
---
 .../apache/comet/CometSparkSessionExtensions.scala |  16 -
 .../org/apache/comet/rules/CometExecRule.scala     | 313 +-------------------
 .../shuffle/CometShuffleExchangeExec.scala         | 326 ++++++++++++++++++++-
 3 files changed, 328 insertions(+), 327 deletions(-)
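
The substance of the change: the `CometShuffleExchangeExec` companion object now
implements the operator-serde sink interface (`getSupportLevel` / `createExec`
from `CometSink[ShuffleExchangeExec]`), so `CometExecRule` no longer needs its
bespoke `tryNativeShuffle` / `tryColumnarShuffle` helpers. A minimal,
self-contained sketch of that pattern follows; every type below is a simplified
stand-in for illustration, not Comet's real class (the real trait is
`org.apache.comet.serde.operator.CometSink`):

    // Stand-ins: Plan models SparkPlan, Sink models CometSink. Member names
    // mirror the diff below; the bodies are assumptions for illustration.
    sealed trait SupportLevel
    case class Compatible() extends SupportLevel
    case class Unsupported() extends SupportLevel

    trait Plan
    case class ShuffleExchange(child: Plan) extends Plan
    case class CometShuffle(original: ShuffleExchange, native: Boolean) extends Plan

    trait Sink[T <: Plan] {
      def getSupportLevel(op: T): SupportLevel // decide eligibility up front
      def createExec(op: T): Plan              // build the Comet replacement
    }

    object ShuffleSink extends Sink[ShuffleExchange] {
      // Placeholders standing in for nativeShuffleSupported /
      // columnarShuffleSupported in the real companion object.
      private def nativeOk(op: ShuffleExchange): Boolean = true
      private def columnarOk(op: ShuffleExchange): Boolean = true

      override def getSupportLevel(op: ShuffleExchange): SupportLevel =
        if (nativeOk(op) || columnarOk(op)) Compatible() else Unsupported()

      override def createExec(op: ShuffleExchange): Plan =
        if (nativeOk(op)) CometShuffle(op, native = true)         // native first
        else if (columnarOk(op)) CometShuffle(op, native = false) // then columnar
        else throw new IllegalStateException("unsupported shuffle")
    }

With the checks owned by the sink, the rule's shuffle case collapses to
`convertToComet(s, CometShuffleExchangeExec)`, as the hunks below show.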

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 71ce8d311..01a11bf0d 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -157,22 +157,6 @@ object CometSparkSessionExtensions extends Logging {
     COMET_EXEC_ENABLED.get(conf)
   }
 
-  private[comet] def isCometNativeShuffleMode(conf: SQLConf): Boolean = {
-    COMET_SHUFFLE_MODE.get(conf) match {
-      case "native" => true
-      case "auto" => true
-      case _ => false
-    }
-  }
-
-  private[comet] def isCometJVMShuffleMode(conf: SQLConf): Boolean = {
-    COMET_SHUFFLE_MODE.get(conf) match {
-      case "jvm" => true
-      case "auto" => true
-      case _ => false
-    }
-  }
-
   def isCometScan(op: SparkPlan): Boolean = {
     op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }
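
For reference, the two mode helpers deleted above are re-added (as private
methods) to the `CometShuffleExchangeExec` companion object later in this diff.
Both treat `auto` as enabled, so under `auto` both shuffle paths are candidates
and native shuffle is tried first. A stand-alone sketch of that resolution,
where `mode` stands in for the value read from `COMET_SHUFFLE_MODE`:

    // `mode` models COMET_SHUFFLE_MODE.get(conf); the matched strings are the
    // ones in the helpers, and the function names here are illustrative.
    def nativeShuffleEnabled(mode: String): Boolean = mode match {
      case "native" | "auto" => true
      case _ => false
    }

    def jvmShuffleEnabled(mode: String): Boolean = mode match {
      case "jvm" | "auto" => true
      case _ => false
    }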
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 0ef10ec09..6879fba4e 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -22,11 +22,10 @@ package org.apache.comet.rules
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.comet._
-import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
+import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
@@ -35,14 +34,12 @@ import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
-import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.rules.CometExecRule.allExecs
-import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
+import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.operator._
 import org.apache.comet.serde.operator.CometDataWritingCommand
@@ -92,21 +89,19 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
+      case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
        // Switch to use Decimal128 regardless of precision, since Arrow native execution
         // doesn't support Decimal32 and Decimal64 yet.
         conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
         CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
 
-      case s: ShuffleExchangeExec if columnarShuffleSupported(s) =>
+      case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
        // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
        // (if configured)
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
     }
   }
 
-  private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
-
   private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
 
   // spotless:off
@@ -249,9 +244,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       case s: ShuffleExchangeExec =>
-        // try native shuffle first, then columnar shuffle, then fall back to Spark
-        // if neither is supported
-        tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
+        convertToComet(s, CometShuffleExchangeExec).getOrElse(s)
 
       case op =>
         val handler = allExecs
@@ -288,39 +281,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
-    Some(s)
-      .filter(nativeShuffleSupported)
-      .filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
-      .flatMap(_ => operator2Proto(s))
-      .map { nativeOp =>
-        // Switch to use Decimal128 regardless of precision, since Arrow native execution
-        // doesn't support Decimal32 and Decimal64 yet.
-        conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-        val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
-        CometSinkPlaceHolder(nativeOp, s, cometOp)
-      }
-  }
-
-  private def tryColumnarShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
-    // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
-    // (if configured).
-    // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
-    // convert it to CometColumnarShuffle.
-    Some(s)
-      .filter(columnarShuffleSupported)
-      .flatMap(_ => operator2Proto(s))
-      .flatMap { nativeOp =>
-        s.child match {
-          case n if n.isInstanceOf[CometNativeExec] || !n.supportsColumnar =>
-            val cometOp = CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
-            Some(CometSinkPlaceHolder(nativeOp, s, cometOp))
-          case _ =>
-            None
-        }
-      }
-  }
-
   private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
@@ -497,269 +457,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  /**
-   * Returns true if a given Spark plan is a Comet shuffle operator.
-   */
-  private def isShuffleOperator(op: SparkPlan): Boolean = {
-    op match {
-      case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
-      case _: CometShuffleExchangeExec => true
-      case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
-      case _ => false
-    }
-  }
-
-  def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
-    if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
-      withInfo(
-        op,
-        s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is 
not enabled")
-      false
-    } else if (!isCometShuffleManagerEnabled(op.conf)) {
-      withInfo(op, s"spark.shuffle.manager is not set to 
${classOf[CometShuffleManager].getName}")
-      false
-    } else {
-      true
-    }
-  }
-
-  /**
-   * Whether the given Spark partitioning is supported by Comet native shuffle.
-   */
-  private def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
-    /**
-     * Determine which data types are supported as partition columns in native shuffle.
-     *
-     * For HashPartitioning this defines the key that determines how data should be collocated for
-     * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
-     * hashing complex types, see hash_funcs/utils.rs
-     */
-    def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    /**
-     * Determine which data types are supported as partition columns in native shuffle.
-     *
-     * For RangePartitioning this defines the key that determines how data should be collocated
-     * for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
-     * complex types.
-     */
-    def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    /**
-     * Determine which data types are supported as data columns in native shuffle.
-     *
-     * Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
-     * support all types that Comet supports.
-     */
-    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case StructType(fields) =>
-        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
-      case ArrayType(elementType, _) =>
-        supportedSerializableDataType(elementType)
-      case MapType(keyType, valueType, _) =>
-        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
-      case _ =>
-        false
-    }
-
-    if (!isCometShuffleEnabledWithInfo(s)) {
-      return false
-    }
-
-    if (!isCometNativeShuffleMode(s.conf)) {
-      withInfo(s, "Comet native shuffle not enabled")
-      return false
-    }
-
-    if (!isCometPlan(s.child)) {
-      // we do not need to report a fallback reason if the child plan is not a Comet plan
-      return false
-    }
-
-    val inputs = s.child.output
-
-    for (input <- inputs) {
-      if (!supportedSerializableDataType(input.dataType)) {
-        withInfo(s, s"unsupported shuffle data type ${input.dataType} for 
input $input")
-        return false
-      }
-    }
-
-    val partitioning = s.outputPartitioning
-    val conf = SQLConf.get
-    partitioning match {
-      case HashPartitioning(expressions, _) =>
-        var supported = true
-        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
-          withInfo(
-            s,
-            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
-          supported = false
-        }
-        for (expr <- expressions) {
-          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
-            withInfo(s, s"unsupported hash partitioning expression: $expr")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        for (dt <- expressions.map(_.dataType).distinct) {
-          if (!supportedHashPartitioningDataType(dt)) {
-            withInfo(s, s"unsupported hash partitioning data type for native 
shuffle: $dt")
-            supported = false
-          }
-        }
-        supported
-      case SinglePartition =>
-        // we already checked that the input types are supported
-        true
-      case RangePartitioning(orderings, _) =>
-        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
-          withInfo(
-            s,
-            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
-          return false
-        }
-        var supported = true
-        for (o <- orderings) {
-          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
-            withInfo(s, s"unsupported range partitioning sort order: $o", o)
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        for (dt <- orderings.map(_.dataType).distinct) {
-          if (!supportedRangePartitioningDataType(dt)) {
-            withInfo(s, s"unsupported range partitioning data type for native 
shuffle: $dt")
-            supported = false
-          }
-        }
-        supported
-      case _ =>
-        withInfo(
-          s,
-          s"unsupported Spark partitioning for native shuffle: 
${partitioning.getClass.getName}")
-        false
-    }
-  }
-
-  /**
-   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
-   * which supports struct/array.
-   */
-  private def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
-    /**
-     * Determine which data types are supported as data columns in columnar shuffle.
-     *
-     * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches, see
-     * shuffle/row.rs
-     */
-    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case StructType(fields) =>
-        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
-        // Java Arrow stream reader cannot work on duplicate field name
-        fields.map(f => f.name).distinct.length == fields.length &&
-        fields.nonEmpty
-      case ArrayType(elementType, _) =>
-        supportedSerializableDataType(elementType)
-      case MapType(keyType, valueType, _) =>
-        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
-      case _ =>
-        false
-    }
-
-    if (!isCometShuffleEnabledWithInfo(s)) {
-      return false
-    }
-
-    if (!isCometJVMShuffleMode(s.conf)) {
-      withInfo(s, "Comet columnar shuffle not enabled")
-      return false
-    }
-
-    if (isShuffleOperator(s.child)) {
-      withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
-      return false
-    }
-
-    if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
-      withInfo(s, s"Child ${s.child.getClass.getName} is a neither row-based 
or a Comet operator")
-      return false
-    }
-
-    val inputs = s.child.output
-
-    for (input <- inputs) {
-      if (!supportedSerializableDataType(input.dataType)) {
-        withInfo(s, s"unsupported shuffle data type ${input.dataType} for 
input $input")
-        return false
-      }
-    }
-
-    val partitioning = s.outputPartitioning
-    partitioning match {
-      case HashPartitioning(expressions, _) =>
-        var supported = true
-        for (expr <- expressions) {
-          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
-            withInfo(s, s"unsupported hash partitioning expression: $expr")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        supported
-      case SinglePartition =>
-        // we already checked that the input types are supported
-        true
-      case RoundRobinPartitioning(_) =>
-        // we already checked that the input types are supported
-        true
-      case RangePartitioning(orderings, _) =>
-        var supported = true
-        for (o <- orderings) {
-          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
-            withInfo(s, s"unsupported range partitioning sort order: $o")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        supported
-      case _ =>
-        withInfo(
-          s,
-          s"unsupported Spark partitioning for columnar shuffle: 
${partitioning.getClass.getName}")
-        false
-    }
-  }
-
   /**
    * Fallback for handling sinks that have not been handled explicitly. This method should
    * eventually be removed once CometExecRule fully uses the operator serde framework.
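
Much of the moved code is the recursive `supportedSerializableDataType` walk
over Spark's `DataType` tree: primitives are accepted, structs must be
non-empty with every field supported, arrays and maps recurse on their element
types, and everything else falls back to Spark. A self-contained sketch of that
recursion over a toy type ADT (not Spark's types):

    // Toy stand-in for Spark's DataType hierarchy; only the recursion shape
    // matches the moved helpers, and the leaf set is deliberately tiny.
    sealed trait DType
    case object IntT extends DType
    case object StringT extends DType
    case class StructT(fields: List[(String, DType)]) extends DType
    case class ArrayT(elem: DType) extends DType
    case class MapT(key: DType, value: DType) extends DType
    case class OtherT(name: String) extends DType

    def serializable(dt: DType): Boolean = dt match {
      case IntT | StringT  => true // primitive leaves are supported directly
      case StructT(fields) => // non-empty and every field supported
        fields.nonEmpty && fields.forall { case (_, t) => serializable(t) }
      case ArrayT(elem)    => serializable(elem)
      case MapT(k, v)      => serializable(k) && serializable(v)
      case OtherT(_)       => false // anything else falls back to Spark
    }

    // serializable(StructT(List("xs" -> ArrayT(IntT))))  == true
    // serializable(StructT(Nil))                         == false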
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 1f7d37a10..2e6ab9aff 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -29,16 +29,18 @@ import org.apache.spark.internal.config
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.comet.{CometMetricNode, CometPlan}
+import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -47,6 +49,10 @@ import org.apache.spark.util.random.XORShiftRandom
 import com.google.common.base.Objects
 
 import org.apache.comet.CometConf
+import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
+import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
+import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
+import org.apache.comet.serde.operator.CometSink
 import org.apache.comet.shims.ShimCometShuffleExchangeExec
 
 /**
@@ -210,7 +216,321 @@ case class CometShuffleExchangeExec(
     Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ Iterator(s"[plan_id=$id]")
 }
 
-object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
+object CometShuffleExchangeExec
+    extends CometSink[ShuffleExchangeExec]
+    with ShimCometShuffleExchangeExec
+    with SQLConfHelper {
+
+  override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
+    if (nativeShuffleSupported(op) || columnarShuffleSupported(op)) {
+      Compatible()
+    } else {
+      Unsupported()
+    }
+  }
+
+  override def createExec(
+      nativeOp: OperatorOuterClass.Operator,
+      op: ShuffleExchangeExec): CometNativeExec = {
+    if (nativeShuffleSupported(op) && op.children.forall(_.isInstanceOf[CometNativeExec])) {
+      // Switch to use Decimal128 regardless of precision, since Arrow native execution
+      // doesn't support Decimal32 and Decimal64 yet.
+      conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
+      CometSinkPlaceHolder(
+        nativeOp,
+        op,
+        CometShuffleExchangeExec(op, shuffleType = CometNativeShuffle))
+
+    } else if (columnarShuffleSupported(op)) {
+      CometSinkPlaceHolder(
+        nativeOp,
+        op,
+        CometShuffleExchangeExec(op, shuffleType = CometColumnarShuffle))
+    } else {
+      throw new IllegalStateException()
+    }
+  }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
+
+    /**
+     * Determine which data types are supported as partition columns in native shuffle.
+     *
+     * For HashPartitioning this defines the key that determines how data should be collocated for
+     * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
+     * hashing complex types, see hash_funcs/utils.rs
+     */
+    def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    /**
+     * Determine which data types are supported as partition columns in native shuffle.
+     *
+     * For RangePartitioning this defines the key that determines how data should be collocated
+     * for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
+     * complex types.
+     */
+    def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    /**
+     * Determine which data types are supported as data columns in native shuffle.
+     *
+     * Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
+     * support all types that Comet supports.
+     */
+    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case StructType(fields) =>
+        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
+      case ArrayType(elementType, _) =>
+        supportedSerializableDataType(elementType)
+      case MapType(keyType, valueType, _) =>
+        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
+      case _ =>
+        false
+    }
+
+    if (!isCometShuffleEnabledWithInfo(s)) {
+      return false
+    }
+
+    if (!isCometNativeShuffleMode(s.conf)) {
+      withInfo(s, "Comet native shuffle not enabled")
+      return false
+    }
+
+    if (!isCometPlan(s.child)) {
+      // we do not need to report a fallback reason if the child plan is not a Comet plan
+      return false
+    }
+
+    val inputs = s.child.output
+
+    for (input <- inputs) {
+      if (!supportedSerializableDataType(input.dataType)) {
+        withInfo(s, s"unsupported shuffle data type ${input.dataType} for 
input $input")
+        return false
+      }
+    }
+
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    partitioning match {
+      case HashPartitioning(expressions, _) =>
+        var supported = true
+        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
+          withInfo(
+            s,
+            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
+          supported = false
+        }
+        for (expr <- expressions) {
+          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
+            withInfo(s, s"unsupported hash partitioning expression: $expr")
+            supported = false
+            // We don't short-circuit in case there is more than one unsupported expression
+            // to provide info for.
+          }
+        }
+        for (dt <- expressions.map(_.dataType).distinct) {
+          if (!supportedHashPartitioningDataType(dt)) {
+            withInfo(s, s"unsupported hash partitioning data type for native 
shuffle: $dt")
+            supported = false
+          }
+        }
+        supported
+      case SinglePartition =>
+        // we already checked that the input types are supported
+        true
+      case RangePartitioning(orderings, _) =>
+        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
+          withInfo(
+            s,
+            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
+          return false
+        }
+        var supported = true
+        for (o <- orderings) {
+          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
+            withInfo(s, s"unsupported range partitioning sort order: $o", o)
+            supported = false
+            // We don't short-circuit in case there is more than one unsupported expression
+            // to provide info for.
+          }
+        }
+        for (dt <- orderings.map(_.dataType).distinct) {
+          if (!supportedRangePartitioningDataType(dt)) {
+            withInfo(s, s"unsupported range partitioning data type for native 
shuffle: $dt")
+            supported = false
+          }
+        }
+        supported
+      case _ =>
+        withInfo(
+          s,
+          s"unsupported Spark partitioning for native shuffle: 
${partitioning.getClass.getName}")
+        false
+    }
+  }
+
+  /**
+   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
+   * which supports struct/array.
+   */
+  def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
+
+    /**
+     * Determine which data types are supported as data columns in columnar shuffle.
+     *
+     * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches, see
+     * shuffle/row.rs
+     */
+    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case StructType(fields) =>
+        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
+        // Java Arrow stream reader cannot work on duplicate field name
+        fields.map(f => f.name).distinct.length == fields.length &&
+        fields.nonEmpty
+      case ArrayType(elementType, _) =>
+        supportedSerializableDataType(elementType)
+      case MapType(keyType, valueType, _) =>
+        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
+      case _ =>
+        false
+    }
+
+    if (!isCometShuffleEnabledWithInfo(s)) {
+      return false
+    }
+
+    if (!isCometJVMShuffleMode(s.conf)) {
+      withInfo(s, "Comet columnar shuffle not enabled")
+      return false
+    }
+
+    if (isShuffleOperator(s.child)) {
+      withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
+      return false
+    }
+
+    if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
+      withInfo(s, s"Child ${s.child.getClass.getName} is a neither row-based 
or a Comet operator")
+      return false
+    }
+
+    val inputs = s.child.output
+
+    for (input <- inputs) {
+      if (!supportedSerializableDataType(input.dataType)) {
+        withInfo(s, s"unsupported shuffle data type ${input.dataType} for 
input $input")
+        return false
+      }
+    }
+
+    val partitioning = s.outputPartitioning
+    partitioning match {
+      case HashPartitioning(expressions, _) =>
+        var supported = true
+        for (expr <- expressions) {
+          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
+            withInfo(s, s"unsupported hash partitioning expression: $expr")
+            supported = false
+            // We don't short-circuit in case there is more than one unsupported expression
+            // to provide info for.
+          }
+        }
+        supported
+      case SinglePartition =>
+        // we already checked that the input types are supported
+        true
+      case RoundRobinPartitioning(_) =>
+        // we already checked that the input types are supported
+        true
+      case RangePartitioning(orderings, _) =>
+        var supported = true
+        for (o <- orderings) {
+          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
+            withInfo(s, s"unsupported range partitioning sort order: $o")
+            supported = false
+            // We don't short-circuit in case there is more than one unsupported expression
+            // to provide info for.
+          }
+        }
+        supported
+      case _ =>
+        withInfo(
+          s,
+          s"unsupported Spark partitioning for columnar shuffle: 
${partitioning.getClass.getName}")
+        false
+    }
+  }
+
+  private def isCometNativeShuffleMode(conf: SQLConf): Boolean = {
+    COMET_SHUFFLE_MODE.get(conf) match {
+      case "native" => true
+      case "auto" => true
+      case _ => false
+    }
+  }
+
+  private def isCometJVMShuffleMode(conf: SQLConf): Boolean = {
+    COMET_SHUFFLE_MODE.get(conf) match {
+      case "jvm" => true
+      case "auto" => true
+      case _ => false
+    }
+  }
+
+  private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
+
+  /**
+   * Returns true if a given Spark plan is a Comet shuffle operator.
+   */
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
+    op match {
+      case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
+      case _: CometShuffleExchangeExec => true
+      case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
+      case _ => false
+    }
+  }
+
+  def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
+    if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
+      withInfo(
+        op,
+        s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is 
not enabled")
+      false
+    } else if (!isCometShuffleManagerEnabled(op.conf)) {
+      withInfo(op, s"spark.shuffle.manager is not set to 
${classOf[CometShuffleManager].getName}")
+      false
+    } else {
+      true
+    }
+  }
 
   def prepareShuffleDependency(
       rdd: RDD[ColumnarBatch],


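Taken together, the fast path in `CometExecRule.applyCometShuffle` (first hunk
above) now amounts to the following call-site shape. This sketch mirrors the
match arms in the diff; it compiles only with Spark and Comet on the classpath,
and it omits the `COMET_USE_DECIMAL_128` side effect the real rule also
performs before choosing native shuffle:

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
    import org.apache.spark.sql.comet.execution.shuffle.{
      CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}

    // Try native shuffle first, then columnar; otherwise leave the exchange
    // for Spark to execute.
    def rewriteShuffles(plan: SparkPlan): SparkPlan = plan.transformUp {
      case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
        CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
      case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
        CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
    }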
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
