This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f212c61 [SPARK-34868][SQL] Support divide an year-month interval by a numeric f212c61 is described below commit f212c61c435f74cf021e4e780ef9a20ff6ab8c90 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Fri Mar 26 05:56:56 2021 +0000 [SPARK-34868][SQL] Support divide an year-month interval by a numeric ### What changes were proposed in this pull request? 1. Add a new expression `DivideYMInterval` which divides a `YearMonthIntervalType` expression by a `NumericType` expression including ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType. 2. Extend binary arithmetic rules to support `year-month interval / numeric`. ### Why are the changes needed? To conform to the ANSI SQL standard, which requires such an operation over year-month intervals: <img width="656" alt="Screenshot 2021-03-25 at 18 44 58" src="https://user-images.githubusercontent.com/1580697/112501559-68f07080-8d9a-11eb-8781-66e6631bb7ef.png"> ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By running new tests: ``` $ build/sbt "test:testOnly *IntervalExpressionsSuite" $ build/sbt "test:testOnly *ColumnExpressionSuite" ``` Closes #31961 from MaxGekk/div-ym-interval-by-num. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 1 + .../catalyst/expressions/intervalExpressions.scala | 54 +++++++++++++++++++++- .../expressions/IntervalExpressionsSuite.scala | 33 +++++++++++++ .../apache/spark/sql/ColumnExpressionSuite.scala | 34 ++++++++++++++ 4 files changed, 120 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 66546f8..fedf9ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -380,6 +380,7 @@ class Analyzer(override val catalogManager: CatalogManager) } case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match { case (CalendarIntervalType, _) => DivideInterval(l, r, f) + case (YearMonthIntervalType, _) => DivideYMInterval(l, r) case _ => d } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 8c64d23..78b3871 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import java.math.RoundingMode import java.util.Locale -import com.google.common.math.DoubleMath +import com.google.common.math.{DoubleMath, IntMath, LongMath} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import 
org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.internal.SQLConf @@ -341,3 +341,53 @@ case class MultiplyDTInterval( override def toString: String = s"($left * $right)" } + +// Divide an year-month interval by a numeric +case class DivideYMInterval( + interval: Expression, + num: Expression) + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + override def left: Expression = interval + override def right: Expression = num + + override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType) + override def dataType: DataType = YearMonthIntervalType + + @transient + private lazy val evalFunc: (Int, Any) => Any = right.dataType match { + case LongType => (months: Int, num) => + LongMath.divide(months, num.asInstanceOf[Long], RoundingMode.HALF_UP).toInt + case _: IntegralType => (months: Int, num) => + IntMath.divide(months, num.asInstanceOf[Number].intValue(), RoundingMode.HALF_UP) + case _: DecimalType => (months: Int, num) => + val decimalRes = ((new Decimal).set(months) / num.asInstanceOf[Decimal]).toJavaBigDecimal + decimalRes.setScale(0, java.math.RoundingMode.HALF_UP).intValueExact() + case _: FractionalType => (months: Int, num) => + DoubleMath.roundToInt(months / num.asInstanceOf[Number].doubleValue(), RoundingMode.HALF_UP) + } + + override def nullSafeEval(interval: Any, num: Any): Any = { + evalFunc(interval.asInstanceOf[Int], num) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = right.dataType match { + case LongType => + val math = classOf[LongMath].getName + val javaType = CodeGenerator.javaType(dataType) + defineCodeGen(ctx, ev, (m, n) => + s"($javaType)($math.divide($m, $n, java.math.RoundingMode.HALF_UP))") + case _: IntegralType => + val math = classOf[IntMath].getName + defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, java.math.RoundingMode.HALF_UP)") + case _: 
DecimalType => + defineCodeGen(ctx, ev, (m, n) => + s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + + ".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()") + case _: FractionalType => + val math = classOf[DoubleMath].getName + defineCodeGen(ctx, ev, (m, n) => + s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)") + } + + override def toString: String = s"($left / $right)" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index bc9a50f..6971b08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -344,4 +344,37 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DayTimeIntervalType, numType) } } + + test("SPARK-34868: divide year-month interval by numeric") { + Seq( + (Period.ofYears(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null, + (Period.ofMonths(0), 10) -> Period.ofMonths(0), + (Period.ofMonths(200), Double.PositiveInfinity) -> Period.ofMonths(0), + (Period.ofMonths(-200), Float.NegativeInfinity) -> Period.ofMonths(0), + (Period.ofYears(100), -1.toByte) -> Period.ofYears(-100), + (Period.ofYears(1), 2.toShort) -> Period.ofMonths(6), + (Period.ofYears(-1), -3) -> Period.ofMonths(4), + (Period.ofMonths(-1000), 0.5f) -> Period.ofMonths(-2000), + (Period.ofYears(1000), 100d) -> Period.ofYears(10), + (Period.ofMonths(2), BigDecimal(0.1)) -> Period.ofMonths(20) + ).foreach { case ((period, num), expected) => + checkEvaluation(DivideYMInterval(Literal(period), Literal(num)), expected) + } + + Seq( + (Period.ofMonths(1), 0) -> "/ by zero", + (Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN", + (Period.ofMonths(-100), Float.NaN) -> "input is 
infinite or NaN" + ).foreach { case ((period, num), expectedErrMsg) => + checkExceptionInExpression[ArithmeticException]( + DivideYMInterval(Literal(period), Literal(num)), + expectedErrMsg) + } + + numericTypes.foreach { numType => + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + (interval: Expression, num: Expression) => DivideYMInterval(interval, num), + YearMonthIntervalType, numType) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 60044ad..8c57b7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -2647,4 +2647,38 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(e.isInstanceOf[ArithmeticException]) assert(e.getMessage.contains("overflow")) } + + test("SPARK-34868: divide year-month interval by numeric") { + checkAnswer( + Seq((Period.ofYears(0), 10.toByte)).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofYears(0))) + checkAnswer( + Seq((Period.ofYears(10), 3.toShort)).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofYears(3).plusMonths(4))) + checkAnswer( + Seq((Period.ofYears(1000), "2")).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofYears(500))) + checkAnswer( + Seq((Period.ofMonths(1).multipliedBy(Int.MaxValue), Int.MaxValue)) + .toDF("i", "n").select($"i" / $"n"), + Row(Period.ofMonths(1))) + checkAnswer( + Seq((Period.ofYears(-1), 12L)).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofMonths(-1))) + checkAnswer( + Seq((Period.ofMonths(-1), 0.499f)).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofMonths(-2))) + checkAnswer( + Seq((Period.ofMonths(10000000), 10000000d)).toDF("i", "n").select($"i" / $"n"), + Row(Period.ofMonths(1))) + checkAnswer( + Seq((Period.ofMonths(-1), BigDecimal(0.5))).toDF("i", "n").select($"i" / $"n"), + 
Row(Period.ofMonths(-2))) + + val e = intercept[SparkException] { + Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect() + }.getCause + assert(e.isInstanceOf[ArithmeticException]) + assert(e.getMessage.contains("/ by zero")) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org