Github user yhuai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9819#discussion_r47971208
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
 ---
    @@ -246,85 +260,281 @@ object SpecifiedWindowFrame {
       }
     }
     
    +case class UnresolvedWindowExpression(
    +    child: Expression,
    +    windowSpec: WindowSpecReference) extends UnaryExpression with 
Unevaluable {
    +
    +  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    +  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    +  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    +  override lazy val resolved = false
    +}
    +
    +case class WindowExpression(
    +    windowFunction: Expression,
    +    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
    +
    +  override def children: Seq[Expression] = windowFunction :: windowSpec :: 
Nil
    +
    +  override def dataType: DataType = windowFunction.dataType
    +  override def foldable: Boolean = windowFunction.foldable
    +  override def nullable: Boolean = windowFunction.nullable
    +
    +  override def toString: String = s"$windowFunction $windowSpec"
    +}
    +
     /**
    - * Every window function needs to maintain a output buffer for its output.
    - * It should expect that for a n-row window frame, it will be called n 
times
    - * to retrieve value corresponding with these n rows.
    + * A window function is a function that can only be evaluated in the 
context of a window operator.
      */
     trait WindowFunction extends Expression {
    -  def init(): Unit
    +  /** Frame in which the window operator must be executed. */
    +  def frame: WindowFrame = UnspecifiedFrame
    +}
    +
    +/**
    + * An offset window function is a window function that returns the value 
of the input column offset
    + * by a number of rows within the partition. For instance: an 
OffsetWindowFunction for value x with
    + * offset -2, will get the value of x 2 rows back in the partition.
    + */
    +abstract class OffsetWindowFunction
    +  extends Expression with WindowFunction with Unevaluable with 
