This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 793ba608181 [SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation
793ba608181 is described below

commit 793ba608181b3eba8f1f57fcdd12dcd3fe035362
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Tue Apr 26 22:39:44 2022 -0700

    [SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation
    
    ### What changes were proposed in this pull request?
    This PR updates `ProjectionOverSchema` to use the outputs of the data source relation
    to filter the attributes in the nested schema pruning. This is needed because the
    attributes in the schema do not necessarily belong to the current data source relation.
    For example, if a filter contains a correlated subquery, then the subquery's children
    can contain attributes from both the inner query and the outer query. Since the
    `RewriteSubquery` batch happens after early scan pushdown rules, n [...]
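
    For illustration, a minimal sketch of the query shape that triggers the problem,
    assuming the `contacts` and `employees` temp views registered by the new test below
    (via `withContacts` / `withEmployees` in `SchemaPruningSuite`). The correlated
    predicate references attributes from both the inner relation (`e`) and the outer
    relation (`c`):

    ```scala
    // Hypothetical reproduction; assumes an active SparkSession `spark` with the two
    // temp views above. The outer attribute c.employer.company.name appears inside the
    // subquery's filter, so nested schema pruning must not treat it as a column of the
    // employees relation.
    val query = spark.sql(
      """select count(*)
        |from contacts c
        |where not exists (
        |  select null from employees e
        |  where e.name.first = c.name.first
        |    and e.employer.name = c.employer.company.name)
        |""".stripMargin)
    query.explain(true)  // the scans should read only the pruned nested fields
    ```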
    
    ### Why are the changes needed?
    
    To fix a bug in `SchemaPruning`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test
    
    Closes #36216 from allisonwang-db/spark-38918-nested-column-pruning.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
    (cherry picked from commit 150434b5d7909dcf8248ffa5ec3d937ea3da09fd)
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../expressions/ProjectionOverSchema.scala         |  8 +++-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  1 +
 .../spark/sql/catalyst/optimizer/objects.scala     |  2 +-
 .../sql/execution/datasources/SchemaPruning.scala  |  4 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |  6 +--
 .../execution/datasources/SchemaPruningSuite.scala | 45 +++++++++++++++++++++-
 6 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index a6be98c8a3a..69d30dd5048 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] =
     expr match {
-      case a: AttributeReference if fieldNames.contains(a.name) =>
+      case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
       case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
         getProjection(child).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dc3e4c3da34..753d81e4003 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 82aef32c5a2..3387bb20077 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
         }
 
         // Builds new projection.
-        val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+        val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
         val newProjects = p.projectList.map(_.transformDown {
           case projectionOverSchema(expr) => expr
         }).map { case expr: NamedExpression => expr }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index a49c10c852b..26d5d92fecb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -91,8 +91,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
       if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
         countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
        val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
-        val projectionOverSchema =
-          ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
+        val projectionOverSchema = ProjectionOverSchema(
+          prunedDataSchema.merge(prunedMetadataSchema), AttributeSet(relation.output))
        Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 5371829271d..92859f94888 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -320,7 +319,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
 
      val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-      val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+      val projectionOverSchema =
+        ProjectionOverSchema(output.toStructType, AttributeSet(output))
       val projectionFunc = (expr: Expression) => expr transformDown {
         case projectionOverSchema(newExpr) => newExpr
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 1d62713331f..d66a9aba9a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -61,11 +61,15 @@ abstract class SchemaPruningSuite
   override protected def sparkConf: SparkConf =
     super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false")
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
   val employerWithNullCompany2 = Employer(2, null)
 
@@ -81,6 +85,8 @@ abstract class SchemaPruningSuite
     Department(1, "Marketing", 1, employerWithNullCompany) ::
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
 
@@ -621,6 +627,26 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated 
subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            |  and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<id:int,company:struct<name:string,address:string>>>",
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<name:string,address:string>>")
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - 
$testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -701,6 +727,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, 
`last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
