Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20211#discussion_r160599447 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala --- @@ -80,27 +84,77 @@ case class FlatMapGroupsInPandasExec( val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone - inputRDD.mapPartitionsInternal { iter => - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { + if (additionalGroupingAttributes.isEmpty) { + // Fast path if additional grouping attributes is empty + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + } + } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(grouped, context.partitionId(), context) + + columnarBatchIter + .flatMap(_.rowIterator.asScala) + .map(UnsafeProjection.create(output, output)) + } + } else { + // If additionGroupingAttributes is not empty, join the grouping attributes with + // the udf output to get the final result + + inputRDD.mapPartitionsInternal { iter => + assert(groupingAttributes.nonEmpty) + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + + val context = TaskContext.get() + + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), additionalGroupingAttributes.length) + context.addTaskCompletionListener { _ => + queue.close() + } + val 
additionalGroupingProj = UnsafeProjection.create( + additionalGroupingAttributes, groupingAttributes) val dropGrouping = UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + val grouped = groupedIter.map { + case (k, groupedRowIter) => + val additionalGrouping = additionalGroupingProj(k) + queue.add(additionalGrouping) + (additionalGrouping, groupedRowIter.map(dropGrouping)) --- End diff -- We can return only `groupedRowIter.map(dropGrouping)` here: the grouping key is already buffered via `queue.add(additionalGrouping)`, so there is no need to also return it as the first element of the tuple.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org