This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4567ed99a52 [SPARK-39340][SQL] DS v2 agg pushdown should allow dots in the name of top-level columns 4567ed99a52 is described below commit 4567ed99a52d0274312ba78024c548f91659a12a Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Wed Jun 22 22:59:16 2022 +0800 [SPARK-39340][SQL] DS v2 agg pushdown should allow dots in the name of top-level columns ### What changes were proposed in this pull request? It turns out that I was wrong in https://github.com/apache/spark/pull/36727 . We still have the limitation (column name cannot contain dot) in master and 3.3 branches, in a very implicit way: The `V2ExpressionBuilder` has a boolean flag `nestedPredicatePushdownEnabled` whose default value is false. When it's false, it uses `PushableColumnWithoutNestedColumn` to match columns, which doesn't support dot in names. `V2ExpressionBuilder` is only used in 2 places: 1. `PushableExpression`. This is a pattern match that is only used in v2 agg pushdown 2. `PushablePredicate`. This is a pattern match that is used in various places, but all the caller sides set `nestedPredicatePushdownEnabled` to true. This PR removes the `nestedPredicatePushdownEnabled` flag from `V2ExpressionBuilder`, and makes it always support nested fields. `PushablePredicate` is also updated accordingly to remove the boolean flag, as it's always true. ### Why are the changes needed? Fix a mistake to eliminate an unexpected limitation in DS v2 pushdown. ### Does this PR introduce _any_ user-facing change? No for end users. For data source developers, they can trigger agg pushdown more often. ### How was this patch tested? a new test Closes #36945 from cloud-fan/dsv2. 
Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/util/V2ExpressionBuilder.scala | 25 +++++++------- .../datasources/v2/DataSourceV2Strategy.scala | 38 ++++++++-------------- .../execution/datasources/v2/PushDownUtils.scala | 2 +- .../datasources/v2/DataSourceV2StrategySuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 31 ++++++++++++------ 5 files changed, 49 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 120b2044135..81d0b7dfeb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, UnaryMinus, Upper, WidthBucket} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} -import org.apache.spark.sql.execution.datasources.PushableColumn import org.apache.spark.sql.types.BooleanType /** * The builder to generate V2 expressions from catalyst expressions. 
*/ -class V2ExpressionBuilder( - e: Expression, nestedPredicatePushdownEnabled: Boolean = false, isPredicate: Boolean = false) { - - val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled) +class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { def build(): Option[V2Expression] = generateExpression(e, isPredicate) @@ -49,12 +45,8 @@ class V2ExpressionBuilder( case Literal(true, BooleanType) => Some(new AlwaysTrue()) case Literal(false, BooleanType) => Some(new AlwaysFalse()) case Literal(value, dataType) => Some(LiteralValue(value, dataType)) - case col @ pushableColumn(name) => - val ref = if (nestedPredicatePushdownEnabled) { - FieldReference(name) - } else { - FieldReference.column(name) - } + case col @ ColumnOrField(nameParts) => + val ref = FieldReference(nameParts) if (isPredicate && col.dataType.isInstanceOf[BooleanType]) { Some(new V2Predicate("=", Array(ref, LiteralValue(true, BooleanType)))) } else { @@ -266,3 +258,12 @@ class V2ExpressionBuilder( case _ => None } } + +object ColumnOrField { + def unapply(e: Expression): Option[Seq[String]] = e match { + case a: Attribute => Some(Seq(a.name)) + case s: GetStructField => + unapply(s.child).map(_ :+ s.childSchema(s.ordinal).name) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 28dbf8b13f2..c0fa3e2ba65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -491,12 +491,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat private[sql] object DataSourceV2Strategy { - private def translateLeafNodeFilterV2( - predicate: Expression, - supportNestedPredicatePushdown: Boolean): Option[Predicate] = { - val 
pushablePredicate = PushablePredicate(supportNestedPredicatePushdown) + private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = { predicate match { - case pushablePredicate(expr) => Some(expr) + case PushablePredicate(expr) => Some(expr) case _ => None } } @@ -506,10 +503,8 @@ private[sql] object DataSourceV2Strategy { * * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. */ - protected[sql] def translateFilterV2( - predicate: Expression, - supportNestedPredicatePushdown: Boolean): Option[Predicate] = { - translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown) + protected[sql] def translateFilterV2(predicate: Expression): Option[Predicate] = { + translateFilterV2WithMapping(predicate, None) } /** @@ -523,8 +518,7 @@ private[sql] object DataSourceV2Strategy { */ protected[sql] def translateFilterV2WithMapping( predicate: Expression, - translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]], - nestedPredicatePushdownEnabled: Boolean) + translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]]) : Option[Predicate] = { predicate match { case And(left, right) => @@ -538,26 +532,21 @@ private[sql] object DataSourceV2Strategy { // Pushing one leg of AND down is only safe to do at the top level. // You can see ParquetFilters' createFilter for more details. 
for { - leftFilter <- translateFilterV2WithMapping( - left, translatedFilterToExpr, nestedPredicatePushdownEnabled) - rightFilter <- translateFilterV2WithMapping( - right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr) + rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr) } yield new V2And(leftFilter, rightFilter) case Or(left, right) => for { - leftFilter <- translateFilterV2WithMapping( - left, translatedFilterToExpr, nestedPredicatePushdownEnabled) - rightFilter <- translateFilterV2WithMapping( - right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr) + rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr) } yield new V2Or(leftFilter, rightFilter) case Not(child) => - translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) - .map(new V2Not(_)) + translateFilterV2WithMapping(child, translatedFilterToExpr).map(new V2Not(_)) case other => - val filter = translateLeafNodeFilterV2(other, nestedPredicatePushdownEnabled) + val filter = translateLeafNodeFilterV2(other) if (filter.isDefined && translatedFilterToExpr.isDefined) { translatedFilterToExpr.get(filter.get) = predicate } @@ -589,10 +578,9 @@ private[sql] object DataSourceV2Strategy { /** * Get the expression of DS V2 to represent catalyst predicate that can be pushed down. 
*/ -case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) { - +object PushablePredicate { def unapply(e: Expression): Option[Predicate] = - new V2ExpressionBuilder(e, nestedPredicatePushdownEnabled, true).build().map { v => + new V2ExpressionBuilder(e, true).build().map { v => assert(v.isInstanceOf[Predicate]) v.asInstanceOf[Predicate] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 492db45626a..60371d6bf43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -80,7 +80,7 @@ object PushDownUtils extends PredicateHelper { for (filterExpr <- filters) { val translated = DataSourceV2Strategy.translateFilterV2WithMapping( - filterExpr, Some(translatedFilterToExpr), nestedPredicatePushdownEnabled = true) + filterExpr, Some(translatedFilterToExpr)) if (translated.isEmpty) { untranslatableExprs += filterExpr } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index d149dfbb510..66dc65cf681 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -37,7 +37,7 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { */ def testTranslateFilter(catalystFilter: Expression, result: Option[Predicate]): Unit = { assertResult(result) { - DataSourceV2Strategy.translateFilterV2(catalystFilter, true) + DataSourceV2Strategy.translateFilterV2(catalystFilter) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 6c73ee09741..a6073566813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -82,9 +82,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel conn.prepareStatement( "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate() conn.prepareStatement( - "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate() + "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate() // scalastyle:off conn.prepareStatement( @@ -120,10 +121,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2))) } - private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String): Unit = { + private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expectedPlanFragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) } } @@ -1284,11 +1285,21 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("column name with composite field") { - checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2))) - val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept") - 
checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [COUNT(`dept id`)]") - checkAnswer(df, Seq(Row(2))) + checkAnswer(sql("SELECT `dept id`, `dept.id` FROM h2.test.dept"), Seq(Row(1, 1), Row(2, 1))) + + val df1 = sql("SELECT COUNT(`dept id`) FROM h2.test.dept") + checkPushedInfo(df1, "PushedAggregates: [COUNT(`dept id`)]") + checkAnswer(df1, Seq(Row(2))) + + val df2 = sql("SELECT `dept.id`, COUNT(`dept id`) FROM h2.test.dept GROUP BY `dept.id`") + checkPushedInfo(df2, + "PushedGroupByExpressions: [`dept.id`]", "PushedAggregates: [COUNT(`dept id`)]") + checkAnswer(df2, Seq(Row(1, 2))) + + val df3 = sql("SELECT `dept id`, COUNT(`dept.id`) FROM h2.test.dept GROUP BY `dept id`") + checkPushedInfo(df3, + "PushedGroupByExpressions: [`dept id`]", "PushedAggregates: [COUNT(`dept.id`)]") + checkAnswer(df3, Seq(Row(1, 1), Row(2, 1))) } test("column name with non-ascii") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org