This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 96c8b4f47c2 [SPARK-38855][SQL] DS V2 supports push down math functions
96c8b4f47c2 is described below

commit 96c8b4f47c2d0df249efb088882b248b5c230188
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Apr 13 14:41:47 2022 +0800

    [SPARK-38855][SQL] DS V2 supports push down math functions
    
    ### What changes were proposed in this pull request?
    Currently, Spark has some math functions from the ANSI standard. Please refer to
    https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L388
    These functions are listed below:
    `LN`,
    `EXP`,
    `POWER`,
    `SQRT`,
    `FLOOR`,
    `CEIL`,
    `WIDTH_BUCKET`
    
    Mainstream database support for these functions is shown below.
    
    | Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch |
    | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
    | `LN` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
    | `EXP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
    | `POWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes |
    | `SQRT` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
    | `FLOOR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
    | `CEIL` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
    | `WIDTH_BUCKET` | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | Yes | No | No | No | No | No | No | No |
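
    For reference, the ANSI semantic of `WIDTH_BUCKET(expr, min, max, numBuckets)` divides `[min, max)` into `numBuckets` equal-width buckets and returns the 1-based bucket that `expr` falls into (0 below the range, `numBuckets + 1` above it). For example, `WIDTH_BUCKET(5.35, 0.024, 10.06, 5)` returns 3.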
    
    DS V2 should support pushing down these math functions.
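
    For illustration, here is a minimal sketch of the intended user-facing effect, against the same `h2.test.employee` table the new tests below use (assuming `spark` is an active SparkSession; this sketch is not part of the change itself):

    ```scala
    import org.apache.spark.sql.functions.sqrt
    import spark.implicits._ // for the $"..." column syntax

    // With this change, the SQRT call below can be compiled into the data
    // source's dialect (e.g. "SQRT(CAST(SALARY AS double)) > 100.0") and
    // evaluated by the JDBC source itself, instead of by Spark after a scan.
    val df = spark.table("h2.test.employee").filter(sqrt($"salary") > 100)
    df.explain() // the compiled filter should appear under "PushedFilters"
    ```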
    
    ### Why are the changes needed?
    This lets DS V2 push math functions down to data sources that can evaluate them natively.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New tests.
    
    Closes #36140 from beliefer/SPARK-38855.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit bf75b495e18ed87d0c118bfd5f1ceb52d720cad9)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/GeneralScalarExpression.java       | 54 ++++++++++++++++++++++
 .../sql/connector/util/V2ExpressionSQLBuilder.java |  7 +++
 .../spark/sql/errors/QueryCompilationErrors.scala  |  4 ++
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 28 ++++++++++-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      | 26 +++++++++++
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 28 ++++++++++-
 6 files changed, 145 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 8952761f9ef..58082d5ee09 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -94,6 +94,60 @@ import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
  *    <li>Since version: 3.3.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>ABS</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>ABS(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>COALESCE</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LN</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LN(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>EXP</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>EXP(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>POWER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>POWER(expr, number)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SQRT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SQRT(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>FLOOR</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>FLOOR(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CEIL</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CEIL(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>WIDTH_BUCKET</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>WIDTH_BUCKET(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
  * </ol>
 * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
  * including: add, subtract, multiply, divide, remainder, pmod.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index a7d1ed7f85e..c9dfa2003e3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -95,6 +95,13 @@ public class V2ExpressionSQLBuilder {
           return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
         case "ABS":
         case "COALESCE":
+        case "LN":
+        case "EXP":
+        case "POWER":
+        case "SQRT":
+        case "FLOOR":
+        case "CEIL":
+        case "WIDTH_BUCKET":
           return visitSQLFunction(name,
            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
         case "CASE_WHEN": {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 0532a953ef4..f1357f91f9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2392,4 +2392,8 @@ object QueryCompilationErrors {
    new AnalysisException(
      "Sinks cannot request distribution and ordering in continuous execution mode")
   }
+
+  def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+    new AnalysisException(s"$database does not support function: $funcInfo")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 37db499470a..487b809d48a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -104,6 +104,32 @@ class V2ExpressionBuilder(
       } else {
         None
       }
+    case Log(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+    case Exp(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+    case Pow(left, right) =>
+      val l = generateExpression(left)
+      val r = generateExpression(right)
+      if (l.isDefined && r.isDefined) {
+        Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
+      } else {
+        None
+      }
+    case Sqrt(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+    case Floor(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+    case Ceil(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+    case wb: WidthBucket =>
+      val childrenExpressions = wb.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == wb.children.length) {
+        Some(new GeneralScalarExpression("WIDTH_BUCKET",
+          childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 643376cdb12..0aa971c0d3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc
 import java.sql.SQLException
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  class H2SQLBuilder extends JDBCSQLBuilder {
+    override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+      funcName match {
+        case "WIDTH_BUCKET" =>
+          val functionInfo = super.visitSQLFunction(funcName, inputs)
+          throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+        case _ => super.visitSQLFunction(funcName, inputs)
+      }
+    }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val h2SQLBuilder = new H2SQLBuilder()
+    try {
+      Some(h2SQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
+
   override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
     super.compileAggregate(aggFunction).orElse(
       aggFunction match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 858781f2cde..e28d9ba9ba8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -440,6 +440,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkPushedInfo(df5, expectedPlanFragment5)
         checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
          Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
+
+        val df6 = spark.table("h2.test.employee")
+          .filter(ln($"dept") > 1)
+          .filter(exp($"salary") > 2000)
+          .filter(pow($"dept", 2) > 4)
+          .filter(sqrt($"salary") > 100)
+          .filter(floor($"dept") > 1)
+          .filter(ceil($"dept") > 1)
+        checkFiltersRemoved(df6, ansiMode)
+        val expectedPlanFragment6 = if (ansiMode) {
+          "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+            "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+        } else {
+          "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+        }
+        checkPushedInfo(df6, expectedPlanFragment6)
+        checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+        // H2 does not support width_bucket
+        val df7 = sql("""
+                        |SELECT * FROM h2.test.employee
+                        |WHERE width_bucket(dept, 1, 6, 3) > 1
+                        |""".stripMargin)
+        checkFiltersRemoved(df7, false)
+        checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+        checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
       }
     }
   }
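
As a usage note, the `H2Dialect` change above also serves as a template for any other JDBC dialect that cannot evaluate one of these functions. Below is a minimal sketch, assuming a hypothetical `somedb` database and treating `LN` as its unsupported function; `JDBCSQLBuilder`, `visitSQLFunction`, `compileExpression`, `build`, and `noSuchFunctionError` are the hooks shown in the diff above:

```scala
package org.apache.spark.sql.jdbc

import java.util.Locale

import scala.util.control.NonFatal

import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.errors.QueryCompilationErrors

// Hypothetical dialect: the "somedb" URL scheme and the rejected function
// are made up for illustration.
private object SomeDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:somedb")

  class SomeSQLBuilder extends JDBCSQLBuilder {
    override def visitSQLFunction(funcName: String, inputs: Array[String]): String =
      funcName match {
        // Pretend this database lacks LN. Throwing here makes compileExpression
        // return None below, so Spark keeps evaluating the expression itself
        // rather than pushing SQL the database cannot run.
        case "LN" =>
          throw QueryCompilationErrors.noSuchFunctionError(
            "somedb", super.visitSQLFunction(funcName, inputs))
        case _ => super.visitSQLFunction(funcName, inputs)
      }
  }

  override def compileExpression(expr: Expression): Option[String] = {
    try {
      Some(new SomeSQLBuilder().build(expr))
    } catch {
      case NonFatal(_) => None // no pushdown for this expression
    }
  }
}
```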

