Repository: spark
Updated Branches:
  refs/heads/master 84d31aa5d -> 952e4d1c8


[SPARK-24321][SQL] Extract common code from Divide/Remainder to a base trait

## What changes were proposed in this pull request?

Extract common code from `Divide`/`Remainder` to a new base trait, `DivModLike`.

Further refactoring to make `Pmod` work with `DivModLike` is to be done as a 
separate task.

## How was this patch tested?

Existing tests in `ArithmeticExpressionSuite` cover the functionality.

Author: Kris Mok <kris....@databricks.com>

Closes #21367 from rednaxelafx/catalyst-divmod.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/952e4d1c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/952e4d1c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/952e4d1c

Branch: refs/heads/master
Commit: 952e4d1c830c4eb3dfd522be3d292dd02d8c9065
Parents: 84d31aa
Author: Kris Mok <kris....@databricks.com>
Authored: Tue May 22 19:12:30 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue May 22 19:12:30 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 145 +++++++------------
 1 file changed, 51 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/952e4d1c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d4e322d..efd4e99 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -220,30 +220,12 @@ case class Multiply(left: Expression, right: Expression) 
extends BinaryArithmeti
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
numeric.times(input1, input2)
 }
 
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs 
floating point division.",
-  examples = """
-    Examples:
-      > SELECT 3 _FUNC_ 2;
-       1.5
-      > SELECT 2L _FUNC_ 2L;
-       1.0
-  """)
-// scalastyle:on line.size.limit
-case class Divide(left: Expression, right: Expression) extends 
BinaryArithmetic {
-
-  override def inputType: AbstractDataType = TypeCollection(DoubleType, 
DecimalType)
+// Common base trait for Divide and Remainder, since these two classes are 
almost identical
+trait DivModLike extends BinaryArithmetic {
 
-  override def symbol: String = "/"
-  override def decimalMethod: String = "$div"
   override def nullable: Boolean = true
 
-  private lazy val div: (Any, Any) => Any = dataType match {
-    case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
-  }
-
-  override def eval(input: InternalRow): Any = {
+  final override def eval(input: InternalRow): Any = {
     val input2 = right.eval(input)
     if (input2 == null || input2 == 0) {
       null
@@ -252,13 +234,15 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
       if (input1 == null) {
         null
       } else {
-        div(input1, input2)
+        evalOperation(input1, input2)
       }
     }
   }
 
+  def evalOperation(left: Any, right: Any): Any
+
   /**
-   * Special case handling due to division by 0 => null.
+   * Special case handling due to division/remainder by 0 => null.
    */
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval1 = left.genCode(ctx)
@@ -269,7 +253,7 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
       s"${eval2.value} == 0"
     }
     val javaType = CodeGenerator.javaType(dataType)
-    val divide = if (dataType.isInstanceOf[DecimalType]) {
+    val operation = if (dataType.isInstanceOf[DecimalType]) {
       s"${eval1.value}.$decimalMethod(${eval2.value})"
     } else {
       s"($javaType)(${eval1.value} $symbol ${eval2.value})"
@@ -283,7 +267,7 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
           ${ev.isNull} = true;
         } else {
           ${eval1.code}
-          ${ev.value} = $divide;
+          ${ev.value} = $operation;
         }""")
     } else {
       ev.copy(code = s"""
@@ -297,13 +281,38 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
           if (${eval1.isNull}) {
             ${ev.isNull} = true;
           } else {
-            ${ev.value} = $divide;
+            ${ev.value} = $operation;
           }
         }""")
     }
   }
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs 
floating point division.",
+  examples = """
+    Examples:
+      > SELECT 3 _FUNC_ 2;
+       1.5
+      > SELECT 2L _FUNC_ 2L;
+       1.0
+  """)
+// scalastyle:on line.size.limit
+case class Divide(left: Expression, right: Expression) extends DivModLike {
+
+  override def inputType: AbstractDataType = TypeCollection(DoubleType, 
DecimalType)
+
+  override def symbol: String = "/"
+  override def decimalMethod: String = "$div"
+
+  private lazy val div: (Any, Any) => Any = dataType match {
+    case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
+  }
+
+  override def evalOperation(left: Any, right: Any): Any = div(left, right)
+}
+
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.",
   examples = """
@@ -313,82 +322,30 @@ case class Divide(left: Expression, right: Expression) 
extends BinaryArithmetic
       > SELECT MOD(2, 1.8);
        0.2
   """)
-case class Remainder(left: Expression, right: Expression) extends 
BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression) extends DivModLike {
 
   override def inputType: AbstractDataType = NumericType
 
   override def symbol: String = "%"
   override def decimalMethod: String = "remainder"
-  override def nullable: Boolean = true
 
-  private lazy val integral = dataType match {
-    case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
-    case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
+  private lazy val mod: (Any, Any) => Any = dataType match {
+    // special cases to make float/double primitive types faster
+    case DoubleType =>
+      (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double]
+    case FloatType =>
+      (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float]
+
+    // catch-all cases
+    case i: IntegralType =>
+      val integral = i.integral.asInstanceOf[Integral[Any]]
+      (left, right) => integral.rem(left, right)
+    case i: FractionalType => // should only be DecimalType for now
+      val integral = i.asIntegral.asInstanceOf[Integral[Any]]
+      (left, right) => integral.rem(left, right)
   }
 
-  override def eval(input: InternalRow): Any = {
-    val input2 = right.eval(input)
-    if (input2 == null || input2 == 0) {
-      null
-    } else {
-      val input1 = left.eval(input)
-      if (input1 == null) {
-        null
-      } else {
-        input1 match {
-          case d: Double => d % input2.asInstanceOf[java.lang.Double]
-          case f: Float => f % input2.asInstanceOf[java.lang.Float]
-          case _ => integral.rem(input1, input2)
-        }
-      }
-    }
-  }
-
-  /**
-   * Special case handling for x % 0 ==> null.
-   */
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val eval1 = left.genCode(ctx)
-    val eval2 = right.genCode(ctx)
-    val isZero = if (dataType.isInstanceOf[DecimalType]) {
-      s"${eval2.value}.isZero()"
-    } else {
-      s"${eval2.value} == 0"
-    }
-    val javaType = CodeGenerator.javaType(dataType)
-    val remainder = if (dataType.isInstanceOf[DecimalType]) {
-      s"${eval1.value}.$decimalMethod(${eval2.value})"
-    } else {
-      s"($javaType)(${eval1.value} $symbol ${eval2.value})"
-    }
-    if (!left.nullable && !right.nullable) {
-      ev.copy(code = s"""
-        ${eval2.code}
-        boolean ${ev.isNull} = false;
-        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        if ($isZero) {
-          ${ev.isNull} = true;
-        } else {
-          ${eval1.code}
-          ${ev.value} = $remainder;
-        }""")
-    } else {
-      ev.copy(code = s"""
-        ${eval2.code}
-        boolean ${ev.isNull} = false;
-        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        if (${eval2.isNull} || $isZero) {
-          ${ev.isNull} = true;
-        } else {
-          ${eval1.code}
-          if (${eval1.isNull}) {
-            ${ev.isNull} = true;
-          } else {
-            ${ev.value} = $remainder;
-          }
-        }""")
-    }
-  }
+  override def evalOperation(left: Any, right: Any): Any = mod(left, right)
 }
 
 @ExpressionDescription(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to