This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 13edafab9f4 [SPARK-38530][SQL] Fix a bug that GeneratorNestedColumnAliasing can be incorrectly applied to some expressions 13edafab9f4 is described below commit 13edafab9f45cc80aee41e2f82475367d88357ec Author: minyyy <min.y...@databricks.com> AuthorDate: Wed Apr 13 14:01:27 2022 +0800 [SPARK-38530][SQL] Fix a bug that GeneratorNestedColumnAliasing can be incorrectly applied to some expressions ### What changes were proposed in this pull request? This PR makes GeneratorNestedColumnAliasing applicable only to GetStructField*(_: AttributeReference), where GetStructField* denotes an arbitrarily nested chain of GetStructField expressions. The current way of collecting expressions is top-down and only inspects two levels, which is wrong. The rule is simple: as soon as we see an expression other than GetStructField, we stop. When an expression E is pushed down into an Explode, E(x) is rewritten as E(array(x)). Therefore only expressions that can operate on both x and array(x) can be pushed down. GetStructField is special because GetArrayStructFields exists: when a GetStructField is pushed down, it becomes a GetArrayStructFields. No other expression is applicable. We also no longer need to check whether the child type is Array(Array()) or whether the rewritten expression has the pattern GetArrayStructFields(GetArrayStructFields()), for two reasons. 1. When the child input type is Array(Array()), the ExtractValue expressions we collect always start from an innermost GetArrayStructFields, which does not match GetStructField*(x). 2. If GetArrayStructFields(GetArrayStructFields()) appears in the rewritten generator, then a GetArrayStructFields must already have been present in the expressions before pushdown, and the new collection rule excludes it. ### Why are the changes needed? It fixes some correctness issues. See the section above for more details. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. 
Closes #35866 from minyyy/gnca_wrong_expr. Lead-authored-by: minyyy <min.y...@databricks.com> Co-authored-by: minyyy <98760575+min...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/optimizer/NestedColumnAliasing.scala | 50 ++++++++++++++-------- .../optimizer/NestedColumnAliasingSuite.scala | 40 ++++++++++++++++- 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 9cf2925cdd2..45f84c21b7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -240,12 +240,14 @@ object NestedColumnAliasing { */ def getAttributeToExtractValues( exprList: Seq[Expression], - exclusiveAttrs: Seq[Attribute]): Map[Attribute, Seq[ExtractValue]] = { + exclusiveAttrs: Seq[Attribute], + extractor: (Expression) => Seq[Expression] = collectRootReferenceAndExtractValue) + : Map[Attribute, Seq[ExtractValue]] = { val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]() val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]() exprList.foreach { e => - collectRootReferenceAndExtractValue(e).foreach { + extractor(e).foreach { // we can not alias the attr from lambda variable whose expr id is not available case ev: ExtractValue if !ev.exists(_.isInstanceOf[NamedLambdaVariable]) => if (ev.references.size == 1) { @@ -350,23 +352,44 @@ object GeneratorNestedColumnAliasing { return None } val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput) - val (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) = + var (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) = attrToExtractValues.partition { case (attr, _) => 
attr.references.subsetOf(generatorOutputSet) } val pushedThrough = NestedColumnAliasing.rewritePlanWithAliases( plan, attrToExtractValuesNotOnGenerator) - // If the generator output is `ArrayType`, we cannot push through the extractor. - // It is because we don't allow field extractor on two-level array, - // i.e., attr.field when attr is a ArrayType(ArrayType(...)). - // Similarily, we also cannot push through if the child of generator is `MapType`. + // We cannot push through if the child of generator is `MapType`. g.generator.children.head.dataType match { case _: MapType => return Some(pushedThrough) case ArrayType(_: ArrayType, _) => return Some(pushedThrough) case _ => } + // This function collects all GetStructField*(attribute) from the passed in expression. + // GetStructField* means arbitrary levels of nesting. + def collectNestedGetStructFields(e: Expression): Seq[Expression] = { + // The helper function returns a tuple of + // (nested GetStructField including the current level, all other nested GetStructField) + def helper(e: Expression): (Seq[Expression], Seq[Expression]) = e match { + case _: AttributeReference => (Seq(e), Seq.empty) + case gsf: GetStructField => + val child_res = helper(gsf.child) + (child_res._1.map(p => gsf.withNewChildren(Seq(p))), child_res._2) + case other => + val child_res = other.children.map(helper) + val child_res_combined = (child_res.flatMap(_._1), child_res.flatMap(_._2)) + (Seq.empty, child_res_combined._1 ++ child_res_combined._2) + } + + val res = helper(e) + (res._1 ++ res._2).filterNot(_.isInstanceOf[Attribute]) + } + + attrToExtractValuesOnGenerator = NestedColumnAliasing.getAttributeToExtractValues( + attrToExtractValuesOnGenerator.flatMap(_._2).toSeq, Seq.empty, + collectNestedGetStructFields) + // Pruning on `Generator`'s output. We only process single field case. // For multiple field case, we cannot directly move field extractor into // the generator expression. 
A workaround is to re-construct array of struct @@ -391,17 +414,6 @@ object GeneratorNestedColumnAliasing { e.withNewChildren(Seq(extractor)) } - // If after replacing generator expression with nested extractor, there - // is invalid extractor pattern like - // `GetArrayStructFields(GetArrayStructFields(...), ...), we cannot do - // pruning but fallback to original query plan. - val invalidExtractor = rewrittenG.generator.children.head.collect { - case GetArrayStructFields(_: GetArrayStructFields, _, _, _, _) => true - } - if (invalidExtractor.nonEmpty) { - return Some(pushedThrough) - } - // As we change the child of the generator, its output data type must be updated. val updatedGeneratorOutput = rewrittenG.generatorOutput .zip(rewrittenG.generator.elementSchema.toAttributes) @@ -416,7 +428,7 @@ object GeneratorNestedColumnAliasing { // Replace nested column accessor with generator output. val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet val updatedProject = p.withNewChildren(Seq(updatedGenerate)).transformExpressions { - case f: ExtractValue if nestedFieldsOnGenerator.contains(f) => + case f: GetStructField if nestedFieldsOnGenerator.contains(f) => updatedGenerate.output .find(a => attrExprIdsOnGenerator.contains(a.exprId)) .getOrElse(f) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index 93ae98ecb31..1475b3dc1b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.SchemaPruningTest -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import 
org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} class NestedColumnAliasingSuite extends SchemaPruningTest { @@ -813,6 +813,42 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val expected3 = contact.select($"name").rebalance($"name").select($"name.first").analyze comparePlans(optimized3, expected3) } + + test("SPARK-38530: Do not push down nested ExtractValues with other expressions") { + val inputType = StructType.fromDDL( + "a int, b struct<c: array<int>, c2: int>") + val simpleStruct = StructType.fromDDL( + "b struct<c: struct<d: int, e: int>, c2 int>" + ) + val input = LocalRelation( + 'id.int, + 'col1.array(ArrayType(inputType))) + + val query = input + .generate(Explode('col1)) + .select( + UnresolvedExtractValue( + UnresolvedExtractValue( + CaseWhen(Seq(('col.getField("a") === 1, + Literal.default(simpleStruct)))), + Literal("b")), + Literal("c")).as("result")) + .analyze + val optimized = Optimize.execute(query) + + val aliases = collectGeneratedAliases(optimized) + + // Only the inner-most col.a should be pushed down. 
+ val expected = input + .select('col1.getField("a").as(aliases(0))) + .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0)) + .select(UnresolvedExtractValue(UnresolvedExtractValue( + CaseWhen(Seq(('col === 1, + Literal.default(simpleStruct)))), Literal("b")), Literal("c")).as("result")) + .analyze + + comparePlans(optimized, expected) + } } object NestedColumnAliasingSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org