wForget commented on code in PR #3731:
URL: https://github.com/apache/datafusion-comet/pull/3731#discussion_r2964182143
##########
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala:
##########
@@ -606,6 +613,108 @@ abstract class CometNativeExec extends CometExec {
}
}
+ /**
+ * Walk the serialized protobuf plan depth-first to find which input indices
correspond to
+ * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes
one input in order.
+ */
+ private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = {
+ val plan = OperatorOuterClass.Operator.parseFrom(planBytes)
+ var scanIndex = 0
+ val indices = mutable.Set.empty[Int]
+ def walk(op: OperatorOuterClass.Operator): Unit = {
+ if (op.hasShuffleScan) {
+ indices += scanIndex
+ scanIndex += 1
+ } else if (op.hasScan) {
+ scanIndex += 1
+ } else {
+ op.getChildrenList.asScala.foreach(walk)
+ }
+ }
+ walk(plan)
+ indices.toSet
+ }
+
+ /**
+ * Build factory functions that produce CometShuffleBlockIterator for each
input index that is a
+ * ShuffleScan. Maps from input index to a factory that, given TaskContext
and Partition,
+ * creates the iterator.
+ */
+ private def buildShuffleBlockIteratorFactories(
+ sparkPlans: ArrayBuffer[SparkPlan],
+ inputs: ArrayBuffer[RDD[ColumnarBatch]],
+ shuffleScanIndices: Set[Int])
+ : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = {
+ if (shuffleScanIndices.isEmpty) return Map.empty
+
+ val factories = mutable.Map.empty[Int, (TaskContext, Partition) =>
CometShuffleBlockIterator]
+
+ shuffleScanIndices.foreach { scanIdx =>
+ if (scanIdx < inputs.length) {
+ inputs(scanIdx) match {
+ case rdd: CometShuffledBatchRDD =>
+ val dep = rdd.dependency
+ val rddMetrics = rdd.metrics
+ factories(scanIdx) = (context, part) => {
Review Comment:
In that case, we no longer need `buildShuffleBlockIteratorFactories`; we
can compute it in CometExecRDD like this:
```
class CometExecRDD {
override def compute(split: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometExecPartition]
val inputs = inputRDDs.zip(partition.inputPartitions).zipWithIndex.map {
case ((rdd: CometShuffledBatchRDD, part), idx) if
shuffleScanIndices.contains(idx) =>
rdd.computeAsShuffleBlockIterator(part, context)
case ((rdd, part), _) => rdd.iterator(part, context)
}
    ...
  }
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]