Repository: spark
Updated Branches:
  refs/heads/master 608353c8e -> 98d6d9c7a


[SPARK-9549][SQL] fix bugs in expressions

JIRA: https://issues.apache.org/jira/browse/SPARK-9549

This PR fixes the following bugs:
1.  `UnaryMinus`'s codegen version would fail to compile when the input is `Long.MinValue`.
2.  `BinaryComparison` would fail to compile in codegen mode when comparing boolean types.
3.  `AddMonths` would fail if passed a huge negative number of months, which would lead to accessing a negative index of the `monthDays` array.
4.  `NaNvl` would fail when its operands have different types (e.g. `FloatType` and `DoubleType`).

Author: Yijie Shen <henry.yijies...@gmail.com>

Closes #7882 from yjshen/minor_bug_fix and squashes the following commits:

41bbd2c [Yijie Shen] fix bug in Nanvl type coercion
3dee204 [Yijie Shen] address comments
4fa5de0 [Yijie Shen] fix bugs in expressions


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98d6d9c7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98d6d9c7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98d6d9c7

Branch: refs/heads/master
Commit: 98d6d9c7a996f5456eb2653bb96985a1a05f4ce1
Parents: 608353c
Author: Yijie Shen <henry.yijies...@gmail.com>
Authored: Mon Aug 3 00:15:24 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Aug 3 00:15:24 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  5 ++
 .../sql/catalyst/expressions/arithmetic.scala   |  9 ++-
 .../sql/catalyst/expressions/predicates.scala   |  1 +
 .../spark/sql/catalyst/util/DateTimeUtils.scala |  7 ++-
 .../analysis/HiveTypeCoercionSuite.scala        | 12 ++++
 .../expressions/ArithmeticExpressionSuite.scala |  6 +-
 .../expressions/DateExpressionsSuite.scala      |  2 +
 .../catalyst/expressions/PredicateSuite.scala   | 62 ++++++++++----------
 .../spark/sql/ColumnExpressionSuite.scala       | 18 +++---
 9 files changed, 79 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 603afc4..422d423 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -562,6 +562,11 @@ object HiveTypeCoercion {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }
+
+      case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType 
=>
+        NaNvl(l, Cast(r, DoubleType))
+      case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType 
=>
+        NaNvl(Cast(l, DoubleType), r)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 6f8f4dd..0891b55 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends 
UnaryExpression with ExpectsInp
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = dataType match {
     case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
-    case dt: NumericType => defineCodeGen(ctx, ev, c => 
s"(${ctx.javaType(dt)})(-($c))")
+    case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
+      val originValue = ctx.freshName("origin")
+      // codegen would fail to compile if we just write (-($c))
+      // for example, we could not write --9223372036854775808L in code
+      s"""
+        ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
+        ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue));
+      """})
     case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => 
s"$c.negate()")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ab7d3af..b69bbab 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with 
Predicate {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     if (ctx.isPrimitiveType(left.dataType)
+        && left.dataType != BooleanType // java boolean doesn't support > or < 
operator
         && left.dataType != FloatType
         && left.dataType != DoubleType) {
       // faster version

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 6a98f4d..f645eb5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -614,8 +614,9 @@ object DateTimeUtils {
    */
   def dateAddMonths(days: Int, months: Int): Int = {
     val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + 
months
-    val currentMonthInYear = absoluteMonth % 12
-    val currentYear = absoluteMonth / 12
+    val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
+    val currentMonthInYear = nonNegativeMonth % 12
+    val currentYear = nonNegativeMonth / 12
     val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + 
YearZero)) 1 else 0
     val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
 
@@ -626,7 +627,7 @@ object DateTimeUtils {
     } else {
       dayOfMonth
     }
-    firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+    firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 7060877..cbdf453 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Nil))
   }
 
+  test("nanvl casts") {
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
+      NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), 
Literal.create(1.0, DoubleType)))
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
+      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, 
FloatType), DoubleType)))
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+  }
+
   test("type coercion for If") {
     val rule = HiveTypeCoercion.IfCoercion
     ruleTest(rule,

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index d03b0fb..0bae8fe 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
 
@@ -56,6 +56,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
       checkEvaluation(UnaryMinus(input), convert(-1))
       checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
     }
+    checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
   }
 
   test("- (Minus)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 3bff8e0..e6e8790 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), 
null)
     checkEvaluation(AddMonths(Literal.create(null, DateType), 
Literal.create(null, IntegerType)),
       null)
