Repository: spark
Updated Branches:
  refs/heads/branch-2.3 77cccc5e1 -> f9c913263


[SPARK-23315][SQL] failed to get output from canonicalized data source v2 
related plans

## What changes were proposed in this pull request?

`DataSourceV2Relation` keeps a `fullOutput` and resolves the real output on 
demand by looking up each column by name, i.e.:
```
lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
  fullOutput.find(_.name == name).get
}
```

This breaks once the plan is canonicalized: canonicalization renames every 
attribute to "None", so the by-name lookup above no longer finds a match and its `.get` fails. See 
https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42

To fix this, `DataSourceV2Relation` should keep only `output`, and update that 
`output` when doing column pruning.

## How was this patch tested?

A new test case in `DataSourceV2Suite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #20485 from cloud-fan/canonicalize.

(cherry picked from commit b96a083b1c6ff0d2c588be9499b456e1adce97dc)
Signed-off-by: gatorsmile <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9c91326
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9c91326
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9c91326

Branch: refs/heads/branch-2.3
Commit: f9c913263219f5e8a375542994142645dd0f6c6a
Parents: 77cccc5
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Feb 6 12:43:45 2018 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Feb 6 12:43:53 2018 -0800

----------------------------------------------------------------------
 .../datasources/v2/DataSourceReaderHolder.scala | 12 +++-----
 .../datasources/v2/DataSourceV2Relation.scala   |  8 +++---
 .../datasources/v2/DataSourceV2ScanExec.scala   |  4 +--
 .../v2/PushDownOperatorsToDataSource.scala      | 29 ++++++++++++++------
 .../sql/sources/v2/DataSourceV2Suite.scala      | 20 +++++++++++++-
 5 files changed, 48 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f9c91326/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
index 6460c97..81219e9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import java.util.Objects
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.sources.v2.reader._
 
 /**
@@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._
 trait DataSourceReaderHolder {
 
   /**
-   * The full output of the data source reader, without column pruning.
+   * The output of the data source reader, w.r.t. column pruning.
    */
-  def fullOutput: Seq[AttributeReference]
+  def output: Seq[Attribute]
 
   /**
    * The held data source reader.
@@ -46,7 +46,7 @@ trait DataSourceReaderHolder {
       case s: SupportsPushDownFilters => s.pushedFilters().toSet
       case _ => Nil
     }
-    Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
+    Seq(output, reader.getClass, filters)
   }
 
   def canEqual(other: Any): Boolean
@@ -61,8 +61,4 @@ trait DataSourceReaderHolder {
   override def hashCode(): Int = {
     metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
   }
-
-  lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name 
=>
-    fullOutput.find(_.name == name).get
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9c91326/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index eebfa29..38f6b15 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, 
Statistics}
 import org.apache.spark.sql.sources.v2.reader._
 
 case class DataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
+    output: Seq[AttributeReference],
     reader: DataSourceReader)
   extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
 
@@ -37,7 +37,7 @@ case class DataSourceV2Relation(
   }
 
   override def newInstance(): DataSourceV2Relation = {
-    copy(fullOutput = fullOutput.map(_.newInstance()))
+    copy(output = output.map(_.newInstance()))
   }
 }
 
@@ -46,8 +46,8 @@ case class DataSourceV2Relation(
  * to the non-streaming relation.
  */
 class StreamingDataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
-    reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) 
{
+    output: Seq[AttributeReference],
+    reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
   override def isStreaming: Boolean = true
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9c91326/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index df469af..7d9581b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType
  * Physical plan node for scanning data from a data source.
  */
 case class DataSourceV2ScanExec(
-    fullOutput: Seq[AttributeReference],
+    output: Seq[AttributeReference],
     @transient reader: DataSourceReader)
   extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
 
   override def canEqual(other: Any): Boolean = 
other.isInstanceOf[DataSourceV2ScanExec]
 
-  override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
-
   override def outputPartitioning: physical.Partitioning = reader match {
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(

http://git-wip-us.apache.org/repos/asf/spark/blob/f9c91326/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index 566a483..1ca6cbf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends 
Rule[LogicalPlan] with PredicateHel
 
     // TODO: add more push down rules.
 
-    pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+    val columnPruned = pushDownRequiredColumns(filterPushed, 
filterPushed.outputSet)
     // After column pruning, we may have redundant PROJECT nodes in the query 
plan, remove them.
-    RemoveRedundantProject(filterPushed)
+    RemoveRedundantProject(columnPruned)
   }
 
   // TODO: nested fields pruning
-  private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: 
AttributeSet): Unit = {
+  private def pushDownRequiredColumns(
+      plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = {
     plan match {
-      case Project(projectList, child) =>
+      case p @ Project(projectList, child) =>
         val required = projectList.flatMap(_.references)
-        pushDownRequiredColumns(child, AttributeSet(required))
+        p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))
 
-      case Filter(condition, child) =>
+      case f @ Filter(condition, child) =>
         val required = requiredByParent ++ condition.references
-        pushDownRequiredColumns(child, required)
+        f.copy(child = pushDownRequiredColumns(child, required))
 
       case relation: DataSourceV2Relation => relation.reader match {
         case reader: SupportsPushDownRequiredColumns =>
+          // TODO: Enable the below assert after we make 
`DataSourceV2Relation` immutable. For now
+          // it's possible that the mutable reader is updated by someone 
else, so we need to
+          // always call `reader.pruneColumns` here to correct it.
+          // assert(relation.output.toStructType == reader.readSchema(),
+          //  "Schema of data source reader does not match the relation plan.")
+
           val requiredColumns = 
relation.output.filter(requiredByParent.contains)
           reader.pruneColumns(requiredColumns.toStructType)
 
-        case _ =>
+          val nameToAttr = 
relation.output.map(_.name).zip(relation.output).toMap
+          val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
+          relation.copy(output = newOutput)
+
+        case _ => relation
       }
 
       // TODO: there may be more operators that can be used to calculate the 
required columns. We
       // can add more and more in the future.
-      case _ => plan.children.foreach(child => pushDownRequiredColumns(child, 
child.outputSet))
+      case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9c91326/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e0e034d..6ad0e5f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
DataSourceV2ScanExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
@@ -297,6 +297,24 @@ class DataSourceV2Suite extends QueryTest with 
SharedSQLContext {
     val reader4 = getReader(q4)
     assert(reader4.requiredSchema.fieldNames === Seq("i"))
   }
+
+  test("SPARK-23315: get output from canonicalized data source v2 related 
plans") {
+    def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = {
+      val logical = df.queryExecution.optimizedPlan.collect {
+        case d: DataSourceV2Relation => d
+      }.head
+      assert(logical.canonicalized.output.length == numOutput)
+
+      val physical = df.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d
+      }.head
+      assert(physical.canonicalized.output.length == numOutput)
+    }
+
+    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+    checkCanonicalizedOutput(df, 2)
+    checkCanonicalizedOutput(df.select('i), 1)
+  }
 }
 
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to