Repository: spark Updated Branches: refs/heads/master 84d31aa5d -> 952e4d1c8
[SPARK-24321][SQL] Extract common code from Divide/Remainder to a base trait ## What changes were proposed in this pull request? Extract common code from `Divide`/`Remainder` to a new base trait, `DivModLike`. Further refactoring to make `Pmod` work with `DivModLike` is to be done as a separate task. ## How was this patch tested? Existing tests in `ArithmeticExpressionSuite` cover the functionality. Author: Kris Mok <kris....@databricks.com> Closes #21367 from rednaxelafx/catalyst-divmod. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/952e4d1c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/952e4d1c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/952e4d1c Branch: refs/heads/master Commit: 952e4d1c830c4eb3dfd522be3d292dd02d8c9065 Parents: 84d31aa Author: Kris Mok <kris....@databricks.com> Authored: Tue May 22 19:12:30 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue May 22 19:12:30 2018 +0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/arithmetic.scala | 145 +++++++------------ 1 file changed, 51 insertions(+), 94 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/952e4d1c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d..efd4e99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -220,30 +220,12 @@ case class Multiply(left: Expression, right: Expression) extends 
BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +234,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. 
*/ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" @@ -283,7 +267,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { ev.copy(code = s""" @@ -297,13 +281,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. 
It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +322,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = 
i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org