Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20211#discussion_r160599447
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 ---
    @@ -80,27 +84,77 @@ case class FlatMapGroupsInPandasExec(
         val sessionLocalTimeZone = conf.sessionLocalTimeZone
         val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
     
    -    inputRDD.mapPartitionsInternal { iter =>
    -      val grouped = if (groupingAttributes.isEmpty) {
    -        Iterator(iter)
    -      } else {
    +    if (additionalGroupingAttributes.isEmpty) {
    +      // Fast path if additionalGroupingAttributes is empty
    +
    +      inputRDD.mapPartitionsInternal { iter =>
    +        val grouped = if (groupingAttributes.isEmpty) {
    +          Iterator(iter)
    +        } else {
    +          val groupedIter = GroupedIterator(iter, groupingAttributes, 
child.output)
    +          val dropGrouping =
    +            
UnsafeProjection.create(child.output.drop(groupingAttributes.length), 
child.output)
    +          groupedIter.map {
    +            case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
    +          }
    +        }
    +
    +        val context = TaskContext.get()
    +
    +        val columnarBatchIter = new ArrowPythonRunner(
    +          chainedFunc, bufferSize, reuseWorker,
    +          PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema,
    +          sessionLocalTimeZone, pandasRespectSessionTimeZone)
    +          .compute(grouped, context.partitionId(), context)
    +
    +        columnarBatchIter
    +          .flatMap(_.rowIterator.asScala)
    +          .map(UnsafeProjection.create(output, output))
    +      }
    +    } else {
    +      // If additionalGroupingAttributes is not empty, join the grouping attributes with
    +      // the udf output to get the final result
    +
    +      inputRDD.mapPartitionsInternal { iter =>
    +        assert(groupingAttributes.nonEmpty)
    +
             val groupedIter = GroupedIterator(iter, groupingAttributes, 
child.output)
    +
    +        val context = TaskContext.get()
    +
    +        val queue = HybridRowQueue(context.taskMemoryManager(),
    +          new File(Utils.getLocalDir(SparkEnv.get.conf)), 
additionalGroupingAttributes.length)
    +        context.addTaskCompletionListener { _ =>
    +          queue.close()
    +        }
    +        val additionalGroupingProj = UnsafeProjection.create(
    +          additionalGroupingAttributes, groupingAttributes)
             val dropGrouping =
               
UnsafeProjection.create(child.output.drop(groupingAttributes.length), 
child.output)
    -        groupedIter.map {
    -          case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
    +        val grouped = groupedIter.map {
    +          case (k, groupedRowIter) =>
    +            val additionalGrouping = additionalGroupingProj(k)
    +            queue.add(additionalGrouping)
    +            (additionalGrouping, groupedRowIter.map(dropGrouping))
    --- End diff --
    
    We can return only `groupedRowIter.map(dropGrouping)`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to