Repository: spark
Updated Branches:
  refs/heads/master 9073a426e -> 53c16b92a


[SPARK-8362] [SQL] Add unit tests for +, -, *, /, %

Added unit tests for all supported data types for:
- Add
- Subtract
- Multiply
- Divide
- UnaryMinus
- Remainder

Fixed bugs caught by the unit tests.

Author: Reynold Xin <r...@databricks.com>

Closes #6813 from rxin/SPARK-8362 and squashes the following commits:

fb3fe62 [Reynold Xin] Added Remainder.
3b266ba [Reynold Xin] [SPARK-8362] Add unit tests for +, -, *, /.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/53c16b92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/53c16b92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/53c16b92

Branch: refs/heads/master
Commit: 53c16b92a537c392a7c3ebc3ef24c1e86cb1a7a4
Parents: 9073a42
Author: Reynold Xin <r...@databricks.com>
Authored: Sun Jun 14 11:23:23 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Sun Jun 14 11:23:23 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   |  31 ++--
 .../expressions/ArithmeticExpressionSuite.scala | 173 ++++++++++---------
 2 files changed, 99 insertions(+), 105 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/53c16b92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 18ddac1..9d1e965 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, 
GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = dataType match {
-    case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
-    case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
+    case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
+    case dt: NumericType => defineCodeGen(ctx, ev, c => 
s"(${ctx.javaType(dt)})(-($c))")
   }
 
   protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
@@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
       defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$eval1.$decimalMethod($eval2)")
     // byte and short are casted into int when add, minus, times or divide
     case ByteType | ShortType =>
-      defineCodeGen(ctx, ev, (eval1, eval2) =>
-        s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+      defineCodeGen(ctx, ev,
+        (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol 
$eval2)")
     case _ =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
   }
@@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
 
 case class Divide(left: Expression, right: Expression) extends 
BinaryArithmetic {
   override def symbol: String = "/"
-  override def decimalMethod: String = "$divide"
+  override def decimalMethod: String = "$div"
 
   override def nullable: Boolean = true
 
@@ -245,11 +244,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
     } else {
       s"${eval2.primitive} == 0"
     }
-    val method = if (left.dataType.isInstanceOf[DecimalType]) {
-      s".$decimalMethod"
-    } else {
-      s"$symbol"
-    }
+    val method = if (left.dataType.isInstanceOf[DecimalType]) 
s".$decimalMethod" else s" $symbol "
+    val javaType = ctx.javaType(left.dataType)
     eval1.code + eval2.code +
       s"""
       boolean ${ev.isNull} = false;
@@ -257,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
       if (${eval1.isNull} || ${eval2.isNull} || $test) {
         ${ev.isNull} = true;
       } else {
-        ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
+        ${ev.primitive} = ($javaType) 
(${eval1.primitive}$method(${eval2.primitive}));
       }
       """
   }
@@ -265,7 +261,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
 
 case class Remainder(left: Expression, right: Expression) extends 
BinaryArithmetic {
   override def symbol: String = "%"
-  override def decimalMethod: String = "reminder"
+  override def decimalMethod: String = "remainder"
 
   override def nullable: Boolean = true
 
@@ -305,11 +301,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
     } else {
       s"${eval2.primitive} == 0"
     }
-    val method = if (left.dataType.isInstanceOf[DecimalType]) {
-      s".$decimalMethod"
-    } else {
-      s"$symbol"
-    }
+    val method = if (left.dataType.isInstanceOf[DecimalType]) 
s".$decimalMethod" else s" $symbol "
+    val javaType = ctx.javaType(left.dataType)
     eval1.code + eval2.code +
       s"""
       boolean ${ev.isNull} = false;
@@ -317,7 +310,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
       if (${eval1.isNull} || ${eval2.isNull} || $test) {
         ${ev.isNull} = true;
       } else {
-        ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
+        ${ev.primitive} = ($javaType) 
(${eval1.primitive}$method(${eval2.primitive}));
       }
       """
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/53c16b92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 5ff1bca..3f48432 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.scalatest.Matchers._
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
@@ -26,100 +24,103 @@ import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
 
 class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
 
-  test("arithmetic") {
-    val row = create_row(1, 2, 3, null)
-    val c1 = 'a.int.at(0)
-    val c2 = 'a.int.at(1)
-    val c3 = 'a.int.at(2)
-    val c4 = 'a.int.at(3)
-
-    checkEvaluation(UnaryMinus(c1), -1, row)
-    checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100)
-
-    checkEvaluation(Add(c1, c4), null, row)
-    checkEvaluation(Add(c1, c2), 3, row)
-    checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row)
-    checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row)
-    checkEvaluation(
-      Add(Literal.create(null, IntegerType), Literal.create(null, 
IntegerType)), null, row)
-
-    checkEvaluation(-c1, -1, row)
-    checkEvaluation(c1 + c2, 3, row)
-    checkEvaluation(c1 - c2, -1, row)
-    checkEvaluation(c1 * c2, 2, row)
-    checkEvaluation(c1 / c2, 0, row)
-    checkEvaluation(c1 % c2, 1, row)
+  /**
+   * Runs through the testFunc for all numeric data types.
+   *
+   * @param testFunc a test function that accepts a conversion function to 
convert an integer
+   *                 into another data type.
+   */
+  private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = {
+    testFunc(_.toByte)
+    testFunc(_.toShort)
+    testFunc(identity)
+    testFunc(_.toLong)
+    testFunc(_.toFloat)
+    testFunc(_.toDouble)
+    testFunc(Decimal(_))
   }
 
