This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
     new 1cca8d6f feat: Hook DataFusion Parquet native scan with Comet execution (#1094)
1cca8d6f is described below

commit 1cca8d6f7bd2dfb8e1996bdd55ebe09d08eb8221
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Nov 19 12:26:59 2024 -0800

    feat: Hook DataFusion Parquet native scan with Comet execution (#1094)
    
    * init
    
    * more
    
    * fix
    
    * more
    
    * more
    
    * fix
---
 native/core/src/execution/datafusion/planner.rs    |  18 +-
 .../apache/comet/CometSparkSessionExtensions.scala |  14 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   6 +-
 .../org/apache/spark/sql/comet/CometExecRDD.scala  |  56 +++
 .../spark/sql/comet/CometNativeScanExec.scala      | 447 ++-------------------
 .../org/apache/spark/sql/comet/operators.scala     |  62 ++-
 .../org/apache/comet/exec/CometExecSuite.scala     |  26 +-
 7 files changed, 166 insertions(+), 463 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs 
b/native/core/src/execution/datafusion/planner.rs
index cd79e8e0..8cd30161 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1023,17 +1023,15 @@ impl PhysicalPlanner {
                         .with_file_groups(file_groups);
 
                 // Check for projection, if so generate the vector and add to 
FileScanConfig.
-                if !required_schema_arrow.fields.is_empty() {
-                    let mut projection_vector: Vec<usize> =
-                        Vec::with_capacity(required_schema_arrow.fields.len());
-                    // TODO: could be faster with a hashmap rather than 
iterating over data_schema_arrow with index_of.
-                    required_schema_arrow.fields.iter().for_each(|field| {
-                        
projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                    });
+                let mut projection_vector: Vec<usize> =
+                    Vec::with_capacity(required_schema_arrow.fields.len());
+                // TODO: could be faster with a hashmap rather than iterating 
over data_schema_arrow with index_of.
+                required_schema_arrow.fields.iter().for_each(|field| {
+                    
projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
+                });
 
-                    assert_eq!(projection_vector.len(), 
required_schema_arrow.fields.len());
-                    file_scan_config = 
file_scan_config.with_projection(Some(projection_vector));
-                }
+                assert_eq!(projection_vector.len(), 
required_schema_arrow.fields.len());
+                file_scan_config = 
file_scan_config.with_projection(Some(projection_vector));
 
                 let mut table_parquet_options = TableParquetOptions::new();
                 // TODO: Maybe these are configs?
diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 6026fcff..d88f129a 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -202,8 +202,9 @@ class CometSparkSessionExtensions
               if CometNativeScanExec.isSchemaSupported(requiredSchema)
                 && CometNativeScanExec.isSchemaSupported(partitionSchema)
                 && COMET_FULL_NATIVE_SCAN_ENABLED.get =>
-            logInfo("Comet extension enabled for v1 Scan")
-            CometNativeScanExec(scanExec, session)
+            logInfo("Comet extension enabled for v1 full native Scan")
+            CometScanExec(scanExec, session)
+
           // data source V1
           case scanExec @ FileSourceScanExec(
                 HadoopFsRelation(_, partitionSchema, _, _, _: 
ParquetFileFormat, _),
@@ -365,6 +366,12 @@ class CometSparkSessionExtensions
       }
 
       plan.transformUp {
+        // Fully native scan for V1
+        case scan: CometScanExec if COMET_FULL_NATIVE_SCAN_ENABLED.get =>
+          val nativeOp = QueryPlanSerde.operator2Proto(scan).get
+          CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
+
+        // Comet JVM + native scan for V1 and V2
         case op if isCometScan(op) =>
           val nativeOp = QueryPlanSerde.operator2Proto(op).get
           CometScanWrapper(nativeOp, op)
@@ -1221,8 +1228,7 @@ object CometSparkSessionExtensions extends Logging {
   }
 
   def isCometScan(op: SparkPlan): Boolean = {
-    op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec] ||
-    op.isInstanceOf[CometNativeScanExec]
+    op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }
 
   private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): 
Boolean = {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f9a25466..b8a780e6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, 
BuildRight, Normalize
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometNativeScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, 
DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometNativeScanExec, CometScanExec, CometSinkPlaceHolder, 
CometSparkToColumnarExec, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
@@ -2481,7 +2481,9 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
     childOp.foreach(result.addChildren)
 
     op match {
-      case scan: CometNativeScanExec =>
+
+      // Fully native scan for V1
+      case scan: CometScanExec if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.get 
=>
         val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
         nativeScanBuilder.setSource(op.simpleStringWithNodeId())
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
new file mode 100644
index 00000000..952515d5
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * A RDD that executes Spark SQL query in Comet native execution to generate 
ColumnarBatch.
+ */
+private[spark] class CometExecRDD(
+    sc: SparkContext,
+    partitionNum: Int,
+    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch])
+    extends RDD[ColumnarBatch](sc, Nil) {
+
+  override def compute(s: Partition, context: TaskContext): 
Iterator[ColumnarBatch] = {
+    f(Seq.empty)
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array.tabulate(partitionNum)(i =>
+      new Partition {
+        override def index: Int = i
+      })
+  }
+}
+
+object CometExecRDD {
+  def apply(sc: SparkContext, partitionNum: Int)(
+      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): 
RDD[ColumnarBatch] =
+    withScope(sc) {
+      new CometExecRDD(sc, partitionNum, f)
+    }
+
+  private[spark] def withScope[U](sc: SparkContext)(body: => U): U =
+    RDDOperationScope.withScope[U](sc)(body)
+}
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala
index ccd7de0d..fd5afdb8 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala
@@ -19,38 +19,28 @@
 
 package org.apache.spark.sql.comet
 
-import scala.collection.mutable.HashMap
-import scala.concurrent.duration.NANOSECONDS
 import scala.reflect.ClassTag
 
-import org.apache.hadoop.fs.Path
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.comet.shims.ShimCometScanExec
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
-import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
-import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.collection._
 
-import org.apache.comet.{CometConf, DataTypeSupport, MetricsSupport}
-import org.apache.comet.parquet.{CometParquetFileFormat, 
CometParquetPartitionReaderFactory}
+import com.google.common.base.Objects
+
+import org.apache.comet.DataTypeSupport
+import org.apache.comet.parquet.CometParquetFileFormat
+import org.apache.comet.serde.OperatorOuterClass.Operator
 
 /**
- * Comet physical scan node for DataSource V1. Most of the code here follow 
Spark's
- * [[FileSourceScanExec]],
+ * Comet fully native scan node for DataSource V1.
  */
 case class CometNativeScanExec(
+    override val nativeOp: Operator,
     @transient relation: HadoopFsRelation,
     override val output: Seq[Attribute],
     requiredSchema: StructType,
@@ -60,415 +50,34 @@ case class CometNativeScanExec(
     dataFilters: Seq[Expression],
     tableIdentifier: Option[TableIdentifier],
     disableBucketedScan: Boolean = false,
-    originalPlan: FileSourceScanExec)
-    extends CometPlan
-    with DataSourceScanExec
-    with ShimCometScanExec {
-
-  def wrapped: FileSourceScanExec = originalPlan
-
-  // FIXME: ideally we should reuse wrapped.supportsColumnar, however that 
fails many tests
-  override lazy val supportsColumnar: Boolean =
-    relation.fileFormat.supportBatch(relation.sparkSession, schema)
-
-  override def vectorTypes: Option[Seq[String]] = originalPlan.vectorTypes
-
-  private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
-
-  /**
-   * Send the driver-side metrics. Before calling this function, 
selectedPartitions has been
-   * initialized. See SPARK-26327 for more details.
-   */
-  private def sendDriverMetrics(): Unit = {
-    driverMetrics.foreach(e => metrics(e._1).add(e._2))
-    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(
-      sparkContext,
-      executionId,
-      metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
-  }
-
-  private def isDynamicPruningFilter(e: Expression): Boolean =
-    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
-
-  @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
-    val optimizerMetadataTimeNs = 
relation.location.metadataOpsTimeNs.getOrElse(0L)
-    val startTime = System.nanoTime()
-    val ret =
-      
relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), 
dataFilters)
-    setFilesNumAndSizeMetric(ret, true)
-    val timeTakenMs =
-      NANOSECONDS.toMillis((System.nanoTime() - startTime) + 
optimizerMetadataTimeNs)
-    driverMetrics("metadataTime") = timeTakenMs
-    ret
-  }.toArray
-
-  // We can only determine the actual partitions at runtime when a dynamic 
partition filter is
-  // present. This is because such a filter relies on information that is only 
available at run
-  // time (for instance the keys used in the other side of a join).
-  @transient private lazy val dynamicallySelectedPartitions: 
Array[PartitionDirectory] = {
-    val dynamicPartitionFilters = 
partitionFilters.filter(isDynamicPruningFilter)
-
-    if (dynamicPartitionFilters.nonEmpty) {
-      val startTime = System.nanoTime()
-      // call the file index for the files matching all filters except dynamic 
partition filters
-      val predicate = dynamicPartitionFilters.reduce(And)
-      val partitionColumns = relation.partitionSchema
-      val boundPredicate = Predicate.create(
-        predicate.transform { case a: AttributeReference =>
-          val index = partitionColumns.indexWhere(a.name == _.name)
-          BoundReference(index, partitionColumns(index).dataType, nullable = 
true)
-        },
-        Nil)
-      val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values))
-      setFilesNumAndSizeMetric(ret, false)
-      val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000
-      driverMetrics("pruningTime") = timeTakenMs
-      ret
-    } else {
-      selectedPartitions
-    }
-  }
-
-  // exposed for testing
-  lazy val bucketedScan: Boolean = originalPlan.bucketedScan
-
-  override lazy val (outputPartitioning, outputOrdering): (Partitioning, 
Seq[SortOrder]) =
-    (originalPlan.outputPartitioning, originalPlan.outputOrdering)
-
-  @transient
-  private lazy val pushedDownFilters = getPushedDownFilters(relation, 
dataFilters)
-
-  override lazy val metadata: Map[String, String] =
-    if (originalPlan == null) Map.empty else originalPlan.metadata
-
-  override def verboseStringWithOperatorId(): String = {
-    val metadataStr = metadata.toSeq.sorted
-      .filterNot {
-        case (_, value) if (value.isEmpty || value.equals("[]")) => true
-        case (key, _) if (key.equals("DataFilters") || key.equals("Format")) 
=> true
-        case (_, _) => false
-      }
-      .map {
-        case (key, _) if (key.equals("Location")) =>
-          val location = relation.location
-          val numPaths = location.rootPaths.length
-          val abbreviatedLocation = if (numPaths <= 1) {
-            location.rootPaths.mkString("[", ", ", "]")
-          } else {
-            "[" + location.rootPaths.head + s", ... ${numPaths - 1} entries]"
-          }
-          s"$key: ${location.getClass.getSimpleName} 
${redact(abbreviatedLocation)}"
-        case (key, value) => s"$key: ${redact(value)}"
-      }
-
-    s"""
-       |$formattedNodeName
-       |${ExplainUtils.generateFieldString("Output", output)}
-       |${metadataStr.mkString("\n")}
-       |""".stripMargin
-  }
-
-  lazy val inputRDD: RDD[InternalRow] = {
-    val options = relation.options +
-      (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString)
-    val readFile: (PartitionedFile) => Iterator[InternalRow] =
-      relation.fileFormat.buildReaderWithPartitionValues(
-        sparkSession = relation.sparkSession,
-        dataSchema = relation.dataSchema,
-        partitionSchema = relation.partitionSchema,
-        requiredSchema = requiredSchema,
-        filters = pushedDownFilters,
-        options = options,
-        hadoopConf =
-          
relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
-
-    val readRDD = if (bucketedScan) {
-      createBucketedReadRDD(
-        relation.bucketSpec.get,
-        readFile,
-        dynamicallySelectedPartitions,
-        relation)
-    } else {
-      createReadRDD(readFile, dynamicallySelectedPartitions, relation)
-    }
-    sendDriverMetrics()
-    readRDD
-  }
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    inputRDD :: Nil
-  }
-
-  /** Helper for computing total number and size of files in selected 
partitions. */
-  private def setFilesNumAndSizeMetric(
-      partitions: Seq[PartitionDirectory],
-      static: Boolean): Unit = {
-    val filesNum = partitions.map(_.files.size.toLong).sum
-    val filesSize = partitions.map(_.files.map(_.getLen).sum).sum
-    if (!static || !partitionFilters.exists(isDynamicPruningFilter)) {
-      driverMetrics("numFiles") = filesNum
-      driverMetrics("filesSize") = filesSize
-    } else {
-      driverMetrics("staticFilesNum") = filesNum
-      driverMetrics("staticFilesSize") = filesSize
-    }
-    if (relation.partitionSchema.nonEmpty) {
-      driverMetrics("numPartitions") = partitions.length
-    }
-  }
-
-  override lazy val metrics: Map[String, SQLMetric] = originalPlan.metrics ++ {
-    // Tracking scan time has overhead, we can't afford to do it for each row, 
and can only do
-    // it for each batch.
-    if (supportsColumnar) {
-      Map(
-        "scanTime" -> SQLMetrics.createNanoTimingMetric(
-          sparkContext,
-          "scan time")) ++ CometMetricNode.scanMetrics(sparkContext)
-    } else {
-      Map.empty
-    }
-  } ++ {
-    relation.fileFormat match {
-      case f: MetricsSupport => f.initMetrics(sparkContext)
-      case _ => Map.empty
-    }
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    ColumnarToRowExec(this).doExecute()
-  }
-
-  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    val numOutputRows = longMetric("numOutputRows")
-    val scanTime = longMetric("scanTime")
-    inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches 
=>
-      new Iterator[ColumnarBatch] {
-
-        override def hasNext: Boolean = {
-          // The `FileScanRDD` returns an iterator which scans the file during 
the `hasNext` call.
-          val startNs = System.nanoTime()
-          val res = batches.hasNext
-          scanTime += System.nanoTime() - startNs
-          res
-        }
-
-        override def next(): ColumnarBatch = {
-          val batch = batches.next()
-          numOutputRows += batch.numRows()
-          batch
-        }
-      }
-    }
-  }
-
-  override def executeCollect(): Array[InternalRow] = {
-    ColumnarToRowExec(this).executeCollect()
-  }
-
-  override val nodeName: String =
-    s"CometNativeScan $relation 
${tableIdentifier.map(_.unquotedString).getOrElse("")}"
-
-  /**
-   * Create an RDD for bucketed reads. The non-bucketed variant of this 
function is
-   * [[createReadRDD]].
-   *
-   * The algorithm is pretty simple: each RDD partition being returned should 
include all the
-   * files with the same bucket id from all the given Hive partitions.
-   *
-   * @param bucketSpec
-   *   the bucketing spec.
-   * @param readFile
-   *   a function to read each (part of a) file.
-   * @param selectedPartitions
-   *   Hive-style partition that are part of the read.
-   * @param fsRelation
-   *   [[HadoopFsRelation]] associated with the read.
-   */
-  private def createBucketedReadRDD(
-      bucketSpec: BucketSpec,
-      readFile: (PartitionedFile) => Iterator[InternalRow],
-      selectedPartitions: Array[PartitionDirectory],
-      fsRelation: HadoopFsRelation): RDD[InternalRow] = {
-    logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
-    val filesGroupedToBuckets =
-      selectedPartitions
-        .flatMap { p =>
-          p.files.map { f =>
-            getPartitionedFile(f, p)
-          }
-        }
-        .groupBy { f =>
-          BucketingUtils
-            .getBucketId(new Path(f.filePath.toString()).getName)
-            .getOrElse(throw invalidBucketFile(f.filePath.toString(), 
sparkContext.version))
-        }
-
-    val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
-      val bucketSet = optionalBucketSet.get
-      filesGroupedToBuckets.filter { f =>
-        bucketSet.get(f._1)
-      }
-    } else {
-      filesGroupedToBuckets
-    }
+    originalPlan: FileSourceScanExec,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometLeafExec {
 
-    val filePartitions = optionalNumCoalescedBuckets
-      .map { numCoalescedBuckets =>
-        logInfo(s"Coalescing to ${numCoalescedBuckets} buckets")
-        val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % 
numCoalescedBuckets)
-        Seq.tabulate(numCoalescedBuckets) { bucketId =>
-          val partitionedFiles = coalescedBuckets
-            .get(bucketId)
-            .map {
-              _.values.flatten.toArray
-            }
-            .getOrElse(Array.empty)
-          FilePartition(bucketId, partitionedFiles)
-        }
-      }
-      .getOrElse {
-        Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
-          FilePartition(bucketId, 
prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
-        }
-      }
+  override def outputPartitioning: Partitioning =
+    UnknownPartitioning(originalPlan.inputRDD.getNumPartitions)
+  override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
 
