Repository: spark
Updated Branches:
  refs/heads/master b9ef7ac98 -> fba3f5ba8


[SPARK-9169][SQL] Improve unit test coverage for null expressions.

Author: Reynold Xin <r...@databricks.com>

Closes #7490 from rxin/unit-test-null-funcs and squashes the following commits:

7b276f0 [Reynold Xin] Move isNaN.
8307287 [Reynold Xin] [SPARK-9169][SQL] Improve unit test coverage for null expressions.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fba3f5ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fba3f5ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fba3f5ba

Branch: refs/heads/master
Commit: fba3f5ba85673336c0556ef8731dcbcd175c7418
Parents: b9ef7ac
Author: Reynold Xin <r...@databricks.com>
Authored: Sat Jul 18 11:06:46 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sat Jul 18 11:06:46 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/nullFunctions.scala    | 81 ++++++++++++++++++--
 .../sql/catalyst/expressions/predicates.scala   | 51 ------------
 .../expressions/NullFunctionsSuite.scala        | 78 ++++++++++---------
 .../catalyst/expressions/PredicateSuite.scala   | 12 +--
 4 files changed, 119 insertions(+), 103 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fba3f5ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 1522bca..98c6708 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, 
GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
 
+
+/**
+ * An expression that is evaluated to the first non-null input.
+ *
+ * {{{
+ *   coalesce(1, 2) => 1
+ *   coalesce(null, 1, 2) => 1
+ *   coalesce(null, null, 2) => 2
+ *   coalesce(null, null, null) => null
+ * }}}
+ */
 case class Coalesce(children: Seq[Expression]) extends Expression {
 
   /** Coalesce is nullable if all of its children are nullable, or if it has 
no children. */
@@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   }
 }
 
+
+/**
+ * Evaluates to `true` if it's NaN or null
+ */
+case class IsNaN(child: Expression) extends UnaryExpression
+  with Predicate with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(DoubleType, FloatType))
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      true
+    } else {
+      child.dataType match {
+        case DoubleType => value.asInstanceOf[Double].isNaN
+        case FloatType => value.asInstanceOf[Float].isNaN
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    val eval = child.gen(ctx)
+    child.dataType match {
+      case FloatType =>
+        s"""
+          ${eval.code}
+          boolean ${ev.isNull} = false;
+          ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
+          if (${eval.isNull}) {
+            ${ev.primitive} = true;
+          } else {
+            ${ev.primitive} = Float.isNaN(${eval.primitive});
+          }
+        """
+      case DoubleType =>
+        s"""
+          ${eval.code}
+          boolean ${ev.isNull} = false;
+          ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
+          if (${eval.isNull}) {
+            ${ev.primitive} = true;
+          } else {
+            ${ev.primitive} = Double.isNaN(${eval.primitive});
+          }
+        """
+    }
+  }
+}
+
+
+/**
+ * An expression that is evaluated to true if the input is null.
+ */
 case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   override def nullable: Boolean = false
 
@@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
     ev.primitive = eval.isNull
     eval.code
   }
-
-  override def toString: String = s"IS NULL $child"
 }
 
+
+/**
+ * An expression that is evaluated to true if the input is not null.
+ */
 case class IsNotNull(child: Expression) extends UnaryExpression with Predicate 
{
   override def nullable: Boolean = false
-  override def toString: String = s"IS NOT NULL $child"
 
   override def eval(input: InternalRow): Any = {
     child.eval(input) != null
@@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   }
 }
 
+
 /**
- * A predicate that is evaluated to be true if there are at least `n` non-null 
values.
+ * A predicate that is evaluated to be true if there are at least `n` non-null 
and non-NaN values.
  */
 case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends 
Predicate {
   override def nullable: Boolean = false
-  override def foldable: Boolean = false
+  override def foldable: Boolean = children.forall(_.foldable)
   override def toString: String = s"AtLeastNNulls(n, 
${children.mkString(",")})"
 
   private[this] val childrenArray = children.toArray

http://git-wip-us.apache.org/repos/asf/spark/blob/fba3f5ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2751c8e..bddd2a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import 
org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, 
CodeGenContext}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any])
   }
 }
 
