Repository: spark
Updated Branches:
  refs/heads/master 408a3ff2c -> 0f3fa2f28
[SPARK-24996][SQL] Use DSL in DeclarativeAggregate

## What changes were proposed in this pull request?

The PR refactors the aggregate expressions which were not using DSL in order to simplify them.

## How was this patch tested?

NA

Author: Marco Gaido <marcogaid...@gmail.com>

Closes #21970 from mgaido91/SPARK-24996.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f3fa2f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f3fa2f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f3fa2f2

Branch: refs/heads/master
Commit: 0f3fa2f289f53a8ceea3b0a52fa6dc319001b10b
Parents: 408a3ff
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Mon Aug 6 19:46:51 2018 -0400
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Aug 6 19:46:51 2018 -0400

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala | 2 +
 .../expressions/aggregate/Average.scala | 2 +-
 .../aggregate/CentralMomentAgg.scala | 40 +++++++++-----------
 .../catalyst/expressions/aggregate/Corr.scala | 13 +++----
 .../expressions/aggregate/Covariance.scala | 16 ++++----
 .../catalyst/expressions/aggregate/First.scala | 7 ++--
 .../catalyst/expressions/aggregate/Last.scala | 7 ++--
 .../catalyst/expressions/aggregate/Max.scala | 5 ++-
 .../catalyst/expressions/aggregate/Min.scala | 5 ++-
 .../catalyst/expressions/aggregate/Sum.scala | 7 ++--
 .../expressions/windowExpressions.scala | 30 +++++++--------
 11 files changed, 65 insertions(+), 69 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75387fa..2b582b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -167,6 +167,8 @@ package object dsl { def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def coalesce(args: Expression*): Expression = Coalesce(args) + def greatest(args: Expression*): Expression = Greatest(args) + def least(args: Expression*): Expression = Least(args) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) def star(names: String*): Expression = names match { http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index f1fad77..5ecb77b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -68,7 +68,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { Add( sum, coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(child.isNull, count, count + 1L) ) override lazy val updateExpressions = updateExpressionsDef http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 6bbb083..e2ff0ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -75,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression) val n2 = n.right val newN = n1 + n2 val delta = avg.right - avg.left - val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val deltaN = If(newN === 0.0, 0.0, delta / newN) val newAvg = avg.left + deltaN * n2 // higher order moments computed according to: @@ -102,7 +102,7 @@ abstract class CentralMomentAgg(child: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val delta = child - avg val deltaN = delta / newN val newAvg = avg + deltaN @@ -123,11 +123,11 @@ abstract class CentralMomentAgg(child: Expression) } trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) + If(child.isNull, n, newN), + If(child.isNull, avg, newAvg), + If(child.isNull, m2, newM2), + If(child.isNull, m3, newM3), + If(child.isNull, m4, newM4) )) } } @@ -142,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - Sqrt(m2 / n)) + If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n)) } override def prettyName: String = "stddev_pop" @@ -159,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val 
evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - Sqrt(m2 / (n - Literal(1.0))))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } override def prettyName: String = "stddev_samp" @@ -175,8 +173,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - m2 / n) + If(n === 0.0, Literal.create(null, DoubleType), m2 / n) } override def prettyName: String = "var_pop" @@ -190,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - m2 / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } override def prettyName: String = "var_samp" @@ -207,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 3 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } } @@ -220,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - n * m4 / (m2 * m2) - Literal(3.0))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, n * m4 / (m2 
* m2) - 3.0)) } override def prettyName: String = "kurtosis" http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 3cdef72..e14cc71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -54,9 +54,9 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -67,7 +67,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dxN = dx / newN val dy = y - yAvg @@ -78,7 +78,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val newXMk = xMk + dx * (x - newXAvg) val newYMk = yMk + dy * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -99,9 +99,8 @@ case class Corr(x: Expression, y: Expression) extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), 
Literal(Double.NaN), - ck / Sqrt(xMk * yMk))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk))) } override def prettyName: String = "corr" http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 72a7c62..ee28eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -50,9 +50,9 @@ abstract class Covariance(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -61,7 +61,7 @@ abstract class Covariance(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dy = y - yAvg val dyN = dy / newN @@ -69,7 +69,7 @@ abstract class Covariance(x: Expression, y: Expression) val newYAvg = yAvg + dyN val newCk = ck + dx * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -83,8 +83,7 @@ abstract class Covariance(x: Expression, y: Expression) usage = "_FUNC_(expr1, expr2) - Returns the 
population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - ck / n) + If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" } @@ -94,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 4e671e1..f51bfd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* first = */ If(valueSet || child.isNull, first, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) // false, we are safe to do so because first.right will be null in this case). Seq( /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ Or(valueSet.left, valueSet.right) + /* valueSet = */ valueSet.left || valueSet.right ) } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 0ccabb9..2650d7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child), - 
/* valueSet = */ Or(valueSet, IsNotNull(child)) + /* last = */ If(child.isNull, last, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) // Prefer the right hand expression if it has been set. Seq( /* last = */ If(valueSet.right, last.right, last.left), - /* valueSet = */ Or(valueSet.right, valueSet.left) + /* valueSet = */ valueSet.right || valueSet.left ) } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 58fd1d8..71099eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ Greatest(Seq(max, child)) + /* max = */ greatest(max, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* max = */ Greatest(Seq(max.left, max.right)) + /* max = */ greatest(max.left, max.right) ) } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index b2724ee..8c4ba93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ Least(Seq(min, child)) + /* min = */ least(min, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* min = */ Least(Seq(min.left, min.right)) + /* min = */ least(min.left, min.right) ) } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 86e40a9..761dba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,12 +62,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) ) } else { Seq( /* sum = */ - Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + coalesce(sum, zero) + child.cast(sumDataType) ) } } @@ -74,7 +75,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) ) } http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 53c6f01..707f312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ 
@@ -476,7 +477,7 @@ abstract class RowNumberLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil override val initialValues: Seq[Expression] = zero :: Nil - override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil + override val updateExpressions: Seq[Expression] = rowNumber + one :: Nil } /** @@ -527,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must // return the same value for equal values in the partition. override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override val evaluateExpression = rowNumber.cast(DoubleType) / n.cast(DoubleType) override def prettyName: String = "cume_dist" } @@ -587,8 +588,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() private val bucketsWithPadding = AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() - private def bucketOverflow(e: Expression) = - If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + private def bucketOverflow(e: Expression) = If(rowNumber >= bucketThreshold, e, zero) override val aggBufferAttributes = Seq( rowNumber, @@ -602,15 +602,14 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow zero, zero, zero, - Cast(Divide(n, buckets), IntegerType), - Cast(Remainder(n, buckets), IntegerType) + (n / buckets).cast(IntegerType), + (n % buckets).cast(IntegerType) ) override val updateExpressions = Seq( - Add(rowNumber, one), - Add(bucket, bucketOverflow(one)), - Add(bucketThreshold, bucketOverflow( - Add(bucketSize, 
If(LessThan(bucket, bucketsWithPadding), one, zero)))), + rowNumber + one, + bucket + bucketOverflow(one), + bucketThreshold + bucketOverflow(bucketSize + If(bucket < bucketsWithPadding, one, zero)), NoOp, NoOp ) @@ -644,7 +643,7 @@ abstract class RankLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() protected val zero = Literal(0) protected val one = Literal(1) - protected val increaseRowNumber = Add(rowNumber, one) + protected val increaseRowNumber = rowNumber + one /** * Different RankLike implementations use different source expressions to update their rank value. @@ -653,7 +652,7 @@ abstract class RankLike extends AggregateWindowFunction { protected def rankSource: Expression = rowNumber /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */ - protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + protected val increaseRank = If(orderEquals && rank =!= zero, rank, rankSource) override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs override val initialValues = zero +: one +: orderInit @@ -707,7 +706,7 @@ case class Rank(children: Seq[Expression]) extends RankLike { case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) - override protected def rankSource = Add(rank, one) + override protected def rankSource = rank + one override val updateExpressions = increaseRank +: children override val aggBufferAttributes = rank +: orderAttrs override val initialValues = zero +: orderInit @@ -736,8 +735,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) override def dataType: DataType = DoubleType - override val evaluateExpression = 
If(GreaterThan(n, one), - Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), - Literal(0.0d)) + override val evaluateExpression = + If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d) override def prettyName: String = "percent_rank" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org