This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 13edafab9f4 [SPARK-38530][SQL] Fix a bug that 
GeneratorNestedColumnAliasing can be incorrectly applied to some expressions
13edafab9f4 is described below

commit 13edafab9f45cc80aee41e2f82475367d88357ec
Author: minyyy <min.y...@databricks.com>
AuthorDate: Wed Apr 13 14:01:27 2022 +0800

    [SPARK-38530][SQL] Fix a bug that GeneratorNestedColumnAliasing can be 
incorrectly applied to some expressions
    
    ### What changes were proposed in this pull request?
    
    This PR restricts GeneratorNestedColumnAliasing to expressions of the form
    GetStructField*(_: AttributeReference), where GetStructField* denotes any number of
    nested GetStructField expressions. The current approach collects expressions top-down
    and only checks two levels, which is wrong. The rule is simple: as soon as we see an
    expression other than GetStructField, we stop. When an expression E is pushed down
    into an Explode, the following happens:
    E(x) is now pushed down to apply to E(array(x)).
    So only expressions that can operate on both x and array(x) can be pushed down.
    GetStructField is special because it has an array counterpart, GetArrayStructFields:
    when a GetStructField is pushed down, it becomes a GetArrayStructFields. No other
    expression can be pushed down this way.
    We also do not need to check whether the child type is Array(Array()), or whether
    the rewritten expression has the pattern GetArrayStructFields(GetArrayStructFields()):
    1. When the child input type is Array(Array()), the collected ExtractValue
    expressions always start from an innermost GetArrayStructFields, which does not
    match the pattern GetStructField*(x).
    2. When we see GetArrayStructFields(GetArrayStructFields()) in the 