-/**
- * Evaluates to `true` if it's NaN or null
- */
-case class IsNaN(child: Expression) extends UnaryExpression
-    with Predicate with ImplicitCastInputTypes {
-
-  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(DoubleType, FloatType))
-
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      true
-    } else {
-      child.dataType match {
-        case DoubleType => value.asInstanceOf[Double].isNaN
-        case FloatType => value.asInstanceOf[Float].isNaN
-      }
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val eval = child.gen(ctx)
-    child.dataType match {
-      case FloatType =>
-        s"""
-          ${eval.code}
-          boolean ${ev.isNull} = false;
-          ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
-          if (${eval.isNull}) {
-            ${ev.primitive} = true;
-          } else {
-            ${ev.primitive} = Float.isNaN(${eval.primitive});
-          }
-        """
-      case DoubleType =>
-        s"""
-          ${eval.code}
-          boolean ${ev.isNull} = false;
-          ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
-          if (${eval.isNull}) {
-            ${ev.primitive} = true;
-          } else {
-            ${ev.primitive} = Double.isNaN(${eval.primitive});
-          }
-        """
-    }
-  }
-}
 
 case class And(left: Expression, right: Expression) extends BinaryOperator 
with Predicate {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fba3f5ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index ccdada8..765cc7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -18,48 +18,52 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{BooleanType, StringType, ShortType}
+import org.apache.spark.sql.types._
 
 class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  test("null checking") {
-    val row = create_row("^Ba*n", null, true, null)
-    val c1 = 'a.string.at(0)
-    val c2 = 'a.string.at(1)
-    val c3 = 'a.boolean.at(2)
-    val c4 = 'a.boolean.at(3)
-
-    checkEvaluation(c1.isNull, false, row)
-    checkEvaluation(c1.isNotNull, true, row)
-
-    checkEvaluation(c2.isNull, true, row)
-    checkEvaluation(c2.isNotNull, false, row)
-
-    checkEvaluation(Literal.create(1, ShortType).isNull, false)
-    checkEvaluation(Literal.create(1, ShortType).isNotNull, true)
-
-    checkEvaluation(Literal.create(null, ShortType).isNull, true)
-    checkEvaluation(Literal.create(null, ShortType).isNotNull, false)
+  def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
+    testFunc(false, BooleanType)
+    testFunc(1.toByte, ByteType)
+    testFunc(1.toShort, ShortType)
+    testFunc(1, IntegerType)
+    testFunc(1L, LongType)
+    testFunc(1.0F, FloatType)
+    testFunc(1.0, DoubleType)
+    testFunc(Decimal(1.5), DecimalType.Unlimited)
+    testFunc(new java.sql.Date(10), DateType)
+    testFunc(new java.sql.Timestamp(10), TimestampType)
+    testFunc("abcd", StringType)
+  }
 
-    checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
-    checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, 
row)
-    checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: 
Nil), "^Ba*n", row)
+  test("isnull and isnotnull") {
+    testAllTypes { (value: Any, tpe: DataType) =>
+      checkEvaluation(IsNull(Literal.create(value, tpe)), false)
+      checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
+      checkEvaluation(IsNull(Literal.create(null, tpe)), true)
+      checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
+    }
+  }
 
-    checkEvaluation(
-      If(c3, Literal.create("a", StringType), Literal.create("b", 
StringType)), "a", row)
-    checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
-    checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
-    checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", 
row)
-    checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", 
row)
-    checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", 
row)
-    checkEvaluation(If(Literal.create(false, BooleanType),
-      Literal.create("a", StringType), Literal.create("b", StringType)), "b", 
row)
+  test("IsNaN") {
+    checkEvaluation(IsNaN(Literal(Double.NaN)), true)
+    checkEvaluation(IsNaN(Literal(Float.NaN)), true)
+    checkEvaluation(IsNaN(Literal(math.log(-3))), true)
+    checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
+    checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
+    checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
+    checkEvaluation(IsNaN(Literal(5.5f)), false)
+  }
 
-    checkEvaluation(c1 in (c1, c2), true, row)
-    checkEvaluation(
-      Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", 
StringType)), true, row)
-    checkEvaluation(
-      Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", 
StringType), c2), true, row)
+  test("coalesce") {
+    testAllTypes { (value: Any, tpe: DataType) =>
+      val lit = Literal.create(value, tpe)
+      val nullLit = Literal.create(null, tpe)
+      checkEvaluation(Coalesce(Seq(nullLit)), null)
+      checkEvaluation(Coalesce(Seq(lit)), value)
+      checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
+      checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
+      checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fba3f5ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 052abc5..2173a0c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), 
Seq(Literal(1), Literal(2)))),
       true)
-  }
 
-  test("IsNaN") {
-    checkEvaluation(IsNaN(Literal(Double.NaN)), true)
-    checkEvaluation(IsNaN(Literal(Float.NaN)), true)
-    checkEvaluation(IsNaN(Literal(math.log(-3))), true)
-    checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
-    checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
-    checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
-    checkEvaluation(IsNaN(Literal(5.5f)), false)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), 
Literal("^Ba*n"))), true)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), 
false)
   }
 
   test("INSET") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to