Repository: spark
Updated Branches:
  refs/heads/master 6718c1eb6 -> c46aaf47f


[SPARK-8759][SQL] Add default eval to binary and unary expressions according to
their default nullability

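We already have `nullSafeCodeGen` to provide default null-safe code generation
for binary and unary expressions. This change does the same for `eval`: the
base classes now return null when any child evaluates to null and otherwise
delegate to a new `nullSafeEval` hook that subclasses override.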

Author: Wenchen Fan <cloud0...@outlook.com>

Closes #7157 from cloud-fan/refactor and squashes the following commits:

f3987c6 [Wenchen Fan] refactor Expression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c46aaf47
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c46aaf47
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c46aaf47

Branch: refs/heads/master
Commit: c46aaf47f38163e9c7be671d7b8398512df34e62
Parents: 6718c1e
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Mon Jul 6 22:13:50 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Jul 6 22:13:50 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   |   7 +-
 .../sql/catalyst/expressions/Expression.scala   |  69 ++++++-
 .../sql/catalyst/expressions/ExtractValue.scala |  51 ++---
 .../sql/catalyst/expressions/arithmetic.scala   |  97 ++++------
 .../sql/catalyst/expressions/bitwise.scala      |   8 +-
 .../catalyst/expressions/decimalFunctions.scala |  33 +---
 .../spark/sql/catalyst/expressions/math.scala   | 186 +++++--------------
 .../spark/sql/catalyst/expressions/misc.scala   | 153 ++++++---------
 .../sql/catalyst/expressions/predicates.scala   |  84 ++++-----
 .../spark/sql/catalyst/expressions/sets.scala   |   9 +-
 .../catalyst/expressions/stringOperations.scala | 138 ++++----------
 11 files changed, 292 insertions(+), 543 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 2d99d1a..4f73ba4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -114,8 +114,6 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
     }
   }
 
-  override def foldable: Boolean = child.foldable
-
   override def nullable: Boolean = Cast.forceNullable(child.dataType, 
dataType) || child.nullable
 
   override def toString: String = s"CAST($child, $dataType)"
@@ -426,10 +424,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
-  override def eval(input: InternalRow): Any = {
-    val evaluated = child.eval(input)
-    if (evaluated == null) null else cast(evaluated)
-  }
+  protected override def nullSafeEval(input: Any): Any = cast(input)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     // TODO: Add support for more data types.

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index cafbbaf..386feb9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -184,6 +184,27 @@ abstract class UnaryExpression extends Expression with 
trees.UnaryNode[Expressio
   override def nullable: Boolean = child.nullable
 
   /**
+   * Default behavior of evaluation according to the default nullability of 
UnaryExpression.
+   * If subclass of UnaryExpression override nullable, probably should also 
override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      nullSafeEval(value)
+    }
+  }
+
+  /**
+   * Called by default [[eval]] implementation.  If subclass of 
UnaryExpression keep the default
+   * nullability, they can override this method to save null-check code.  If 
we need full control
+   * of evaluation process, we should override [[eval]].
+   */
+  protected def nullSafeEval(input: Any): Any =
+    sys.error(s"UnaryExpressions must override either eval or nullSafeEval")
+
+  /**
    * Called by unary expressions to generate a code block that returns null if 
its parent returns
    * null, and if not not null, use `f` to generate the expression.
    *
@@ -198,21 +219,24 @@ abstract class UnaryExpression extends Expression with 
trees.UnaryNode[Expressio
       ctx: CodeGenContext,
       ev: GeneratedExpressionCode,
       f: String => String): String = {
-    nullSafeCodeGen(ctx, ev, (result, eval) => {
-      s"$result = ${f(eval)};"
+    nullSafeCodeGen(ctx, ev, eval => {
+      s"${ev.primitive} = ${f(eval)};"
     })
   }
 
   /**
    * Called by unary expressions to generate a code block that returns null if 
its parent returns
    * null, and if not not null, use `f` to generate the expression.
+   *
+   * @param f function that accepts the non-null evaluation result name of 
child and returns Java
+   *          code to compute the output.
    */
   protected def nullSafeCodeGen(
       ctx: CodeGenContext,
       ev: GeneratedExpressionCode,
-      f: (String, String) => String): String = {
+      f: String => String): String = {
     val eval = child.gen(ctx)
-    val resultCode = f(ev.primitive, eval.primitive)
+    val resultCode = f(eval.primitive)
     eval.code + s"""
       boolean ${ev.isNull} = ${eval.isNull};
       ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
@@ -236,6 +260,32 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
   override def nullable: Boolean = left.nullable || right.nullable
 
   /**
+   * Default behavior of evaluation according to the default nullability of 
BinaryExpression.
+   * If subclass of BinaryExpression override nullable, probably should also 
override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      val value2 = right.eval(input)
+      if (value2 == null) {
+        null
+      } else {
+        nullSafeEval(value1, value2)
+      }
+    }
+  }
+
+  /**
+   * Called by default [[eval]] implementation.  If subclass of 
BinaryExpression keep the default
+   * nullability, they can override this method to save null-check code.  If 
we need full control
+   * of evaluation process, we should override [[eval]].
+   */
+  protected def nullSafeEval(input1: Any, input2: Any): Any =
+    sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
+
+  /**
    * Short hand for generating binary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
@@ -246,8 +296,8 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
     ctx: CodeGenContext,
     ev: GeneratedExpressionCode,
     f: (String, String) => String): String = {
-    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
-      s"$result = ${f(eval1, eval2)};"
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      s"${ev.primitive} = ${f(eval1, eval2)};"
     })
   }
 
@@ -255,14 +305,17 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
    * Short hand for generating binary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
+   *
+   * @param f function that accepts the 2 non-null evaluation result names of 
children
+   *          and returns Java code to compute the output.
    */
   protected def nullSafeCodeGen(
     ctx: CodeGenContext,
     ev: GeneratedExpressionCode,
-    f: (String, String, String) => String): String = {
+    f: (String, String) => String): String = {
     val eval1 = left.gen(ctx)
     val eval2 = right.gen(ctx)
-    val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
+    val resultCode = f(eval1.primitive, eval2.primitive)
     s"""
       ${eval1.code}
       boolean ${ev.isNull} = ${eval1.isNull};

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 3020e7f..e451c7f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -122,18 +122,16 @@ case class GetStructField(child: Expression, field: 
StructField, ordinal: Int)
   override def dataType: DataType = field.dataType
   override def nullable: Boolean = child.nullable || field.nullable
 
-  override def eval(input: InternalRow): Any = {
-    val baseValue = child.eval(input).asInstanceOf[InternalRow]
-    if (baseValue == null) null else baseValue(ordinal)
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[InternalRow](ordinal)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    nullSafeCodeGen(ctx, ev, (result, eval) => {
+    nullSafeCodeGen(ctx, ev, eval => {
       s"""
         if ($eval.isNullAt($ordinal)) {
           ${ev.isNull} = true;
         } else {
-          $result = ${ctx.getColumn(eval, dataType, ordinal)};
+          ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
         }
       """
     })
@@ -152,12 +150,9 @@ case class GetArrayStructFields(
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
   override def nullable: Boolean = child.nullable || containsNull || 
field.nullable
 
-  override def eval(input: InternalRow): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
-    if (baseValue == null) null else {
-      baseValue.map { row =>
-        if (row == null) null else row(ordinal)
-      }
+  protected override def nullSafeEval(input: Any): Any = {
+    input.asInstanceOf[Seq[InternalRow]].map { row =>
+      if (row == null) null else row(ordinal)
     }
   }
 
@@ -165,7 +160,7 @@ case class GetArrayStructFields(
     val arraySeqClass = "scala.collection.mutable.ArraySeq"
     // TODO: consider using Array[_] for ArrayType child to avoid
     // boxing of primitives
-    nullSafeCodeGen(ctx, ev, (result, eval) => {
+    nullSafeCodeGen(ctx, ev, eval => {
       s"""
         final int n = $eval.size();
         final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
@@ -175,7 +170,7 @@ case class GetArrayStructFields(
             values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
           }
         }
-        $result = (${ctx.javaType(dataType)}) values;
+        ${ev.primitive} = (${ctx.javaType(dataType)}) values;
       """
     })
   }
@@ -193,22 +188,6 @@ abstract class ExtractValueWithOrdinal extends 
BinaryExpression with ExtractValu
   /** `Null` is returned for invalid ordinals. */
   override def nullable: Boolean = true
   override def toString: String = s"$child[$ordinal]"
-
-  override def eval(input: InternalRow): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      val o = ordinal.eval(input)
-      if (o == null) {
-        null
-      } else {
-        evalNotNull(value, o)
-      }
-    }
-  }
-
-  protected def evalNotNull(value: Any, ordinal: Any): Any
 }
 
 /**
@@ -219,7 +198,7 @@ case class GetArrayItem(child: Expression, ordinal: 
Expression)
 
   override def dataType: DataType = 
child.dataType.asInstanceOf[ArrayType].elementType
 
-  protected def evalNotNull(value: Any, ordinal: Any) = {
+  protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
     // TODO: consider using Array[_] for ArrayType child to avoid
     // boxing of primitives
     val baseValue = value.asInstanceOf[Seq[_]]
@@ -232,13 +211,13 @@ case class GetArrayItem(child: Expression, ordinal: 
Expression)
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       s"""
         final int index = (int)$eval2;
         if (index >= $eval1.size() || index < 0) {
           ${ev.isNull} = true;
         } else {
-          $result = (${ctx.boxedType(dataType)})$eval1.apply(index);
+          ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
         }
       """
     })
@@ -253,16 +232,16 @@ case class GetMapValue(child: Expression, ordinal: 
Expression)
 
   override def dataType: DataType = 
child.dataType.asInstanceOf[MapType].valueType
 
-  protected def evalNotNull(value: Any, ordinal: Any) = {
+  protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
     val baseValue = value.asInstanceOf[Map[Any, _]]
     baseValue.get(ordinal).orNull
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       s"""
         if ($eval1.contains($eval2)) {
-          $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
+          ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
         } else {
           ${ev.isNull} = true;
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4fbf4c8..dca6642 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -26,18 +26,6 @@ abstract class UnaryArithmetic extends UnaryExpression {
   self: Product =>
 
   override def dataType: DataType = child.dataType
-
-  override def eval(input: InternalRow): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      evalInternal(evalE)
-    }
-  }
-
-  protected def evalInternal(evalE: Any): Any =
-    sys.error(s"UnaryArithmetics must override either eval or evalInternal")
 }
 
 case class UnaryMinus(child: Expression) extends UnaryArithmetic {
@@ -53,7 +41,7 @@ case class UnaryMinus(child: Expression) extends 
UnaryArithmetic {
     case dt: NumericType => defineCodeGen(ctx, ev, c => 
s"(${ctx.javaType(dt)})(-($c))")
   }
 
-  protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
+  protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
 }
 
 case class UnaryPositive(child: Expression) extends UnaryArithmetic {
@@ -62,7 +50,7 @@ case class UnaryPositive(child: Expression) extends 
UnaryArithmetic {
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String =
     defineCodeGen(ctx, ev, c => c)
 
-  protected override def evalInternal(evalE: Any) = evalE
+  protected override def nullSafeEval(input: Any): Any = input
 }
 
 /**
@@ -74,7 +62,7 @@ case class Abs(child: Expression) extends UnaryArithmetic {
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
+  protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
 }
 
 abstract class BinaryArithmetic extends BinaryOperator {
@@ -94,20 +82,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
 
   protected def checkTypesInternal(t: DataType): TypeCheckResult
 
-  override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    if(evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        evalInternal(evalE1, evalE2)
-      }
-    }
-  }
-
   /** Name of the function for this expression on a [[Decimal]] type. */
   def decimalMethod: String =
     sys.error("BinaryArithmetics must override either decimalMethod or 
genCode")
@@ -122,9 +96,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
     case _ =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
   }
-
-  protected def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryArithmetics must override either eval or evalInternal")
 }
 
 private[sql] object BinaryArithmetic {
@@ -143,7 +114,7 @@ case class Add(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = 
numeric.plus(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
numeric.plus(input1, input2)
 }
 
 case class Subtract(left: Expression, right: Expression) extends 
BinaryArithmetic {
@@ -158,7 +129,7 @@ case class Subtract(left: Expression, right: Expression) 
extends BinaryArithmeti
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = 
numeric.minus(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
numeric.minus(input1, input2)
 }
 
 case class Multiply(left: Expression, right: Expression) extends 
BinaryArithmetic {
@@ -173,7 +144,7 @@ case class Multiply(left: Expression, right: Expression) 
extends BinaryArithmeti
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = 
numeric.times(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
numeric.times(input1, input2)
 }
 
 case class Divide(left: Expression, right: Expression) extends 
BinaryArithmetic {
@@ -194,15 +165,15 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
   }
 
   override def eval(input: InternalRow): Any = {
-    val evalE2 = right.eval(input)
-    if (evalE2 == null || evalE2 == 0) {
+    val input2 = right.eval(input)
+    if (input2 == null || input2 == 0) {
       null
     } else {
-      val evalE1 = left.eval(input)
-      if (evalE1 == null) {
+      val input1 = left.eval(input)
+      if (input1 == null) {
         null
       } else {
-        div(evalE1, evalE2)
+        div(input1, input2)
       }
     }
   }
@@ -260,15 +231,15 @@ case class Remainder(left: Expression, right: Expression) 
extends BinaryArithmet
   }
 
   override def eval(input: InternalRow): Any = {
-    val evalE2 = right.eval(input)
-    if (evalE2 == null || evalE2 == 0) {
+    val input2 = right.eval(input)
+    if (input2 == null || input2 == 0) {
       null
     } else {
-      val evalE1 = left.eval(input)
-      if (evalE1 == null) {
+      val input1 = left.eval(input)
+      if (input1 == null) {
         null
       } else {
-        integral.rem(evalE1, evalE2)
+        integral.rem(input1, input2)
       }
     }
   }
@@ -317,17 +288,17 @@ case class MaxOf(left: Expression, right: Expression) 
extends BinaryArithmetic {
   private lazy val ordering = TypeUtils.getOrdering(dataType)
 
   override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    val evalE2 = right.eval(input)
-    if (evalE1 == null) {
-      evalE2
-    } else if (evalE2 == null) {
-      evalE1
+    val input1 = left.eval(input)
+    val input2 = right.eval(input)
+    if (input1 == null) {
+      input2
+    } else if (input2 == null) {
+      input1
     } else {
-      if (ordering.compare(evalE1, evalE2) < 0) {
-        evalE2
+      if (ordering.compare(input1, input2) < 0) {
+        input2
       } else {
-        evalE1
+        input1
       }
     }
   }
@@ -371,17 +342,17 @@ case class MinOf(left: Expression, right: Expression) 
extends BinaryArithmetic {
   private lazy val ordering = TypeUtils.getOrdering(dataType)
 
   override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    val evalE2 = right.eval(input)
-    if (evalE1 == null) {
-      evalE2
-    } else if (evalE2 == null) {
-      evalE1
+    val input1 = left.eval(input)
+    val input2 = right.eval(input)
+    if (input1 == null) {
+      input2
+    } else if (input2 == null) {
+      input1
     } else {
-      if (ordering.compare(evalE1, evalE2) < 0) {
-        evalE1
+      if (ordering.compare(input1, input2) < 0) {
+        input1
       } else {
-        evalE2
+        input2
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index 9002dda..2d47124 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -45,7 +45,7 @@ case class BitwiseAnd(left: Expression, right: Expression) 
extends BinaryArithme
       ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, 
Any) => Any]
   }
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, 
evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
and(input1, input2)
 }
 
 /**
@@ -70,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) 
extends BinaryArithmet
       ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, 
Any) => Any]
   }
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, 
evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
or(input1, input2)
 }
 
 /**
@@ -95,7 +95,7 @@ case class BitwiseXor(left: Expression, right: Expression) 
extends BinaryArithme
       ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, 
Any) => Any]
   }
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any): Any = 
xor(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = 
xor(input1, input2)
 }
 
 /**
@@ -122,5 +122,5 @@ case class BitwiseNot(child: Expression) extends 
UnaryArithmetic {
     defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
   }
 
-  protected override def evalInternal(evalE: Any) = not(evalE)
+  protected override def nullSafeEval(input: Any): Any = not(input)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index f5c2dde..2fa74b4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -30,14 +30,8 @@ case class UnscaledValue(child: Expression) extends 
UnaryExpression {
   override def dataType: DataType = LongType
   override def toString: String = s"UnscaledValue($child)"
 
-  override def eval(input: InternalRow): Any = {
-    val childResult = child.eval(input)
-    if (childResult == null) {
-      null
-    } else {
-      childResult.asInstanceOf[Decimal].toUnscaledLong
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[Decimal].toUnscaledLong
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
@@ -54,26 +48,15 @@ case class MakeDecimal(child: Expression, precision: Int, 
scale: Int) extends Un
   override def dataType: DataType = DecimalType(precision, scale)
   override def toString: String = s"MakeDecimal($child,$precision,$scale)"
 
-  override def eval(input: InternalRow): Decimal = {
-    val childResult = child.eval(input)
-    if (childResult == null) {
-      null
-    } else {
-      new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    Decimal(input.asInstanceOf[Long], precision, scale)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.decimalType} ${ev.primitive} = null;
-
-      if (!${ev.isNull}) {
-        ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
-          ${eval.primitive}, $precision, $scale);
+    nullSafeCodeGen(ctx, ev, eval => {
+      s"""
+        ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, 
$precision, $scale);
         ${ev.isNull} = ${ev.primitive} == null;
-      }
       """
+    })
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 9250045..9dca851 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -61,21 +61,16 @@ abstract class UnaryMathExpression(f: Double => Double, 
name: String)
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
 
-  override def eval(input: InternalRow): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      val result = f(evalE.asInstanceOf[Double])
-      if (result.isNaN) null else result
-    }
+  protected override def nullSafeEval(input: Any): Any = {
+    val result = f(input.asInstanceOf[Double])
+    if (result.isNaN) null else result
   }
 
   // name of function in java.lang.Math
   def funcName: String = name.toLowerCase
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    nullSafeCodeGen(ctx, ev, (result, eval) => {
+    nullSafeCodeGen(ctx, ev, eval => {
       s"""
         ${ev.primitive} = java.lang.Math.${funcName}($eval);
         if (Double.valueOf(${ev.primitive}).isNaN()) {
@@ -101,19 +96,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => 
Double, name: String)
 
   override def dataType: DataType = DoubleType
 
-  override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        val result = f(evalE1.asInstanceOf[Double], 
evalE2.asInstanceOf[Double])
-        if (result.isNaN) null else result
-      }
-    }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
+    if (result.isNaN) null else result
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
@@ -194,39 +179,29 @@ case class Factorial(child: Expression) extends 
UnaryExpression with ExpectsInpu
 
   override def dataType: DataType = LongType
 
-  override def foldable: Boolean = child.foldable
-
   // If the value not in the range of [0, 20], it still will be null, so set 
it to be true here.
   override def nullable: Boolean = true
 
-  override def eval(input: InternalRow): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
+  protected override def nullSafeEval(input: Any): Any = {
+    val value = input.asInstanceOf[jl.Integer]
+    if (value > 20 || value < 0) {
       null
     } else {
-      val input = evalE.asInstanceOf[jl.Integer]
-      if (input > 20 || input < 0) {
-        null
-      } else {
-        Factorial.factorial(input)
-      }
+      Factorial.factorial(value)
     }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        if (${eval.primitive} > 20 || ${eval.primitive} < 0) {
+    nullSafeCodeGen(ctx, ev, eval => {
+      s"""
+        if ($eval > 20 || $eval < 0) {
           ${ev.isNull} = true;
         } else {
           ${ev.primitive} =
-            
org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive});
+            
org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval);
         }
-      }
-    """
+      """
+    })
   }
 }
 
@@ -235,17 +210,14 @@ case class Log(child: Expression) extends 
UnaryMathExpression(math.log, "LOG")
 case class Log2(child: Expression)
   extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), 
"LOG2") {
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / 
java.lang.Math.log(2);
+    nullSafeCodeGen(ctx, ev, eval => {
+      s"""
+        ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2);
         if (Double.valueOf(${ev.primitive}).isNaN()) {
           ${ev.isNull} = true;
         }
-      }
-    """
+      """
+    })
   }
 }
 
@@ -283,14 +255,8 @@ case class Bin(child: Expression)
   override def inputTypes: Seq[DataType] = Seq(LongType)
   override def dataType: DataType = StringType
 
-  override def eval(input: InternalRow): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long]))
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, (c) =>
@@ -326,17 +292,10 @@ case class Hex(child: Expression) extends UnaryExpression 
with ExpectsInputTypes
 
   override def dataType: DataType = StringType
 
-  override def eval(input: InternalRow): Any = {
-    val num = child.eval(input)
-    if (num == null) {
-      null
-    } else {
-      child.dataType match {
-        case LongType => hex(num.asInstanceOf[Long])
-        case BinaryType => hex(num.asInstanceOf[Array[Byte]])
-        case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
-      }
-    }
+  protected override def nullSafeEval(num: Any): Any = child.dataType match {
+    case LongType => hex(num.asInstanceOf[Long])
+    case BinaryType => hex(num.asInstanceOf[Array[Byte]])
+    case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
   }
 
   private[this] def hex(bytes: Array[Byte]): UTF8String = {
@@ -377,14 +336,8 @@ case class Unhex(child: Expression) extends 
UnaryExpression with ExpectsInputTyp
   override def nullable: Boolean = true
   override def dataType: DataType = BinaryType
 
-  override def eval(input: InternalRow): Any = {
-    val num = child.eval(input)
-    if (num == null) {
-      null
-    } else {
-      unhex(num.asInstanceOf[UTF8String].getBytes)
-    }
-  }
+  protected override def nullSafeEval(num: Any): Any =
+    unhex(num.asInstanceOf[UTF8String].getBytes)
 
   private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
     val out = new Array[Byte]((bytes.length + 1) >> 1)
@@ -429,21 +382,10 @@ case class Unhex(child: Expression) extends 
UnaryExpression with ExpectsInputTyp
 case class Atan2(left: Expression, right: Expression)
   extends BinaryMathExpression(math.atan2, "ATAN2") {
 
-  override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
-        val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
-          evalE2.asInstanceOf[Double] + 0.0)
-        if (result.isNaN) null else result
-      }
-    }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
+    val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
+    if (result.isNaN) null else result
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
@@ -480,25 +422,15 @@ case class ShiftLeft(left: Expression, right: Expression)
 
   override def dataType: DataType = left.dataType
 
-  override def eval(input: InternalRow): Any = {
-    val valueLeft = left.eval(input)
-    if (valueLeft != null) {
-      val valueRight = right.eval(input)
-      if (valueRight != null) {
-        valueLeft match {
-          case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer]
-          case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer]
-        }
-      } else {
-        null
-      }
-    } else {
-      null
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    input1 match {
+      case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
+      case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
     }
   }
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+    defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
   }
 }
 
@@ -516,25 +448,15 @@ case class ShiftRight(left: Expression, right: Expression)
 
   override def dataType: DataType = left.dataType
 
-  override def eval(input: InternalRow): Any = {
-    val valueLeft = left.eval(input)
-    if (valueLeft != null) {
-      val valueRight = right.eval(input)
-      if (valueRight != null) {
-        valueLeft match {
-          case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer]
-          case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer]
-        }
-      } else {
-        null
-      }
-    } else {
-      null
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    input1 match {
+      case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
+      case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
     }
   }
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+    defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
   }
 }
 
@@ -552,25 +474,15 @@ case class ShiftRightUnsigned(left: Expression, right: 
Expression)
 
   override def dataType: DataType = left.dataType
 
-  override def eval(input: InternalRow): Any = {
-    val valueLeft = left.eval(input)
-    if (valueLeft != null) {
-      val valueRight = right.eval(input)
-      if (valueRight != null) {
-        valueLeft match {
-          case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer]
-          case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer]
-        }
-      } else {
-        null
-      }
-    } else {
-      null
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    input1 match {
+      case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
+      case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
     }
   }
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
+    defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e008af3..3b59cd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -37,14 +37,8 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes
 
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
-  override def eval(input: InternalRow): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      
UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]]))
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]]))
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, c =>
@@ -67,76 +61,56 @@ case class Sha2(left: Expression, right: Expression)
 
   override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
 
-  override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        val bitLength = evalE2.asInstanceOf[Int]
-        val input = evalE1.asInstanceOf[Array[Byte]]
-        bitLength match {
-          case 224 =>
-            // DigestUtils doesn't support SHA-224 now
-            try {
-              val md = MessageDigest.getInstance("SHA-224")
-              md.update(input)
-              UTF8String.fromBytes(md.digest())
-            } catch {
-              // SHA-224 is not supported on the system, return null
-              case noa: NoSuchAlgorithmException => null
-            }
-          case 256 | 0 =>
-            UTF8String.fromString(DigestUtils.sha256Hex(input))
-          case 384 =>
-            UTF8String.fromString(DigestUtils.sha384Hex(input))
-          case 512 =>
-            UTF8String.fromString(DigestUtils.sha512Hex(input))
-          case _ => null
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val bitLength = input2.asInstanceOf[Int]
+    val input = input1.asInstanceOf[Array[Byte]]
+    bitLength match {
+      case 224 =>
+        // DigestUtils doesn't support SHA-224 now
+        try {
+          val md = MessageDigest.getInstance("SHA-224")
+          md.update(input)
+          UTF8String.fromBytes(md.digest())
+        } catch {
+          // SHA-224 is not supported on the system, return null
+          case noa: NoSuchAlgorithmException => null
         }
-      }
+      case 256 | 0 =>
+        UTF8String.fromString(DigestUtils.sha256Hex(input))
+      case 384 =>
+        UTF8String.fromString(DigestUtils.sha384Hex(input))
+      case 512 =>
+        UTF8String.fromString(DigestUtils.sha512Hex(input))
+      case _ => null
     }
   }
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
     val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
-
-    s"""
-      ${eval1.code}
-      boolean ${ev.isNull} = ${eval1.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${eval2.code}
-        if (!${eval2.isNull}) {
-          if (${eval2.primitive} == 224) {
-            try {
-              java.security.MessageDigest md = 
java.security.MessageDigest.getInstance("SHA-224");
-              md.update(${eval1.primitive});
-              ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
-            } catch (java.security.NoSuchAlgorithmException e) {
-              ${ev.isNull} = true;
-            }
-          } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
-            ${ev.primitive} =
-              
${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
-          } else if (${eval2.primitive} == 384) {
-            ${ev.primitive} =
-              
${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
-          } else if (${eval2.primitive} == 512) {
-            ${ev.primitive} =
-              
${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
-          } else {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      s"""
+        if ($eval2 == 224) {
+          try {
+            java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+            md.update($eval1);
+            ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+          } catch (java.security.NoSuchAlgorithmException e) {
             ${ev.isNull} = true;
           }
+        } else if ($eval2 == 256 || $eval2 == 0) {
+          ${ev.primitive} =
+            ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1));
+        } else if ($eval2 == 384) {
+          ${ev.primitive} =
+            ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1));
+        } else if ($eval2 == 512) {
+          ${ev.primitive} =
+            ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1));
         } else {
           ${ev.isNull} = true;
         }
-      }
-    """
+      """
+    })
   }
 }
 
@@ -150,19 +124,12 @@ case class Sha1(child: Expression) extends 
UnaryExpression with ExpectsInputType
 
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
-  override def eval(input: InternalRow): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      
UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]]))
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]]))
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, c =>
-      "org.apache.spark.unsafe.types.UTF8String.fromString" +
-        s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+      s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
     )
   }
 }
@@ -177,30 +144,20 @@ case class Crc32(child: Expression) extends 
UnaryExpression with ExpectsInputTyp
 
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
-  override def eval(input: InternalRow): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      val checksum = new CRC32
-      checksum.update(value.asInstanceOf[Array[Byte]], 0, 
value.asInstanceOf[Array[Byte]].length)
-      checksum.getValue
-    }
+  protected override def nullSafeEval(input: Any): Any = {
+    val checksum = new CRC32
+    checksum.update(input.asInstanceOf[Array[Byte]], 0, input.asInstanceOf[Array[Byte]].length)
+    checksum.getValue
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    val value = child.gen(ctx)
     val CRC32 = "java.util.zip.CRC32"
-    s"""
-      ${value.code}
-      boolean ${ev.isNull} = ${value.isNull};
-      long ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${CRC32} checksum = new ${CRC32}();
-        checksum.update(${value.primitive}, 0, ${value.primitive}.length);
+    nullSafeCodeGen(ctx, ev, value => {
+      s"""
+        $CRC32 checksum = new $CRC32();
+        checksum.update($value, 0, $value.length);
         ${ev.primitive} = checksum.getValue();
-      }
-    """
+      """
+    })
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 0b479f4..402a0aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -74,12 +74,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
 
   override def inputTypes: Seq[DataType] = Seq(BooleanType)
 
-  override def eval(input: InternalRow): Any = {
-    child.eval(input) match {
-      case null => null
-      case b: Boolean => !b
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, c => s"!($c)")
@@ -105,17 +100,14 @@ case class In(value: Expression, list: Seq[Expression]) 
extends Predicate {
  * Optimized version of In clause, when all filter values of In clause are
  * static.
  */
-case class InSet(value: Expression, hset: Set[Any])
-  extends Predicate {
-
-  override def children: Seq[Expression] = value :: Nil
+case class InSet(child: Expression, hset: Set[Any])
+  extends UnaryExpression with Predicate {
 
-  override def foldable: Boolean = value.foldable
   override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
-  override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}"
+  override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
 
   override def eval(input: InternalRow): Any = {
-    hset.contains(value.eval(input))
+    hset.contains(child.eval(input))
   }
 }
 
@@ -127,15 +119,15 @@ case class And(left: Expression, right: Expression)
   override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
 
   override def eval(input: InternalRow): Any = {
-    val l = left.eval(input)
-    if (l == false) {
+    val input1 = left.eval(input)
+    if (input1 == false) {
        false
     } else {
-      val r = right.eval(input)
-      if (r == false) {
+      val input2 = right.eval(input)
+      if (input2 == false) {
         false
       } else {
-        if (l != null && r != null) {
+        if (input1 != null && input2 != null) {
           true
         } else {
           null
@@ -176,15 +168,15 @@ case class Or(left: Expression, right: Expression)
   override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
 
   override def eval(input: InternalRow): Any = {
-    val l = left.eval(input)
-    if (l == true) {
+    val input1 = left.eval(input)
+    if (input1 == true) {
       true
     } else {
-      val r = right.eval(input)
-      if (r == true) {
+      val input2 = right.eval(input)
+      if (input2 == true) {
         true
       } else {
-        if (l != null && r != null) {
+        if (input1 != null && input2 != null) {
           false
         } else {
           null
@@ -232,20 +224,6 @@ abstract class BinaryComparison extends BinaryOperator 
with Predicate {
 
   protected def checkTypesInternal(t: DataType): TypeCheckResult
 
-  override def eval(input: InternalRow): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        evalInternal(evalE1, evalE2)
-      }
-    }
-  }
-
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     if (ctx.isPrimitiveType(left.dataType)) {
       // faster version
@@ -254,9 +232,6 @@ abstract class BinaryComparison extends BinaryOperator with 
Predicate {
       defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, 
c2)} $symbol 0")
     }
   }
-
-  protected def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryComparisons must override either eval or evalInternal")
 }
 
 private[sql] object BinaryComparison {
@@ -277,9 +252,9 @@ case class EqualTo(left: Expression, right: Expression) 
extends BinaryComparison
 
   override protected def checkTypesInternal(t: DataType) = 
TypeCheckResult.TypeCheckSuccess
 
-  protected override def evalInternal(l: Any, r: Any) = {
-    if (left.dataType != BinaryType) l == r
-    else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], 
r.asInstanceOf[Array[Byte]])
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    if (left.dataType != BinaryType) input1 == input2
+    else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
@@ -295,15 +270,18 @@ case class EqualNullSafe(left: Expression, right: 
Expression) extends BinaryComp
   override protected def checkTypesInternal(t: DataType) = 
TypeCheckResult.TypeCheckSuccess
 
   override def eval(input: InternalRow): Any = {
-    val l = left.eval(input)
-    val r = right.eval(input)
-    if (l == null && r == null) {
+    val input1 = left.eval(input)
+    val input2 = right.eval(input)
+    if (input1 == null && input2 == null) {
       true
-    } else if (l == null || r == null) {
+    } else if (input1 == null || input2 == null) {
       false
     } else {
-      if (left.dataType != BinaryType) l == r
-      else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], 
r.asInstanceOf[Array[Byte]])
+      if (left.dataType != BinaryType) {
+        input1 == input2
+      } else {
+        java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+      }
     }
   }
 
@@ -327,7 +305,7 @@ case class LessThan(left: Expression, right: Expression) 
extends BinaryCompariso
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
 }
 
 case class LessThanOrEqual(left: Expression, right: Expression) extends 
BinaryComparison {
@@ -338,7 +316,7 @@ case class LessThanOrEqual(left: Expression, right: 
Expression) extends BinaryCo
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
 }
 
 case class GreaterThan(left: Expression, right: Expression) extends 
BinaryComparison {
@@ -349,7 +327,7 @@ case class GreaterThan(left: Expression, right: Expression) 
extends BinaryCompar
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
 }
 
 case class GreaterThanOrEqual(left: Expression, right: Expression) extends 
BinaryComparison {
@@ -360,5 +338,5 @@ case class GreaterThanOrEqual(left: Expression, right: 
Expression) extends Binar
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
 
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2)
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 5d51a4c..9b44fb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -135,6 +135,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
  */
 case class CombineSets(left: Expression, right: Expression) extends 
BinaryExpression {
 
+  override def nullable: Boolean = left.nullable
   override def dataType: DataType = left.dataType
 
   override def eval(input: InternalRow): Any = {
@@ -183,12 +184,8 @@ case class CountSet(child: Expression) extends 
UnaryExpression {
 
   override def dataType: DataType = LongType
 
-  override def eval(input: InternalRow): Any = {
-    val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
-    if (childEval != null) {
-      childEval.size.toLong
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[OpenHashSet[Any]].size.toLong
 
   override def toString: String = s"$child.count()"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c46aaf47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1a14a7a..6e6a7fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -31,7 +31,6 @@ trait StringRegexExpression extends ExpectsInputTypes {
   def escape(v: String): String
   def matches(regex: Pattern, str: String): Boolean
 
-  override def nullable: Boolean = left.nullable || right.nullable
   override def dataType: DataType = BooleanType
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
@@ -50,22 +49,12 @@ trait StringRegexExpression extends ExpectsInputTypes {
 
   protected def pattern(str: String) = if (cache == null) compile(str) else cache
 
-  override def eval(input: InternalRow): Any = {
-    val l = left.eval(input)
-    if (l == null) {
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val regex = pattern(input2.asInstanceOf[UTF8String].toString())
+    if(regex == null) {
       null
     } else {
-      val r = right.eval(input)
-      if(r == null) {
-        null
-      } else {
-        val regex = pattern(r.asInstanceOf[UTF8String].toString())
-        if(regex == null) {
-          null
-        } else {
-          matches(regex, l.asInstanceOf[UTF8String].toString())
-        }
-      }
+      matches(regex, input1.asInstanceOf[UTF8String].toString())
     }
   }
 }
@@ -120,14 +109,8 @@ trait CaseConversionExpression extends ExpectsInputTypes {
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val evaluated = child.eval(input)
-    if (evaluated == null) {
-      null
-    } else {
-      convert(evaluated.asInstanceOf[UTF8String])
-    }
-  }
+  protected override def nullSafeEval(input: Any): Any =
+    convert(input.asInstanceOf[UTF8String])
 }
 
 /**
@@ -160,20 +143,10 @@ trait StringComparison extends ExpectsInputTypes {
 
   def compare(l: UTF8String, r: UTF8String): Boolean
 
-  override def nullable: Boolean = left.nullable || right.nullable
-
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val leftEval = left.eval(input)
-    if(leftEval == null) {
-      null
-    } else {
-      val rightEval = right.eval(input)
-      if (rightEval == null) null
-      else compare(leftEval.asInstanceOf[UTF8String], 
rightEval.asInstanceOf[UTF8String])
-    }
-  }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any =
+    compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])
 
   override def toString: String = s"$nodeName($left, $right)"
 }
@@ -288,10 +261,8 @@ case class StringLength(child: Expression) extends 
UnaryExpression with ExpectsI
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val string = child.eval(input)
-    if (string == null) null else string.asInstanceOf[UTF8String].length
-  }
+  protected override def nullSafeEval(string: Any): Any =
+    string.asInstanceOf[UTF8String].length
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     defineCodeGen(ctx, ev, c => s"($c).length()")
@@ -310,24 +281,13 @@ case class Levenshtein(left: Expression, right: 
Expression) extends BinaryExpres
 
   override def dataType: DataType = IntegerType
 
-  override def eval(input: InternalRow): Any = {
-    val leftValue = left.eval(input)
-    if (leftValue == null) {
-      null
-    } else {
-      val rightValue = right.eval(input)
-      if(rightValue == null) {
-        null
-      } else {
-        StringUtils.getLevenshteinDistance(leftValue.toString, 
rightValue.toString)
-      }
-    }
-  }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any =
+    StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
     val stringUtils = classOf[StringUtils].getName
-    nullSafeCodeGen(ctx, ev, (res, left, right) =>
-      s"$res = $stringUtils.getLevenshteinDistance($left.toString(), 
$right.toString());")
+    defineCodeGen(ctx, ev, (left, right) =>
+      s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
   }
 }
 
@@ -338,17 +298,12 @@ case class Ascii(child: Expression) extends 
UnaryExpression with ExpectsInputTyp
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val string = child.eval(input)
-    if (string == null) {
-      null
+  protected override def nullSafeEval(string: Any): Any = {
+    val bytes = string.asInstanceOf[UTF8String].getBytes
+    if (bytes.length > 0) {
+      bytes(0).asInstanceOf[Int]
     } else {
-      val bytes = string.asInstanceOf[UTF8String].getBytes
-      if (bytes.length > 0) {
-        bytes(0).asInstanceOf[Int]
-      } else {
-        0
-      }
+      0
     }
   }
 }
@@ -360,15 +315,10 @@ case class Base64(child: Expression) extends 
UnaryExpression with ExpectsInputTy
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
-  override def eval(input: InternalRow): Any = {
-    val bytes = child.eval(input)
-    if (bytes == null) {
-      null
-    } else {
-      UTF8String.fromBytes(
-        org.apache.commons.codec.binary.Base64.encodeBase64(
-          bytes.asInstanceOf[Array[Byte]]))
-    }
+  protected override def nullSafeEval(bytes: Any): Any = {
+    UTF8String.fromBytes(
+      org.apache.commons.codec.binary.Base64.encodeBase64(
+        bytes.asInstanceOf[Array[Byte]]))
   }
 }
 
@@ -379,14 +329,8 @@ case class UnBase64(child: Expression) extends 
UnaryExpression with ExpectsInput
   override def dataType: DataType = BinaryType
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val string = child.eval(input)
-    if (string == null) {
-      null
-    } else {
-      
org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
-    }
-  }
+  protected override def nullSafeEval(string: Any): Any =
+    org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
 }
 
 /**
@@ -402,19 +346,9 @@ case class Decode(bin: Expression, charset: Expression)
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val l = bin.eval(input)
-    if (l == null) {
-      null
-    } else {
-      val r = charset.eval(input)
-      if (r == null) {
-        null
-      } else {
-        val fromCharset = r.asInstanceOf[UTF8String].toString
-        UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], 
fromCharset))
-      }
-    }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val fromCharset = input2.asInstanceOf[UTF8String].toString
+    UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
   }
 }
 
@@ -431,19 +365,9 @@ case class Encode(value: Expression, charset: Expression)
   override def dataType: DataType = BinaryType
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val l = value.eval(input)
-    if (l == null) {
-      null
-    } else {
-      val r = charset.eval(input)
-      if (r == null) {
-        null
-      } else {
-        val toCharset = r.asInstanceOf[UTF8String].toString
-        l.asInstanceOf[UTF8String].toString.getBytes(toCharset)
-      }
-    }
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val toCharset = input2.asInstanceOf[UTF8String].toString
+    input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to