This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01191c83f8c [SPARK-44505][SQL] Provide override for columnar support in Scan for DSv2
01191c83f8c is described below

commit 01191c83f8c77f5dcc85b9017551023d81ed0d45
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Jul 27 20:53:43 2023 +0800

    [SPARK-44505][SQL] Provide override for columnar support in Scan for DSv2
    
    ### What changes were proposed in this pull request?
    Previously, when a new DSv2 data source is implemented, planning always calls `BatchScanExec:supportsColumnar`, which in turn iterates over all input partitions to check whether they support columnar reads.
    
    When the `planInputPartitions` method is expensive, this can be problematic. This patch adds an option to the Scan interface that allows a Scan to specify its columnar support directly. For backward compatibility, the default value provided by the Scan interface is PARTITION_DEFINED, but a Scan implementation can override it accordingly.
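
    As an illustration only (not part of this patch), a minimal Scala sketch of a Scan that opts out of the per-partition check; `MyRowBasedScan` and its schema are hypothetical:

    ```scala
    import org.apache.spark.sql.connector.read.Scan
    import org.apache.spark.sql.types.{IntegerType, StructType}

    // Hypothetical sketch: a Scan that knows it only produces row-based partitions
    // declares that up front instead of relying on the per-partition check.
    class MyRowBasedScan extends Scan {
      override def readSchema(): StructType = new StructType().add("i", IntegerType)

      // Opt out of PARTITION_DEFINED so planning never needs to materialize partitions.
      override def columnarSupportMode(): Scan.ColumnarSupportMode =
        Scan.ColumnarSupportMode.UNSUPPORTED
    }
    ```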
    
    To fully support the changes of this PR, the following additional changes were needed:
    
    * `DataSourceV2ScanExecBase::outputPartitioning` no longer special-cases single partitions.
    * `lazy val DataSourceV2ScanExecBase::groupedPartitions` now checks for an empty key-grouped partitioning first, so that the simple case does not trigger a materialization of the input partitions during planning (see the usage sketch below).
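
    From a user's perspective, the combined effect is that explain() on such a source no longer plans input partitions. A hedged usage sketch (the source name is hypothetical and an active `spark` session is assumed):

    ```scala
    // A DSv2 source whose Scan returns ColumnarSupportMode.SUPPORTED or UNSUPPORTED.
    val df = spark.read.format("com.example.RowOnlyV2Source").load()

    // supportsColumnar is answered from the scan's declared mode, so rendering
    // the plan does not call planInputPartitions().
    df.explain()
    ```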
    
    Additionally:
    * Fixes similar issues as https://github.com/apache/spark/pull/40004
    
    ### Why are the changes needed?
    Avoid costly operations during explain operations.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a new unit test.
    
    Closes #42099 from grundprinzip/SPARK-44505.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/connector/read/Scan.java  | 24 ++++++++++++
 .../datasources/v2/DataSourceV2ScanExecBase.scala  | 44 +++++++++++++---------
 .../spark/sql/connector/DataSourceV2Suite.scala    | 37 ++++++++++++++++++
 .../connector/KeyGroupedPartitioningSuite.scala    |  5 ++-
 .../command/AlignAssignmentsSuiteBase.scala        |  4 +-
 5 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
index 8f79c656210..969a47be707 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
@@ -125,4 +125,28 @@ public interface Scan {
   default CustomTaskMetric[] reportDriverMetrics() {
     return new CustomTaskMetric[]{};
   }
+
+  /**
+   * This enum defines how the columnar support for the partitions of the data source
+   * should be determined. The default value is `PARTITION_DEFINED` which indicates that each
+   * partition can determine if it should be columnar or not. SUPPORTED and UNSUPPORTED provide
+   * default shortcuts to indicate support for columnar data or not.
+   *
+   * @since 3.5.0
+   */
+  enum ColumnarSupportMode {
+    PARTITION_DEFINED,
+    SUPPORTED,
+    UNSUPPORTED
+  }
+
+  /**
+   * Subclasses can implement this method to indicate if the support for columnar data should
+   * be determined by each partition or is set as a default for the whole scan.
+   *
+   * @since 3.5.0
+   */
+  default ColumnarSupportMode columnarSupportMode() {
+    return ColumnarSupportMode.PARTITION_DEFINED;
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index e539b1c4ee3..f688d3514d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan}
 import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution}
@@ -91,22 +91,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
   }
 
   override def outputPartitioning: physical.Partitioning = {
-    if (partitions.length == 1) {
-      SinglePartition
-    } else {
-      keyGroupedPartitioning match {
-        case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
-          groupedPartitions.map { partitionValues =>
+    keyGroupedPartitioning match {
+      case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
+        groupedPartitions
+          .map { partitionValues =>
            KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1))
-          }.getOrElse(super.outputPartitioning)
-        case _ =>
-          super.outputPartitioning
-      }
+          }
+          .getOrElse(super.outputPartitioning)
+      case _ =>
+        super.outputPartitioning
     }
   }
 
-  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
-    groupPartitions(inputPartitions)
+  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+    // Early check if we actually need to materialize the input partitions.
+    keyGroupedPartitioning match {
+      case Some(_) => groupPartitions(inputPartitions)
+      case _ => None
+    }
+  }
 
   /**
   * Group partition values for all the input partitions. This returns `Some` iff:
@@ -170,11 +173,16 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
   }
 
   override def supportsColumnar: Boolean = {
-    require(inputPartitions.forall(readerFactory.supportColumnarReads) ||
-      !inputPartitions.exists(readerFactory.supportColumnarReads),
-      "Cannot mix row-based and columnar input partitions.")
-
-    inputPartitions.exists(readerFactory.supportColumnarReads)
+    scan.columnarSupportMode() match {
+      case Scan.ColumnarSupportMode.PARTITION_DEFINED =>
+        require(
+          inputPartitions.forall(readerFactory.supportColumnarReads) ||
+            !inputPartitions.exists(readerFactory.supportColumnarReads),
+          "Cannot mix row-based and columnar input partitions.")
+        inputPartitions.exists(readerFactory.supportColumnarReads)
+      case Scan.ColumnarSupportMode.SUPPORTED => true
+      case Scan.ColumnarSupportMode.UNSUPPORTED => false
+    }
   }
 
   def inputRDD: RDD[InternalRow]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 236b4c702d1..52d0151ee46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -388,6 +389,23 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
     assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
   }
 
+  test("SPARK-44505: should not call planInputPartitions() on explain") {
+    val df = spark.read.format(classOf[ScanDefinedColumnarSupport].getName)
+      .option("columnar", "PARTITION_DEFINED").load()
+    // Default mode will throw an exception on explain.
+    var ex = intercept[IllegalArgumentException](df.explain())
+    assert(ex.getMessage == "planInputPartitions must not be called")
+
+    Seq("SUPPORTED", "UNSUPPORTED").foreach { o =>
+      val dfScan = spark.read.format(classOf[ScanDefinedColumnarSupport].getName)
+        .option("columnar", o).load()
+      dfScan.explain()
+      //  Will fail during regular execution.
+      ex = intercept[IllegalArgumentException](dfScan.count())
+      assert(ex.getMessage == "planInputPartitions must not be called")
+    }
+  }
+
   test("simple writable data source") {
     Seq(classOf[SimpleWritableDataSource], classOf[JavaSimpleWritableDataSource]).foreach { cls =>
       withTempPath { file =>
@@ -686,6 +704,25 @@ class SimpleSinglePartitionSource extends TestingV2Source {
   }
 }
 
+class ScanDefinedColumnarSupport extends TestingV2Source {
+
+  class MyScanBuilder(st: ColumnarSupportMode) extends SimpleScanBuilder {
+    override def planInputPartitions(): Array[InputPartition] = {
+      throw new IllegalArgumentException("planInputPartitions must not be called")
+    }
+
+    override def columnarSupportMode() : ColumnarSupportMode = st
+
+  }
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+      new MyScanBuilder(Scan.ColumnarSupportMode.valueOf(options.get("columnar")))
+    }
+  }
+
+}
+
 
 // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
 // tests still pass.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index be5e1b524e5..8be3c6d9e13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -128,7 +128,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     val distribution = physical.ClusteredDistribution(
       Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
 
-    checkQueryPlan(df, distribution, physical.SinglePartition)
+    // Has exactly one partition.
+    val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+    checkQueryPlan(df, distribution,
+      physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues))
   }
 
   test("non-clustered distribution: no V2 catalog") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
index 66a986da936..a2f3d872a68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
@@ -147,7 +147,7 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest {
   private val v2Catalog = {
     val newCatalog = mock(classOf[TableCatalog])
     when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
-      val ident = invocation.getArgument[Identifier](0)
+      val ident = invocation.getArguments()(0).asInstanceOf[Identifier]
       ident.name match {
         case "primitive_table" => primitiveTable
         case "primitive_table_src" => primitiveTableSource
@@ -172,7 +172,7 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest {
   private val catalogManager = {
     val manager = mock(classOf[CatalogManager])
     when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+      invocation.getArguments()(0).asInstanceOf[String] match {
         case "testcat" => v2Catalog
         case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog
        case name => throw new CatalogNotFoundException(s"No such catalog: $name")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