-    prepareRDD(fsRelation, readFile, filePartitions)
-  }
+  override def stringArgs: Iterator[Any] = Iterator(output)
 
-  /**
-   * Create an RDD for non-bucketed reads. The bucketed variant of this 
function is
-   * [[createBucketedReadRDD]].
-   *
-   * @param readFile
-   *   a function to read each (part of a) file.
-   * @param selectedPartitions
-   *   Hive-style partition that are part of the read.
-   * @param fsRelation
-   *   [[HadoopFsRelation]] associated with the read.
-   */
-  private def createReadRDD(
-      readFile: (PartitionedFile) => Iterator[InternalRow],
-      selectedPartitions: Array[PartitionDirectory],
-      fsRelation: HadoopFsRelation): RDD[InternalRow] = {
-    val openCostInBytes = 
fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
-    val maxSplitBytes =
-      FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions)
-    logInfo(
-      s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
-        s"open cost is considered as scanning $openCostInBytes bytes.")
-
-    // Filter files with bucket pruning if possible
-    val bucketingEnabled = 
fsRelation.sparkSession.sessionState.conf.bucketingEnabled
-    val shouldProcess: Path => Boolean = optionalBucketSet match {
-      case Some(bucketSet) if bucketingEnabled =>
-        // Do not prune the file if bucket file name is invalid
-        filePath => 
BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometNativeScanExec =>
+        this.output == other.output &&
+        this.serializedPlanOpt == other.serializedPlanOpt
       case _ =>
-        _ => true
+        false
     }
