Github user yhuai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9819#discussion_r47466167
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
 ---
    @@ -246,85 +260,238 @@ object SpecifiedWindowFrame {
       }
     }
     
    +case class UnresolvedWindowExpression(
    +    child: Expression,
    +    windowSpec: WindowSpecReference) extends UnaryExpression with 
Unevaluable {
    +
    +  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    +  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    +  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    +  override lazy val resolved = false
    +}
    +
    +case class WindowExpression(
    +    windowFunction: Expression,
    +    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
    +
    +  override def children: Seq[Expression] = windowFunction :: windowSpec :: 
Nil
    +
    +  override def dataType: DataType = windowFunction.dataType
    +  override def foldable: Boolean = windowFunction.foldable
    +  override def nullable: Boolean = windowFunction.nullable
    +
    +  override def toString: String = s"$windowFunction $windowSpec"
    +}
    +
     /**
    - * Every window function needs to maintain a output buffer for its output.
    - * It should expect that for a n-row window frame, it will be called n 
times
    - * to retrieve value corresponding with these n rows.
    + * A window function is a function that can only be evaluated in the 
context of a window operator.
      */
     trait WindowFunction extends Expression {
    -  def init(): Unit
    +  /** Frame in which the window operator must be executed. */
    +  def frame: WindowFrame = UnspecifiedFrame
    +}
     
    -  def reset(): Unit
    +/**
    + * An offset window function is a window function that returns the value 
of the input column offset
    + * by a number of rows within the partition. For instance: an 
OffsetWindowfunction for value x with
    + * offset -2, will get the value of x 2 rows back in the partition.
    + */
    +abstract class OffsetWindowFunction
    +  extends Expression with WindowFunction with Unevaluable with 
