c21 commented on a change in pull request #32242: URL: https://github.com/apache/spark/pull/32242#discussion_r616468793
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ##########

```
@@ -128,6 +128,16 @@ case class HashAggregateExec(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct
+  // This is for testing final aggregate with number-of-rows-based fall back as specified in
+  // `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and
+  // regular hash map. So the aggregation buffers from both maps need to be merged together
+  // to avoid correctness issue.
+  //
+  // This scenario only happens in unit test with number-of-rows-based fall back.
+  // There should not be same keys in both maps with size-based fall back in production.
+  private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined &&
```

Review comment:

@cloud-fan - sure. This is how number-of-rows-based fallback works. With the internal config `spark.sql.TungstenAggregate.testFallbackStartsAt`, we can set (1) when to fall back from the first-level hash map to the second-level hash map, and (2) when to fall back from the second-level hash map to sort.

Suppose `spark.sql.TungstenAggregate.testFallbackStartsAt` = "2, 3". Then the generated code that aggregates each input row into the hash map looks like:

```
UnsafeRow agg_buffer = null;
if (counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
if (agg_buffer == null) {
  // generated code for key in unsafe row format
  ...
  if (counter < 3) {
    // 2nd level hash map
    agg_buffer = regularHashMap.getAggregationBufferFromUnsafeRow(key_in_unsafe_row, ...);
  }
  if (agg_buffer == null) {
    // sort-based fallback
    regularHashMap.destructAndCreateExternalSorter();
    ...
    counter = 0;
  }
}
counter += 1;
```

Lines 187-232 of https://gist.github.com/c21/d0f704c0a33c24ec05387ff4df438bff show an example of the generated code.

I tried to add a method `fastHashMap.find(key): boolean` and change the code like this:

```
...
if (fastHashMap.find(key) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
...
```

But I later found the case I mentioned above:

1. key(a) is inserted into the second-level hash map (when counter reaches the 1st threshold).
2. Sort-based fallback happens, and counter is reset to 0 (when counter reaches the 2nd threshold).
3. key(a) is not in the first-level hash map and counter is below the 1st threshold, so key(a) is mistakenly inserted into the first-level hash map as well.

We could further add code like this:

```
if ((fastHashMap.find(key) && !regularHashMap.find(key_in_unsafe_row)) || counter < 2) {
  // 1st level hash map
  agg_buffer = fastHashMap.findOrInsert(key);
}
```

But that introduces more ad-hoc changes and looks pretty ugly, since a lot of code would need to be moved (for example, the code that builds `key_in_unsafe_row`, which currently runs only after the first-level lookup fails).
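To make the counter-based routing and the duplicate-key problem concrete, here is a small Java simulation. It is a sketch, not Spark code: `FallbackSimulation`, `aggregate`, `checkFastMapFirst` and `mergeByKey` are hypothetical names. It assumes a COUNT per key, a first-level map that never runs out of capacity, a `LinkedHashMap` standing in for the external sorter, and that after a spill the current row goes into the emptied second-level map. `checkFastMapFirst` models the `fastHashMap.find(key) || counter < 2` variant.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class FallbackSimulation {

    /**
     * Simulates number-of-rows-based fallback for COUNT(*) GROUP BY key.
     * fastLimit/regularLimit play the role of testFallbackStartsAt = "fastLimit, regularLimit".
     * Fast-map rows are emitted separately from regular-map/sorter rows,
     * so the same key can appear in two output rows.
     */
    public static List<Map.Entry<String, Long>> aggregate(
            List<String> keys, int fastLimit, int regularLimit, boolean checkFastMapFirst) {
        Map<String, Long> fastMap = new LinkedHashMap<>();    // 1st level hash map
        Map<String, Long> regularMap = new LinkedHashMap<>(); // 2nd level hash map
        Map<String, Long> sorter = new LinkedHashMap<>();     // stand-in for the external sorter
        int counter = 0;
        for (String key : keys) {
            // Hypothetical find(key) check: keep a key in the fast map once it is there.
            boolean inFast = checkFastMapFirst && fastMap.containsKey(key);
            if (inFast || counter < fastLimit) {
                fastMap.merge(key, 1L, Long::sum);
            } else if (counter < regularLimit) {
                regularMap.merge(key, 1L, Long::sum);
            } else {
                // Sort-based fallback: spill the regular map and reset the counter.
                regularMap.forEach((k, v) -> sorter.merge(k, v, Long::sum));
                regularMap.clear();
                counter = 0;
                // Assumption: the current row then goes into the emptied regular map.
                regularMap.merge(key, 1L, Long::sum);
            }
            counter += 1;
        }
        // The sorter merges spilled data with the remaining regular map by key...
        regularMap.forEach((k, v) -> sorter.merge(k, v, Long::sum));
        // ...but fast-map rows are emitted on their own.
        List<Map.Entry<String, Long>> out = new ArrayList<>();
        fastMap.forEach((k, v) -> out.add(Map.entry(k, v)));
        sorter.forEach((k, v) -> out.add(Map.entry(k, v)));
        return out;
    }

    /** Merges aggregation buffers that share a key, as a final aggregate must. */
    public static Map<String, Long> mergeByKey(List<Map.Entry<String, Long>> rows) {
        Map<String, Long> merged = new LinkedHashMap<>();
        for (Map.Entry<String, Long> row : rows) {
            merged.merge(row.getKey(), row.getValue(), Long::sum);
        }
        return merged;
    }

    public static void main(String[] args) {
        // Counter-only routing: "a" lands in both the fast and the regular map.
        System.out.println(aggregate(List.of("a", "b", "a"), 2, 3, false));
        // The find() variant keeps "a" in the fast map for this input.
        System.out.println(aggregate(List.of("a", "b", "a"), 2, 3, true));
        // Spill case: "a" ends up in the sorter and, after the reset, in the fast map.
        List<Map.Entry<String, Long>> rows = aggregate(List.of("x", "y", "a", "z", "a"), 2, 3, true);
        System.out.println(rows);
        System.out.println(mergeByKey(rows));
    }
}
```

With `checkFastMapFirst` enabled, the input `a, b, a` no longer yields a duplicate, but `x, y, a, z, a` still does: `a` goes to the second-level map, is spilled when the counter reaches 3, and is then inserted into the first-level map after the counter reset. Merging rows by key, as the comment in the diff describes for the test-only final aggregate, restores the correct count.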