-  test("fractional arithmetic") {
-    val row = create_row(1.1, 2.0, 3.1, null)
-    val c1 = 'a.double.at(0)
-    val c2 = 'a.double.at(1)
-    val c3 = 'a.double.at(2)
-    val c4 = 'a.double.at(3)
-
-    checkEvaluation(UnaryMinus(c1), -1.1, row)
-    checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0)
-    checkEvaluation(Add(c1, c4), null, row)
-    checkEvaluation(Add(c1, c2), 3.1, row)
-    checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row)
-    checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row)
-    checkEvaluation(
-      Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), 
null, row)
-
-    checkEvaluation(-c1, -1.1, row)
-    checkEvaluation(c1 + c2, 3.1, row)
-    checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)
-    checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
-    checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
-    checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
+  test("+ (Add)") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(1))
+      val right = Literal(convert(2))
+      checkEvaluation(Add(left, right), convert(3))
+      checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
+      checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
+    }
   }
 
-  test("Abs") {
-    def testAbs(convert: (Int) => Any): Unit = {
-      checkEvaluation(Abs(Literal(convert(0))), convert(0))
-      checkEvaluation(Abs(Literal(convert(1))), convert(1))
-      checkEvaluation(Abs(Literal(convert(-1))), convert(1))
+  test("- (UnaryMinus)") {
+    testNumericDataTypes { convert =>
+      val input = Literal(convert(1))
+      val dataType = input.dataType
+      checkEvaluation(UnaryMinus(input), convert(-1))
+      checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
     }
-    testAbs(_.toByte)
-    testAbs(_.toShort)
-    testAbs(identity)
-    testAbs(_.toLong)
-    testAbs(_.toFloat)
-    testAbs(_.toDouble)
-    testAbs(Decimal(_))
   }
 
-  test("Divide") {
-    checkEvaluation(Divide(Literal(2), Literal(1)), 2)
-    checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+  test("- (Minus)") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(1))
+      val right = Literal(convert(2))
+      checkEvaluation(Subtract(left, right), convert(-1))
+      checkEvaluation(Subtract(Literal.create(null, left.dataType), right), 
null)
+      checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), 
null)
+    }
+  }
+
+  test("* (Multiply)") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(1))
+      val right = Literal(convert(2))
+      checkEvaluation(Multiply(left, right), convert(2))
+      checkEvaluation(Multiply(Literal.create(null, left.dataType), right), 
null)
+      checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), 
null)
+    }
+  }
+
+  test("/ (Divide) basic") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(2))
+      val right = Literal(convert(1))
+      val dataType = left.dataType
+      checkEvaluation(Divide(left, right), convert(2))
+      checkEvaluation(Divide(Literal.create(null, dataType), right), null)
+      checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null)
+      checkEvaluation(Divide(left, Literal(convert(0))), null)  // divide by 
zero
+    }
+  }
+
+  test("/ (Divide) for integral type") {
+    checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
+    checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
     checkEvaluation(Divide(Literal(1), Literal(2)), 0)
-    checkEvaluation(Divide(Literal(1), Literal(0)), null)
-    checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
-    checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
-    checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), 
null)
-    checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), 
null)
-    checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), 
null)
-    checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), 
null)
-    checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), 
null)
-    checkEvaluation(Divide(Literal.create(null, IntegerType), 
Literal.create(null, IntegerType)),
-      null)
+    checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
   }
 
-  test("Remainder") {
-    checkEvaluation(Remainder(Literal(2), Literal(1)), 0)
-    checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0)
-    checkEvaluation(Remainder(Literal(1), Literal(2)), 1)
-    checkEvaluation(Remainder(Literal(1), Literal(0)), null)
-    checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null)
-    checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null)
-    checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), 
null)
-    checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), 
null)
-    checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), 
null)
-    checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), 
null)
-    checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), 
null)
-    checkEvaluation(Remainder(Literal.create(null, IntegerType), 
Literal.create(null, IntegerType)),
-      null)
+  test("/ (Divide) for floating point") {
+    checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
+    checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+    checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), 
Decimal(0.5))
+  }
+
+  test("% (Remainder)") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(1))
+      val right = Literal(convert(2))
+      checkEvaluation(Remainder(left, right), convert(1))
+      checkEvaluation(Remainder(Literal.create(null, left.dataType), right), 
null)
+      checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), 
null)
+      checkEvaluation(Remainder(left, Literal(convert(0))), null)  // mod by 0
+    }
+  }
+
+  test("Abs") {
+    testNumericDataTypes { convert =>
+      checkEvaluation(Abs(Literal(convert(0))), convert(0))
+      checkEvaluation(Abs(Literal(convert(1))), convert(1))
+      checkEvaluation(Abs(Literal(convert(-1))), convert(1))
+    }
   }
 
   test("MaxOf") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to