comphead commented on code in PR #3731:
URL: https://github.com/apache/datafusion-comet/pull/3731#discussion_r2968192142


##########
spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala:
##########
@@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends 
CometOperatorSerde[T] {
 
 object CometExchangeSink extends CometSink[SparkPlan] {
 
-  /**
-   * Exchange data is FFI safe because there is no use of mutable buffers 
involved.
-   *
-   * Source of broadcast exchange batches is ArrowStreamReader.
-   *
-   * Source of shuffle exchange batches is NativeBatchDecoderIterator.
-   */
   override def isFfiSafe: Boolean = true
 
+  override def convert(
+      op: SparkPlan,
+      builder: Operator.Builder,
+      childOp: OperatorOuterClass.Operator*): 
Option[OperatorOuterClass.Operator] = {
+    if (shouldUseShuffleScan(op)) {
+      convertToShuffleScan(op, builder)
+    } else {
+      super.convert(op, builder, childOp: _*)
+    }
+  }
+
+  private def shouldUseShuffleScan(op: SparkPlan): Boolean = {
+    if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false
+
+    // Extract the CometShuffleExchangeExec from the wrapper
+    val shuffleExec = op match {
+      case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s)
+      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: 
CometShuffleExchangeExec), _) =>
+        Some(s)
+      case s: CometShuffleExchangeExec => Some(s)
+      case _ => None
+    }
+
+    shuffleExec.isDefined
+  }
+
+  private def convertToShuffleScan(
+      op: SparkPlan,
+      builder: Operator.Builder): Option[OperatorOuterClass.Operator] = {
+    val supportedTypes =
+      op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
+
+    if (!supportedTypes) {
+      withInfo(op, "Unsupported data type for shuffle direct read")
+      return None
+    }
+
+    val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder()
+    val source = op.simpleStringWithNodeId()
+    if (source.isEmpty) {
+      scanBuilder.setSource(op.getClass.getSimpleName)
+    } else {
+      scanBuilder.setSource(source)
+    }
+
+    val scanTypes = op.output.flatMap { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    if (scanTypes.length == op.output.length) {
+      scanBuilder.addAllFields(scanTypes.asJava)
+      builder.clearChildren()
+      Some(builder.setShuffleScan(scanBuilder).build())
+    } else {
+      withInfo(op, "unsupported data types for shuffle direct read")

Review Comment:
   this error message does not really conform to the condition, IMO



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to