This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new bdeae870674 [SPARK-44341][SQL][PYTHON][FOLLOWUP] Move the base trait WindowEvaluatorFactoryBase to a single file bdeae870674 is described below commit bdeae87067452bb41f4776c4ab444a9d9645fdfc Author: Jiaan Geng <belie...@163.com> AuthorDate: Fri Jul 21 21:21:15 2023 +0800 [SPARK-44341][SQL][PYTHON][FOLLOWUP] Move the base trait WindowEvaluatorFactoryBase to a single file ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/41939 defined the computing logic through the PartitionEvaluator API and used it in `WindowExec` and `WindowInPandasExec`. According to the comment https://github.com/apache/spark/pull/41939#discussion_r1270194752, this PR wants to move the base trait `WindowEvaluatorFactoryBase` to a separate file. ### Why are the changes needed? Improve the code. ### Does this PR introduce _any_ user-facing change? 'No'. Just update inner implementation. ### How was this patch tested? N/A Closes #42106 from beliefer/SPARK-44341_followup. 
Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../execution/window/WindowEvaluatorFactory.scala | 268 +-------------------- ...tory.scala => WindowEvaluatorFactoryBase.scala} | 130 +--------- 2 files changed, 2 insertions(+), 396 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala index 913f8762c79..fb4ea7f35c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala @@ -17,278 +17,12 @@ package org.apache.spark.sql.execution.window -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, JoinedRow, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecificInternalRow, SpecifiedWindowFrame, TimeAdd, TimestampAddYMIn [...] 
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, NamedExpression, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, DecimalType, IntegerType, TimestampNTZType, TimestampType, YearMonthIntervalType} -import org.apache.spark.util.collection.Utils - -trait WindowEvaluatorFactoryBase { - def windowExpression: Seq[NamedExpression] - def partitionSpec: Seq[Expression] - def orderSpec: Seq[SortOrder] - def childOutput: Seq[Attribute] - def spillSize: SQLMetric - - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map { case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(childOutput.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = Utils.toMap(expressions, references) - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - childOutput ++ patchedWindowExpression, - childOutput) - } - - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame to evaluate. This can either be a Row or Range frame. 
- * @param bound with respect to the row. - * @param timeZone the session local timezone for time related calculations. - * @return a bound ordering object. - */ - private def createBoundOrdering( - frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { - (frame, bound) match { - case (RowFrame, CurrentRow) => - RowBoundOrdering(0) - - case (RowFrame, IntegerLiteral(offset)) => - RowBoundOrdering(offset) - - case (RowFrame, _) => - throw new IllegalStateException(s"Unhandled bound in windows expressions: $bound") - - case (RangeFrame, CurrentRow) => - val ordering = RowOrdering.create(orderSpec, childOutput) - RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) - - case (RangeFrame, offset: Expression) if orderSpec.size == 1 => - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - - // Create the projection which returns the current 'value'. - val current = MutableProjection.create(expr :: Nil, childOutput) - - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => UnaryMinus(offset) - case Ascending => offset - } - - // Create the projection which returns the current 'value' modified by adding the offset. 
- val boundExpr = (expr.dataType, boundOffset.dataType) match { - case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) - case (TimestampType | TimestampNTZType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d) - case (a, b) if a == b => Add(expr, boundOffset) - } - val bound = MutableProjection.create(boundExpr :: Nil, childOutput) - - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil - val ordering = RowOrdering.create(boundSortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - - case (RangeFrame, _) => - throw new IllegalStateException("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frame's - * [[WindowExpression]]s and factory function for the [[WindowFunctionFrame]]. - */ - protected lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression]) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. 
- def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = fn match { - // This branch is used for Lead/Lag to support ignoring null and optimize the performance - // for NthValue ignoring null. - // All window frames move in rows. If there are multiple Leads, Lags or NthValues acting on - // a row and operating on different input expressions, they should not be moved uniformly - // by row. Therefore, we put these functions in different window frames. - case f: OffsetWindowFunction if f.ignoreNulls => - (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized)) - case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil) - } - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e@WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: FrameLessOffsetWindowFunction => - collect("FRAME_LESS_OFFSET", f.fakeFrame, e, f) - case f: OffsetWindowFunction if frame.frameType == RowFrame && - frame.lower == UnboundedPreceding => - frame.upper match { - case UnboundedFollowing => collect("UNBOUNDED_OFFSET", f.fakeFrame, e, f) - case CurrentRow => collect("UNBOUNDED_PRECEDING_OFFSET", f.fakeFrame, e, f) - case _ => collect("AGGREGATE", frame, e, f) - } - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f => throw new IllegalStateException(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. 
- var numExpressions = 0 - val timeZone = SQLConf.get.sessionLocalTimeZone - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions - // in a single Window physical node. Therefore, we can assume no SQL aggregation - // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL - // aggregation function in a single physical node. - def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) { - null - } else { - AggregateProcessor( - functions, - ordinal, - childOutput, - (expressions, schema) => - MutableProjection.create(expressions, schema)) - } - - // Create the factory to produce WindowFunctionFrame. - val factory = key match { - // Frameless offset Frame - case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => - new FrameLessOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - childOutput, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - childOutput, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedPrecedingOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. 
- functions.map(_.asInstanceOf[OffsetWindowFunction]), - childOutput, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - - // Entire Partition Frame. - case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - - // Growing Frame. - case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) => - target: InternalRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, upper, timeZone)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) - } - - // Moving Frame. - case ("AGGREGATE", frameType, lower, upper, _) => - target: InternalRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone)) - } - - case _ => - throw new IllegalStateException(s"Unsupported factory: $key") - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Window Expression - Frame Factory pair. 
- (expressions, factory) - } - } - -} class WindowEvaluatorFactory( val windowExpression: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala similarity index 73% copy from sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala copy to sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index 913f8762c79..4491861dd9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -20,11 +20,9 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, JoinedRow, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecificInternalRow, SpecifiedWindowFrame, TimeAdd, TimestampAddYMIn [...] 
+import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimeAdd, TimestampAddYMInterval, UnaryMinus, UnboundedFol [...] import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, DecimalType, IntegerType, TimestampNTZType, TimestampType, YearMonthIntervalType} @@ -289,129 +287,3 @@ trait WindowEvaluatorFactoryBase { } } - -class WindowEvaluatorFactory( - val windowExpression: Seq[NamedExpression], - val partitionSpec: Seq[Expression], - val orderSpec: Seq[SortOrder], - val childOutput: Seq[Attribute], - val spillSize: SQLMetric) - extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase { - - override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { - new WindowPartitionEvaluator() - } - - class WindowPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { - private val conf: SQLConf = SQLConf.get - - // Unwrap the window expressions and window frame factories from the map. 
- private val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) - private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - private val spillThreshold = conf.windowExecBufferSpillThreshold - - override def eval( - partitionIndex: Int, - inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { - val stream = inputs.head - new Iterator[InternalRow] { - - // Get all relevant projections. - val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, childOutput) - - // Manage the stream and the grouping. - var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow(): Unit = { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - - var bufferIterator: Iterator[UnsafeRow] = _ - - val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) - val frames = factories.map(_(windowFunctionResult)) - val numFrames = frames.length - private[this] def fetchNextPartition(): Unit = { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - buffer.clear() - - while (nextRowAvailable && nextGroup == currentGroup) { - buffer.add(nextRow) - fetchNextRow() - } - - // Setup the frames. 
- var i = 0 - while (i < numFrames) { - frames(i).prepare(buffer) - i += 1 - } - - // Setup iteration - rowIndex = 0 - bufferIterator = buffer.generateIterator() - } - - // Iteration - var rowIndex = 0 - - override final def hasNext: Boolean = { - val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable - if (!found) { - // clear final partition - buffer.clear() - spillSize += buffer.spillSize - } - found - } - - val join = new JoinedRow - override final def next(): InternalRow = { - // Load the next partition if we need to. - if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { - fetchNextPartition() - } - - if (bufferIterator.hasNext) { - val current = bufferIterator.next() - - // Get the results for the window frames. - var i = 0 - while (i < numFrames) { - frames(i).write(rowIndex, current) - i += 1 - } - - // 'Merge' the input row with the window function result - join(current, windowFunctionResult) - rowIndex += 1 - - // Return the projection. - result(join) - } else { - throw new NoSuchElementException - } - } - } - } - } -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org