-
-    val splitFiles = selectedPartitions
-      .flatMap { partition =>
-        partition.files.flatMap { file =>
-          // getPath() is very expensive so we only want to call it once in 
this block:
-          val filePath = file.getPath
-
-          if (shouldProcess(filePath)) {
-            val isSplitable = relation.fileFormat.isSplitable(
-              relation.sparkSession,
-              relation.options,
-              filePath) &&
-              // SPARK-39634: Allow file splitting in combination with row 
index generation once
-              // the fix for PARQUET-2161 is available.
-              !isNeededForSchema(requiredSchema)
-            super.splitFiles(
-              sparkSession = relation.sparkSession,
-              file = file,
-              filePath = filePath,
-              isSplitable = isSplitable,
-              maxSplitBytes = maxSplitBytes,
-              partitionValues = partition.values)
-          } else {
-            Seq.empty
-          }
-        }
-      }
-      .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
-
-    prepareRDD(
-      fsRelation,
-      readFile,
-      FilePartition.getFilePartitions(relation.sparkSession, splitFiles, 
maxSplitBytes))
-  }
-
-  private def prepareRDD(
-      fsRelation: HadoopFsRelation,
-      readFile: (PartitionedFile) => Iterator[InternalRow],
-      partitions: Seq[FilePartition]): RDD[InternalRow] = {
-    val hadoopConf = 
relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)
-    val prefetchEnabled = hadoopConf.getBoolean(
-      CometConf.COMET_SCAN_PREFETCH_ENABLED.key,
-      CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get)
-
-    val sqlConf = fsRelation.sparkSession.sessionState.conf
-    if (prefetchEnabled) {
-      CometParquetFileFormat.populateConf(sqlConf, hadoopConf)
-      val broadcastedConf =
-        fsRelation.sparkSession.sparkContext.broadcast(new 
SerializableConfiguration(hadoopConf))
-      val partitionReaderFactory = CometParquetPartitionReaderFactory(
-        sqlConf,
-        broadcastedConf,
-        requiredSchema,
-        relation.partitionSchema,
-        pushedDownFilters.toArray,
-        new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf),
-        metrics)
-
-      new DataSourceRDD(
-        fsRelation.sparkSession.sparkContext,
-        partitions.map(Seq(_)),
-        partitionReaderFactory,
-        true,
-        Map.empty)
-    } else {
-      newFileScanRDD(
-        fsRelation,
-        readFile,
-        partitions,
-        new StructType(requiredSchema.fields ++ 
fsRelation.partitionSchema.fields),
-        new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf))
-    }
-  }
-
-  // Filters unused DynamicPruningExpression expressions - one which has been 
replaced
-  // with DynamicPruningExpression(Literal.TrueLiteral) during Physical 
Planning
-  private def filterUnusedDynamicPruningExpressions(
-      predicates: Seq[Expression]): Seq[Expression] = {
-    predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
-  }
-
-  override def doCanonicalize(): CometNativeScanExec = {
-    CometNativeScanExec(
-      relation,
-      output.map(QueryPlan.normalizeExpressions(_, output)),
-      requiredSchema,
-      QueryPlan.normalizePredicates(
-        filterUnusedDynamicPruningExpressions(partitionFilters),
-        output),
-      optionalBucketSet,
-      optionalNumCoalescedBuckets,
-      QueryPlan.normalizePredicates(dataFilters, output),
-      None,
-      disableBucketedScan,
-      null)
   }
 
