This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6926849 [SPARK-28395][SQL] Division operator support integral division 6926849 is described below commit 69268492471137dd7a3da54c218026c3b1fa7db3 Author: Yuming Wang <yumw...@ebay.com> AuthorDate: Tue Jul 16 15:43:15 2019 +0800 [SPARK-28395][SQL] Division operator support integral division ## What changes were proposed in this pull request? PostgreSQL, Teradata, SQL Server, DB2 and Presto perform integral division with the `/` operator. But Oracle, Vertica, Hive, MySQL and MariaDB perform fractional division with the `/` operator. This PR adds a flag (`spark.sql.function.preferIntegralDivision`) to control whether to use integral division with the `/` operator when both operands are integral types. Examples: **PostgreSQL**: ```sql postgres=# select substr(version(), 0, 16), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8); substr | ?column? | ?column? | ?column? | ?column? -----------------+----------+------------------+-----------------+------------------ PostgreSQL 11.3 | 3 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903 (1 row) ``` **SQL Server**: ```sql 1> select cast(10 as int) / cast(3 as int), cast(10.1 as float) / cast(3 as int), cast(10 as int) / cast(3.1 as float), cast(10.1 as float)/cast(3.1 as float); 2> go ----------- ------------------------ ------------------------ ------------------------ 3 3.3666666666666667 3.225806451612903 3.258064516129032 (1 rows affected) ``` **DB2**: ```sql [db2inst12f3c821d36b7 ~]$ db2 "select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double) from table (sysproc.env_get_inst_info())" 1 2 3 4 ----------- ------------------------ ------------------------ ------------------------ 3 +3.36666666666667E+000 +3.22580645161290E+000 +3.25806451612903E+000 1 record(s) selected. 
``` **Presto**: ```sql presto> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double); _col0 | _col1 | _col2 | _col3 -------+--------------------+-------------------+------------------- 3 | 3.3666666666666667 | 3.225806451612903 | 3.258064516129032 (1 row) ``` **Teradata**: ![image](https://user-images.githubusercontent.com/5399861/61200701-e97d5380-a714-11e9-9a1d-57fd99d38c8d.png) **Oracle**: ```sql SQL> select 10 / 3 from dual; 10/3 ---------- 3.33333333 ``` **Vertica** ```sql dbadmin=> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as float8) / cast(3 as int), cast(10 as int) / cast(3.1 as float8), cast(10.1 as float8)/cast(3.1 as float8); version | ?column? | ?column? | ?column? | ?column? ------------------------------------+----------------------+------------------+-----------------+------------------ Vertica Analytic Database v9.1.1-0 | 3.333333333333333333 | 3.36666666666667 | 3.2258064516129 | 3.25806451612903 (1 row) ``` **Hive**: ```sql hive> select cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double); OK 3.3333333333333335 3.3666666666666667 3.225806451612903 3.258064516129032 Time taken: 0.143 seconds, Fetched: 1 row(s) ``` **MariaDB**: ```sql MariaDB [(none)]> select version(), cast(10 as int) / cast(3 as int), cast(10.1 as double) / cast(3 as int), cast(10 as int) / cast(3.1 as double), cast(10.1 as double)/cast(3.1 as double); +--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+ | version() | cast(10 as int) / cast(3 as int) | cast(10.1 as double) / cast(3 as int) | cast(10 as int) / cast(3.1 as double) | cast(10.1 as double)/cast(3.1 as double) | 
+--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+ | 10.4.6-MariaDB-1:10.4.6+maria~bionic | 3.3333 | 3.3666666666666667 | 3.225806451612903 | 3.258064516129032 | +--------------------------------------+----------------------------------+---------------------------------------+---------------------------------------+------------------------------------------+ 1 row in set (0.000 sec) ``` **MySQL**: ```sql mysql> select version(), 10 / 3, 10 / 3.1, 10.1 / 3, 10.1 / 3.1; +-----------+--------+----------+----------+------------+ | version() | 10 / 3 | 10 / 3.1 | 10.1 / 3 | 10.1 / 3.1 | +-----------+--------+----------+----------+------------+ | 8.0.16 | 3.3333 | 3.2258 | 3.36667 | 3.25806 | +-----------+--------+----------+----------+------------+ 1 row in set (0.00 sec) ``` ## How was this patch tested? unit tests Closes #25158 from wangyum/SPARK-28395. 
Authored-by: Yuming Wang <yumw...@ebay.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 11 ++++++--- .../org/apache/spark/sql/internal/SQLConf.scala | 8 +++++++ .../sql/catalyst/analysis/TypeCoercionSuite.scala | 27 ++++++++++++++++++++-- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1fdec89..3125f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { CaseWhenCoercion :: IfCoercion :: StackCoercion :: - Division :: + Division(conf) :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -666,7 +666,7 @@ object TypeCoercion { * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. 
*/ - object Division extends TypeCoercionRule { + case class Division(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, @@ -677,7 +677,12 @@ object TypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => - Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + (left.dataType, right.dataType) match { + case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision => + IntegralDivide(left, right) + case _ => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + } } private def isNumericOrNull(ex: Expression): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f76103e..57f5128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1524,6 +1524,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision") + .doc("When true, will perform integral division with the / operator " + + "if both sides are integral types.") + .booleanConf + .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation") .internal() @@ -2294,6 +2300,8 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a725e4b..949bb30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion, Division(conf)) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) } + test("SPARK-28395 Division operator support integral division") { + val rules = Seq(FunctionArgumentConversion, Division(conf)) + Seq(true, false).foreach { preferIntegralDivision => + withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") { + val result1 = if (preferIntegralDivision) { + IntegralDivide(1L, 1L) + } else { + Divide(Cast(1L, DoubleType), Cast(1L, DoubleType)) + } + ruleTest(rules, Divide(1L, 1L), result1) + val result2 = if (preferIntegralDivision) { + IntegralDivide(1, Cast(1, ShortType)) + } else { + 
Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType)) + } + ruleTest(rules, Divide(1, Cast(1, ShortType)), result2) + + ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType))) + ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L)) + } + } + } + test("binary comparison with string promotion") { val rule = TypeCoercion.PromoteStrings(conf) ruleTest(rule, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org