ImplicitCastInputTypes {
    +  /**
    +   * Input expression to evaluate against a row which is a number of rows 
below or above (depending on
    +   * the value and sign of the offset) the current row.
    +   */
    +  val input: Expression
    +
    +  /**
    +   * Default result value for the function when the input expression 
returns NULL. The default will
    +   * be evaluated against the current row instead of the offset row.
    +   */
    +  val default: Expression
     
    -  def reset(): Unit
    +  /**
    +   * (Foldable) expression that contains the number of rows between the 
current row and the row
    +   * where the input expression is evaluated.
    +   */
    +  val offset: Expression
     
    -  def prepareInputParameters(input: InternalRow): AnyRef
    +  /**
    +   * Direction (above = 1/below = -1) of the number of rows between the 
current row and the row
    +   * where the input expression is evaluated.
    +   */
    +  val direction: SortDirection
     
    -  def update(input: AnyRef): Unit
    +  override def children: Seq[Expression] = Seq(input, offset, default)
     
    -  def batchUpdate(inputs: Array[AnyRef]): Unit
    +  /*
    +   * The result of an OffsetWindowFunction is dependent on the frame in 
which the
    +   * OffsetWindowFunction is executed, the input expression and the 
default expression. Even when
    +   * both the input and the default expression are foldable, the result is 
still not foldable due to
    +   * the frame.
    +   */
    +  override def foldable: Boolean = input.foldable && (default == null || 
default.foldable)
     
    -  def evaluate(): Unit
    +  override def nullable: Boolean = input.nullable && (default == null || 
default.nullable)
     
    -  def get(index: Int): Any
    +  override lazy val frame = {
    +    // This will be triggered by the Analyzer.
    +    val offsetValue = offset.eval() match {
    +      case o: Int => o
    +      case x => throw new AnalysisException(
    +        s"Offset expression must be a foldable integer expression: $x")
    +    }
    +    val boundary = direction match {
    +      case Ascending => ValueFollowing(offsetValue)
    +      case Descending => ValuePreceding(offsetValue)
    +    }
    +    SpecifiedWindowFrame(RowFrame, boundary, boundary)
    +  }
    +
    +  override def dataType: DataType = input.dataType
     
    -  def newInstance(): WindowFunction
    +  override def inputTypes: Seq[AbstractDataType] =
    +    Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType))
    +
    +  override def toString: String = s"$prettyName($input, $offset, $default)"
     }
     
    -case class UnresolvedWindowFunction(
    -    name: String,
    -    children: Seq[Expression])
    -  extends Expression with WindowFunction with Unevaluable {
    +case class Lead(input: Expression, offset: Expression, default: Expression)
    +    extends OffsetWindowFunction {
     
    -  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    -  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    -  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    -  override lazy val resolved = false
    +  def this(input: Expression, offset: Expression) = this(input, offset, 
Literal(null))
     
    -  override def init(): Unit = throw new UnresolvedException(this, "init")
    -  override def reset(): Unit = throw new UnresolvedException(this, "reset")
    -  override def prepareInputParameters(input: InternalRow): AnyRef =
    -    throw new UnresolvedException(this, "prepareInputParameters")
    -  override def update(input: AnyRef): Unit = throw new 
UnresolvedException(this, "update")
    -  override def batchUpdate(inputs: Array[AnyRef]): Unit =
    -    throw new UnresolvedException(this, "batchUpdate")
    -  override def evaluate(): Unit = throw new UnresolvedException(this, 
"evaluate")
    -  override def get(index: Int): Any = throw new UnresolvedException(this, 
"get")
    +  def this(input: Expression) = this(input, Literal(1))
     
    -  override def toString: String = s"'$name(${children.mkString(",")})"
    +  def this() = this(Literal(null))
     
    -  override def newInstance(): WindowFunction = throw new 
UnresolvedException(this, "newInstance")
    +  override val direction = Ascending
     }
     
    -case class UnresolvedWindowExpression(
    -    child: UnresolvedWindowFunction,
    -    windowSpec: WindowSpecReference) extends UnaryExpression with 
Unevaluable {
    +case class Lag(input: Expression, offset: Expression, default: Expression)
    +    extends OffsetWindowFunction {
     
    -  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    -  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    -  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    -  override lazy val resolved = false
    +  def this(input: Expression, offset: Expression) = this(input, offset, 
Literal(null))
    +
    +  def this(input: Expression) = this(input, Literal(1))
    +
    +  def this() = this(Literal(null))
    +
    +  override val direction = Descending
     }
     
    -case class WindowExpression(
    -    windowFunction: WindowFunction,
    -    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
    +abstract class AggregateWindowFunction extends DeclarativeAggregate with 
WindowFunction {
    +  self: Product =>
    +  override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, 
CurrentRow)
    +  override def dataType: DataType = IntegerType
    +  override def nullable: Boolean = false
    +  override def supportsPartial: Boolean = false
    +  override lazy val mergeExpressions =
    +    throw new UnsupportedOperationException("Window Functions do not 
support merging.")
    +}
     
    -  override def children: Seq[Expression] = windowFunction :: windowSpec :: 
Nil
    +abstract class RowNumberLike extends AggregateWindowFunction {
    +  override def children: Seq[Expression] = Nil
    +  override def inputTypes: Seq[AbstractDataType] = Nil
    +  protected val zero = Literal(0)
    +  protected val one = Literal(1)
    +  protected val rowNumber = AttributeReference("rowNumber", IntegerType, 
nullable = false)()
    +  override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: 
Nil
    +  override val initialValues: Seq[Expression] = zero :: Nil
    +  override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: 
Nil
    +}
     
    -  override def dataType: DataType = windowFunction.dataType
    -  override def foldable: Boolean = windowFunction.foldable
    -  override def nullable: Boolean = windowFunction.nullable
    +/**
    + * A [[SizeBasedWindowFunction]] needs the size of the current window for 
its calculation.
    + */
    +trait SizeBasedWindowFunction extends AggregateWindowFunction {
    +  protected def n: AttributeReference = SizeBasedWindowFunction.n
    +}
     
    -  override def toString: String = s"$windowFunction $windowSpec"
    +object SizeBasedWindowFunction {
    +  val n = AttributeReference("window__partition__size", IntegerType, 
nullable = false)()
    +}
    +
    +case class RowNumber() extends RowNumberLike {
    +  override val evaluateExpression = rowNumber
    +}
    +
    +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
    +  override def dataType: DataType = DoubleType
    +  // The frame for CUME_DIST is Range based instead of Row based, because 
CUME_DIST must
    +  // return the same value for equal values in the partition.
    +  override val frame = SpecifiedWindowFrame(RangeFrame, 
UnboundedPreceding, CurrentRow)
    +  override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), 
Cast(n, DoubleType))
    +}
    +
    +case class NTile(buckets: Expression) extends RowNumberLike with 
SizeBasedWindowFunction {
    --- End diff --
    
    Seems we can add some comments to explain how `NTile` works in a follow-up PR?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to