+  override def hashCode(): Int = Objects.hashCode(output)
 }
 
 object CometNativeScanExec extends DataTypeSupport {
-  def apply(scanExec: FileSourceScanExec, session: SparkSession): 
CometNativeScanExec = {
+  def apply(
+      nativeOp: Operator,
+      scanExec: FileSourceScanExec,
+      session: SparkSession): CometNativeScanExec = {
     // TreeNode.mapProductIterator is protected method.
     def mapProductIterator[B: ClassTag](product: Product, f: Any => B): 
Array[B] = {
       val arr = Array.ofDim[B](product.productArity)
@@ -493,6 +102,7 @@ object CometNativeScanExec extends DataTypeSupport {
     val newArgs = mapProductIterator(scanExec, transform(_))
     val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec]
     val batchScanExec = CometNativeScanExec(
+      nativeOp,
       wrapped.relation,
       wrapped.output,
       wrapped.requiredSchema,
@@ -502,7 +112,8 @@ object CometNativeScanExec extends DataTypeSupport {
       wrapped.dataFilters,
       wrapped.tableIdentifier,
       wrapped.disableBucketedScan,
-      wrapped)
+      wrapped,
+      SerializedPlan(None))
     scanExec.logicalLink.foreach(batchScanExec.setLogicalLink)
     batchScanExec
   }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 293bc35f..8b50ad19 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -249,53 +249,77 @@ abstract class CometNativeExec extends CometExec {
           case _ => true
         }
 
+        val containsBroadcastInput = sparkPlans.exists {
+          case _: CometBroadcastExchangeExec => true
+          case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => 
true
+          case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true
+          case _ => false
+        }
+
         // If the first non broadcast plan is not found, it means all the 
plans are broadcast plans.
         // This is not expected, so throw an exception.
-        if (firstNonBroadcastPlan.isEmpty) {
+        if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) {
           throw new CometRuntimeException(s"Cannot find the first non 
broadcast plan: $this")
         }
 
         // If the first non broadcast plan is found, we need to adjust the 
partition number of
         // the broadcast plans to make sure they have the same partition 
number as the first non
         // broadcast plan.
-        val firstNonBroadcastPlanRDD = 
firstNonBroadcastPlan.get._1.executeColumnar()
-        val firstNonBroadcastPlanNumPartitions = 
firstNonBroadcastPlanRDD.getNumPartitions
+        val firstNonBroadcastPlanNumPartitions =
+          firstNonBroadcastPlan.map(_._1.outputPartitioning.numPartitions)
 
         // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule 
Broadcast RDDs with
         // same partition number. But for Comet, we need to zip them so we 
need to adjust the
         // partition number of Broadcast RDDs to make sure they have the same 
partition number.
         sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
           plan match {
-            case c: CometBroadcastExchangeExec =>
-              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
-            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) 
=>
-              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
-            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
-              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case c: CometBroadcastExchangeExec if 
firstNonBroadcastPlanNumPartitions.nonEmpty =>
+              inputs += c
+                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
+                .executeColumnar()
+            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _)
+                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
+              inputs += c
+                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
+                .executeColumnar()
+            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec)
+                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
+              inputs += c
+                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
+                .executeColumnar()
             case BroadcastQueryStageExec(
                   _,
                   ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
-                  _) =>
-              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
-            case _ if idx == firstNonBroadcastPlan.get._2 =>
-              inputs += firstNonBroadcastPlanRDD
-            case _ =>
+                  _) if firstNonBroadcastPlanNumPartitions.nonEmpty =>
+              inputs += c
+                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
+                .executeColumnar()
+            case _: CometNativeExec =>
+            // no-op
+            case _ if firstNonBroadcastPlanNumPartitions.nonEmpty =>
               val rdd = plan.executeColumnar()
-              if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) {
+              if (plan.outputPartitioning.numPartitions != 
firstNonBroadcastPlanNumPartitions.get) {
                 throw new CometRuntimeException(
                   s"Partition number mismatch: ${rdd.getNumPartitions} != " +
-                    s"$firstNonBroadcastPlanNumPartitions")
+                    s"${firstNonBroadcastPlanNumPartitions.get}")
               } else {
                 inputs += rdd
               }
