This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new bf75b495e18 [SPARK-38855][SQL] DS V2 supports push down math functions bf75b495e18 is described below commit bf75b495e18ed87d0c118bfd5f1ceb52d720cad9 Author: Jiaan Geng <belie...@163.com> AuthorDate: Wed Apr 13 14:41:47 2022 +0800 [SPARK-38855][SQL] DS V2 supports push down math functions ### What changes were proposed in this pull request? Currently, Spark has some math functions of the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L388 These functions are shown below: `LN`, `EXP`, `POWER`, `SQRT`, `FLOOR`, `CEIL`, `WIDTH_BUCKET` Support for these functions in the mainstream databases is shown below. | Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | `LN` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | `EXP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | `POWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | | `SQRT` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | `FLOOR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | `CEIL` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes 
| Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | `WIDTH_BUCKET` | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | Yes | No | No | No | No | No | No | No | DS V2 should support pushing down these math functions. ### Why are the changes needed? So that DS V2 supports pushing down math functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests. Closes #36140 from beliefer/SPARK-38855. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/GeneralScalarExpression.java | 54 ++++++++++++++++++++++ .../sql/connector/util/V2ExpressionSQLBuilder.java | 7 +++ .../spark/sql/errors/QueryCompilationErrors.scala | 4 ++ .../sql/catalyst/util/V2ExpressionBuilder.scala | 28 ++++++++++- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 26 +++++++++++ .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 28 ++++++++++- 6 files changed, 145 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 8952761f9ef..58082d5ee09 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -94,6 +94,60 @@ import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; * <li>Since version: 3.3.0</li> * </ul> * </li> + * <li>Name: <code>ABS</code> + * <ul> + * <li>SQL semantic: <code>ABS(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>COALESCE</code> + * <ul> + * <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>LN</code> + * <ul> + * <li>SQL semantic: 
<code>LN(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>EXP</code> + * <ul> + * <li>SQL semantic: <code>EXP(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>POWER</code> + * <ul> + * <li>SQL semantic: <code>POWER(expr, number)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>SQRT</code> + * <ul> + * <li>SQL semantic: <code>SQRT(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>FLOOR</code> + * <ul> + * <li>SQL semantic: <code>FLOOR(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>CEIL</code> + * <ul> + * <li>SQL semantic: <code>CEIL(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> + * <li>Name: <code>WIDTH_BUCKET</code> + * <ul> + * <li>SQL semantic: <code>WIDTH_BUCKET(expr)</code></li> + * <li>Since version: 3.3.0</li> + * </ul> + * </li> * </ol> * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, * including: add, subtract, multiply, divide, remainder, pmod. 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index a7d1ed7f85e..c9dfa2003e3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -95,6 +95,13 @@ public class V2ExpressionSQLBuilder { return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); case "ABS": case "COALESCE": + case "LN": + case "EXP": + case "POWER": + case "SQRT": + case "FLOOR": + case "CEIL": + case "WIDTH_BUCKET": return visitSQLFunction(name, Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); case "CASE_WHEN": { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 5e42c1012bf..f42f0dbe772 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2384,4 +2384,8 @@ object QueryCompilationErrors { new AnalysisException( "Sinks cannot request distribution and ordering in continuous execution mode") } + + def noSuchFunctionError(database: String, funcInfo: String): Throwable = { + new AnalysisException(s"$database does not support function: $funcInfo") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 37db499470a..487b809d48a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn @@ -104,6 +104,32 @@ class V2ExpressionBuilder( } else { None } + case Log(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) + case Exp(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) + case Pow(left, right) => + val l = generateExpression(left) + val r = generateExpression(right) + if (l.isDefined && r.isDefined) { + Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) + } else { + None + } + case Sqrt(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) + case Floor(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) + case Ceil(child) => generateExpression(child) + .map(v => new 
GeneralScalarExpression("CEIL", Array[V2Expression](v))) + case wb: WidthBucket => + val childrenExpressions = wb.children.flatMap(generateExpression(_)) + if (childrenExpressions.length == wb.children.length) { + Some(new GeneralScalarExpression("WIDTH_BUCKET", + childrenExpressions.toArray[V2Expression])) + } else { + None + } case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 643376cdb12..0aa971c0d3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.errors.QueryCompilationErrors private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "WIDTH_BUCKET" => + val functionInfo = super.visitSQLFunction(funcName, inputs) + throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo) + case _ => super.visitSQLFunction(funcName, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + 
logWarning("Error occurs while compiling V2 expression", e) + None + } + } + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index e60af877e9c..5cfa2f465a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when} +import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -464,6 +464,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df5, expectedPlanFragment5) checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true))) + + val df6 = spark.table("h2.test.employee") + .filter(ln($"dept") > 1) + .filter(exp($"salary") > 2000) + .filter(pow($"dept", 2) > 4) + .filter(sqrt($"salary") > 100) + .filter(floor($"dept") > 1) + .filter(ceil($"dept") > 1) + checkFiltersRemoved(df6, ansiMode) + val expectedPlanFragment6 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " + + "LN(CAST(DEPT AS double)) > 1.0, 
EXP(CAST(SALARY AS double)...," + } else { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]" + } + checkPushedInfo(df6, expectedPlanFragment6) + checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true))) + + // H2 does not support width_bucket + val df7 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE width_bucket(dept, 1, 6, 3) > 1 + |""".stripMargin) + checkFiltersRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]") + checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true))) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org