This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 85b504d64701 [SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and PERCENTILE_DISC
85b504d64701 is described below

commit 85b504d64701ca470b946841ca5b2b4e129293c1
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Jan 10 12:24:24 2024 +0800

    [SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and PERCENTILE_DISC

    ### What changes were proposed in this pull request?
    This PR translates the aggregate functions `PERCENTILE_CONT` and `PERCENTILE_DISC` for pushdown.

    - This PR adds `Expression[] orderingWithinGroups` to `GeneralAggregateFunc`, so that the DS V2 pushdown framework can compile the `WITHIN GROUP (ORDER BY ...)` clause easily.
    - This PR also splits `visitInverseDistributionFunction` out of `visitAggregateFunction`, so that the DS V2 pushdown framework can generate the `WITHIN GROUP (ORDER BY ...)` syntax easily.
    - This PR also fixes a bug where `JdbcUtils` could not handle the precision and scale of decimals returned from JDBC.

    ### Why are the changes needed?
    It enables DS V2 to push down `PERCENTILE_CONT` and `PERCENTILE_DISC`.

    ### Does this PR introduce _any_ user-facing change?
    'No'. New feature.

    ### How was this patch tested?
    New test cases.

    ### Was this patch authored or co-authored using generative AI tooling?
    'No'.

    Closes #44397 from beliefer/SPARK-46442.

    Lead-authored-by: Jiaan Geng <belie...@163.com>
    Co-authored-by: beliefer <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../aggregate/GeneralAggregateFunc.java            | 21 +++++++-
 .../sql/connector/util/V2ExpressionSQLBuilder.java | 21 +++++++-
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 20 +++++--
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      | 15 +-----
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 17 +++++-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 62 ++++++++++++++++++++--
 6 files changed, 132 insertions(+), 24 deletions(-)
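[Editor's note] For orientation before the diff: a minimal sketch, not part of the patch, of the V2 expression the new four-argument `GeneralAggregateFunc` constructor is meant to carry. The `SALARY` column and the 0.5 percentage are example values only; the connector-expression classes are the existing ones imported in the diff below.

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.types.DoubleType

// PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST):
// the percentage literal stays in `children`, while the ordering expression
// travels in the new `orderingWithinGroups` array instead of being smuggled
// in as an extra child (as the old MODE translation did).
val ordering = SortValue(
  FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
val percentileCont = new GeneralAggregateFunc(
  "PERCENTILE_CONT",
  false, // DISTINCT is rejected for inverse distribution functions
  Array[Expression](LiteralValue(0.5d, DoubleType)),
  Array(ordering))
```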
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 4d787eaf9644..d287288ba33f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -21,6 +21,7 @@ import java.util.Arrays;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.SortValue;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
 
 /**
@@ -41,7 +42,9 @@ import org.apache.spark.sql.internal.connector.ExpressionWithToString;
  *  <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
- *  <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
+ *  <li><pre>MODE() WITHIN (ORDER BY input1 [ASC|DESC])</pre> Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_CONT(input1) WITHIN (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_DISC(input1) WITHIN (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
  * </ol>
  *
  * @since 3.3.0
@@ -51,11 +54,21 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
   private final String name;
   private final boolean isDistinct;
   private final Expression[] children;
+  private final SortValue[] orderingWithinGroups;
 
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
+    this.orderingWithinGroups = new SortValue[]{};
+  }
+
+  public GeneralAggregateFunc(
+      String name, boolean isDistinct, Expression[] children, SortValue[] orderingWithinGroups) {
+    this.name = name;
+    this.isDistinct = isDistinct;
+    this.children = children;
+    this.orderingWithinGroups = orderingWithinGroups;
   }
 
   public String name() { return name; }
@@ -64,6 +77,8 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
   @Override
   public Expression[] children() { return children; }
 
+  public SortValue[] orderingWithinGroups() { return orderingWithinGroups; }
+
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
@@ -73,7 +88,8 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
 
     if (isDistinct != that.isDistinct) return false;
     if (!name.equals(that.name)) return false;
-    return Arrays.equals(children, that.children);
+    if (!Arrays.equals(children, that.children)) return false;
+    return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
   }
 
   @Override
@@ -81,6 +97,7 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
     int result = name.hashCode();
     result = 31 * result + (isDistinct ? 1 : 0);
     result = 31 * result + Arrays.hashCode(children);
+    result = 31 * result + Arrays.hashCode(orderingWithinGroups);
     return result;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index fb11de4fdedd..1035d2da0240 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -146,8 +146,16 @@ public class V2ExpressionSQLBuilder {
       return visitAggregateFunction("AVG", avg.isDistinct(),
         expressionsToStringArray(avg.children()));
     } else if (expr instanceof GeneralAggregateFunc f) {
-      return visitAggregateFunction(f.name(), f.isDistinct(),
-        expressionsToStringArray(f.children()));
+      if (f.orderingWithinGroups().length == 0) {
+        return visitAggregateFunction(f.name(), f.isDistinct(),
+          expressionsToStringArray(f.children()));
+      } else {
+        return visitInverseDistributionFunction(
+          f.name(),
+          f.isDistinct(),
+          expressionsToStringArray(f.children()),
+          expressionsToStringArray(f.orderingWithinGroups()));
+      }
     } else if (expr instanceof UserDefinedScalarFunc f) {
       return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
         expressionsToStringArray(f.children()));
@@ -273,6 +281,15 @@ public class V2ExpressionSQLBuilder {
     }
   }
 
+  protected String visitInverseDistributionFunction(
+      String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
+    assert(isDistinct == false);
+    String withinGroup =
+      joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
+    String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
+    return functionCall + " " + withinGroup;
+  }
+
   protected String visitUserDefinedScalarFunction(
       String funcName, String canonicalName, String[] inputs) {
     throw new SparkUnsupportedOperationException(
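[Editor's note] The new `visitInverseDistributionFunction` is plain string assembly. As a rough, self-contained Scala restatement of the two `joinArrayToString` calls above, assuming the inputs and orderings have already been compiled to SQL fragments:

```scala
// Mirrors visitInverseDistributionFunction: funcName(inputs...) followed by
// WITHIN GROUP (ORDER BY orderings...).
def inverseDistributionSQL(
    funcName: String,
    inputs: Seq[String],
    orderingWithinGroups: Seq[String]): String = {
  val functionCall = inputs.mkString(s"$funcName(", ", ", ")")
  val withinGroup =
    orderingWithinGroups.mkString("WITHIN GROUP (ORDER BY ", ", ", ")")
  s"$functionCall $withinGroup"
}

inverseDistributionSQL("PERCENTILE_DISC", Seq("0.3"), Seq("SALARY DESC NULLS LAST"))
// => PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)
```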
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 2766bbaa8880..3942d193a328 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableExpression
@@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
     // Translate Mode if it is deterministic or reverse is defined.
     case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
-      Some(new GeneralAggregateFunc("MODE", isDistinct,
-        Array(expr, LiteralValue(reverse, BooleanType))))
+      Some(new GeneralAggregateFunc(
+        "MODE", isDistinct, Array.empty, Array(generateSortValue(expr, !reverse))))
+    case aggregate.Percentile(
+        PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, _, reverse) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
+    case aggregate.PercentileDisc(
+        PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
     // TODO supports other aggregate functions
     case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
       val translatedExprs = children.flatMap(PushableExpression.unapply(_))
@@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       None
     }
   }
+
+  private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = if (reverse) {
+    SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
+  } else {
+    SortValue(expr, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
+  }
 }
 
 object ColumnOrField {
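[Editor's note] Note the `!reverse` flip in the rewritten `MODE` case: with this change, `Some(reverse = true)` yields an ascending sort and `reverse = false` a `DESC NULLS LAST` one, which is what the updated `JDBCV2Suite` expectations assert. A sketch of the V2 expression built for the ascending case (the column name is an example only):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// MODE now has no children at all; the column moves into the ordering, and
// V2ExpressionSQLBuilder renders it as
//   MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)
val mode = new GeneralAggregateFunc(
  "MODE",
  false,
  Array.empty[Expression],
  Array(SortValue(
    FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)))
```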
"MODE", "PERCENTILE_CONT", "PERCENTILE_DISC") private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions @@ -271,18 +271,7 @@ private[sql] object H2Dialect extends JdbcDialect { throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + s"support aggregate function: $funcName with DISTINCT") } else { - funcName match { - case "MODE" => - // Support Mode only if it is deterministic or reverse is defined. - assert(inputs.length == 2) - if (inputs.last == "true") { - s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})" - } else { - s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)" - } - case _ => - super.visitAggregateFunction(funcName, isDistinct, inputs) - } + super.visitAggregateFunction(funcName, isDistinct, inputs) } override def visitExtract(field: String, source: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 888ef4a20be3..bee870fcf7b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with Logging { super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs) } else { throw new UnsupportedOperationException( - s"${this.getClass.getSimpleName} does not support aggregate function: $funcName"); + s"${this.getClass.getSimpleName} does not support aggregate function: $funcName") + } + } + + override def visitInverseDistributionFunction( + funcName: String, + isDistinct: Boolean, + inputs: Array[String], + orderingWithinGroups: Array[String]): String = { + if (isSupportedFunction(funcName)) { + super.visitInverseDistributionFunction( + dialectFunctionName(funcName), isDistinct, inputs, orderingWithinGroups) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support " + + s"inverse distribution function: $funcName") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 0a66680edd63..05b3787d0ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df1) checkPushedInfo(df1, """ - |PushedAggregates: [MODE(SALARY, true)], + |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)], |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], |PushedGroupByExpressions: [DEPT], |""".stripMargin.replaceAll("\n", " ")) @@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df3) checkPushedInfo(df3, """ - |PushedAggregates: [MODE(SALARY, true)], + |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)], |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], |PushedGroupByExpressions: [DEPT], |""".stripMargin.replaceAll("\n", " ")) @@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df4) checkPushedInfo(df4, """ - |PushedAggregates: [MODE(SALARY, false)], + |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 0a66680edd63..05b3787d0ff2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df1)
     checkPushedInfo(df1,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df3)
     checkPushedInfo(df3,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df4)
     checkPushedInfo(df4,
       """
-        |PushedAggregates: [MODE(SALARY, false)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
     checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
   }
 
+  test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with filter and group by") {
+    val df1 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE(salary, 0.5)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df1)
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))
+
+    val df2 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df2)
+    checkAggregateRemoved(df2)
+    checkPushedInfo(df2,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+        |PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df2,
+      Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 12000.0)))
+
+    val df3 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df3)
+    checkAggregateRemoved(df3)
+    checkPushedInfo(df3,
+      """
+        |PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+        |PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df3,
+      Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 12000.0)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org