+    checkEvaluation(
+      AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), 
-7293498)
   }
 
   test("months_between") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0bc2812..d7eb13c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,60 +136,60 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 
0d).map(Literal(_))
+  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, 
false).map(Literal(_))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, 
Double.NaN).map(Literal(_))
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, 
true).map(Literal(_))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, 
Double.NaN).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, 
true).map(Literal(_))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, 
Double.NaN).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, 
true).map(Literal(_))
 
-  test("BinaryComparison: <") {
+  test("BinaryComparison: lessThan") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) < largeValues(i), true)
-      checkEvaluation(equalValues1(i) < equalValues2(i), false)
-      checkEvaluation(largeValues(i) < smallValues(i), false)
+      checkEvaluation(LessThan(smallValues(i), largeValues(i)), true)
+      checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false)
+      checkEvaluation(LessThan(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: <=") {
+  test("BinaryComparison: LessThanOrEqual") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) <= largeValues(i), true)
-      checkEvaluation(equalValues1(i) <= equalValues2(i), true)
-      checkEvaluation(largeValues(i) <= smallValues(i), false)
+      checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true)
+      checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: >") {
+  test("BinaryComparison: GreaterThan") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) > largeValues(i), false)
-      checkEvaluation(equalValues1(i) > equalValues2(i), false)
-      checkEvaluation(largeValues(i) > smallValues(i), true)
+      checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false)
+      checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false)
+      checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true)
     }
   }
 
-  test("BinaryComparison: >=") {
+  test("BinaryComparison: GreaterThanOrEqual") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) >= largeValues(i), false)
-      checkEvaluation(equalValues1(i) >= equalValues2(i), true)
-      checkEvaluation(largeValues(i) >= smallValues(i), true)
+      checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), 
false)
+      checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), 
true)
+      checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true)
     }
   }
 
-  test("BinaryComparison: ===") {
+  test("BinaryComparison: EqualTo") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) === largeValues(i), false)
-      checkEvaluation(equalValues1(i) === equalValues2(i), true)
-      checkEvaluation(largeValues(i) === smallValues(i), false)
+      checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false)
+      checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: <=>") {
+  test("BinaryComparison: EqualNullSafe") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) <=> largeValues(i), false)
-      checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
-      checkEvaluation(largeValues(i) <=> smallValues(i), false)
+      checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false)
+      checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false)
     }
   }
 
@@ -209,8 +209,8 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     nullTest(GreaterThanOrEqual)
     nullTest(EqualTo)
 
-    checkEvaluation(normalInt <=> nullInt, false)
-    checkEvaluation(nullInt <=> normalInt, false)
-    checkEvaluation(nullInt <=> nullInt, true)
+    checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
+    checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
+    checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/98d6d9c7/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index eb64684..35ca0b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with 
SQLTestUtils {
 
   test("nanvl") {
     val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
-      Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
+      Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
       StructType(Seq(StructField("a", DoubleType), StructField("b", 
DoubleType),
-        StructField("c", DoubleType), StructField("d", DoubleType))))
+        StructField("c", DoubleType), StructField("d", DoubleType),
+        StructField("e", FloatType), StructField("f", IntegerType))))
 
     checkAnswer(
       testData.select(
-        nanvl($"a", lit(5)), nanvl($"b", lit(10)),
-        nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
-      Row(null, 3.0, null, Double.PositiveInfinity)
+        nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"),
+        nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)),
+        nanvl($"b", $"e"), nanvl($"e", $"f")),
+      Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
     )
     testData.registerTempTable("t")
     checkAnswer(
-      ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) 
from t"),
-      Row(null, 3.0, null, Double.PositiveInfinity)
+      ctx.sql(
+        "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), 
nanvl(d, 10), " +
+          " nanvl(b, e), nanvl(e, f) from t"),
+      Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to