This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 95dc829c92a [SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec
95dc829c92a is described below

commit 95dc829c92ae7cac51662895e2c1e79ada7ff4e6
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Tue Jul 18 11:37:14 2023 +0800

    [SPARK-44341][SQL][PYTHON] Define the computing logic through PartitionEvaluator API and use it in WindowExec and WindowInPandasExec
    
    ### What changes were proposed in this pull request?
    `WindowExec` and `WindowInPandasExec` are updated to perform execution through the `PartitionEvaluator` API.
    
    ### Why are the changes needed?
    To define the computing logic through the `PartitionEvaluator` API, which requires the caller side to explicitly list what needs to be serialized and sent to executors.
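    
    For context, the `PartitionEvaluator` API this PR builds on has roughly the
    following shape, shown here as a minimal hypothetical sketch
    (`AddSuffixEvaluatorFactory` and `suffix` are illustrative names, not part
    of this PR):
    
        import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
    
        // A sketch: the factory captures, via its constructor, exactly the
        // state that must be serialized and shipped to executors.
        class AddSuffixEvaluatorFactory(suffix: String)
          extends PartitionEvaluatorFactory[String, String] {
    
          override def createEvaluator(): PartitionEvaluator[String, String] =
            new PartitionEvaluator[String, String] {
              // One evaluator instance processes one input partition.
              override def eval(
                  partitionIndex: Int,
                  inputs: Iterator[String]*): Iterator[String] =
                inputs.head.map(_ + suffix)
            }
        }
    
        // Mirroring doExecute() in this PR: run through the evaluator, or fall
        // back to plain mapPartitions when the config flag is disabled.
        //   if (conf.usePartitionEvaluator) {
        //     rdd.mapPartitionsWithEvaluator(new AddSuffixEvaluatorFactory("!"))
        //   } else {
        //     rdd.mapPartitions { iter =>
        //       new AddSuffixEvaluatorFactory("!").createEvaluator().eval(0, iter)
        //     }
        //   }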
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    This just updates the internal implementation.
    
    ### How was this patch tested?
    Add new test cases.
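    
    For example, the new code path can be exercised under both branches by
    toggling the partition evaluator flag. An illustrative sketch (assumes the
    usual QueryTest/SharedSparkSession helpers `sql`, `withSQLConf`,
    `checkAnswer`, and `Row`, and that the config constant is named
    `SQLConf.USE_PARTITION_EVALUATOR`):
    
        Seq(true, false).foreach { enabled =>
          withSQLConf(SQLConf.USE_PARTITION_EVALUATOR.key -> enabled.toString) {
            val df = sql(
              "SELECT id, SUM(id) OVER (PARTITION BY id % 2 ORDER BY id) FROM range(5)")
            // Running sums within the even (0, 2, 4) and odd (1, 3) partitions.
            checkAnswer(df, Seq(
              Row(0L, 0L), Row(1L, 1L), Row(2L, 2L), Row(3L, 4L), Row(4L, 6L)))
          }
        }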
    
    Closes #41939 from beliefer/SPARK-44341.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 ....scala => WindowInPandasEvaluatorFactory.scala} | 176 +++++------
 .../sql/execution/python/WindowInPandasExec.scala  | 332 +--------------------
 ...ExecBase.scala => WindowEvaluatorFactory.scala} | 189 +++++++++---
 .../spark/sql/execution/window/WindowExec.scala    | 118 +-------
 .../sql/execution/window/WindowExecBase.scala      | 257 ----------------
 5 files changed, 248 insertions(+), 824 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
similarity index 67%
copy from sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
copy to sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index 3d43c417dcb..364e94ab158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -22,74 +22,28 @@ import java.io.File
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
+import org.apache.spark.{JobArtifactSet, PartitionEvaluator, 
PartitionEvaluatorFactory, SparkEnv, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, 
NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, 
SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, 
SparkPlan}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.window._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, 
UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, 
UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, IntegerType, StructField, 
StructType}
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
-/**
- * This class calculates and outputs windowed aggregates over the rows in a 
single partition.
- *
- * This is similar to [[WindowExec]]. The main difference is that this node 
does not compute
- * any window aggregation values. Instead, it computes the lower and upper 
bound for each window
- * (i.e. window bounds) and pass the data and indices to Python worker to do 
the actual window
- * aggregation.
- *
- * It currently materializes all data associated with the same partition key 
and passes them to
- * Python worker. This is not strictly necessary for sliding windows and can 
be improved (by
- * possibly slicing data into overlapping chunks and stitching them together).
- *
- * This class groups window expressions by their window boundaries so that 
window expressions
- * with the same window boundaries can share the same window bounds. The 
window bounds are
- * prepended to the data passed to the python worker.
- *
- * For example, if we have:
- *     avg(v) over specifiedwindowframe(RowFrame, -5, 5),
- *     avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, 
UnboundedFollowing),
- *     avg(v) over specifiedwindowframe(RowFrame, -3, 3),
- *     max(v) over specifiedwindowframe(RowFrame, -3, 3)
- *
- * The python input will look like:
- * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v)
- *
- * where w1 is specifiedwindowframe(RowFrame, -5, 5)
- *       w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, 
UnboundedFollowing)
- *       w3 is specifiedwindowframe(RowFrame, -3, 3)
- *
- * Note that w2 doesn't have bound indices in the python input because it's 
unbounded window
- * so it's bound indices will always be the same.
- *
- * Bounded window and Unbounded window are evaluated differently in Python 
worker:
- * (1) Bounded window takes the window bound indices in addition to the input 
columns.
- *     Unbounded window takes only input columns.
- * (2) Bounded window evaluates the udf once per input row.
- *     Unbounded window evaluates the udf once per window partition.
- * This is controlled by Python runner conf "pandas_window_bound_types"
- *
- * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] 
and shared with
- * [[WindowExec]]
- *
- * Note this doesn't support partial aggregation and all aggregation is 
computed from the entire
- * window.
- */
-case class WindowInPandasExec(
-    windowExpression: Seq[NamedExpression],
-    partitionSpec: Seq[Expression],
-    orderSpec: Seq[SortOrder],
-    child: SparkPlan)
-  extends WindowExecBase with PythonSQLMetrics {
-  override lazy val metrics: Map[String, SQLMetric] = pythonMetrics ++ Map(
-    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
-  )
+class WindowInPandasEvaluatorFactory(
+    val windowExpression: Seq[NamedExpression],
+    val partitionSpec: Seq[Expression],
+    val orderSpec: Seq[SortOrder],
+    val childOutput: Seq[Attribute],
+    val spillSize: SQLMetric,
+    pythonMetrics: Map[String, SQLMetric])
+  extends PartitionEvaluatorFactory[InternalRow, InternalRow] with 
WindowEvaluatorFactoryBase {
 
   /**
    * Helper functions and data structures for window bounds
@@ -106,7 +60,9 @@ case class WindowInPandasExec(
    * Enum for window bound types. Used only inside this class.
    */
   private sealed case class WindowBoundType(value: String)
+
   private object UnboundedWindow extends WindowBoundType("unbounded")
+
   private object BoundedWindow extends WindowBoundType("bounded")
 
   private val windowBoundTypeConf = "pandas_window_bound_types"
@@ -124,19 +80,19 @@ case class WindowInPandasExec(
     }
   }
 
+  // Helper functions
   /**
    * See [[WindowBoundHelpers]] for details.
    */
   private def computeWindowBoundHelpers(
-      factories: Seq[InternalRow => WindowFunctionFrame]
-  ): WindowBoundHelpers = {
-    val functionFrames = factories.map(_(EmptyRow))
+      factories: Seq[InternalRow => WindowFunctionFrame]): WindowBoundHelpers 
= {
+    val functionFrames = factories.map(_ (EmptyRow))
 
     val windowBoundTypes = functionFrames.map {
       case _: UnboundedWindowFunctionFrame => UnboundedWindow
       case _: UnboundedFollowingWindowFunctionFrame |
-        _: SlidingWindowFunctionFrame |
-        _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
+           _: SlidingWindowFunctionFrame |
+           _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
       // It should be impossible to get other types of window function frame 
here
       case frame => throw 
QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString)
     }
@@ -149,66 +105,72 @@ case class WindowInPandasExec(
     val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
 
     val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, 
upperBoundIndex) =>
-        if (num == 0) {
-          // Sentinel values for unbounded window
-          (-1, -1)
-        } else {
-          (upperBoundIndex - 2, upperBoundIndex - 1)
-        }
+      if (num == 0) {
+        // Sentinel values for unbounded window
+        (-1, -1)
+      } else {
+        (upperBoundIndex - 2, upperBoundIndex - 1)
+      }
     }
 
     def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
+
     def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
 
     (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
   }
 
-  protected override def doExecute(): RDD[InternalRow] = {
+  override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] 
= {
+    new WindowInPandasPartitionEvaluator()
+  }
+
+  class WindowInPandasPartitionEvaluator extends 
PartitionEvaluator[InternalRow, InternalRow] {
+    private val conf: SQLConf = SQLConf.get
+
     // Unwrap the expressions and factories from the map.
-    val expressionsWithFrameIndex =
+    private val expressionsWithFrameIndex =
       windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
         case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
       }
 
-    val expressions = expressionsWithFrameIndex.map(_._1)
-    val expressionIndexToFrameIndex =
+    private val expressions = expressionsWithFrameIndex.map(_._1)
+    private val expressionIndexToFrameIndex =
       expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
 
-    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+    private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
 
-    // Helper functions
-    val (numBoundIndices, lowerBoundIndex, upperBoundIndex, 
frameWindowBoundTypes) =
+    private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, 
frameWindowBoundTypes) =
       computeWindowBoundHelpers(factories)
-    val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
-    val numFrames = factories.length
+    private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) 
>= 0 }
+    private val numFrames = factories.length
 
-    val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
-    val spillThreshold = conf.windowExecBufferSpillThreshold
-    val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val largeVarTypes = conf.arrowUseLargeVarTypes
+    private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
+    private val spillThreshold = conf.windowExecBufferSpillThreshold
+    private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+    private val largeVarTypes = conf.arrowUseLargeVarTypes
 
     // Extract window expressions and window functions
-    val windowExpressions = expressions.flatMap(_.collect { case e: 
WindowExpression => e })
-    val udfExpressions = windowExpressions.map { e =>
+    private val windowExpressions = expressions.flatMap(_.collect { case e: 
WindowExpression => e })
+    private val udfExpressions = windowExpressions.map { e =>
       
e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF]
     }
 
     // We shouldn't be chaining anything here.
     // All chained python functions should only contain one function.
-    val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+    private val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
     require(pyFuncs.length == expressions.length)
 
-    val udfWindowBoundTypes = pyFuncs.indices.map(i =>
+    private val udfWindowBoundTypes = pyFuncs.indices.map(i =>
       frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
-    val pythonRunnerConf: Map[String, String] = 
(ArrowUtils.getPythonRunnerConfMap(conf)
+    private val pythonRunnerConf: Map[String, String] = 
(ArrowUtils.getPythonRunnerConfMap(conf)
       + (windowBoundTypeConf -> 
udfWindowBoundTypes.map(_.value).mkString(",")))
 
     // Filter child output attributes down to only those that are UDF inputs.
     // Also eliminate duplicate UDF inputs. This is similar to how other 
Python UDF node
     // handles UDF inputs.
-    val dataInputs = new ArrayBuffer[Expression]
-    val dataInputTypes = new ArrayBuffer[DataType]
-    val argOffsets = inputs.map { input =>
+    private val dataInputs = new ArrayBuffer[Expression]
+    private val dataInputTypes = new ArrayBuffer[DataType]
+    private val argOffsets = inputs.map { input =>
       input.map { e =>
         if (dataInputs.exists(_.semanticEquals(e))) {
           dataInputs.indexWhere(_.semanticEquals(e))
@@ -227,7 +189,7 @@ case class WindowInPandasExec(
 
     // Setting window bounds for each window frames. Each window frame has 
different bounds so
     // each has its own window bound columns.
-    val windowBoundsInput = factories.indices.flatMap { frameIndex =>
+    private val windowBoundsInput = factories.indices.flatMap { frameIndex =>
       if (isBounded(frameIndex)) {
         Seq(
           BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = 
false),
@@ -252,13 +214,14 @@ case class WindowInPandasExec(
       }
     }
 
-    val allInputs = windowBoundsInput ++ dataInputs
-    val allInputTypes = allInputs.map(_.dataType)
-    val spillSize = longMetric("spillSize")
-    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    private val allInputs = windowBoundsInput ++ dataInputs
+    private val allInputTypes = allInputs.map(_.dataType)
+    private val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
-    // Start processing.
-    child.execute().mapPartitions { iter =>
+    override def eval(
+        partitionIndex: Int,
+        inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
+      val iter = inputs.head
       val context = TaskContext.get()
 
       // Get all relevant projections.
@@ -266,19 +229,19 @@ case class WindowInPandasExec(
       val pythonInputProj = UnsafeProjection.create(
         allInputs,
         windowBoundsInput.map(ref =>
-          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ 
child.output
+          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ 
childOutput
       )
       val pythonInputSchema = StructType(
         allInputTypes.zipWithIndex.map { case (dt, i) =>
           StructField(s"_$i", dt)
         }
       )
-      val grouping = UnsafeProjection.create(partitionSpec, child.output)
+      val grouping = UnsafeProjection.create(partitionSpec, childOutput)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
       val queue = HybridRowQueue(context.taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), childOutput.length)
       context.addTaskCompletionListener[Unit] { _ =>
         queue.close()
       }
@@ -294,6 +257,7 @@ case class WindowInPandasExec(
         var nextRow: UnsafeRow = null
         var nextGroup: UnsafeRow = null
         var nextRowAvailable: Boolean = false
+
         private[this] def fetchNextRow(): Unit = {
           nextRowAvailable = stream.hasNext
           if (nextRowAvailable) {
@@ -304,6 +268,7 @@ case class WindowInPandasExec(
             nextGroup = null
           }
         }
+
         fetchNextRow()
 
         // Manage the current partition.
@@ -313,7 +278,7 @@ case class WindowInPandasExec(
 
         val indexRow = new 
SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
 
-        val frames = factories.map(_(indexRow))
+        val frames = factories.map(_ (indexRow))
 
         private[this] def fetchNextPartition(): Unit = {
           // Collect all the rows in the current partition.
@@ -401,7 +366,4 @@ case class WindowInPandasExec(
       }
     }
   }
-
-  override protected def withNewChildInternal(newChild: SparkPlan): 
WindowInPandasExec =
-    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 3d43c417dcb..ba1f2c132ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -17,24 +17,12 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, 
SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.window._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.util.Utils
 
 /**
  * This class calculates and outputs windowed aggregates over the rows in a 
single partition.
@@ -91,313 +79,25 @@ case class WindowInPandasExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
   )
 
-  /**
-   * Helper functions and data structures for window bounds
-   *
-   * It contains:
-   * (1) Total number of window bound indices in the python input row
-   * (2) Function from frame index to its lower bound column index in the 
python input row
-   * (3) Function from frame index to its upper bound column index in the 
python input row
-   * (4) Seq from frame index to its window bound type
-   */
-  private type WindowBoundHelpers = (Int, Int => Int, Int => Int, 
Seq[WindowBoundType])
-
-  /**
-   * Enum for window bound types. Used only inside this class.
-   */
-  private sealed case class WindowBoundType(value: String)
-  private object UnboundedWindow extends WindowBoundType("unbounded")
-  private object BoundedWindow extends WindowBoundType("bounded")
-
-  private val windowBoundTypeConf = "pandas_window_bound_types"
-
-  private def collectFunctions(
-      udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonFuncExpression) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be 
evaluated directly.
-        
assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  /**
-   * See [[WindowBoundHelpers]] for details.
-   */
-  private def computeWindowBoundHelpers(
-      factories: Seq[InternalRow => WindowFunctionFrame]
-  ): WindowBoundHelpers = {
-    val functionFrames = factories.map(_(EmptyRow))
-
-    val windowBoundTypes = functionFrames.map {
-      case _: UnboundedWindowFunctionFrame => UnboundedWindow
-      case _: UnboundedFollowingWindowFunctionFrame |
-        _: SlidingWindowFunctionFrame |
-        _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
-      // It should be impossible to get other types of window function frame 
here
-      case frame => throw 
QueryExecutionErrors.unexpectedWindowFunctionFrameError(frame.toString)
-    }
-
-    val requiredIndices = functionFrames.map {
-      case _: UnboundedWindowFunctionFrame => 0
-      case _ => 2
-    }
-
-    val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
-
-    val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, 
upperBoundIndex) =>
-        if (num == 0) {
-          // Sentinel values for unbounded window
-          (-1, -1)
-        } else {
-          (upperBoundIndex - 2, upperBoundIndex - 1)
-        }
-    }
-
-    def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
-    def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
-
-    (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
-    // Unwrap the expressions and factories from the map.
-    val expressionsWithFrameIndex =
-      windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
-        case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
-      }
-
-    val expressions = expressionsWithFrameIndex.map(_._1)
-    val expressionIndexToFrameIndex =
-      expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
-
-    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
-
-    // Helper functions
-    val (numBoundIndices, lowerBoundIndex, upperBoundIndex, 
frameWindowBoundTypes) =
-      computeWindowBoundHelpers(factories)
-    val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
-    val numFrames = factories.length
-
-    val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
-    val spillThreshold = conf.windowExecBufferSpillThreshold
-    val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val largeVarTypes = conf.arrowUseLargeVarTypes
-
-    // Extract window expressions and window functions
-    val windowExpressions = expressions.flatMap(_.collect { case e: 
WindowExpression => e })
-    val udfExpressions = windowExpressions.map { e =>
-      
e.windowFunction.asInstanceOf[AggregateExpression].aggregateFunction.asInstanceOf[PythonUDAF]
-    }
-
-    // We shouldn't be chaining anything here.
-    // All chained python functions should only contain one function.
-    val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
-    require(pyFuncs.length == expressions.length)
-
-    val udfWindowBoundTypes = pyFuncs.indices.map(i =>
-      frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
-    val pythonRunnerConf: Map[String, String] = 
(ArrowUtils.getPythonRunnerConfMap(conf)
-      + (windowBoundTypeConf -> 
udfWindowBoundTypes.map(_.value).mkString(",")))
-
-    // Filter child output attributes down to only those that are UDF inputs.
-    // Also eliminate duplicate UDF inputs. This is similar to how other 
Python UDF node
-    // handles UDF inputs.
-    val dataInputs = new ArrayBuffer[Expression]
-    val dataInputTypes = new ArrayBuffer[DataType]
-    val argOffsets = inputs.map { input =>
-      input.map { e =>
-        if (dataInputs.exists(_.semanticEquals(e))) {
-          dataInputs.indexWhere(_.semanticEquals(e))
-        } else {
-          dataInputs += e
-          dataInputTypes += e.dataType
-          dataInputs.length - 1
-        }
-      }.toArray
-    }.toArray
-
-    // In addition to UDF inputs, we will prepend window bounds for each UDFs.
-    // For bounded windows, we prepend lower bound and upper bound. For 
unbounded windows,
-    // we no not add window bounds. (strictly speaking, we only need to lower 
or upper bound
-    // if the window is bounded only on one side, this can be improved in the 
future)
-
-    // Setting window bounds for each window frames. Each window frame has 
different bounds so
-    // each has its own window bound columns.
-    val windowBoundsInput = factories.indices.flatMap { frameIndex =>
-      if (isBounded(frameIndex)) {
-        Seq(
-          BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = 
false),
-          BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = 
false)
-        )
-      } else {
-        Seq.empty
-      }
-    }
-
-    // Setting the window bounds argOffset for each UDF. For UDFs with bounded 
window, argOffset
-    // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, 
inputOffset2, ...)
-    // For UDFs with unbounded window, argOffset is (inputOffset1, 
inputOffset2, ...)
-    pyFuncs.indices.foreach { exprIndex =>
-      val frameIndex = expressionIndexToFrameIndex(exprIndex)
-      if (isBounded(frameIndex)) {
-        argOffsets(exprIndex) =
-          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
-            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
-      } else {
-        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + 
windowBoundsInput.length)
-      }
-    }
-
-    val allInputs = windowBoundsInput ++ dataInputs
-    val allInputTypes = allInputs.map(_.dataType)
     val spillSize = longMetric("spillSize")
-    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
-    // Start processing.
-    child.execute().mapPartitions { iter =>
-      val context = TaskContext.get()
-
-      // Get all relevant projections.
-      val resultProj = createResultProjection(expressions)
-      val pythonInputProj = UnsafeProjection.create(
-        allInputs,
-        windowBoundsInput.map(ref =>
-          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ 
child.output
-      )
-      val pythonInputSchema = StructType(
-        allInputTypes.zipWithIndex.map { case (dt, i) =>
-          StructField(s"_$i", dt)
-        }
-      )
-      val grouping = UnsafeProjection.create(partitionSpec, child.output)
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(context.taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      context.addTaskCompletionListener[Unit] { _ =>
-        queue.close()
-      }
-
-      val stream = iter.map { row =>
-        queue.add(row.asInstanceOf[UnsafeRow])
-        row
-      }
-
-      val pythonInput = new Iterator[Iterator[UnsafeRow]] {
-
-        // Manage the stream and the grouping.
-        var nextRow: UnsafeRow = null
-        var nextGroup: UnsafeRow = null
-        var nextRowAvailable: Boolean = false
-        private[this] def fetchNextRow(): Unit = {
-          nextRowAvailable = stream.hasNext
-          if (nextRowAvailable) {
-            nextRow = stream.next().asInstanceOf[UnsafeRow]
-            nextGroup = grouping(nextRow)
-          } else {
-            nextRow = null
-            nextGroup = null
-          }
-        }
-        fetchNextRow()
-
-        // Manage the current partition.
-        val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, 
spillThreshold)
-        var bufferIterator: Iterator[UnsafeRow] = _
-
-        val indexRow = new 
SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
 
-        val frames = factories.map(_(indexRow))
+    val evaluatorFactory =
+      new WindowInPandasEvaluatorFactory(
+        windowExpression,
+        partitionSpec,
+        orderSpec,
+        child.output,
+        spillSize,
+        pythonMetrics)
 
-        private[this] def fetchNextPartition(): Unit = {
-          // Collect all the rows in the current partition.
-          // Before we start to fetch new input rows, make a copy of nextGroup.
-          val currentGroup = nextGroup.copy()
-
-          // clear last partition
-          buffer.clear()
-
-          while (nextRowAvailable && nextGroup == currentGroup) {
-            buffer.add(nextRow)
-            fetchNextRow()
-          }
-
-          // Setup the frames.
-          var i = 0
-          while (i < numFrames) {
-            frames(i).prepare(buffer)
-            i += 1
-          }
-
-          // Setup iteration
-          rowIndex = 0
-          bufferIterator = buffer.generateIterator()
-        }
-
-        // Iteration
-        var rowIndex = 0
-
-        override final def hasNext: Boolean = {
-          val found = (bufferIterator != null && bufferIterator.hasNext) || 
nextRowAvailable
-          if (!found) {
-            // clear final partition
-            buffer.clear()
-            spillSize += buffer.spillSize
-          }
-          found
-        }
-
-        override final def next(): Iterator[UnsafeRow] = {
-          // Load the next partition if we need to.
-          if ((bufferIterator == null || !bufferIterator.hasNext) && 
nextRowAvailable) {
-            fetchNextPartition()
-          }
-
-          val join = new JoinedRow
-
-          bufferIterator.zipWithIndex.map {
-            case (current, index) =>
-              var frameIndex = 0
-              while (frameIndex < numFrames) {
-                frames(frameIndex).write(index, current)
-                // If the window is unbounded we don't need to write out 
window bounds.
-                if (isBounded(frameIndex)) {
-                  indexRow.setInt(
-                    lowerBoundIndex(frameIndex), 
frames(frameIndex).currentLowerBound())
-                  indexRow.setInt(
-                    upperBoundIndex(frameIndex), 
frames(frameIndex).currentUpperBound())
-                }
-                frameIndex += 1
-              }
-
-              pythonInputProj(join(indexRow, current))
-          }
-        }
-      }
-
-      val windowFunctionResult = new ArrowPythonRunner(
-        pyFuncs,
-        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
-        argOffsets,
-        pythonInputSchema,
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID).compute(pythonInput, context.partitionId(), context)
-
-      val joined = new JoinedRow
-
-      windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
-        val leftRow = queue.remove()
-        val joinedRow = joined(leftRow, windowOutput)
-        resultProj(joinedRow)
+    // Start processing.
+    if (conf.usePartitionEvaluator) {
+      child.execute().mapPartitionsWithEvaluator(evaluatorFactory)
+    } else {
+      child.execute().mapPartitions { iter =>
+        val evaluator = evaluatorFactory.createEvaluator()
+        evaluator.eval(0, iter)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
similarity index 67%
copy from sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
copy to sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
index 82fc308e409..913f8762c79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
@@ -20,42 +20,22 @@ package org.apache.spark.sql.execution.window
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Add, 
AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, 
DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, 
FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, 
JoinedRow, MutableProjection, NamedExpression, OffsetWindowFunction, 
PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, 
SpecificInternalRow, SpecifiedWindowFrame, TimeAdd, TimestampAddYMIn [...]
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{CalendarIntervalType, DateType, 
DayTimeIntervalType, DecimalType, IntegerType, TimestampNTZType, TimestampType, 
YearMonthIntervalType}
 import org.apache.spark.util.collection.Utils
 
-/**
- * Holds common logic for window operators
- */
-trait WindowExecBase extends UnaryExecNode {
+trait WindowEvaluatorFactoryBase {
   def windowExpression: Seq[NamedExpression]
   def partitionSpec: Seq[Expression]
   def orderSpec: Seq[SortOrder]
-
-  override def output: Seq[Attribute] =
-    child.output ++ windowExpression.map(_.toAttribute)
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    if (partitionSpec.isEmpty) {
-      // Only show warning when the number of bytes is larger than 100 MiB?
-      logWarning("No Partition Defined for Window operation! Moving all data 
to a single "
-        + "partition, this can cause serious performance degradation.")
-      AllTuples :: Nil
-    } else {
-      ClusteredDistribution(partitionSpec) :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
-
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  def childOutput: Seq[Attribute]
+  def spillSize: SQLMetric
 
   /**
    * Create the resulting projection.
@@ -68,13 +48,13 @@ trait WindowExecBase extends UnaryExecNode {
   protected def createResultProjection(expressions: Seq[Expression]): 
UnsafeProjection = {
     val references = expressions.zipWithIndex.map { case (e, i) =>
       // Results of window expressions will be on the right side of child's 
output
-      BoundReference(child.output.size + i, e.dataType, e.nullable)
+      BoundReference(childOutput.size + i, e.dataType, e.nullable)
     }
     val unboundToRefMap = Utils.toMap(expressions, references)
     val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToRefMap))
     UnsafeProjection.create(
-      child.output ++ patchedWindowExpression,
-      child.output)
+      childOutput ++ patchedWindowExpression,
+      childOutput)
   }
 
   /**
@@ -101,7 +81,7 @@ trait WindowExecBase extends UnaryExecNode {
         throw new IllegalStateException(s"Unhandled bound in windows 
expressions: $bound")
 
       case (RangeFrame, CurrentRow) =>
-        val ordering = RowOrdering.create(orderSpec, child.output)
+        val ordering = RowOrdering.create(orderSpec, childOutput)
         RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
 
       case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
@@ -110,7 +90,7 @@ trait WindowExecBase extends UnaryExecNode {
         val expr = sortExpr.child
 
         // Create the projection which returns the current 'value'.
-        val current = MutableProjection.create(expr :: Nil, child.output)
+        val current = MutableProjection.create(expr :: Nil, childOutput)
 
         // Flip the sign of the offset when processing the order is descending
         val boundOffset = sortExpr.direction match {
@@ -131,7 +111,7 @@ trait WindowExecBase extends UnaryExecNode {
           case (d: DecimalType, _: DecimalType) => 
DecimalAddNoOverflowCheck(expr, boundOffset, d)
           case (a, b) if a == b => Add(expr, boundOffset)
         }
-        val bound = MutableProjection.create(boundExpr :: Nil, child.output)
+        val bound = MutableProjection.create(boundExpr :: Nil, childOutput)
 
         // Construct the ordering. This is used to compare the result of 
current value projection
         // to the result of bound value projection. This is done manually 
because we want to use
@@ -176,7 +156,7 @@ trait WindowExecBase extends UnaryExecNode {
     // Collect all valid window functions and group them by their frame.
     windowExpression.foreach { x =>
       x.foreach {
-        case e @ WindowExpression(function, spec) =>
+        case e@WindowExpression(function, spec) =>
           val frame = 
spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
           function match {
             case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", 
frame, e, f)
@@ -198,7 +178,7 @@ trait WindowExecBase extends UnaryExecNode {
 
     // Map the groups to a (unbound) expression and frame factory pair.
     var numExpressions = 0
-    val timeZone = conf.sessionLocalTimeZone
+    val timeZone = SQLConf.get.sessionLocalTimeZone
     framedFunctions.toSeq.map {
       case (key, (expressions, functionSeq)) =>
         val ordinal = numExpressions
@@ -215,7 +195,7 @@ trait WindowExecBase extends UnaryExecNode {
           AggregateProcessor(
             functions,
             ordinal,
-            child.output,
+            childOutput,
             (expressions, schema) =>
               MutableProjection.create(expressions, schema))
         }
@@ -230,7 +210,7 @@ trait WindowExecBase extends UnaryExecNode {
                 ordinal,
                 // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
                 functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
+                childOutput,
                 (expressions, schema) =>
                   MutableProjection.create(expressions, schema),
                 offset,
@@ -242,7 +222,7 @@ trait WindowExecBase extends UnaryExecNode {
                 ordinal,
                 // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
                 functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
+                childOutput,
                 (expressions, schema) =>
                   MutableProjection.create(expressions, schema),
                 offset,
@@ -255,7 +235,7 @@ trait WindowExecBase extends UnaryExecNode {
                 ordinal,
                 // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
                 functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
+                childOutput,
                 (expressions, schema) =>
                   MutableProjection.create(expressions, schema),
                 offset,
@@ -307,4 +287,131 @@ trait WindowExecBase extends UnaryExecNode {
         (expressions, factory)
     }
   }
+
+}
+
+class WindowEvaluatorFactory(
+    val windowExpression: Seq[NamedExpression],
+    val partitionSpec: Seq[Expression],
+    val orderSpec: Seq[SortOrder],
+    val childOutput: Seq[Attribute],
+    val spillSize: SQLMetric)
+  extends PartitionEvaluatorFactory[InternalRow, InternalRow] with 
WindowEvaluatorFactoryBase {
+
+  override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] 
= {
+    new WindowPartitionEvaluator()
+  }
+
+  class WindowPartitionEvaluator extends PartitionEvaluator[InternalRow, 
InternalRow] {
+    private val conf: SQLConf = SQLConf.get
+
+    // Unwrap the window expressions and window frame factories from the map.
+    private val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
+    private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+    private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
+    private val spillThreshold = conf.windowExecBufferSpillThreshold
+
+    override def eval(
+        partitionIndex: Int,
+        inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
+      val stream = inputs.head
+      new Iterator[InternalRow] {
+
+        // Get all relevant projections.
+        val result = createResultProjection(expressions)
+        val grouping = UnsafeProjection.create(partitionSpec, childOutput)
+
+        // Manage the stream and the grouping.
+        var nextRow: UnsafeRow = null
+        var nextGroup: UnsafeRow = null
+        var nextRowAvailable: Boolean = false
+        private[this] def fetchNextRow(): Unit = {
+          nextRowAvailable = stream.hasNext
+          if (nextRowAvailable) {
+            nextRow = stream.next().asInstanceOf[UnsafeRow]
+            nextGroup = grouping(nextRow)
+          } else {
+            nextRow = null
+            nextGroup = null
+          }
+        }
+        fetchNextRow()
+
+        // Manage the current partition.
+        val buffer: ExternalAppendOnlyUnsafeRowArray =
+          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, 
spillThreshold)
+
+        var bufferIterator: Iterator[UnsafeRow] = _
+
+        val windowFunctionResult = new 
SpecificInternalRow(expressions.map(_.dataType))
+        val frames = factories.map(_(windowFunctionResult))
+        val numFrames = frames.length
+        private[this] def fetchNextPartition(): Unit = {
+          // Collect all the rows in the current partition.
+          // Before we start to fetch new input rows, make a copy of nextGroup.
+          val currentGroup = nextGroup.copy()
+
+          // clear last partition
+          buffer.clear()
+
+          while (nextRowAvailable && nextGroup == currentGroup) {
+            buffer.add(nextRow)
+            fetchNextRow()
+          }
+
+          // Setup the frames.
+          var i = 0
+          while (i < numFrames) {
+            frames(i).prepare(buffer)
+            i += 1
+          }
+
+          // Setup iteration
+          rowIndex = 0
+          bufferIterator = buffer.generateIterator()
+        }
+
+        // Iteration
+        var rowIndex = 0
+
+        override final def hasNext: Boolean = {
+          val found = (bufferIterator != null && bufferIterator.hasNext) || 
nextRowAvailable
+          if (!found) {
+            // clear final partition
+            buffer.clear()
+            spillSize += buffer.spillSize
+          }
+          found
+        }
+
+        val join = new JoinedRow
+        override final def next(): InternalRow = {
+          // Load the next partition if we need to.
+          if ((bufferIterator == null || !bufferIterator.hasNext) && 
nextRowAvailable) {
+            fetchNextPartition()
+          }
+
+          if (bufferIterator.hasNext) {
+            val current = bufferIterator.next()
+
+            // Get the results for the window frames.
+            var i = 0
+            while (i < numFrames) {
+              frames(i).write(rowIndex, current)
+              i += 1
+            }
+
+            // 'Merge' the input row with the window function result
+            join(current, windowFunctionResult)
+            rowIndex += 1
+
+            // Return the projection.
+            result(join)
+          } else {
+            throw new NoSuchElementException
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index dda5da6c9e9..35e59aef94f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.window
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, 
SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 
 /**
@@ -95,111 +95,23 @@ case class WindowExec(
   )
 
   protected override def doExecute(): RDD[InternalRow] = {
-    // Unwrap the window expressions and window frame factories from the map.
-    val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
-    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
-    val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
-    val spillThreshold = conf.windowExecBufferSpillThreshold
     val spillSize = longMetric("spillSize")
 
-    // Start processing.
-    child.execute().mapPartitions { stream =>
-      new Iterator[InternalRow] {
-
-        // Get all relevant projections.
-        val result = createResultProjection(expressions)
-        val grouping = UnsafeProjection.create(partitionSpec, child.output)
-
-        // Manage the stream and the grouping.
-        var nextRow: UnsafeRow = null
-        var nextGroup: UnsafeRow = null
-        var nextRowAvailable: Boolean = false
-        private[this] def fetchNextRow(): Unit = {
-          nextRowAvailable = stream.hasNext
-          if (nextRowAvailable) {
-            nextRow = stream.next().asInstanceOf[UnsafeRow]
-            nextGroup = grouping(nextRow)
-          } else {
-            nextRow = null
-            nextGroup = null
-          }
-        }
-        fetchNextRow()
-
-        // Manage the current partition.
-        val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, 
spillThreshold)
-
-        var bufferIterator: Iterator[UnsafeRow] = _
-
-        val windowFunctionResult = new 
SpecificInternalRow(expressions.map(_.dataType))
-        val frames = factories.map(_(windowFunctionResult))
-        val numFrames = frames.length
-        private[this] def fetchNextPartition(): Unit = {
-          // Collect all the rows in the current partition.
-          // Before we start to fetch new input rows, make a copy of nextGroup.
-          val currentGroup = nextGroup.copy()
-
-          // clear last partition
-          buffer.clear()
-
-          while (nextRowAvailable && nextGroup == currentGroup) {
-            buffer.add(nextRow)
-            fetchNextRow()
-          }
-
-          // Setup the frames.
-          var i = 0
-          while (i < numFrames) {
-            frames(i).prepare(buffer)
-            i += 1
-          }
+    val evaluatorFactory =
+      new WindowEvaluatorFactory(
+        windowExpression,
+        partitionSpec,
+        orderSpec,
+        child.output,
+        spillSize)
 
-          // Setup iteration
-          rowIndex = 0
-          bufferIterator = buffer.generateIterator()
-        }
-
-        // Iteration
-        var rowIndex = 0
-
-        override final def hasNext: Boolean = {
-          val found = (bufferIterator != null && bufferIterator.hasNext) || 
nextRowAvailable
-          if (!found) {
-            // clear final partition
-            buffer.clear()
-            spillSize += buffer.spillSize
-          }
-          found
-        }
-
-        val join = new JoinedRow
-        override final def next(): InternalRow = {
-          // Load the next partition if we need to.
-          if ((bufferIterator == null || !bufferIterator.hasNext) && 
nextRowAvailable) {
-            fetchNextPartition()
-          }
-
-          if (bufferIterator.hasNext) {
-            val current = bufferIterator.next()
-
-            // Get the results for the window frames.
-            var i = 0
-            while (i < numFrames) {
-              frames(i).write(rowIndex, current)
-              i += 1
-            }
-
-            // 'Merge' the input row with the window function result
-            join(current, windowFunctionResult)
-            rowIndex += 1
-
-            // Return the projection.
-            result(join)
-          } else {
-            throw new NoSuchElementException
-          }
-        }
+    // Start processing.
+    if (conf.usePartitionEvaluator) {
+      child.execute().mapPartitionsWithEvaluator(evaluatorFactory)
+    } else {
+      child.execute().mapPartitions { iter =>
+        val evaluator = evaluatorFactory.createEvaluator()
+        evaluator.eval(0, iter)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index 82fc308e409..29f2256efc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -17,16 +17,9 @@
 
 package org.apache.spark.sql.execution.window
 
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.Utils
 
 /**
  * Holds common logic for window operators
@@ -57,254 +50,4 @@ trait WindowExecBase extends UnaryExecNode {
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
-  /**
-   * Create the resulting projection.
-   *
-   * This method uses Code Generation. It can only be used on the executor 
side.
-   *
-   * @param expressions unbound ordered function expressions.
-   * @return the final resulting projection.
-   */
-  protected def createResultProjection(expressions: Seq[Expression]): 
UnsafeProjection = {
-    val references = expressions.zipWithIndex.map { case (e, i) =>
-      // Results of window expressions will be on the right side of child's 
output
-      BoundReference(child.output.size + i, e.dataType, e.nullable)
-    }
-    val unboundToRefMap = Utils.toMap(expressions, references)
-    val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToRefMap))
-    UnsafeProjection.create(
-      child.output ++ patchedWindowExpression,
-      child.output)
-  }
-
-  /**
-   * Create a bound ordering object for a given frame type and offset. A bound 
ordering object is
-   * used to determine which input row lies within the frame boundaries of an 
output row.
-   *
-   * This method uses Code Generation. It can only be used on the executor 
side.
-   *
-   * @param frame to evaluate. This can either be a Row or Range frame.
-   * @param bound with respect to the row.
-   * @param timeZone the session local timezone for time related calculations.
-   * @return a bound ordering object.
-   */
-  private def createBoundOrdering(
-      frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
-    (frame, bound) match {
-      case (RowFrame, CurrentRow) =>
-        RowBoundOrdering(0)
-
-      case (RowFrame, IntegerLiteral(offset)) =>
-        RowBoundOrdering(offset)
-
-      case (RowFrame, _) =>
-        throw new IllegalStateException(s"Unhandled bound in windows 
expressions: $bound")
-
-      case (RangeFrame, CurrentRow) =>
-        val ordering = RowOrdering.create(orderSpec, child.output)
-        RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
-
-      case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
-        // Use only the first order expression when the offset is non-null.
-        val sortExpr = orderSpec.head
-        val expr = sortExpr.child
-
-        // Create the projection which returns the current 'value'.
-        val current = MutableProjection.create(expr :: Nil, child.output)
-
-        // Flip the sign of the offset when processing the order is descending
-        val boundOffset = sortExpr.direction match {
-          case Descending => UnaryMinus(offset)
-          case Ascending => offset
-        }
-
-        // Create the projection which returns the current 'value' modified by 
adding the offset.
-        val boundExpr = (expr.dataType, boundOffset.dataType) match {
-          case (DateType, IntegerType) => DateAdd(expr, boundOffset)
-          case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, 
boundOffset)
-          case (TimestampType | TimestampNTZType, CalendarIntervalType) =>
-            TimeAdd(expr, boundOffset, Some(timeZone))
-          case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
-            TimestampAddYMInterval(expr, boundOffset, Some(timeZone))
-          case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) =>
-            TimeAdd(expr, boundOffset, Some(timeZone))
-          case (d: DecimalType, _: DecimalType) => 
DecimalAddNoOverflowCheck(expr, boundOffset, d)
-          case (a, b) if a == b => Add(expr, boundOffset)
-        }
-        val bound = MutableProjection.create(boundExpr :: Nil, child.output)
-
-        // Construct the ordering. This is used to compare the result of 
current value projection
-        // to the result of bound value projection. This is done manually 
because we want to use
-        // Code Generation (if it is enabled).
-        val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, 
expr.nullable)) :: Nil
-        val ordering = RowOrdering.create(boundSortExprs, Nil)
-        RangeBoundOrdering(ordering, current, bound)
-
-      case (RangeFrame, _) =>
-        throw new IllegalStateException("Non-Zero range offsets are not 
supported for windows " +
-          "with multiple order expressions.")
-    }
-  }
-
-  /**
-   * Collection containing an entry for each window frame to process. Each 
entry contains a frame's
-   * [[WindowExpression]]s and factory function for the 
[[WindowFunctionFrame]].
-   */
-  protected lazy val windowFrameExpressionFactoryPairs = {
-    type FrameKey = (String, FrameType, Expression, Expression, 
Seq[Expression])
-    type ExpressionBuffer = mutable.Buffer[Expression]
-    val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, 
ExpressionBuffer)]
-
-    // Add a function and its function to the map for a given frame.
-    def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: 
Expression): Unit = {
-      val key = fn match {
-        // This branch is used for Lead/Lag to support ignoring null and 
optimize the performance
-        // for NthValue ignoring null.
-        // All window frames move in rows. If there are multiple Leads, Lags 
or NthValues acting on
-        // a row and operating on different input expressions, they should not 
be moved uniformly
-        // by row. Therefore, we put these functions in different window 
frames.
-        case f: OffsetWindowFunction if f.ignoreNulls =>
-          (tpe, fr.frameType, fr.lower, fr.upper, 
f.children.map(_.canonicalized))
-        case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
-      }
-      val (es, fns) = framedFunctions.getOrElseUpdate(
-        key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
-      es += e
-      fns += fn
-    }
-
-    // Collect all valid window functions and group them by their frame.
-    windowExpression.foreach { x =>
-      x.foreach {
-        case e @ WindowExpression(function, spec) =>
-          val frame = 
spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
-          function match {
-            case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", 
frame, e, f)
-            case f: FrameLessOffsetWindowFunction =>
-              collect("FRAME_LESS_OFFSET", f.fakeFrame, e, f)
-            case f: OffsetWindowFunction if frame.frameType == RowFrame &&
-              frame.lower == UnboundedPreceding =>
-              frame.upper match {
-                case UnboundedFollowing => collect("UNBOUNDED_OFFSET", 
f.fakeFrame, e, f)
-                case CurrentRow => collect("UNBOUNDED_PRECEDING_OFFSET", 
f.fakeFrame, e, f)
-                case _ => collect("AGGREGATE", frame, e, f)
-              }
-            case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, 
f)
-            case f => throw new IllegalStateException(s"Unsupported window 
function: $f")
-          }
-        case _ =>
-      }
-    }
-
-    // Map the groups to a (unbound) expression and frame factory pair.
-    var numExpressions = 0
-    val timeZone = conf.sessionLocalTimeZone
-    framedFunctions.toSeq.map {
-      case (key, (expressions, functionSeq)) =>
-        val ordinal = numExpressions
-        val functions = functionSeq.toArray
-
-        // Construct an aggregate processor if we need one.
-        // Currently we don't allow mixing of Pandas UDF and SQL aggregation 
functions
-        // in a single Window physical node. Therefore, we can assume no SQL 
aggregation
-        // functions if Pandas UDF exists. In the future, we might mix Pandas 
UDF and SQL
-        // aggregation function in a single physical node.
-        def processor = if 
(functions.exists(_.isInstanceOf[PythonFuncExpression])) {
-          null
-        } else {
-          AggregateProcessor(
-            functions,
-            ordinal,
-            child.output,
-            (expressions, schema) =>
-              MutableProjection.create(expressions, schema))
-        }
-
-        // Create the factory to produce WindowFunctionFrame.
-        val factory = key match {
-          // Frameless offset Frame
-          case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
-            target: InternalRow =>
-              new FrameLessOffsetWindowFunctionFrame(
-                target,
-                ordinal,
-                // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
-                functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
-                (expressions, schema) =>
-                  MutableProjection.create(expressions, schema),
-                offset,
-                expr.nonEmpty)
-          case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) =>
-            target: InternalRow => {
-              new UnboundedOffsetWindowFunctionFrame(
-                target,
-                ordinal,
-                // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
-                functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
-                (expressions, schema) =>
-                  MutableProjection.create(expressions, schema),
-                offset,
-                expr.nonEmpty)
-            }
-          case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, 
expr) =>
-            target: InternalRow => {
-              new UnboundedPrecedingOffsetWindowFunctionFrame(
-                target,
-                ordinal,
-                // OFFSET frame functions are guaranteed be 
OffsetWindowFunction.
-                functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
-                (expressions, schema) =>
-                  MutableProjection.create(expressions, schema),
-                offset,
-                expr.nonEmpty)
-            }
-
-          // Entire Partition Frame.
-          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
-            target: InternalRow => {
-              new UnboundedWindowFunctionFrame(target, processor)
-            }
-
-          // Growing Frame.
-          case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
-            target: InternalRow => {
-              new UnboundedPrecedingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, upper, timeZone))
-            }
-
-          // Shrinking Frame.
-          case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
-            target: InternalRow => {
-              new UnboundedFollowingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, lower, timeZone))
-            }
-
-          // Moving Frame.
-          case ("AGGREGATE", frameType, lower, upper, _) =>
-            target: InternalRow => {
-              new SlidingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, lower, timeZone),
-                createBoundOrdering(frameType, upper, timeZone))
-            }
-
-          case _ =>
-            throw new IllegalStateException(s"Unsupported factory: $key")
-        }
-
-        // Keep track of the number of expressions. This is a side-effect in a 
map...
-        numExpressions += expressions.size
-
-        // Create the Window Expression - Frame Factory pair.
-        (expressions, factory)
-    }
-  }
 }

