Github user yhuai commented on a diff in the pull request: https://github.com/apache/spark/pull/9819#discussion_r47971208 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala --- @@ -246,85 +260,281 @@ object SpecifiedWindowFrame { } } +case class UnresolvedWindowExpression( + child: Expression, + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} + +case class WindowExpression( + windowFunction: Expression, + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { + + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil + + override def dataType: DataType = windowFunction.dataType + override def foldable: Boolean = windowFunction.foldable + override def nullable: Boolean = windowFunction.nullable + + override def toString: String = s"$windowFunction $windowSpec" +} + /** - * Every window function needs to maintain a output buffer for its output. - * It should expect that for a n-row window frame, it will be called n times - * to retrieve value corresponding with these n rows. + * A window function is a function that can only be evaluated in the context of a window operator. */ trait WindowFunction extends Expression { - def init(): Unit + /** Frame in which the window operator must be executed. */ + def frame: WindowFrame = UnspecifiedFrame +} + +/** + * An offset window function is a window function that returns the value of the input column offset + * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with + * offset -2, will get the value of x 2 rows back in the partition. 
+ */ +abstract class OffsetWindowFunction + extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes { + /** + * Input expression to evaluate against a row which is a number of rows below or above (depending on + * the value and sign of the offset) the current row. + */ + val input: Expression + + /** + * Default result value for the function when the input expression returns NULL. The default will + * be evaluated against the current row instead of the offset row. + */ + val default: Expression - def reset(): Unit + /** + * (Foldable) expression that contains the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val offset: Expression - def prepareInputParameters(input: InternalRow): AnyRef + /** + * Direction (above = 1/below = -1) of the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val direction: SortDirection - def update(input: AnyRef): Unit + override def children: Seq[Expression] = Seq(input, offset, default) - def batchUpdate(inputs: Array[AnyRef]): Unit + /* + * The result of an OffsetWindowFunction is dependent on the frame in which the + * OffsetWindowFunction is executed, the input expression and the default expression. Even when + * both the input and the default expression are foldable, the result is still not foldable due to + * the frame. + */ + override def foldable: Boolean = input.foldable && (default == null || default.foldable) - def evaluate(): Unit + override def nullable: Boolean = input.nullable && (default == null || default.nullable) - def get(index: Int): Any + override lazy val frame = { + // This will be triggered by the Analyzer. 
+ val offsetValue = offset.eval() match { + case o: Int => o + case x => throw new AnalysisException( + s"Offset expression must be a foldable integer expression: $x") + } + val boundary = direction match { + case Ascending => ValueFollowing(offsetValue) + case Descending => ValuePreceding(offsetValue) + } + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } + + override def dataType: DataType = input.dataType - def newInstance(): WindowFunction + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) + + override def toString: String = s"$prettyName($input, $offset, $default)" } -case class UnresolvedWindowFunction( - name: String, - children: Seq[Expression]) - extends Expression with WindowFunction with Unevaluable { +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) - override def init(): Unit = throw new UnresolvedException(this, "init") - override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: InternalRow): AnyRef = - throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") - override def batchUpdate(inputs: Array[AnyRef]): Unit = - throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = throw new UnresolvedException(this, "get") + def this(input: Expression) = this(input, Literal(1)) - override 
def toString: String = s"'$name(${children.mkString(",")})" + def this() = this(Literal(null)) - override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") + override val direction = Ascending } -case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new UnsupportedOperationException("Window Functions do not support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = 
false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + protected def n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() +} + +case class RowNumber() extends RowNumberLike { + override val evaluateExpression = rowNumber +} + +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { + override def dataType: DataType = DoubleType + // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must + // return the same value for equal values in the partition. + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) +} + +case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { --- End diff -- Seems we can add some comments to explain how it works in a follow-up PR?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org