This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 95dc829c92a [SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec 95dc829c92a is described below commit 95dc829c92ae7cac51662895e2c1e79ada7ff4e6 Author: Jiaan Geng <belie...@163.com> AuthorDate: Tue Jul 18 11:37:14 2023 +0800 [SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec ### What changes were proposed in this pull request? `WindowExec` and `WindowInPandasExec` are updated to use the `PartitionEvaluator` API to do execution. ### Why are the changes needed? The `PartitionEvaluator` API defines the computing logic in one place and requires the caller side to explicitly list what needs to be serialized and sent to executors. ### Does this PR introduce _any_ user-facing change? 'No'. This PR only updates the internal implementation. ### How was this patch tested? New test cases were added. Closes #41939 from beliefer/SPARK-44341. 
Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- ....scala => WindowInPandasEvaluatorFactory.scala} | 176 +++++------ .../sql/execution/python/WindowInPandasExec.scala | 332 +-------------------- ...ExecBase.scala => WindowEvaluatorFactory.scala} | 189 +++++++++--- .../spark/sql/execution/window/WindowExec.scala | 118 +------- .../sql/execution/window/WindowExecBase.scala | 257 ---------------- 5 files changed, 248 insertions(+), 824 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala similarity index 67% copy from sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala copy to sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala index 3d43c417dcb..364e94ab158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala @@ -22,74 +22,28 @@ import java.io.File import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} +import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression} import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.window._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils -/** - * This class calculates and outputs windowed aggregates over the rows in a single partition. - * - * This is similar to [[WindowExec]]. The main difference is that this node does not compute - * any window aggregation values. Instead, it computes the lower and upper bound for each window - * (i.e. window bounds) and pass the data and indices to Python worker to do the actual window - * aggregation. - * - * It currently materializes all data associated with the same partition key and passes them to - * Python worker. This is not strictly necessary for sliding windows and can be improved (by - * possibly slicing data into overlapping chunks and stitching them together). - * - * This class groups window expressions by their window boundaries so that window expressions - * with the same window boundaries can share the same window bounds. The window bounds are - * prepended to the data passed to the python worker. 
- * - * For example, if we have: - * avg(v) over specifiedwindowframe(RowFrame, -5, 5), - * avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing), - * avg(v) over specifiedwindowframe(RowFrame, -3, 3), - * max(v) over specifiedwindowframe(RowFrame, -3, 3) - * - * The python input will look like: - * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v) - * - * where w1 is specifiedwindowframe(RowFrame, -5, 5) - * w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing) - * w3 is specifiedwindowframe(RowFrame, -3, 3) - * - * Note that w2 doesn't have bound indices in the python input because it's unbounded window - * so it's bound indices will always be the same. - * - * Bounded window and Unbounded window are evaluated differently in Python worker: - * (1) Bounded window takes the window bound indices in addition to the input columns. - * Unbounded window takes only input columns. - * (2) Bounded window evaluates the udf once per input row. - * Unbounded window evaluates the udf once per window partition. - * This is controlled by Python runner conf "pandas_window_bound_types" - * - * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with - * [[WindowExec]] - * - * Note this doesn't support partial aggregation and all aggregation is computed from the entire - * window. 
- */ -case class WindowInPandasExec( - windowExpression: Seq[NamedExpression], - partitionSpec: Seq[Expression], - orderSpec: Seq[SortOrder], - child: SparkPlan) - extends WindowExecBase with PythonSQLMetrics { - override lazy val metrics: Map[String, SQLMetric] = pythonMetrics ++ Map( - "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") - ) +class WindowInPandasEvaluatorFactory( + val windowExpression: Seq[NamedExpression], + val partitionSpec: Seq[Expression], + val orderSpec: Seq[SortOrder], + val childOutput: Seq[Attribute], + val spillSize: SQLMetric, + pythonMetrics: Map[String, SQLMetric]) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase { /** * Helper functions and data structures for window bounds @@ -106,7 +60,9 @@ case class WindowInPandasExec( * Enum for window bound types. Used only inside this class. */ private sealed case class WindowBoundType(value: String) + private object UnboundedWindow extends WindowBoundType("unbounded") + private object BoundedWindow extends WindowBoundType("bounded") private val windowBoundTypeConf = "pandas_window_bound_types" @@ -124,19 +80,19 @@ case class WindowInPandasExec( } } + // Helper functions /** * See [[WindowBoundHelpers]] for details. 
*/ private def computeWindowBoundHelpers( - factories: Seq[InternalRow => WindowFunctionFrame] - ): WindowBoundHelpers = { - val functionFrames = factories.map(_(EmptyRow)) + factories: Seq[InternalRow => WindowFunctionFrame]): WindowBoundHelpers = { + val functionFrames = factories.map(_ (EmptyRow)) val windowBoundTypes = functionFrames.map { case _: UnboundedWindowFunctionFrame => UnboundedWindow case _: UnboundedFollowingWindowFunctionFrame | - _: SlidingWindowFunctionFrame | - _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow + _: SlidingWindowFunctionFrame | + _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow // It should be impossible to get other types of window function frame here case frame => throw QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString) } @@ -149,66 +105,72 @@ case class WindowInPandasExec( val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => - if (num == 0) { - // Sentinel values for unbounded window - (-1, -1) - } else { - (upperBoundIndex - 2, upperBoundIndex - 1) - } + if (num == 0) { + // Sentinel values for unbounded window + (-1, -1) + } else { + (upperBoundIndex - 2, upperBoundIndex - 1) + } } def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 + def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) } - protected override def doExecute(): RDD[InternalRow] = { + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { + new WindowInPandasPartitionEvaluator() + } + + class WindowInPandasPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { + private val conf: SQLConf = SQLConf.get + // Unwrap the expressions and factories from the map. 
- val expressionsWithFrameIndex = + private val expressionsWithFrameIndex = windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) } - val expressions = expressionsWithFrameIndex.map(_._1) - val expressionIndexToFrameIndex = + private val expressions = expressionsWithFrameIndex.map(_._1) + private val expressionIndexToFrameIndex = expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - // Helper functions - val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = + private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = computeWindowBoundHelpers(factories) - val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } - val numFrames = factories.length + private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + private val numFrames = factories.length - val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - val spillThreshold = conf.windowExecBufferSpillThreshold - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val largeVarTypes = conf.arrowUseLargeVarTypes + private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + private val spillThreshold = conf.windowExecBufferSpillThreshold + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes // Extract window expressions and window functions - val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) - val udfExpressions = windowExpressions.map { e => + private val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) + private val udfExpressions = windowExpressions.map { e => 
e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF] } // We shouldn't be chaining anything here. // All chained python functions should only contain one function. - val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + private val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip require(pyFuncs.length == expressions.length) - val udfWindowBoundTypes = pyFuncs.indices.map(i => + private val udfWindowBoundTypes = pyFuncs.indices.map(i => frameWindowBoundTypes(expressionIndexToFrameIndex(i))) - val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) // Filter child output attributes down to only those that are UDF inputs. // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node // handles UDF inputs. - val dataInputs = new ArrayBuffer[Expression] - val dataInputTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => + private val dataInputs = new ArrayBuffer[Expression] + private val dataInputTypes = new ArrayBuffer[DataType] + private val argOffsets = inputs.map { input => input.map { e => if (dataInputs.exists(_.semanticEquals(e))) { dataInputs.indexWhere(_.semanticEquals(e)) @@ -227,7 +189,7 @@ case class WindowInPandasExec( // Setting window bounds for each window frames. Each window frame has different bounds so // each has its own window bound columns. 
- val windowBoundsInput = factories.indices.flatMap { frameIndex => + private val windowBoundsInput = factories.indices.flatMap { frameIndex => if (isBounded(frameIndex)) { Seq( BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), @@ -252,13 +214,14 @@ case class WindowInPandasExec( } } - val allInputs = windowBoundsInput ++ dataInputs - val allInputTypes = allInputs.map(_.dataType) - val spillSize = longMetric("spillSize") - val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + private val allInputs = windowBoundsInput ++ dataInputs + private val allInputTypes = allInputs.map(_.dataType) + private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - // Start processing. - child.execute().mapPartitions { iter => + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + val iter = inputs.head val context = TaskContext.get() // Get all relevant projections. @@ -266,19 +229,19 @@ case class WindowInPandasExec( val pythonInputProj = UnsafeProjection.create( allInputs, windowBoundsInput.map(ref => - AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output + AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ childOutput ) val pythonInputSchema = StructType( allInputTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) } ) - val grouping = UnsafeProjection.create(partitionSpec, child.output) + val grouping = UnsafeProjection.create(partitionSpec, childOutput) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. 
val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + new File(Utils.getLocalDir(SparkEnv.get.conf)), childOutput.length) context.addTaskCompletionListener[Unit] { _ => queue.close() } @@ -294,6 +257,7 @@ case class WindowInPandasExec( var nextRow: UnsafeRow = null var nextGroup: UnsafeRow = null var nextRowAvailable: Boolean = false + private[this] def fetchNextRow(): Unit = { nextRowAvailable = stream.hasNext if (nextRowAvailable) { @@ -304,6 +268,7 @@ case class WindowInPandasExec( nextGroup = null } } + fetchNextRow() // Manage the current partition. @@ -313,7 +278,7 @@ case class WindowInPandasExec( val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) - val frames = factories.map(_(indexRow)) + val frames = factories.map(_ (indexRow)) private[this] def fetchNextPartition(): Unit = { // Collect all the rows in the current partition. @@ -401,7 +366,4 @@ case class WindowInPandasExec( } } } - - override protected def withNewChildInternal(newChild: SparkPlan): WindowInPandasExec = - copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 3d43c417dcb..ba1f2c132ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -17,24 +17,12 @@ package org.apache.spark.sql.execution.python -import java.io.File - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.window._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.Utils /** * This class calculates and outputs windowed aggregates over the rows in a single partition. @@ -91,313 +79,25 @@ case class WindowInPandasExec( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") ) - /** - * Helper functions and data structures for window bounds - * - * It contains: - * (1) Total number of window bound indices in the python input row - * (2) Function from frame index to its lower bound column index in the python input row - * (3) Function from frame index to its upper bound column index in the python input row - * (4) Seq from frame index to its window bound type - */ - private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) - - /** - * Enum for window bound types. Used only inside this class. - */ - private sealed case class WindowBoundType(value: String) - private object UnboundedWindow extends WindowBoundType("unbounded") - private object BoundedWindow extends WindowBoundType("bounded") - - private val windowBoundTypeConf = "pandas_window_bound_types" - - private def collectFunctions( - udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = { - udf.children match { - case Seq(u: PythonFuncExpression) => - val (chained, children) = collectFunctions(u) - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) - case children => - // There should not be any other UDFs, or the children can't be evaluated directly. 
- assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression]))) - (ChainedPythonFunctions(Seq(udf.func)), udf.children) - } - } - - /** - * See [[WindowBoundHelpers]] for details. - */ - private def computeWindowBoundHelpers( - factories: Seq[InternalRow => WindowFunctionFrame] - ): WindowBoundHelpers = { - val functionFrames = factories.map(_(EmptyRow)) - - val windowBoundTypes = functionFrames.map { - case _: UnboundedWindowFunctionFrame => UnboundedWindow - case _: UnboundedFollowingWindowFunctionFrame | - _: SlidingWindowFunctionFrame | - _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow - // It should be impossible to get other types of window function frame here - case frame => throw QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString) - } - - val requiredIndices = functionFrames.map { - case _: UnboundedWindowFunctionFrame => 0 - case _ => 2 - } - - val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail - - val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => - if (num == 0) { - // Sentinel values for unbounded window - (-1, -1) - } else { - (upperBoundIndex - 2, upperBoundIndex - 1) - } - } - - def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 - def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 - - (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) - } - protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the expressions and factories from the map. 
- val expressionsWithFrameIndex = - windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { - case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) - } - - val expressions = expressionsWithFrameIndex.map(_._1) - val expressionIndexToFrameIndex = - expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap - - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - - // Helper functions - val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = - computeWindowBoundHelpers(factories) - val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } - val numFrames = factories.length - - val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - val spillThreshold = conf.windowExecBufferSpillThreshold - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val largeVarTypes = conf.arrowUseLargeVarTypes - - // Extract window expressions and window functions - val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) - val udfExpressions = windowExpressions.map { e => - e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF] - } - - // We shouldn't be chaining anything here. - // All chained python functions should only contain one function. - val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip - require(pyFuncs.length == expressions.length) - - val udfWindowBoundTypes = pyFuncs.indices.map(i => - frameWindowBoundTypes(expressionIndexToFrameIndex(i))) - val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) - + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) - - // Filter child output attributes down to only those that are UDF inputs. - // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node - // handles UDF inputs. 
- val dataInputs = new ArrayBuffer[Expression] - val dataInputTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => - input.map { e => - if (dataInputs.exists(_.semanticEquals(e))) { - dataInputs.indexWhere(_.semanticEquals(e)) - } else { - dataInputs += e - dataInputTypes += e.dataType - dataInputs.length - 1 - } - }.toArray - }.toArray - - // In addition to UDF inputs, we will prepend window bounds for each UDFs. - // For bounded windows, we prepend lower bound and upper bound. For unbounded windows, - // we no not add window bounds. (strictly speaking, we only need to lower or upper bound - // if the window is bounded only on one side, this can be improved in the future) - - // Setting window bounds for each window frames. Each window frame has different bounds so - // each has its own window bound columns. - val windowBoundsInput = factories.indices.flatMap { frameIndex => - if (isBounded(frameIndex)) { - Seq( - BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), - BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false) - ) - } else { - Seq.empty - } - } - - // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset - // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...) - // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...) 
- pyFuncs.indices.foreach { exprIndex => - val frameIndex = expressionIndexToFrameIndex(exprIndex) - if (isBounded(frameIndex)) { - argOffsets(exprIndex) = - Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ - argOffsets(exprIndex).map(_ + windowBoundsInput.length) - } else { - argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) - } - } - - val allInputs = windowBoundsInput ++ dataInputs - val allInputTypes = allInputs.map(_.dataType) val spillSize = longMetric("spillSize") - val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - // Start processing. - child.execute().mapPartitions { iter => - val context = TaskContext.get() - - // Get all relevant projections. - val resultProj = createResultProjection(expressions) - val pythonInputProj = UnsafeProjection.create( - allInputs, - windowBoundsInput.map(ref => - AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output - ) - val pythonInputSchema = StructType( - allInputTypes.zipWithIndex.map { case (dt, i) => - StructField(s"_$i", dt) - } - ) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener[Unit] { _ => - queue.close() - } - - val stream = iter.map { row => - queue.add(row.asInstanceOf[UnsafeRow]) - row - } - - val pythonInput = new Iterator[Iterator[UnsafeRow]] { - - // Manage the stream and the grouping. 
- var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow(): Unit = { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - var bufferIterator: Iterator[UnsafeRow] = _ - - val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) - val frames = factories.map(_(indexRow)) + val evaluatorFactory = + new WindowInPandasEvaluatorFactory( + windowExpression, + partitionSpec, + orderSpec, + child.output, + spillSize, + pythonMetrics) - private[this] def fetchNextPartition(): Unit = { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - buffer.clear() - - while (nextRowAvailable && nextGroup == currentGroup) { - buffer.add(nextRow) - fetchNextRow() - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(buffer) - i += 1 - } - - // Setup iteration - rowIndex = 0 - bufferIterator = buffer.generateIterator() - } - - // Iteration - var rowIndex = 0 - - override final def hasNext: Boolean = { - val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable - if (!found) { - // clear final partition - buffer.clear() - spillSize += buffer.spillSize - } - found - } - - override final def next(): Iterator[UnsafeRow] = { - // Load the next partition if we need to. 
- if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { - fetchNextPartition() - } - - val join = new JoinedRow - - bufferIterator.zipWithIndex.map { - case (current, index) => - var frameIndex = 0 - while (frameIndex < numFrames) { - frames(frameIndex).write(index, current) - // If the window is unbounded we don't need to write out window bounds. - if (isBounded(frameIndex)) { - indexRow.setInt( - lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) - indexRow.setInt( - upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) - } - frameIndex += 1 - } - - pythonInputProj(join(indexRow, current)) - } - } - } - - val windowFunctionResult = new ArrowPythonRunner( - pyFuncs, - PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, - pythonInputSchema, - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(pythonInput, context.partitionId(), context) - - val joined = new JoinedRow - - windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => - val leftRow = queue.remove() - val joinedRow = joined(leftRow, windowOutput) - resultProj(joinedRow) + // Start processing. 
+ if (conf.usePartitionEvaluator) { + child.execute().mapPartitionsWithEvaluator(evaluatorFactory) + } else { + child.execute().mapPartitions { iter => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala similarity index 67% copy from sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala copy to sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala index 82fc308e409..913f8762c79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala @@ -20,42 +20,22 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, JoinedRow, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecificInternalRow, SpecifiedWindowFrame, TimeAdd, TimestampAddYMIn [...] 
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} -import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, DecimalType, IntegerType, TimestampNTZType, TimestampType, YearMonthIntervalType} import org.apache.spark.util.collection.Utils -/** - * Holds common logic for window operators - */ -trait WindowExecBase extends UnaryExecNode { +trait WindowEvaluatorFactoryBase { def windowExpression: Seq[NamedExpression] def partitionSpec: Seq[Expression] def orderSpec: Seq[SortOrder] - - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else { - ClusteredDistribution(partitionSpec) :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning + def childOutput: Seq[Attribute] + def spillSize: SQLMetric /** * Create the resulting projection. 
@@ -68,13 +48,13 @@ trait WindowExecBase extends UnaryExecNode { protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map { case (e, i) => // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) + BoundReference(childOutput.size + i, e.dataType, e.nullable) } val unboundToRefMap = Utils.toMap(expressions, references) val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) + childOutput ++ patchedWindowExpression, + childOutput) } /** @@ -101,7 +81,7 @@ trait WindowExecBase extends UnaryExecNode { throw new IllegalStateException(s"Unhandled bound in windows expressions: $bound") case (RangeFrame, CurrentRow) => - val ordering = RowOrdering.create(orderSpec, child.output) + val ordering = RowOrdering.create(orderSpec, childOutput) RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) case (RangeFrame, offset: Expression) if orderSpec.size == 1 => @@ -110,7 +90,7 @@ trait WindowExecBase extends UnaryExecNode { val expr = sortExpr.child // Create the projection which returns the current 'value'. - val current = MutableProjection.create(expr :: Nil, child.output) + val current = MutableProjection.create(expr :: Nil, childOutput) // Flip the sign of the offset when processing the order is descending val boundOffset = sortExpr.direction match { @@ -131,7 +111,7 @@ trait WindowExecBase extends UnaryExecNode { case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d) case (a, b) if a == b => Add(expr, boundOffset) } - val bound = MutableProjection.create(boundExpr :: Nil, child.output) + val bound = MutableProjection.create(boundExpr :: Nil, childOutput) // Construct the ordering. 
This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use @@ -176,7 +156,7 @@ trait WindowExecBase extends UnaryExecNode { // Collect all valid window functions and group them by their frame. windowExpression.foreach { x => x.foreach { - case e @ WindowExpression(function, spec) => + case e@WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f) @@ -198,7 +178,7 @@ trait WindowExecBase extends UnaryExecNode { // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 - val timeZone = conf.sessionLocalTimeZone + val timeZone = SQLConf.get.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -215,7 +195,7 @@ trait WindowExecBase extends UnaryExecNode { AggregateProcessor( functions, ordinal, - child.output, + childOutput, (expressions, schema) => MutableProjection.create(expressions, schema)) } @@ -230,7 +210,7 @@ trait WindowExecBase extends UnaryExecNode { ordinal, // OFFSET frame functions are guaranteed be OffsetWindowFunction. functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, + childOutput, (expressions, schema) => MutableProjection.create(expressions, schema), offset, @@ -242,7 +222,7 @@ trait WindowExecBase extends UnaryExecNode { ordinal, // OFFSET frame functions are guaranteed be OffsetWindowFunction. functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, + childOutput, (expressions, schema) => MutableProjection.create(expressions, schema), offset, @@ -255,7 +235,7 @@ trait WindowExecBase extends UnaryExecNode { ordinal, // OFFSET frame functions are guaranteed be OffsetWindowFunction. 
functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, + childOutput, (expressions, schema) => MutableProjection.create(expressions, schema), offset, @@ -307,4 +287,131 @@ trait WindowExecBase extends UnaryExecNode { (expressions, factory) } } + +} + +class WindowEvaluatorFactory( + val windowExpression: Seq[NamedExpression], + val partitionSpec: Seq[Expression], + val orderSpec: Seq[SortOrder], + val childOutput: Seq[Attribute], + val spillSize: SQLMetric) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase { + + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = { + new WindowPartitionEvaluator() + } + + class WindowPartitionEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { + private val conf: SQLConf = SQLConf.get + + // Unwrap the window expressions and window frame factories from the map. + private val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + private val spillThreshold = conf.windowExecBufferSpillThreshold + + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + val stream = inputs.head + new Iterator[InternalRow] { + + // Get all relevant projections. + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, childOutput) + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow(): Unit = { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. 
+ val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + + var bufferIterator: Iterator[UnsafeRow] = _ + + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) + val numFrames = frames.length + private[this] def fetchNextPartition(): Unit = { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = { + val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + if (!found) { + // clear final partition + buffer.clear() + spillSize += buffer.spillSize + } + found + } + + val join = new JoinedRow + override final def next(): InternalRow = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + + // Get the results for the window frames. + var i = 0 + while (i < numFrames) { + frames(i).write(rowIndex, current) + i += 1 + } + + // 'Merge' the input row with the window function result + join(current, windowFunctionResult) + rowIndex += 1 + + // Return the projection. 
+ result(join) + } else { + throw new NoSuchElementException + } + } + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index dda5da6c9e9..35e59aef94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} /** @@ -95,111 +95,23 @@ case class WindowExec( ) protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the window expressions and window frame factories from the map. - val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold - val spillThreshold = conf.windowExecBufferSpillThreshold val spillSize = longMetric("spillSize") - // Start processing. - child.execute().mapPartitions { stream => - new Iterator[InternalRow] { - - // Get all relevant projections. - val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // Manage the stream and the grouping. 
- var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow(): Unit = { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. - val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - - var bufferIterator: Iterator[UnsafeRow] = _ - - val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) - val frames = factories.map(_(windowFunctionResult)) - val numFrames = frames.length - private[this] def fetchNextPartition(): Unit = { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - buffer.clear() - - while (nextRowAvailable && nextGroup == currentGroup) { - buffer.add(nextRow) - fetchNextRow() - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(buffer) - i += 1 - } + val evaluatorFactory = + new WindowEvaluatorFactory( + windowExpression, + partitionSpec, + orderSpec, + child.output, + spillSize) - // Setup iteration - rowIndex = 0 - bufferIterator = buffer.generateIterator() - } - - // Iteration - var rowIndex = 0 - - override final def hasNext: Boolean = { - val found = (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable - if (!found) { - // clear final partition - buffer.clear() - spillSize += buffer.spillSize - } - found - } - - val join = new JoinedRow - override final def next(): InternalRow = { - // Load the next partition if we need to. 
- if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { - fetchNextPartition() - } - - if (bufferIterator.hasNext) { - val current = bufferIterator.next() - - // Get the results for the window frames. - var i = 0 - while (i < numFrames) { - frames(i).write(rowIndex, current) - i += 1 - } - - // 'Merge' the input row with the window function result - join(current, windowFunctionResult) - rowIndex += 1 - - // Return the projection. - result(join) - } else { - throw new NoSuchElementException - } - } + // Start processing. + if (conf.usePartitionEvaluator) { + child.execute().mapPartitionsWithEvaluator(evaluatorFactory) + } else { + child.execute().mapPartitions { iter => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index 82fc308e409..29f2256efc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -17,16 +17,9 @@ package org.apache.spark.sql.execution.window -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.Utils /** * Holds common logic for window operators @@ -57,254 +50,4 @@ trait WindowExecBase extends UnaryExecNode { override def outputPartitioning: Partitioning = child.outputPartitioning - /** - * Create the resulting projection. 
- * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map { case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = Utils.toMap(expressions, references) - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame to evaluate. This can either be a Row or Range frame. - * @param bound with respect to the row. - * @param timeZone the session local timezone for time related calculations. - * @return a bound ordering object. - */ - private def createBoundOrdering( - frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { - (frame, bound) match { - case (RowFrame, CurrentRow) => - RowBoundOrdering(0) - - case (RowFrame, IntegerLiteral(offset)) => - RowBoundOrdering(offset) - - case (RowFrame, _) => - throw new IllegalStateException(s"Unhandled bound in windows expressions: $bound") - - case (RangeFrame, CurrentRow) => - val ordering = RowOrdering.create(orderSpec, child.output) - RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) - - case (RangeFrame, offset: Expression) if orderSpec.size == 1 => - // Use only the first order expression when the offset is non-null. 
- val sortExpr = orderSpec.head - val expr = sortExpr.child - - // Create the projection which returns the current 'value'. - val current = MutableProjection.create(expr :: Nil, child.output) - - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => UnaryMinus(offset) - case Ascending => offset - } - - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = (expr.dataType, boundOffset.dataType) match { - case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) - case (TimestampType | TimestampNTZType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) - case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d) - case (a, b) if a == b => Add(expr, boundOffset) - } - val bound = MutableProjection.create(boundExpr :: Nil, child.output) - - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil - val ordering = RowOrdering.create(boundSortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - - case (RangeFrame, _) => - throw new IllegalStateException("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - } - - /** - * Collection containing an entry for each window frame to process. 
Each entry contains a frame's - * [[WindowExpression]]s and factory function for the [[WindowFunctionFrame]]. - */ - protected lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression]) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = fn match { - // This branch is used for Lead/Lag to support ignoring null and optimize the performance - // for NthValue ignoring null. - // All window frames move in rows. If there are multiple Leads, Lags or NthValues acting on - // a row and operating on different input expressions, they should not be moved uniformly - // by row. Therefore, we put these functions in different window frames. - case f: OffsetWindowFunction if f.ignoreNulls => - (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized)) - case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil) - } - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. 
- windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: FrameLessOffsetWindowFunction => - collect("FRAME_LESS_OFFSET", f.fakeFrame, e, f) - case f: OffsetWindowFunction if frame.frameType == RowFrame && - frame.lower == UnboundedPreceding => - frame.upper match { - case UnboundedFollowing => collect("UNBOUNDED_OFFSET", f.fakeFrame, e, f) - case CurrentRow => collect("UNBOUNDED_PRECEDING_OFFSET", f.fakeFrame, e, f) - case _ => collect("AGGREGATE", frame, e, f) - } - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f => throw new IllegalStateException(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - val timeZone = conf.sessionLocalTimeZone - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions - // in a single Window physical node. Therefore, we can assume no SQL aggregation - // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL - // aggregation function in a single physical node. - def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) { - null - } else { - AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema)) - } - - // Create the factory to produce WindowFunctionFrame. 
- val factory = key match { - // Frameless offset Frame - case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => - new FrameLessOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { - new UnboundedPrecedingOffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunction. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - MutableProjection.create(expressions, schema), - offset, - expr.nonEmpty) - } - - // Entire Partition Frame. - case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - - // Growing Frame. - case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) => - target: InternalRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, upper, timeZone)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) - } - - // Moving Frame. 
- case ("AGGREGATE", frameType, lower, upper, _) => - target: InternalRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone)) - } - - case _ => - throw new IllegalStateException(s"Unsupported factory: $key") - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Window Expression - Frame Factory pair. - (expressions, factory) - } - } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org