+            case _ =>
+              throw new CometRuntimeException(s"Unexpected plan: $plan")
           }
         }
 
-        if (inputs.isEmpty) {
+        if (inputs.isEmpty && 
!sparkPlans.forall(_.isInstanceOf[CometNativeExec])) {
           throw new CometRuntimeException(s"No input for CometNativeExec:\n 
$this")
         }
 
-        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+        if (inputs.nonEmpty) {
+          ZippedPartitionsRDD(sparkContext, 
inputs.toSeq)(createCometExecIter(_))
+        } else {
+          val partitionNum = firstNonBroadcastPlanNumPartitions.get
+          CometExecRDD(sparkContext, partitionNum)(createCometExecIter(_))
+        }
     }
   }
 
@@ -402,6 +426,8 @@ abstract class CometNativeExec extends CometExec {
   }
 }
 
+abstract class CometLeafExec extends CometNativeExec with LeafExecNode
+
 abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
 
 abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 73ccbbd6..a54b70ea 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -119,17 +119,19 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("ShuffleQueryStageExec could be direct child node of 
CometBroadcastExchangeExec") {
-    val table = "src"
-    withTable(table) {
-      withView("lv_noalias") {
-        sql(s"CREATE TABLE $table (key INT, value STRING) USING PARQUET")
-        sql(s"INSERT INTO $table VALUES(238, 'val_238')")
+    withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
+      val table = "src"
+      withTable(table) {
+        withView("lv_noalias") {
+          sql(s"CREATE TABLE $table (key INT, value STRING) USING PARQUET")
+          sql(s"INSERT INTO $table VALUES(238, 'val_238')")
 
-        sql(
-          "CREATE VIEW lv_noalias AS SELECT myTab.* FROM src " +
-            "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) myTab LIMIT 
2")
-        val df = sql("SELECT * FROM lv_noalias a JOIN lv_noalias b ON 
a.key=b.key");
-        checkSparkAnswer(df)
+          sql(
+            "CREATE VIEW lv_noalias AS SELECT myTab.* FROM src " +
+              "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) myTab LIMIT 
2")
+          val df = sql("SELECT * FROM lv_noalias a JOIN lv_noalias b ON 
a.key=b.key")
+          checkSparkAnswer(df)
+        }
       }
     }
   }
@@ -551,7 +553,9 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("Comet native metrics: scan") {
-    withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
+    withSQLConf(
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "false") {
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "native-scan.parquet")
         makeParquetFileAllTypes(path, dictionaryEnabled = true, 10000)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to