Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22305#discussion_r231429605
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala ---
    @@ -73,68 +118,147 @@ case class WindowInPandasExec(
       }
     
       /**
    -   * Create the resulting projection.
    -   *
    -   * This method uses Code Generation. It can only be used on the executor 
side.
    +   * Helper function to get all relevant helper functions and data 
structures for window bounds
        *
    -   * @param expressions unbound ordered function expressions.
    -   * @return the final resulting projection.
    +   * This function returns:
    +   * (1) Total number of window bound indices in the python input row
    +   * (2) Function from frame index to its lower bound column index in the 
python input row
    +   * (3) Function from frame index to its upper bound column index in the 
python input row
    +   * (4) Function from frame index to whether the frame requires window bound indices
    +   *     in the python input row (unbounded window doesn't need it)
    +   * (5) Function from frame index to its eval type
        */
    -  private[this] def createResultProjection(expressions: Seq[Expression]): 
UnsafeProjection = {
    -    val references = expressions.zipWithIndex.map { case (e, i) =>
    -      // Results of window expressions will be on the right side of 
child's output
    -      BoundReference(child.output.size + i, e.dataType, e.nullable)
    +  private def computeWindowBoundHelpers(
    +      factories: Seq[InternalRow => WindowFunctionFrame]
    +  ): (Int, Int => Int, Int => Int, Int => Boolean, Int => Int) = {
    +    val dummyRow = new SpecificInternalRow()
    +    val functionFrames = factories.map(_(dummyRow))
    +
    +    val evalTypes = functionFrames.map {
    +      case _: UnboundedWindowFunctionFrame => 
PythonEvalType.SQL_UNBOUNDED_WINDOW_AGG_PANDAS_UDF
    +      case _ => PythonEvalType.SQL_BOUNDED_WINDOW_AGG_PANDAS_UDF
    +    }
    +
    +    val requiredIndices = functionFrames.map {
    +      case _: UnboundedWindowFunctionFrame => 0
    +      case _ => 2
         }
    -    val unboundToRefMap = expressions.zip(references).toMap
    -    val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToRefMap))
    -    UnsafeProjection.create(
    -      child.output ++ patchedWindowExpression,
    -      child.output)
    +
    +    val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
    +
    +    val boundIndices = (requiredIndices zip upperBoundIndices).map {case 
(num, upperBoundIndex) =>
    +        if (num == 0) {
    +          // Sentinel values for unbounded window
    +          (-1, -1)
    +        } else {
    +          (upperBoundIndex - 2, upperBoundIndex - 1)
    +        }
    +    }
    +
    +    def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
    +    def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
    +    def frameEvalType(frameIndex: Int) = evalTypes(frameIndex)
    +    def frameRequireIndex(frameIndex: Int) =
    +      evalTypes(frameIndex) == 
PythonEvalType.SQL_BOUNDED_WINDOW_AGG_PANDAS_UDF
    +
    +    (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, 
frameRequireIndex, frameEvalType)
       }
     
       protected override def doExecute(): RDD[InternalRow] = {
    -    val inputRDD = child.execute()
    +    // Unwrap the expressions and factories from the map.
    +    val expressionsWithFrameIndex =
    +      windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
    +        case (buffer, frameIndex) => buffer.map( expr => (expr, 
frameIndex))
    +      }
    +
    +    val expressions = expressionsWithFrameIndex.map(_._1)
    +    val expressionIndexToFrameIndex =
    +      expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
    +
    +    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
    +
    +    val (numBoundIndices, lowerBoundIndex, upperBoundIndex, 
frameRequireIndex, frameEvalType) =
    +      computeWindowBoundHelpers(factories)
    +
    +    val funcEvalTypes = expressions.indices.map(
    +      i => frameEvalType(expressionIndexToFrameIndex(i)))
    +
    +    val numFrames = factories.length
    +
    +    val inMemoryThreshold = 
sqlContext.conf.windowExecBufferInMemoryThreshold
    +    val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
    --- End diff --
    
    Shouldn't `conf` work here instead of `sqlContext.conf`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to