This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new b3cd07b236f [SPARK-38761][SQL] DS V2 supports push down misc non-aggregate functions
b3cd07b236f is described below

commit b3cd07b236f46e8c402b06820d6f3a25fe608593
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Apr 11 13:50:57 2022 +0800

    [SPARK-38761][SQL] DS V2 supports push down misc non-aggregate functions
    
    ### What changes were proposed in this pull request?
    Currently, Spark has some misc non-aggregate functions from the ANSI standard. Please refer to
    https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L362.
    These functions are listed below:
    `abs`,
    `coalesce`,
    `nullif`,
    `CASE WHEN`
    DS V2 should support pushing down these misc non-aggregate functions.
    Because DS V2 already supports pushing down `CASE WHEN`, this PR does not need to handle it again.
    Because `nullif` extends `RuntimeReplaceable`, this PR does not need to handle it either.
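    
    A minimal usage sketch (assuming a SparkSession `spark` with the H2-backed JDBC catalog `h2` registered, as in the test suite changed below): filters built with `abs` and `coalesce` can now be compiled to the `ABS` and `COALESCE` SQL functions and pushed to the data source.
    
        import org.apache.spark.sql.functions.{abs, coalesce}
        import spark.implicits._
    
        // Both filters can be pushed down once ABS and COALESCE are supported.
        val df = spark.table("h2.test.employee")
          .filter(abs($"dept" - 3) > 1)
          .filter(coalesce($"salary", $"bonus") > 2000)
        // With ANSI mode enabled, the scan reports (see the new test below):
        //   PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1,
        //     (COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]
        df.show()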
    
    ### Why are the changes needed?
    DS V2 should support pushing down misc non-aggregate functions.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New tests.
    
    Closes #36039 from beliefer/SPARK-38761.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 9ce4ba02d3f67116a4a9786af453d869596fb3ec)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connector/util/V2ExpressionSQLBuilder.java |  8 ++++
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 11 ++++-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 50 +++++++++++-----------
 3 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index c8d924db75a..a7d1ed7f85e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -93,6 +93,10 @@ public class V2ExpressionSQLBuilder {
           return visitNot(build(e.children()[0]));
         case "~":
           return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
+        case "ABS":
+        case "COALESCE":
+          return visitSQLFunction(name,
+            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
         case "CASE_WHEN": {
           List<String> children =
             Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
@@ -210,6 +214,10 @@ public class V2ExpressionSQLBuilder {
     return sb.toString();
   }
 
+  protected String visitSQLFunction(String funcName, String[] inputs) {
+    return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+  }
+
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 5fd01ac5636..37db499470a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -95,6 +95,15 @@ class V2ExpressionBuilder(
       }
     case Cast(child, dataType, _, true) =>
       generateExpression(child).map(v => new V2Cast(v, dataType))
+    case Abs(child, true) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
+    case Coalesce(children) =>
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (children.length == childrenExpressions.length) {
+        Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 23138b8899b..858781f2cde 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -381,19 +381,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))
 
         val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)
-
         checkFiltersRemoved(df2, ansiMode)
-
-        df2.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment = if (ansiMode) {
-              "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
-            } else {
-              "PushedFilters: [ID IS NOT NULL], "
-            }
-            checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+        val expectedPlanFragment2 = if (ansiMode) {
+          "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
+        } else {
+          "PushedFilters: [ID IS NOT NULL], "
         }
-
+        checkPushedInfo(df2, expectedPlanFragment2)
         if (ansiMode) {
           val e = intercept[SparkException] {
             checkAnswer(df2, Seq.empty)
@@ -422,22 +416,30 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
         val df4 = spark.table("h2.test.employee")
           .filter(($"salary" > 1000d).and($"salary" < 12000d))
-
         checkFiltersRemoved(df4, ansiMode)
-
-        df4.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment = if (ansiMode) {
-              "PushedFilters: [SALARY IS NOT NULL, " +
-                "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 
12000.0], "
-            } else {
-              "PushedFilters: [SALARY IS NOT NULL], "
-            }
-            checkKeywordsExistsInExplain(df4, expected_plan_fragment)
+        val expectedPlanFragment4 = if (ansiMode) {
+          "PushedFilters: [SALARY IS NOT NULL, " +
+            "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 
12000.0], "
+        } else {
+          "PushedFilters: [SALARY IS NOT NULL], "
         }
-
+        checkPushedInfo(df4, expectedPlanFragment4)
         checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true),
           Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, 
true)))
+
+        val df5 = spark.table("h2.test.employee")
+          .filter(abs($"dept" - 3) > 1)
+          .filter(coalesce($"salary", $"bonus") > 2000)
+        checkFiltersRemoved(df5, ansiMode)
+        val expectedPlanFragment5 = if (ansiMode) {
+          "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " +
+            "(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]"
+        } else {
+          "PushedFilters: [DEPT IS NOT NULL]"
+        }
+        checkPushedInfo(df5, expectedPlanFragment5)
+        checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
+          Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, 
true)))
       }
     }
   }