rewritten generator, we must have seen a GetArrayStructFields in the 
expressions before pushdown.
    
    ### Why are the changes needed?
    It fixes some correctness issues. See the above section for more details.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #35866 from minyyy/gnca_wrong_expr.
    
    Lead-authored-by: minyyy <min.y...@databricks.com>
    Co-authored-by: minyyy <98760575+min...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/optimizer/NestedColumnAliasing.scala  | 50 ++++++++++++++--------
 .../optimizer/NestedColumnAliasingSuite.scala      | 40 ++++++++++++++++-
 2 files changed, 69 insertions(+), 21 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 9cf2925cdd2..45f84c21b7d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -240,12 +240,14 @@ object NestedColumnAliasing {
    */
   def getAttributeToExtractValues(
       exprList: Seq[Expression],
-      exclusiveAttrs: Seq[Attribute]): Map[Attribute, Seq[ExtractValue]] = {
+      exclusiveAttrs: Seq[Attribute],
+      extractor: (Expression) => Seq[Expression] = 
collectRootReferenceAndExtractValue)
+    : Map[Attribute, Seq[ExtractValue]] = {
 
     val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
     val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
     exprList.foreach { e =>
-      collectRootReferenceAndExtractValue(e).foreach {
+      extractor(e).foreach {
         // we can not alias the attr from lambda variable whose expr id is not 
available
         case ev: ExtractValue if 
!ev.exists(_.isInstanceOf[NamedLambdaVariable]) =>
           if (ev.references.size == 1) {
@@ -350,23 +352,44 @@ object GeneratorNestedColumnAliasing {
         return None
       }
       val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
-      val (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+      var (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
         attrToExtractValues.partition { case (attr, _) =>
           attr.references.subsetOf(generatorOutputSet) }
 
       val pushedThrough = NestedColumnAliasing.rewritePlanWithAliases(
         plan, attrToExtractValuesNotOnGenerator)
 
-      // If the generator output is `ArrayType`, we cannot push through the 
extractor.
-      // It is because we don't allow field extractor on two-level array,
-      // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
-      // Similarily, we also cannot push through if the child of generator is 
`MapType`.
+      // We cannot push through if the child of generator is `MapType`.
       g.generator.children.head.dataType match {
         case _: MapType => return Some(pushedThrough)
         case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
         case _ =>
       }
 
+      // This function collects all GetStructField*(attribute) from the passed 
in expression.
+      // GetStructField* means arbitrary levels of nesting.
+      def collectNestedGetStructFields(e: Expression): Seq[Expression] = {
+        // The helper function returns a tuple of
+        // (nested GetStructField including the current level, all other 
nested GetStructField)
+        def helper(e: Expression): (Seq[Expression], Seq[Expression]) = e 
match {
+          case _: AttributeReference => (Seq(e), Seq.empty)
+          case gsf: GetStructField =>
+            val child_res = helper(gsf.child)
+            (child_res._1.map(p => gsf.withNewChildren(Seq(p))), child_res._2)
+          case other =>
+            val child_res = other.children.map(helper)
+            val child_res_combined = (child_res.flatMap(_._1), 
child_res.flatMap(_._2))
+            (Seq.empty, child_res_combined._1 ++ child_res_combined._2)
+        }
+
+        val res = helper(e)
+        (res._1 ++ res._2).filterNot(_.isInstanceOf[Attribute])
+      }
+
+      attrToExtractValuesOnGenerator = 
NestedColumnAliasing.getAttributeToExtractValues(
+        attrToExtractValuesOnGenerator.flatMap(_._2).toSeq, Seq.empty,
+        collectNestedGetStructFields)
+
       // Pruning on `Generator`'s output. We only process single field case.
       // For multiple field case, we cannot directly move field extractor into
       // the generator expression. A workaround is to re-construct array of 
struct
@@ -391,17 +414,6 @@ object GeneratorNestedColumnAliasing {
                 e.withNewChildren(Seq(extractor))
             }
 
-            // If after replacing generator expression with nested extractor, 
there
-            // is invalid extractor pattern like
-            // `GetArrayStructFields(GetArrayStructFields(...), ...), we 
cannot do
-            // pruning but fallback to original query plan.
-            val invalidExtractor = rewrittenG.generator.children.head.collect {
-              case GetArrayStructFields(_: GetArrayStructFields, _, _, _, _) 
=> true
-            }
-            if (invalidExtractor.nonEmpty) {
-              return Some(pushedThrough)
-            }
-
             // As we change the child of the generator, its output data type 
must be updated.
             val updatedGeneratorOutput = rewrittenG.generatorOutput
               .zip(rewrittenG.generator.elementSchema.toAttributes)
@@ -416,7 +428,7 @@ object GeneratorNestedColumnAliasing {
             // Replace nested column accessor with generator output.
             val attrExprIdsOnGenerator = 
attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
             val updatedProject = 
p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
-              case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+              case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
                 updatedGenerate.output
                   .find(a => attrExprIdsOnGenerator.contains(a.exprId))
                   .getOrElse(f)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index 93ae98ecb31..1475b3dc1b8 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.SchemaPruningTest
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, 
UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, 
StructField, StructType}
 
 class NestedColumnAliasingSuite extends SchemaPruningTest {
 
@@ -813,6 +813,42 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
     val expected3 = 
contact.select($"name").rebalance($"name").select($"name.first").analyze
     comparePlans(optimized3, expected3)
   }
+
+  test("SPARK-38530: Do not push down nested ExtractValues with other 
expressions") {
+    val inputType = StructType.fromDDL(
+      "a int, b struct<c: array<int>, c2: int>")
+    val simpleStruct = StructType.fromDDL(
+      "b struct<c: struct<d: int, e: int>, c2 int>"
+    )
+    val input = LocalRelation(
+      'id.int,
+      'col1.array(ArrayType(inputType)))
+
+    val query = input
+      .generate(Explode('col1))
+      .select(
+        UnresolvedExtractValue(
+          UnresolvedExtractValue(
+            CaseWhen(Seq(('col.getField("a") === 1,
+              Literal.default(simpleStruct)))),
+            Literal("b")),
+          Literal("c")).as("result"))
+      .analyze
+    val optimized = Optimize.execute(query)
+
+    val aliases = collectGeneratedAliases(optimized)
+
+    // Only the inner-most col.a should be pushed down.
+    val expected = input
+      .select('col1.getField("a").as(aliases(0)))
+      .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0))
+      .select(UnresolvedExtractValue(UnresolvedExtractValue(
+        CaseWhen(Seq(('col === 1,
+          Literal.default(simpleStruct)))), Literal("b")), 
Literal("c")).as("result"))
+      .analyze
+
+    comparePlans(optimized, expected)
+  }
 }
 
 object NestedColumnAliasingSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to