wForget commented on code in PR #3731:
URL: https://github.com/apache/datafusion-comet/pull/3731#discussion_r2964129024
##########
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala:
##########
@@ -606,6 +613,108 @@ abstract class CometNativeExec extends CometExec {
}
}
+ /**
+ * Walk the serialized protobuf plan depth-first to find which input indices
correspond to
+ * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes
one input in order.
+ */
+ private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = {
+ val plan = OperatorOuterClass.Operator.parseFrom(planBytes)
+ var scanIndex = 0
+ val indices = mutable.Set.empty[Int]
+ def walk(op: OperatorOuterClass.Operator): Unit = {
+ if (op.hasShuffleScan) {
+ indices += scanIndex
+ scanIndex += 1
+ } else if (op.hasScan) {
+ scanIndex += 1
+ } else {
+ op.getChildrenList.asScala.foreach(walk)
+ }
+ }
+ walk(plan)
+ indices.toSet
+ }
+
+ /**
+ * Build factory functions that produce CometShuffleBlockIterator for each
input index that is a
+ * ShuffleScan. Maps from input index to a factory that, given TaskContext
and Partition,
+ * creates the iterator.
+ */
+ private def buildShuffleBlockIteratorFactories(
+ sparkPlans: ArrayBuffer[SparkPlan],
+ inputs: ArrayBuffer[RDD[ColumnarBatch]],
+ shuffleScanIndices: Set[Int])
+ : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = {
+ if (shuffleScanIndices.isEmpty) return Map.empty
+
+ val factories = mutable.Map.empty[Int, (TaskContext, Partition) =>
CometShuffleBlockIterator]
+
+ shuffleScanIndices.foreach { scanIdx =>
+ if (scanIdx < inputs.length) {
+ inputs(scanIdx) match {
+ case rdd: CometShuffledBatchRDD =>
+ val dep = rdd.dependency
+ val rddMetrics = rdd.metrics
+ factories(scanIdx) = (context, part) => {
Review Comment:
Much of the logic here duplicates `CometShuffledBatchRDD#compute`. Perhaps
we could add a `computeAsShuffleBlockIterator` method to
`CometShuffledBatchRDD` that reuses the existing `createReader` logic, for example:
```
class CometShuffledBatchRDD {
def computeAsShuffleBlockIterator(context: TaskContext, split: Partition):
CometShuffleBlockIterator = {
...
}
}
factories(scanIdx) = rdd.computeAsShuffleBlockIterator
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]