This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 85b504d64701 [SPARK-46442][SQL] DS V2 supports push down 
PERCENTILE_CONT and PERCENTILE_DISC
85b504d64701 is described below

commit 85b504d64701ca470b946841ca5b2b4e129293c1
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Jan 10 12:24:24 2024 +0800

    [SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and 
PERCENTILE_DISC
    
    ### What changes were proposed in this pull request?
    This PR translates the aggregate functions `PERCENTILE_CONT` and 
`PERCENTILE_DISC` for pushdown.
    
    - This PR adds `SortValue[] orderingWithinGroups` to 
`GeneralAggregateFunc`, so that the DS V2 pushdown framework can compile the 
`WITHIN GROUP (ORDER BY ...)` clause easily.
    
    - This PR also splits `visitInverseDistributionFunction` out of 
`visitAggregateFunction`, so that the DS V2 pushdown framework can generate the 
syntax `WITHIN GROUP (ORDER BY ...)` easily.
    
    - This PR also fixes a bug where `JdbcUtils` can't handle the precision and 
scale of decimals returned from JDBC.
    
    ### Why are the changes needed?
    DS V2 supports push down `PERCENTILE_CONT` and `PERCENTILE_DISC`.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    'No'.
    
    Closes #44397 from beliefer/SPARK-46442.
    
    Lead-authored-by: Jiaan Geng <belie...@163.com>
    Co-authored-by: beliefer <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../aggregate/GeneralAggregateFunc.java            | 21 +++++++-
 .../sql/connector/util/V2ExpressionSQLBuilder.java | 21 +++++++-
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 20 +++++--
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      | 15 +-----
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 17 +++++-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 62 ++++++++++++++++++++--
 6 files changed, 132 insertions(+), 24 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 4d787eaf9644..d287288ba33f 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -21,6 +21,7 @@ import java.util.Arrays;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.SortValue;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
 
 /**
@@ -41,7 +42,9 @@ import 
org.apache.spark.sql.internal.connector.ExpressionWithToString;
  *  <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
- *  <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
+ *  <li><pre>MODE() WITHIN GROUP (ORDER BY input1 [ASC|DESC])</pre> Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_CONT(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> 
Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_DISC(input1) WITHIN GROUP (ORDER BY input2 [ASC|DESC])</pre> 
Since 4.0.0</li>
  * </ol>
  *
  * @since 3.3.0
@@ -51,11 +54,21 @@ public final class GeneralAggregateFunc extends 
ExpressionWithToString implement
   private final String name;
   private final boolean isDistinct;
   private final Expression[] children;
+  private final SortValue[] orderingWithinGroups;
 
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] 
children) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
+    this.orderingWithinGroups = new SortValue[]{};
+  }
+
+  public GeneralAggregateFunc(
+      String name, boolean isDistinct, Expression[] children, SortValue[] 
orderingWithinGroups) {
+    this.name = name;
+    this.isDistinct = isDistinct;
+    this.children = children;
+    this.orderingWithinGroups = orderingWithinGroups;
   }
 
   public String name() { return name; }
@@ -64,6 +77,8 @@ public final class GeneralAggregateFunc extends 
ExpressionWithToString implement
   @Override
   public Expression[] children() { return children; }
 
+  public SortValue[] orderingWithinGroups() { return orderingWithinGroups; }
+
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
@@ -73,7 +88,8 @@ public final class GeneralAggregateFunc extends 
ExpressionWithToString implement
 
     if (isDistinct != that.isDistinct) return false;
     if (!name.equals(that.name)) return false;
-    return Arrays.equals(children, that.children);
+    if (!Arrays.equals(children, that.children)) return false;
+    return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
   }
 
   @Override
@@ -81,6 +97,7 @@ public final class GeneralAggregateFunc extends 
ExpressionWithToString implement
     int result = name.hashCode();
     result = 31 * result + (isDistinct ? 1 : 0);
     result = 31 * result + Arrays.hashCode(children);
+    result = 31 * result + Arrays.hashCode(orderingWithinGroups);
     return result;
   }
 }
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index fb11de4fdedd..1035d2da0240 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -146,8 +146,16 @@ public class V2ExpressionSQLBuilder {
       return visitAggregateFunction("AVG", avg.isDistinct(),
         expressionsToStringArray(avg.children()));
     } else if (expr instanceof GeneralAggregateFunc f) {
-      return visitAggregateFunction(f.name(), f.isDistinct(),
-        expressionsToStringArray(f.children()));
+      if (f.orderingWithinGroups().length == 0) {
+        return visitAggregateFunction(f.name(), f.isDistinct(),
+          expressionsToStringArray(f.children()));
+      } else {
+        return visitInverseDistributionFunction(
+          f.name(),
+          f.isDistinct(),
+          expressionsToStringArray(f.children()),
+          expressionsToStringArray(f.orderingWithinGroups()));
+      }
     } else if (expr instanceof UserDefinedScalarFunc f) {
       return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
         expressionsToStringArray(f.children()));
@@ -273,6 +281,15 @@ public class V2ExpressionSQLBuilder {
     }
   }
 
+  protected String visitInverseDistributionFunction(
+      String funcName, boolean isDistinct, String[] inputs, String[] 
orderingWithinGroups) {
+    assert(isDistinct == false);
+    String withinGroup =
+      joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", 
")");
+    String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
+    return functionCall + " " + withinGroup;
+  }
+
   protected String visitUserDefinedScalarFunction(
       String funcName, String canonicalName, String[] inputs) {
     throw new SparkUnsupportedOperationException(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 2766bbaa8880..3942d193a328 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
AggregateFunction, Complete}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression 
=> V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, 
LiteralValue, UserDefinedScalarFunc}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression 
=> V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, 
LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, 
Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, 
UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, 
AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableExpression
@@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, 
right)))
     // Translate Mode if it is deterministic or reverse is defined.
     case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
-      Some(new GeneralAggregateFunc("MODE", isDistinct,
-        Array(expr, LiteralValue(reverse, BooleanType))))
+      Some(new GeneralAggregateFunc(
+        "MODE", isDistinct, Array.empty, Array(generateSortValue(expr, 
!reverse))))
+    case aggregate.Percentile(
+      PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, 
_, reverse) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
+    case aggregate.PercentileDisc(
+      PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
     // TODO supports other aggregate functions
     case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
       val translatedExprs = children.flatMap(PushableExpression.unapply(_))
@@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       None
     }
   }
+
+  private def generateSortValue(expr: V2Expression, reverse: Boolean): 
SortValue = if (reverse) {
+    SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
+  } else {
+    SortValue(expr, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
+  }
 }
 
 object ColumnOrField {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index fd20e495b10f..ae3a3addf7bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -43,7 +43,7 @@ private[sql] object H2Dialect extends JdbcDialect {
 
   private val distinctUnsupportedAggregateFunctions =
     Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", 
"REGR_SLOPE", "REGR_SXY",
-      "MODE")
+      "MODE", "PERCENTILE_CONT", "PERCENTILE_DISC")
 
   private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", 
"AVG",
     "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ 
distinctUnsupportedAggregateFunctions
@@ -271,18 +271,7 @@ private[sql] object H2Dialect extends JdbcDialect {
         throw new 
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
           s"support aggregate function: $funcName with DISTINCT")
       } else {
-        funcName match {
-          case "MODE" =>
-            // Support Mode only if it is deterministic or reverse is defined.
-            assert(inputs.length == 2)
-            if (inputs.last == "true") {
-              s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})"
-            } else {
-              s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)"
-            }
-          case _ =>
-            super.visitAggregateFunction(funcName, isDistinct, inputs)
-        }
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
       }
 
     override def visitExtract(field: String, source: String): String = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 888ef4a20be3..bee870fcf7b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with 
Logging {
         super.visitAggregateFunction(dialectFunctionName(funcName), 
isDistinct, inputs)
       } else {
         throw new UnsupportedOperationException(
-          s"${this.getClass.getSimpleName} does not support aggregate 
function: $funcName");
+          s"${this.getClass.getSimpleName} does not support aggregate 
function: $funcName")
+      }
+    }
+
+    override def visitInverseDistributionFunction(
+        funcName: String,
+        isDistinct: Boolean,
+        inputs: Array[String],
+        orderingWithinGroups: Array[String]): String = {
+      if (isSupportedFunction(funcName)) {
+        super.visitInverseDistributionFunction(
+          dialectFunctionName(funcName), isDistinct, inputs, 
orderingWithinGroups)
+      } else {
+        throw new UnsupportedOperationException(
+          s"${this.getClass.getSimpleName} does not support " +
+            s"inverse distribution function: $funcName")
       }
     }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 0a66680edd63..05b3787d0ff2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df1)
     checkPushedInfo(df1,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS 
FIRST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df3)
     checkPushedInfo(df3,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS 
FIRST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df4)
     checkPushedInfo(df4,
       """
-        |PushedAggregates: [MODE(SALARY, false)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS 
LAST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
     checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
   }
 
+  test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with 
filter and group by") {
+    val df1 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE(salary, 0.5)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df1)
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY 
ASC NULLS FIRST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))
+
+    val df2 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df2)
+    checkAggregateRemoved(df2)
+    checkPushedInfo(df2,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY 
ASC NULLS FIRST),
+        |PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df2,
+      Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 
12000.0)))
+
+    val df3 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df3)
+    checkAggregateRemoved(df3)
+    checkPushedInfo(df3,
+      """
+        |PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY 
ASC NULLS FIRST),
+        |PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df3,
+      Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 
12000.0)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to