ImplicitCastInputTypes {
    +  val input: Expression
    +  val default: Expression
    +  val offset: Expression
    +  val offsetSign: Int
    +
    +  override def children: Seq[Expression] = Seq(input, offset, default)
     
    -  def prepareInputParameters(input: InternalRow): AnyRef
    +  override def foldable: Boolean = input.foldable && (default == null || 
default.foldable)
     
    -  def update(input: AnyRef): Unit
    +  override def nullable: Boolean = input.nullable && (default == null || 
default.nullable)
     
    -  def batchUpdate(inputs: Array[AnyRef]): Unit
    +  override lazy val frame = {
    +    // This will be triggered by the Analyzer.
    +    val offsetValue = offset.eval() match {
    +      case o: Int => o
    +      case x => throw new AnalysisException(
    +        s"Offset expression must be a foldable integer expression: $x")
    +    }
    +    val boundary = ValueFollowing(offsetSign * offsetValue)
    +    SpecifiedWindowFrame(RowFrame, boundary, boundary)
    +  }
     
    -  def evaluate(): Unit
    +  override def dataType: DataType = input.dataType
     
    -  def get(index: Int): Any
    +  override def inputTypes: Seq[AbstractDataType] =
    +    Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType))
     
    -  def newInstance(): WindowFunction
    +  override def toString: String = s"$prettyName($input, $offset, $default)"
     }
     
    -case class UnresolvedWindowFunction(
    -    name: String,
    -    children: Seq[Expression])
    -  extends Expression with WindowFunction with Unevaluable {
    +case class Lead(input: Expression, offset: Expression, default: Expression)
    +    extends OffsetWindowFunction {
     
    -  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    -  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    -  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    -  override lazy val resolved = false
    +  def this(input: Expression, offset: Expression) = this(input, offset, 
Literal(null))
     
    -  override def init(): Unit = throw new UnresolvedException(this, "init")
    -  override def reset(): Unit = throw new UnresolvedException(this, "reset")
    -  override def prepareInputParameters(input: InternalRow): AnyRef =
    -    throw new UnresolvedException(this, "prepareInputParameters")
    -  override def update(input: AnyRef): Unit = throw new 
UnresolvedException(this, "update")
    -  override def batchUpdate(inputs: Array[AnyRef]): Unit =
    -    throw new UnresolvedException(this, "batchUpdate")
    -  override def evaluate(): Unit = throw new UnresolvedException(this, 
"evaluate")
    -  override def get(index: Int): Any = throw new UnresolvedException(this, 
"get")
    +  def this(input: Expression) = this(input, Literal(1))
     
    -  override def toString: String = s"'$name(${children.mkString(",")})"
    +  def this() = this(Literal(null))
     
    -  override def newInstance(): WindowFunction = throw new 
UnresolvedException(this, "newInstance")
    +  val offsetSign = 1
     }
     
    -case class UnresolvedWindowExpression(
    -    child: UnresolvedWindowFunction,
    -    windowSpec: WindowSpecReference) extends UnaryExpression with 
Unevaluable {
    +case class Lag(input: Expression, offset: Expression, default: Expression)
    +    extends OffsetWindowFunction {
     
    -  override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")
    -  override def foldable: Boolean = throw new UnresolvedException(this, 
"foldable")
    -  override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
    -  override lazy val resolved = false
    -}
    +  def this(input: Expression, offset: Expression) = this(input, offset, 
Literal(null))
     
    -case class WindowExpression(
    -    windowFunction: WindowFunction,
    -    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
    +  def this(input: Expression) = this(input, Literal(1))
     
    -  override def children: Seq[Expression] = windowFunction :: windowSpec :: 
Nil
    +  def this() = this(Literal(null))
     
    -  override def dataType: DataType = windowFunction.dataType
    -  override def foldable: Boolean = windowFunction.foldable
    -  override def nullable: Boolean = windowFunction.nullable
    +  val offsetSign = -1
    +}
     
    -  override def toString: String = s"$windowFunction $windowSpec"
    +abstract class AggregateWindowFunction extends DeclarativeAggregate with 
WindowFunction {
    +  self: Product =>
    +  override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, 
CurrentRow)
    +  override def dataType: DataType = IntegerType
    +  override def nullable: Boolean = false
    +  override val mergeExpressions = Nil // TODO how to deal with this?
    +}
    +
    +abstract class RowNumberLike extends AggregateWindowFunction {
    +  override def children: Seq[Expression] = Nil
    +  override def inputTypes: Seq[AbstractDataType] = Nil
    +  protected val zero = Literal(0)
    +  protected val one = Literal(1)
    +  protected val rowNumber = AttributeReference("rowNumber", IntegerType, 
nullable = false)()
    +  override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: 
Nil
    +  override val initialValues: Seq[Expression] = zero :: Nil
    +  override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: 
Nil
     }
     
     /**
    - * Extractor for making working with frame boundaries easier.
    + * A [[SizeBasedWindowFunction]] needs the size of the current window for 
its calculation.
      */
    -object FrameBoundaryExtractor {
    -  def unapply(boundary: FrameBoundary): Option[Int] = boundary match {
    -    case CurrentRow => Some(0)
    -    case ValuePreceding(offset) => Some(-offset)
    -    case ValueFollowing(offset) => Some(offset)
    -    case _ => None
    +trait SizeBasedWindowFunction extends AggregateWindowFunction {
    +  protected def n: AttributeReference = SizeBasedWindowFunction.n
    +}
    +
    +object SizeBasedWindowFunction {
    +  val n = AttributeReference("window__partition__size", IntegerType, 
nullable = false)()
    +}
    +
    +case class RowNumber() extends RowNumberLike {
    +  override val evaluateExpression = Cast(rowNumber, IntegerType)
    +}
    +
    +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
    +  override def dataType: DataType = DoubleType
    +  override val frame = SpecifiedWindowFrame(RangeFrame, 
UnboundedPreceding, CurrentRow)
    +  override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), 
Cast(n, DoubleType))
    +}
    +
    +case class NTile(buckets: Expression) extends RowNumberLike with 
SizeBasedWindowFunction {
    +  def this() = this(Literal(1))
    +
    +  // Validate buckets.
    +  buckets.eval() match {
    --- End diff --
    
    Oh, I somehow missed the `case x => throw new AnalysisException(` branch... Sorry.
    
    It makes sense. Let's keep it as is.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to