icexelloss closed pull request #23279: [SPARK-26328][SQL] Use GenerateOrdering for group key comparison in WindowExec URL: https://github.com/apache/spark/pull/23279
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index fede0f3e92d67..5cd5f50c44451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -24,8 +24,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType} @@ -304,20 +304,18 @@ case class WindowExec( // Get all relevant projections. val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, child.output) + val groupOrdering = GenerateOrdering.generate( + partitionSpec.map(SortOrder(_, Ascending)), child.output) // Manage the stream and the grouping. 
var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null var nextRowAvailable: Boolean = false private[this] def fetchNextRow() { nextRowAvailable = stream.hasNext if (nextRowAvailable) { nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) } else { nextRow = null - nextGroup = null } } fetchNextRow() @@ -333,13 +331,13 @@ case class WindowExec( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() + // Before we start to fetch new input rows, make a copy of nextRow. + val currentRow = nextRow.copy() // clear last partition buffer.clear() - while (nextRowAvailable && nextGroup == currentGroup) { + while (nextRowAvailable && groupOrdering.compare(currentRow, nextRow) == 0) { buffer.add(nextRow) fetchNextRow() } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org