Copilot commented on code in PR #55422:
URL: https://github.com/apache/spark/pull/55422#discussion_r3122401335
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala:
##########
@@ -295,4 +342,58 @@ trait WindowEvaluatorFactoryBase {
}
}
+ /**
+ * Segment-tree path eligibility. The tree relies on
+ * `DeclarativeAggregate.mergeExpressions`, which
[[AggregateWindowFunction]]s
+ * (NthValue, NTile, Rank, RowNumber, NullIndex) refuse via
+ * `mergeUnsupportedByWindowFunctionError`: they extend DeclarativeAggregate
+ * but are NOT merge-capable. Normal aggregate window expressions reach this
+ * code as the inner DeclarativeAggregate unwrapped from
+ * [[AggregateExpression]] (see `windowFrameExpressionFactoryPairs.collect`).
+ */
+ private def eligibleForSegTree(
+ functions: Array[Expression],
+ filters: Array[Option[Expression]],
+ frameType: FrameType): Boolean = {
+ // RANGE accepted only for single-column order specs. Multi-column RANGE
+ // with non-zero offset is already rejected by `createBoundOrdering`, so
+ // gating here on `orderSpec.size == 1` matches the Sliding-path invariant.
+ val frameTypeOk = frameType match {
+ case RowFrame => true
+ case RangeFrame => orderSpec.size == 1
+ }
+ SQLConf.get.windowSegmentTreeEnabled &&
+ frameTypeOk &&
+ filters.forall(_.isEmpty) &&
+ functions.forall(WindowSegmentTree.isEligible) &&
+ !functions.exists {
+ case ae: AggregateExpression => ae.isDistinct
+ case _ => false
+ }
+ }
Review Comment:
`eligibleForSegTree` tries to reject DISTINCT via `functions.exists { case
ae: AggregateExpression => ae.isDistinct }`, but `functions` here are the
*inner* expressions (e.g. `Min`, `Sum`) extracted from `AggregateExpression`,
so this check will never match and is effectively dead/misleading. Consider
either removing it (DISTINCT window aggregates are already rejected by the
analyzer with `DISTINCT_WINDOW_FUNCTION_UNSUPPORTED`) or plumbing a distinct
flag from `WindowExpression(ae: AggregateExpression, _)` (similar to how
`aggFilters` is derived) and gating on that instead.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala:
##########
@@ -0,0 +1,892 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.sql.{DataFrame, Encoder, Encoders, QueryTest, Row}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer,
Window}
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.util.SparkErrorUtils
+
+/**
+ * End-to-end tests for the block-chunked segment-tree moving window frame.
+ * Covers basic aggregates, frame boundaries, min-rows fallback, NULL/NaN,
+ * numeric/string/date-timestamp types, RANGE, Decimal/Binary merge, UDAF
+ * fallback, and frame lifecycle.
+ */
+class SegmentTreeWindowFunctionSuite extends QueryTest with SharedSparkSession
{
+
+ import testImplicits._
+
+ // Force seg-tree path regardless of partition size (fallback exercised
explicitly).
+ private val enableSegTree: Map[String, String] = Map(
+ SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true",
+ SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1")
+
+ private val disableSegTree: Map[String, String] = Map(
+ SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false")
+
+ /** Build `f(conf)` twice (enabled / disabled) and assert equal results. */
+ private def checkEquivalence(build: () => DataFrame): Unit = {
+ val baseline: Array[Row] = withSQLConf(disableSegTree.toSeq: _*) {
+ build().collect().sortBy(_.toString)
+ }
+ withSQLConf(enableSegTree.toSeq: _*) {
+ val actual = build().collect().sortBy(_.toString)
+ assert(actual.toSeq === baseline.toSeq,
Review Comment:
`checkEquivalence` sorts results by `Row.toString`, but `Row.toString` can
be unstable for some value types (notably `Array[Byte]` where the element
`toString` is address-based). That can reorder baseline vs actual differently
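and cause false failures. Since this suite already extends `QueryTest`,
consider comparing with `QueryTest.sameRows(expected, actual)`. Its
`prepareAnswer` step converts array values to `Seq` before sorting by
`toString`, so the ordering is stable. Alternatively, sort by stable keys
(e.g. `id`, `pk`) instead of the raw `Row.toString`.

Note that `sameRows` returns `Option[String]` (`None` when the rows match,
otherwise an error description), not a `Boolean`. The assertion must
therefore check `.isEmpty`, and the expected (baseline) rows go first.
```suggestion
build().collect()
}
withSQLConf(enableSegTree.toSeq: _*) {
val actual = build().collect()
assert(QueryTest.sameRows(baseline.toSeq, actual.toSeq).isEmpty,
```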
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala:
##########
@@ -0,0 +1,592 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.window
+
+import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count,
DeclarativeAggregate, Max, Min, StddevPop, StddevSamp, Sum, VariancePop,
VarianceSamp}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.util.ArrayImplicits._
+
+/**
+ * Block-chunked segment tree for moving-frame window aggregates. Partitions
are
+ * split into blocks of `blockSize` rows; each block has its own small segtree
+ * (fanout `F`, height `h`). Block roots stay resident; internal nodes are
cached
+ * in an LRU keyed by block index. Queries cost O(log W).
+ *
+ * Memory accounting invariants:
+ * - I1: `SegTreeSpiller.spill()` MUST NOT call `acquireMemory` on its own
+ * consumer (would deadlock TMM's consumer-priority sort). All acquires
+ * happen on the hot path ([[ensureBlockLevels]]).
+ * - I2: `spill(_, trigger)` returns 0 when `trigger eq this` (self-trigger
+ * short-circuit) to prevent re-entrant eviction.
+ * - I3: LRU `removeEldestEntry` is disabled; eviction is driven explicitly
+ * from [[ensureBlockLevels]] or [[SegTreeSpiller.spill]].
+ * - I4: Every successful [[acquireBlockMemory]] is paired with exactly one
+ * [[releaseBlockMemory]]. [[close]] is idempotent.
+ * - I5: Per-block bytes are a conservative upper bound (full block, 16
B/field).
+ * - I8: If `rowArray` already spilled to disk, `spill` returns 0 (rebuild
+ * would O(blockStart)-scan the spill file).
+ *
+ * @note Instances are not thread-safe.
+ */
+private[window] class WindowSegmentTree(
+ functions: Array[DeclarativeAggregate],
+ inputSchema: Seq[Attribute],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) =>
MutableProjection,
+ fanout: Int = WindowSegmentTree.DefaultFanout,
+ blockSize: Int = WindowSegmentTree.DefaultBlockSize,
+ maxCachedBlocks: Option[Int] = None,
+ taskMemoryManager: TaskMemoryManager = null)
+ extends AutoCloseable {
+
+ require(fanout >= 2, s"fanout must be >= 2, got $fanout")
+ require(blockSize >= 1, s"blockSize must be >= 1, got $blockSize")
+ require(functions.nonEmpty, "WindowSegmentTree requires at least one
aggregate function")
+ maxCachedBlocks.foreach { n =>
+ require(n >= 1, s"maxCachedBlocks must be >= 1 when specified, got $n")
+ }
+ require(taskMemoryManager != null,
+ "WindowSegmentTree requires a non-null TaskMemoryManager; " +
+ "in tests use `new TaskMemoryManager(new TestMemoryManager(conf), 0)`")
+
+ // ---------- Schemas & projections ----------
+
+ private val bufferAttrs: Seq[AttributeReference] =
+ functions.flatMap(_.aggBufferAttributes).toImmutableArraySeq
+ private val rightAttrs: Seq[AttributeReference] =
+ functions.flatMap(_.inputAggBufferAttributes).toImmutableArraySeq
+ private val bufferDataTypes: IndexedSeq[DataType] =
+ bufferAttrs.map(_.dataType).toIndexedSeq
+
+ private val initialValues: Seq[Expression] =
functions.flatMap(_.initialValues).toIndexedSeq
+ private val updateExpressions: Seq[Expression] =
+ functions.flatMap(_.updateExpressions).toIndexedSeq
+ private val mergeExpressions: Seq[Expression] =
+ functions.flatMap(_.mergeExpressions).toIndexedSeq
+
+ private[this] val initProj: MutableProjection =
newMutableProjection(initialValues, Nil)
+ private[this] val updateProj: MutableProjection =
+ newMutableProjection(updateExpressions, bufferAttrs ++ inputSchema)
+ private[this] val mergeProj: MutableProjection =
+ newMutableProjection(mergeExpressions, bufferAttrs ++ rightAttrs)
+
+ private[this] val joinedRow = new JoinedRow()
+
+ // ---------- State ----------
+
+ private var numRows: Int = 0
+ private var numBlocks: Int = 0
+ private var rowArray: ExternalAppendOnlyUnsafeRowArray = _
+ private var closed: Boolean = false
+
+ /** Always-resident per-block root aggregates: `blockAggregates(i)` =
+ * merged buffer over all rows in block i. */
+ private var blockAggregates: Array[InternalRow] = Array.empty
+
+ /** Conservative byte width of one aggregate buffer row at 16 B/field:
+ * primitive `MutableValue` is 8 B, boxed references and object headers
+ * push the effective footprint higher. Tighter per-type sizing is out
+ * of scope; TaskMemoryManager remains the hard backstop via spill / OOM. */
+ private val bufferWidthBytes: Long = {
+ val bytesPerField = 16L
+ math.max(1L, bufferDataTypes.size.toLong * bytesPerField)
+ }
+
+ /** Number of aggregate-buffer slots cached per block (see I5).
+ *
+ * Invariant: equals `sum over levels L of levels(L).length` for any block
+ * built by [[buildBlockLevels]]: level 0 holds `blockSize` leaves and each
+ * next level holds `ceil(prev / fanout)` parents until a single root
+ * remains. For `blockSize == 1` this is 1 (single leaf, no parents). */
+ private val cachedSlotsPerBlock: Long = {
+ var n = blockSize.toLong
+ var sum = n
+ while (n > 1L) {
+ n = (n + fanout - 1) / fanout
+ sum += n
+ }
+ sum
+ }
+
+ /** Bytes accounted per cached block (see I5). Conservative: assumes every
+ * block is full; tail block leaves a small headroom. */
+ private[this] val blockBytes: Long =
+ math.max(1L, cachedSlotsPerBlock * bufferWidthBytes)
+
+ /** `spans(L)` = number of leaves covered by a single node at level L.
Depends
+ * only on fanout + blockSize, so precomputed once. */
+ private val spans: Array[Int] = {
+ val maxLevel = {
+ var lvl = 0
+ var span = 1L
+ while (span < blockSize) { span *= fanout; lvl += 1 }
+ lvl
+ }
+ val arr = new Array[Int](maxLevel + 1)
+ var s = 1L
+ var i = 0
+ while (i <= maxLevel) {
+ arr(i) = if (s > Int.MaxValue) Int.MaxValue else s.toInt
+ s *= fanout
+ i += 1
+ }
+ arr
+ }
+
+ /** LRU cache of per-block internal node arrays. Key = blockIdx;
+ * value = `Array[Array[InternalRow]]` with levels(0..h). Auto-eviction
+ * via `removeEldestEntry` is disabled (I3) -- driven explicitly from
+ * [[ensureBlockLevels]] or [[SegTreeSpiller.spill]]. Each entry maps 1:1
+ * to one [[acquireBlockMemory]] accounting. Callers should pass a W-aware
+ * `maxCachedBlocks` like `ceil(W / blockSize) + 2`. */
+ private val blockLevelsCache: JLinkedHashMap[Integer,
Array[Array[InternalRow]]] =
+ new JLinkedHashMap[Integer, Array[Array[InternalRow]]](16, 0.75f, true) {
+ override def removeEldestEntry(
+ eldest: JMap.Entry[Integer, Array[Array[InternalRow]]]): Boolean =
false
+ }
+
+ // ---------- Memory consumer ----------
+
+ /**
+ * Private MemoryConsumer tracking cached block levels under TMM. Heap-only
+ * (no Tungsten pages): uses [[MemoryConsumer.acquireMemory]] /
+ * [[MemoryConsumer.freeMemory]], which update the base class `used`
+ * AtomicLong so TMM's consumer-priority sort sees our pressure accurately.
+ *
+ * Hardcoded [[MemoryMode.ON_HEAP]] (not `tmm.getTungstenMemoryMode`): the
+ * cache holds plain JVM objects (`SpecificInternalRow` /
+ * `Array[Array[InternalRow]]`), never Tungsten pages. Under
+ * `spark.memory.offHeap.enabled=true`, enrolling as OFF_HEAP would let
+ * TMM pick us as a spill candidate for off-heap pressure; our `spill()`
+ * would then phantom-credit the off-heap pool without releasing any
+ * off-heap bytes, violating [[TaskMemoryManager#acquireExecutionMemory]]'s
+ * same-pool spill contract. Mirrors
+ * [[org.apache.spark.util.collection.Spillable]], which also hardcodes
+ * ON_HEAP for the same reason. Consequence under off-heap Tungsten: I8
+ * (below) degrades to a no-op because segtree and `rowArray` live in
+ * different pools -- a loss of optimization, not a correctness hazard.
+ *
+ * @note `spill()` MUST NOT call `acquireMemory` (see I1).
+ */
+ private final class SegTreeSpiller extends MemoryConsumer(
+ taskMemoryManager,
+ taskMemoryManager.pageSizeBytes(),
+ MemoryMode.ON_HEAP) {
+ override def spill(size: Long, trigger: MemoryConsumer): Long = {
+ // I2: self-trigger short-circuit (prevent re-entrant eviction).
+ if (trigger eq this) return 0L
+ // I8: rowArray already spilled -- evicting our cache is
counter-productive
+ // (rebuild would O(blockStart)-scan the spill file). `spillSize > 0` is
+ // the available "has spilled" signal (UnsafeExternalSorter state is not
+ // public).
+ if (rowArray != null && rowArray.spillSize > 0) return 0L
+ evictUntil(size)
+ }
+ }
+
+ private[this] val spiller: SegTreeSpiller = new SegTreeSpiller
+
+ // ---------- Public API ----------
+
+ def size: Int = numRows
+
+ /**
+ * Build the tree against a caller-owned row array.
+ *
+ * Ownership: the tree holds a reference to `rows` for its lifetime but does
+ * NOT own it -- the caller (typically `WindowPartitionEvaluator.buffer`)
+ * manages `clear()` / lifetime at partition boundaries. `close()` drops
+ * the reference without mutating the array.
+ *
+ * Exception-safe: if aggregation throws, previously built state is
preserved.
+ */
+ def build(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+ // rows.length is Int by design; check guards against future widening.
+ val n = rows.length
+ if (n < 0) {
+ throw SparkException.internalError(
+ s"WindowSegmentTree cannot hold more than Int.MaxValue rows, got $n")
+ }
+ val nBlocks = if (n == 0) 0 else (n + blockSize - 1) / blockSize
+ val newBlockAggs = computeBlockAggregates(rows, n, nBlocks)
+
+ // Commit.
+ rowArray = rows
+ numRows = n
+ numBlocks = nBlocks
+ blockAggregates = newBlockAggs
+ // Rebuild invalidates cached block levels; release accounting first (I4).
+ releaseAllCachedBlocks()
+ }
+
+ /**
+ * Query [lo, hi) and directly evaluate the result via `processor.evaluate`
+ * into `target`. Uses an internal pre-allocated buffer so no per-call
+ * allocation is needed.
+ */
+ private[window] def queryInto(
+ lo: Int, hi: Int, processor: AggregateProcessor, target: InternalRow):
Unit = {
+ query(lo, hi, internalQueryBuffer)
+ processor.evaluate(internalQueryBuffer, target)
+ }
+
+ private[this] val internalQueryBuffer: InternalRow = newBuffer()
+
+ def query(lo: Int, hi: Int, outBuffer: InternalRow): Unit = {
+ if (lo < 0 || hi > numRows || lo > hi) {
+ throw SparkException.internalError(
+ s"Invalid range [lo=$lo, hi=$hi) for size=$numRows")
+ }
+ // Reset outBuffer to identity only after bounds validation.
+ initProj.target(outBuffer)(InternalRow.empty)
+ if (lo == hi) return
+
+ val blo = lo / blockSize
+ val bhi = (hi - 1) / blockSize
+
+ if (blo == bhi) {
+ val blockStart = blo * blockSize
+ mergeBlockRange(blo, lo - blockStart, hi - blockStart, outBuffer)
+ } else {
+ // left partial
+ val loStart = blo * blockSize
+ val loBlockRows = math.min(blockSize, numRows - loStart)
+ mergeBlockRange(blo, lo - loStart, loBlockRows, outBuffer)
+ // full blocks
+ var b = blo + 1
+ while (b < bhi) {
+ mergeInto(outBuffer, blockAggregates(b))
+ b += 1
+ }
+ // right partial
+ val hiStart = bhi * blockSize
+ mergeBlockRange(bhi, 0, hi - hiStart, outBuffer)
+ }
+ }
+
+ /** Terminal: releases all state. Idempotent (I4). */
+ override def close(): Unit = {
+ if (closed) return
+ // Free all cached-block accounting before dropping references.
+ releaseAllCachedBlocks()
+ closeRowArray()
+ blockAggregates = Array.empty
+ numRows = 0
+ numBlocks = 0
+ closed = true
+ }
+
+ // ---------- Test hooks (package-private) ----------
+
+ private[window] def peekBlockCount: Int = numBlocks
+
+ private[window] def testOnlySpiller(): MemoryConsumer = spiller
+
+ /** Test-only accessor for the per-block memory accounting value. */
+ private[window] def peekBlockBytes: Long = blockBytes
+
+ /** NOTE: test-only; promotes block to MRU in the LRU cache as a side
effect. */
+ private[window] def peekLevelSize(blockIdx: Int, level: Int): Int = {
+ val levels = ensureBlockLevels(blockIdx)
+ levels(level).length
+ }
+
+ /** NOTE: test-only; promotes block to MRU in the LRU cache as a side
effect. */
+ private[window] def peekLevelCount(blockIdx: Int): Int = {
+ val levels = ensureBlockLevels(blockIdx)
+ levels.length
+ }
+
+ // ---------- Internals ----------
+
+ private def computeBlockAggregates(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ n: Int,
+ nBlocks: Int): Array[InternalRow] = {
+ if (n == 0) return Array.empty
+ val result = new Array[InternalRow](nBlocks)
+ val iter = array.generateIterator()
+ var b = 0
+ while (b < nBlocks) {
+ val buf = newBuffer()
+ initProj.target(buf)(InternalRow.empty)
+ val start = b * blockSize
+ val end = math.min(start + blockSize, n)
+ var i = start
+ while (i < end) {
+ if (!iter.hasNext) {
+ throw SparkException.internalError("rowArray iterator exhausted
unexpectedly")
+ }
+ val row = iter.next()
+ updateProj.target(buf)(joinedRow(buf, row))
+ i += 1
+ }
+ result(b) = buf
+ b += 1
+ }
+ result
+ }
+
+ /** Merge `src` buffer into `dst` buffer using mergeProj. */
+ private def mergeInto(dst: InternalRow, src: InternalRow): Unit = {
+ mergeProj.target(dst)(joinedRow(dst, src))
+ }
+
+ private def newBuffer(): InternalRow =
+ new SpecificInternalRow(bufferDataTypes)
+
+ /** Merge the given leaf range [lo, hi) inside `blockIdx` into `out`. */
+ private def mergeBlockRange(
+ blockIdx: Int, lo: Int, hi: Int, out: InternalRow): Unit = {
+ if (lo >= hi) return
+ val levels = ensureBlockLevels(blockIdx)
+ val blockRows = levels(0).length
+ val topLevel = levels.length - 1
+ queryDescend(levels, blockRows, topLevel, 0, lo, hi, out)
+ }
+
+ /** Descend the (per-block) segment tree merging any node fully contained
+ * in [queryLo, queryHi) into `out`. A node at (level L, index idx) covers
+ * leaves `[idx * span, min((idx+1)*span, blockRows))` where span = F^L. */
+ private def queryDescend(
+ levels: Array[Array[InternalRow]],
+ blockRows: Int,
+ level: Int,
+ idx: Int,
+ queryLo: Int,
+ queryHi: Int,
+ out: InternalRow): Unit = {
+ val span = spans(level)
+ val nodeLo = idx * span
+ val nodeHi = math.min(nodeLo + span, blockRows)
+ if (queryLo >= nodeHi || queryHi <= nodeLo) return
+ if (queryLo <= nodeLo && nodeHi <= queryHi) {
+ mergeInto(out, levels(level)(idx))
+ return
+ }
+ val childLevel = level - 1
+ val childLevelSize = levels(childLevel).length
+ var c = 0
+ while (c < fanout) {
+ val childIdx = idx * fanout + c
+ if (childIdx < childLevelSize) {
+ queryDescend(levels, blockRows, childLevel, childIdx, queryLo,
queryHi, out)
+ }
+ c += 1
+ }
+ }
+
+ /** Build (or fetch from LRU) the full per-block levels array.
+ * Protocol: acquire memory -> build -> cache. Eviction on capacity
+ * overflow or on TMM spill request. */
+ private def ensureBlockLevels(blockIdx: Int): Array[Array[InternalRow]] = {
+ val cached = blockLevelsCache.get(Integer.valueOf(blockIdx))
+ if (cached != null) return cached
+
+ // Enforce LRU capacity before building a new entry (I3).
+ val cap = maxCachedBlocks.getOrElse(Int.MaxValue)
+ while (blockLevelsCache.size() >= cap) {
+ if (!evictEldest()) return throwCacheEvictFailed(blockIdx)
+ }
+
+ // Acquire accounting; on partial grant, try one manual evict-and-retry.
+ if (!acquireBlockMemory(blockIdx)) {
+ if (!evictEldest() || !acquireBlockMemory(blockIdx)) {
+ // scalastyle:off throwerror
+ throw QueryExecutionErrors.cannotAcquireMemoryForWindowAggregateError(
+ blockBytes, 0L)
+ // scalastyle:on throwerror
Review Comment:
The OOM error is always constructed with `receivedBytes = 0L`, even though
`acquireMemory(blockBytes)` may return a partial grant that is then rolled back
in `acquireBlockMemory`. This makes the diagnostic parameters misleading.
Consider having `acquireBlockMemory` return (or record) the bytes actually
granted before the rollback, and passing the value from the final retry into
`cannotAcquireMemoryForWindowAggregateError` instead of the hard-coded `0L`.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala:
##########
@@ -0,0 +1,892 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.sql.{DataFrame, Encoder, Encoders, QueryTest, Row}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer,
Window}
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.util.SparkErrorUtils
+
+/**
+ * End-to-end tests for the block-chunked segment-tree moving window frame.
+ * Covers basic aggregates, frame boundaries, min-rows fallback, NULL/NaN,
+ * numeric/string/date-timestamp types, RANGE, Decimal/Binary merge, UDAF
+ * fallback, and frame lifecycle.
+ */
+class SegmentTreeWindowFunctionSuite extends QueryTest with SharedSparkSession
{
+
+ import testImplicits._
+
+ // Force seg-tree path regardless of partition size (fallback exercised
explicitly).
+ private val enableSegTree: Map[String, String] = Map(
+ SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true",
+ SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1")
+
+ private val disableSegTree: Map[String, String] = Map(
+ SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false")
+
+ /** Build `f(conf)` twice (enabled / disabled) and assert equal results. */
+ private def checkEquivalence(build: () => DataFrame): Unit = {
+ val baseline: Array[Row] = withSQLConf(disableSegTree.toSeq: _*) {
+ build().collect().sortBy(_.toString)
+ }
+ withSQLConf(enableSegTree.toSeq: _*) {
+ val actual = build().collect().sortBy(_.toString)
+ assert(actual.toSeq === baseline.toSeq,
+ s"segment-tree output differs from baseline.\nExpected:
${baseline.toSeq}\n" +
+ s"Actual: ${actual.toSeq}")
+ }
+ }
+
+ /** Standard fixture: 3 partitions, sizes 40/40/40, values = row index. */
+ private def baseDF: DataFrame = {
+ spark.range(0, 120).selectExpr(
+ "id",
+ "(id % 3) AS pk",
+ "CAST(id AS INT) AS v")
+ }
+
+ private def winSpec(lo: Int, hi: Int) =
+ Window.partitionBy($"pk").orderBy($"id").rowsBetween(lo, hi)
+
+ // A1: basic aggregate equivalence
+
+ test("MIN over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", min($"v").over(winSpec(-3, 3)).as("agg")))
+ }
+
+ test("MAX over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", max($"v").over(winSpec(-3, 3)).as("agg")))
+ }
+
+ test("SUM over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", sum($"v").over(winSpec(-3, 3)).as("agg")))
+ }
+
+ test("COUNT over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", count($"v").over(winSpec(-3, 3)).as("agg")))
+ }
+
+ test("AVG over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", avg($"v").over(winSpec(-3, 3)).as("agg")))
+ }
+
+ test("MIN + MAX + SUM share a single window frame") {
+ checkEquivalence(() =>
+ baseDF.select(
+ $"id",
+ $"pk",
+ min($"v").over(winSpec(-3, 3)).as("mn"),
+ max($"v").over(winSpec(-3, 3)).as("mx"),
+ sum($"v").over(winSpec(-3, 3)).as("sm")))
+ }
+
+ // A2: frame-size boundaries
+
+ test("frame size = 1 (CURRENT ROW only)") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", sum($"v").over(winSpec(0, 0)).as("agg")))
+ }
+
+ test("frame spans full partition") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk", sum($"v").over(winSpec(-100,
100)).as("agg")))
+ }
+
+ test("frame extends past both partition edges") {
+ checkEquivalence(() =>
+ baseDF.select($"id", $"pk",
+ sum($"v").over(winSpec(-50, 50)).as("agg"),
+ min($"v").over(winSpec(-50, 50)).as("mn"),
+ max($"v").over(winSpec(-50, 50)).as("mx")))
+ }
+
+
+ test("partition below minPartitionRows falls back to
SlidingWindowFunctionFrame") {
+ // 5-row partition, min threshold = 10 -> must fall back.
+ val df = spark.range(0, 5).selectExpr(
+ "id", "0 AS pk", "CAST(id AS INT) AS v")
+ val enabledConf = Map(
+ SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true",
+ SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "10")
+
+ val baseline = withSQLConf(disableSegTree.toSeq: _*) {
+ df.select($"id", sum($"v").over(winSpec(-1, 1)).as("s"))
+ .collect().sortBy(_.toString)
+ }
+ val actual = withSQLConf(enabledConf.toSeq: _*) {
+ df.select($"id", sum($"v").over(winSpec(-1, 1)).as("s"))
+ .collect().sortBy(_.toString)
+ }
+ assert(actual.toSeq === baseline.toSeq)
+
+ // Confirm the fallback flag actually flipped.
+ withSQLConf(enabledConf.toSeq: _*) {
+ SegmentTreeWindowTestHelpers.withSmallPartitionFrame(
+ SQLConf.get, rows = 5) { frame =>
+ assert(frame.fallbackUsed,
+ "expected fallbackUsed=true for partition smaller than
minPartitionRows")
+ }
+ }
+ }
+
+
+ test("NTH_VALUE over ROWS frame falls back cleanly (no mergeExpressions
crash)") {
+ // NthValue extends DeclarativeAggregate but its mergeExpressions throws
+ // mergeUnsupportedByWindowFunctionError. eligibleForSegTree must exclude
it.
+ val df = baseDF
+ val withSegTree = withSQLConf(enableSegTree.toSeq: _*) {
+ df.selectExpr(
+ "id", "pk",
+ "nth_value(v, 3) OVER (PARTITION BY pk ORDER BY id " +
+ "ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING) AS n3")
+ .collect().sortBy(_.toString)
+ }
+ val baseline = withSQLConf(disableSegTree.toSeq: _*) {
+ df.selectExpr(
+ "id", "pk",
+ "nth_value(v, 3) OVER (PARTITION BY pk ORDER BY id " +
+ "ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING) AS n3")
+ .collect().sortBy(_.toString)
+ }
+ assert(withSegTree.toSeq === baseline.toSeq)
+ }
+
+ test("ROW_NUMBER over ROWS frame falls back cleanly (no mergeExpressions
crash)") {
+ val df = baseDF
+ val withSegTree = withSQLConf(enableSegTree.toSeq: _*) {
+ df.selectExpr(
+ "id", "pk",
+ "row_number() OVER (PARTITION BY pk ORDER BY id) AS rn")
+ .collect().sortBy(_.toString)
+ }
+ val baseline = withSQLConf(disableSegTree.toSeq: _*) {
+ df.selectExpr(
+ "id", "pk",
+ "row_number() OVER (PARTITION BY pk ORDER BY id) AS rn")
+ .collect().sortBy(_.toString)
+ }
+ assert(withSegTree.toSeq === baseline.toSeq)
+ }
+
+ // A3: NULL/NaN/Infinity; A4: numeric/string/date-timestamp types.
+ // Unsupported-merge / DISTINCT / feature-flag fallback.
+ // Oracle: run with seg-tree enabled and disabled, assert equal Row
sequences.
+
+ // A3: NULL / special values
+
+ test("all-NULL column: MIN/MAX/SUM/AVG/COUNT") {
+ val df = spark.range(0, 30).selectExpr(
+ "id", "(id % 3) AS pk", "CAST(NULL AS INT) AS v")
+ checkEquivalence(() =>
+ df.select($"id", $"pk",
+ min($"v").over(winSpec(-3, 3)).as("mn"),
+ max($"v").over(winSpec(-3, 3)).as("mx"),
+ sum($"v").over(winSpec(-3, 3)).as("sm"),
+ avg($"v").over(winSpec(-3, 3)).as("av"),
+ count($"v").over(winSpec(-3, 3)).as("cn")))
+ }
+
+ test("mixed NULL and non-NULL: NULLs must not leak into MIN/MAX") {
+ // Every 3rd value is NULL. Aggregates must skip them (NULL-agnostic
merge).
+ val df = spark.range(0, 60).selectExpr(
+ "id",
+ "(id % 3) AS pk",
+ "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v")
+ checkEquivalence(() =>
+ df.select($"id", $"pk",
+ min($"v").over(winSpec(-4, 4)).as("mn"),
+ max($"v").over(winSpec(-4, 4)).as("mx"),
+ sum($"v").over(winSpec(-4, 4)).as("sm"),
+ count($"v").over(winSpec(-4, 4)).as("cn")))
+ }
+
+ test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") {
+ // Trap: NaN > +Inf in Spark's MIN/MAX ordering; +Inf + -Inf = NaN in SUM.
+ // Seg-tree uses DeclarativeAggregate.merge; behavior must match baseline.
+ val df = spark.range(0, 30).selectExpr(
+ "id",
+ "(id % 2) AS pk",
+ """CASE
+ WHEN id % 7 = 0 THEN double('NaN')
+ WHEN id % 7 = 1 THEN double('Infinity')
+ WHEN id % 7 = 2 THEN double('-Infinity')
+ ELSE CAST(id AS DOUBLE)
+ END AS v""")
+ checkEquivalence(() =>
+ df.select($"id", $"pk",
+ min($"v").over(winSpec(-3, 3)).as("mn"),
+ max($"v").over(winSpec(-3, 3)).as("mx"),
+ sum($"v").over(winSpec(-3, 3)).as("sm")))
+ }
+
+ // A4: data types
+
+ test("numeric types: Int / Long / Double / Decimal") {
+ val df = spark.range(0, 60).selectExpr(
+ "id",
+ "(id % 3) AS pk",
+ "CAST(id AS INT) AS vi",
+ "CAST(id * 1000000L AS LONG) AS vl",
+ "CAST(id AS DOUBLE) + 0.25 AS vd",
+ "CAST(id AS DECIMAL(20,4)) AS vdec")
+ checkEquivalence(() =>
+ df.select($"id", $"pk",
+ sum($"vi").over(winSpec(-2, 2)).as("si"),
+ min($"vl").over(winSpec(-2, 2)).as("ml"),
+ max($"vd").over(winSpec(-2, 2)).as("xd"),
+ sum($"vdec").over(winSpec(-2, 2)).as("sdec"),
+ avg($"vdec").over(winSpec(-2, 2)).as("adec")))
+ }
+
+ test("String lexicographic MIN/MAX") {
+ // Non-monotone values so MIN/MAX exercise the seg-tree merge (not edge).
+ val df = spark.range(0, 40).selectExpr(
+ "id",
+ "(id % 2) AS pk",
+ "CONCAT('s', LPAD(CAST((id * 37) % 97 AS STRING), 3, '0')) AS v")
+ checkEquivalence(() =>
+ df.select($"id", $"pk",
+ min($"v").over(winSpec(-3, 3)).as("mn"),
+ max($"v").over(winSpec(-3, 3)).as("mx")))
+ }
+
+  test("Date / Timestamp MIN/MAX") {
+    // Day and second offsets are scrambled with (id * k) % p so that values
+    // are not monotone in id (same rationale as the String MIN/MAX test).
+    // With monotone inputs MIN/MAX would always come from a frame edge row,
+    // and the seg-tree merge would go largely unexercised.
+    val df = spark.range(0, 40).selectExpr(
+      "id",
+      "(id % 2) AS pk",
+      "date_add(DATE'2020-01-01', CAST((id * 37) % 97 AS INT)) AS vd",
+      "CAST(TIMESTAMP'2020-01-01 00:00:00' + " +
+        "make_interval(0, 0, 0, 0, 0, 0, CAST((id * 53) % 89 AS DECIMAL(18,6))) " +
+        "AS TIMESTAMP) AS vt")
+    checkEquivalence(() =>
+      df.select($"id", $"pk",
+        min($"vd").over(winSpec(-3, 3)).as("mnd"),
+        max($"vd").over(winSpec(-3, 3)).as("mxd"),
+        min($"vt").over(winSpec(-3, 3)).as("mnt"),
+        max($"vt").over(winSpec(-3, 3)).as("mxt")))
+  }
+
+
+  test("collect_list falls back cleanly (non-DeclarativeAggregate)") {
+    // collect_list is a TypedImperativeAggregate, not a DeclarativeAggregate,
+    // so seg-tree eligibility must decline it and the answer must still
+    // match the baseline.
+    checkEquivalence { () =>
+      val listCol = collect_list($"v").over(winSpec(-2, 2)).as("lst")
+      baseDF.select($"id", $"pk", listCol)
+    }
+  }
+
+  test("DISTINCT window aggregate is rejected by analyzer regardless of seg-tree flag") {
+    // The analyzer raises DISTINCT_WINDOW_FUNCTION_UNSUPPORTED before any
+    // window frame is built, so flipping the seg-tree flag must not change it.
+    val distinctCountOverFrame = () =>
+      baseDF.select($"id", $"pk", count_distinct($"v").over(winSpec(-3, 3)).as("cd"))
+        .collect()
+    for (segTreeConf <- Seq(disableSegTree.toSeq, enableSegTree.toSeq)) {
+      withSQLConf(segTreeConf: _*) {
+        val e = intercept[org.apache.spark.sql.AnalysisException](distinctCountOverFrame())
+        assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED"))
+      }
+    }
+  }
+
+  test("feature flag off: segmentTree.enabled=false yields baseline semantics") {
+    // Sanity: with the seg-tree flag off, a seg-tree-eligible workload must
+    // still give the SlidingWindowFunctionFrame answer, also when the
+    // min-partition-rows setting is pinned explicitly.
+    val df = baseDF
+    def sortedMins() =
+      df.select($"id", $"pk", min($"v").over(winSpec(-3, 3)).as("mn"))
+        .collect()
+        .sortBy(_.toString)
+        .toSeq
+    val expected = withSQLConf(disableSegTree.toSeq: _*)(sortedMins())
+    withSQLConf(
+        SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false",
+        SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") {
+      assert(sortedMins() === expected)
+    }
+  }
+
+ // A5: RANGE frame equivalence (single-order-expr admission).
+ // MIN/MAX non-invertible, guaranteeing seg-tree path is exercised.
+
+  /**
+   * Registers `df` as temp view `t` and runs `query` twice: once with the
+   * segment tree disabled, once with it enabled. Asserts both results are
+   * equal after sorting rows by their string form. The view is always dropped.
+   */
+  private def checkRangeEquivalence(df: DataFrame, query: String): Unit = {
+    def sortedRows() = spark.sql(query).collect().sortBy(_.toString)
+    df.createOrReplaceTempView("t")
+    try {
+      val expected = withSQLConf(disableSegTree.toSeq: _*)(sortedRows())
+      withSQLConf(enableSegTree.toSeq: _*) {
+        val observed = sortedRows()
+        assert(observed.toSeq === expected.toSeq,
+          s"segment-tree output differs from baseline.\nExpected: ${expected.toSeq}\n" +
+            s"Actual: ${observed.toSeq}")
+      }
+    } finally {
+      spark.catalog.dropTempView("t")
+    }
+  }
+
+  test("-- RANGE INT offset basic (non-uniform gaps, MIN/MAX)") {
+    // Non-uniform gaps force the admit/drop loops to consult the order-key
+    // comparator. Each run of 7 ids maps to keys {1, 3, 4, 4, 7, 10, 15} plus
+    // 20 * (run index). The run index must use `div` (integral division):
+    // `/` is fractional division in Spark SQL, and it would smear the offset
+    // across ids and destroy the intended key layout.
+    val df = spark.range(0, 40).selectExpr(
+      "CAST(id AS INT) AS id",
+      "(CAST(id AS INT) % 2) AS pk",
+      "CAST(CASE CAST(id AS INT) % 7 " +
+        "WHEN 0 THEN 1 WHEN 1 THEN 3 WHEN 2 THEN 4 WHEN 3 THEN 4 " +
+        "WHEN 4 THEN 7 WHEN 5 THEN 10 ELSE 15 END + (CAST(id AS INT) div 7) * 20 " +
+        "AS INT) AS k",
+      "CAST((id * 31) % 97 AS INT) AS v")
+    checkRangeEquivalence(df,
+      """SELECT id, pk,
+        |  MIN(v) OVER (PARTITION BY pk ORDER BY k
+        |    RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS mn,
+        |  MAX(v) OVER (PARTITION BY pk ORDER BY k
+        |    RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS mx
+        |FROM t""".stripMargin)
+  }
+
+  test("-- RANGE Timestamp with INTERVAL offset (MAX)") {
+    // Minute offsets of 30 * id * {1, 3, 4} give irregular gaps, so the
+    // timestamp comparator decides membership at the 1-hour frame bounds.
+    val minuteOffset =
+      "30 * CAST(id AS INT) * (CASE CAST(id AS INT) % 3 WHEN 0 THEN 1 WHEN 1 THEN 3 ELSE 4 END)"
+    val df = spark.range(0, 30).selectExpr(
+      "CAST(id AS INT) AS id",
+      "(CAST(id AS INT) % 2) AS pk",
+      "CAST(TIMESTAMP'2024-01-01 10:00:00' + " +
+        s"make_interval(0, 0, 0, 0, 0, $minuteOffset, 0) AS TIMESTAMP) AS ts",
+      "CAST((id * 17) % 53 AS INT) AS v")
+    val query =
+      """SELECT id, pk,
+        | MAX(v) OVER (PARTITION BY pk ORDER BY ts
+        | RANGE BETWEEN INTERVAL '1' HOUR PRECEDING
+        | AND INTERVAL '1' HOUR FOLLOWING) AS mx
+        |FROM t""".stripMargin
+    checkRangeEquivalence(df, query)
+  }
+
+  test("-- RANGE with tie (duplicate order keys) inclusion at boundary") {
+    // RANGE `0 PRECEDING AND 0 FOLLOWING` must cover the whole tie group that
+    // shares the current row's key, not just the current row. Confusing ROWS
+    // with RANGE would give per-row MIN/MAX instead of per-group values.
+    val keyCycle = Seq(1, 2, 2, 2, 3, 4, 5)
+    val df = (0 until 40)
+      .map(i => (i, i % 2, keyCycle(i % keyCycle.size), (i * 13) % 41))
+      .toDF("id", "pk", "k", "v")
+    val query =
+      """SELECT id, pk, k,
+        | MIN(v) OVER (PARTITION BY pk ORDER BY k
+        | RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) AS mn,
+        | MAX(v) OVER (PARTITION BY pk ORDER BY k
+        | RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) AS mx
+        |FROM t""".stripMargin
+    checkRangeEquivalence(df, query)
+  }
+
+  test("-- RANGE frame wider than partition (C4: admit/drop loops no-op)") {
+    // Five partitions of five rows each. k lies in 0..22, so the +/-100 RANGE
+    // bounds cover every row of a partition: once the first batch is admitted,
+    // the admit/drop loops must detect no change and skip work.
+    // `pk` must use `div` (integral division). `/` is fractional in Spark SQL
+    // and would give 25 distinct DOUBLE keys, i.e. one row per partition.
+    val df = spark.range(0, 25).selectExpr(
+      "CAST(id AS INT) AS id",
+      "CAST(id div 5 AS INT) AS pk",
+      "CAST((id * 7) % 23 AS INT) AS k",
+      "CAST((id * 19) % 101 AS INT) AS v")
+    checkRangeEquivalence(df,
+      """SELECT id, pk,
+        |  MIN(v) OVER (PARTITION BY pk ORDER BY k
+        |    RANGE BETWEEN 100 PRECEDING AND 100 FOLLOWING) AS mn,
+        |  MAX(v) OVER (PARTITION BY pk ORDER BY k
+        |    RANGE BETWEEN 100 PRECEDING AND 100 FOLLOWING) AS mx
+        |FROM t""".stripMargin)
+  }
+
+  test("-- RANGE with NULL order key (NULLS FIRST / NULLS LAST)") {
+    // Spark treats all NULL keys as one equivalence class, placed at the head
+    // (NULLS FIRST) or the tail (NULLS LAST). The seg-tree must handle that
+    // NULL tie group exactly like the sliding baseline.
+    def keyFor(i: Int): Option[Int] = i % 6 match {
+      case 2 => Some(1)
+      case 3 => Some(2)
+      case 4 => Some(3)
+      case _ => None
+    }
+    val df = (0 until 36)
+      .map(i => (i, i % 2, keyFor(i), (i * 11) % 37))
+      .toDF("id", "pk", "k", "v")
+    val query =
+      """SELECT id, pk,
+        | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST
+        | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mn_nf,
+        | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST
+        | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mx_nf,
+        | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST
+        | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mn_nl,
+        | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST
+        | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mx_nl
+        |FROM t""".stripMargin
+    checkRangeEquivalence(df, query)
+  }
+
+ // Decimal overflow / BinaryType MIN/MAX across block merge; UDAF fallback.
+ // Trap: blockSize=16 is SQLConf minimum; frame > blockSize ensures the
+ // seg-tree merge path is actually crossed.
+
+  /** Segment-tree block size for the tests below (16 is the SQLConf minimum). */
+  private val segTreeBlock: String = "16"
+
+  /** Depth N of the `N PRECEDING AND CURRENT ROW` frames; one more than the block size. */
+  private val segTreeFramePrec: Int = 17
+
+  /** Number of rows generated by `decimalOverflowDF`. */
+  private val segTreeRows: Int = 20
+
+  /**
+   * Runs `body` under `withSQLConf`. The segment-tree block size is set to
+   * `segTreeBlock`, followed by any extra `conf` pairs.
+   */
+  private def withSegTreeBlock(conf: (String, String)*)(body: => Unit): Unit = {
+    val blockSizeConf = SQLConf.WINDOW_SEGMENT_TREE_BLOCK_SIZE.key -> segTreeBlock
+    withSQLConf((blockSizeConf +: conf): _*)(body)
+  }
+
+
+  /**
+   * A single partition (`pk` = 0) of `segTreeRows` rows. Every `v` is 9e37 as
+   * DECIMAL(38, 0), just below the type's maximum (~9.99e37), so summing any
+   * two rows overflows. Combined with the `segTreeFramePrec`-deep frame and
+   * block size `segTreeBlock`, windows span more than one seg-tree block.
+   */
+  private def decimalOverflowDF: DataFrame = {
+    val nineE37 = "90000000000000000000000000000000000000" // 38 digits
+    val columns = Seq(
+      "CAST(id AS INT) AS id",
+      "0 AS pk",
+      s"CAST('$nineE37' AS DECIMAL(38, 0)) AS v")
+    spark.range(0, segTreeRows.toLong).selectExpr(columns: _*)
+  }
+
+ /**
+  * Running SUM(v) per `pk`, ordered by `id`. The ROWS frame runs from
+  * `segTreeFramePrec` rows PRECEDING through the CURRENT ROW and reads temp
+  * view `t`. The Decimal overflow tests share this exact query text.
+  */
+ private val decimalOverflowSql: String =
+ s"""SELECT id, pk,
+ | SUM(v) OVER (PARTITION BY pk ORDER BY id
+ | ROWS BETWEEN $segTreeFramePrec PRECEDING AND CURRENT ROW) AS s
+ |FROM t""".stripMargin
+
+  test("a -- Decimal overflow ANSI on, seg-tree matches sliding (both throw)") {
+    // With ANSI on, the overflowing SUM must fail with an ArithmeticException
+    // root cause on both the sliding (flag off) and seg-tree (flag on) paths.
+    decimalOverflowDF.createOrReplaceTempView("t")
+    try {
+      withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "true") {
+        for (segTreeConf <- Seq(disableSegTree.toSeq, enableSegTree.toSeq)) {
+          withSQLConf(segTreeConf: _*) {
+            val thrown = intercept[Exception](spark.sql(decimalOverflowSql).collect())
+            assert(hasArithmeticCause(thrown),
+              s"expected ArithmeticException root cause, got: ${thrown.getMessage}")
+          }
+        }
+      }
+    } finally {
+      spark.catalog.dropTempView("t")
+    }
+  }
+
+  test("b -- Decimal overflow ANSI off, seg-tree matches sliding (NULL on overflow)") {
+    decimalOverflowDF.createOrReplaceTempView("t")
+    try {
+      withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "false") {
+        def sortedSums() = spark.sql(decimalOverflowSql).collect().sortBy(_.toString)
+        val expected = withSQLConf(disableSegTree.toSeq: _*)(sortedSums())
+        // Guard against vacuous success: overflow must really have produced NULLs.
+        assert(expected.exists(_.isNullAt(2)),
+          "baseline should contain NULL overflow rows; test data may be too small")
+        withSQLConf(enableSegTree.toSeq: _*) {
+          assert(sortedSums().toSeq === expected.toSeq)
+        }
+      }
+    } finally {
+      spark.catalog.dropTempView("t")
+    }
+  }
+
+  test("c -- mid-window Decimal overflow slides past (seg-tree == sliding)") {
+    // ids 14..17 carry 9e37 and straddle the block boundary at id = 16. With
+    // a 3 PRECEDING .. CURRENT ROW frame, any window holding two or more of
+    // them overflows (NULL with ANSI off), so cross-block merges see
+    // overflowed buffers. Windows ending at id >= 21 contain none of them and
+    // must be non-NULL again.
+    val big = "90000000000000000000000000000000000000"
+    val valueExpr =
+      s"""CASE WHEN id IN (14, 15, 16, 17)
+        THEN CAST('$big' AS DECIMAL(38, 0))
+        ELSE CAST(id AS DECIMAL(38, 0))
+      END AS v"""
+    spark.range(0, 24)
+      .selectExpr("CAST(id AS INT) AS id", "0 AS pk", valueExpr)
+      .createOrReplaceTempView("t")
+    val sqlStr =
+      """SELECT id, pk,
+        | SUM(v) OVER (PARTITION BY pk ORDER BY id
+        | ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) AS s
+        |FROM t""".stripMargin
+    try {
+      withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "false") {
+        def sortedSums() = spark.sql(sqlStr).collect().sortBy(_.toString)
+        val expected = withSQLConf(disableSegTree.toSeq: _*)(sortedSums())
+        // Sanity: overflow fired, and later windows recovered.
+        assert(expected.exists(_.isNullAt(2)),
+          "baseline should contain NULL overflow rows")
+        assert(expected.exists(row => row.getInt(0) >= 21 && !row.isNullAt(2)),
+          "rows with id>=21 should be non-NULL (window slid past big values)")
+        withSQLConf(enableSegTree.toSeq: _*) {
+          assert(sortedSums().toSeq === expected.toSeq)
+        }
+      }
+    } finally {
+      spark.catalog.dropTempView("t")
+    }
+  }
+
+  /**
+   * Returns true when the root cause of `t`, as reported by
+   * `SparkErrorUtils.getRootCause`, is an [[ArithmeticException]] (the ANSI
+   * overflow signal). A null root cause yields false.
+   */
+  private def hasArithmeticCause(t: Throwable): Boolean =
+    SparkErrorUtils.getRootCause(t) match {
+      case _: ArithmeticException => true
+      case _ => false
+    }
+
+
+  /**
+   * Twenty `(id, bytes)` rows for the BinaryType MIN/MAX test. They cycle
+   * through eight byte arrays of varying length and content, including
+   * prefix-related values such as 0x10 / 0x10 0x20 / 0x10 0x20 0x30.
+   * Every row gets a freshly allocated array.
+   */
+  private def binaryVariedRows: Seq[(Int, Array[Byte])] = {
+    def bytesFor(slot: Int): Array[Byte] = slot match {
+      case 0 => Array[Byte](0x01, 0x02)
+      case 1 => Array[Byte](0x00)
+      case 2 => Array[Byte](0x7f)
+      case 3 => Array[Byte](0x7f, 0x00)
+      case 4 => Array[Byte](0x10, 0x20, 0x30)
+      case 5 => Array[Byte](0x10, 0x20)
+      case 6 => Array[Byte](0x10)
+      case _ => Array[Byte](0x05, 0x05, 0x05, 0x05)
+    }
+    (0 until 20).map(id => id -> bytesFor(id % 8))
+  }
+
+ test("a -- BinaryType MIN/MAX cross-block merge") {
+ // Varied lengths/content; frame > blockSize guarantees merge path hit.
+ val df = binaryVariedRows.toDF("id", "v").selectExpr("id", "0 AS pk", "v")
+ df.createOrReplaceTempView("t")
+ try {
+ withSegTreeBlock() {
+ val sqlStr =
+ s"""SELECT id, pk,
+ | MIN(v) OVER (PARTITION BY pk ORDER BY id
+ | ROWS BETWEEN $segTreeFramePrec PRECEDING AND CURRENT ROW) AS
mn,
+ | MAX(v) OVER (PARTITION BY pk ORDER BY id
+ | ROWS BETWEEN $segTreeFramePrec PRECEDING AND CURRENT ROW) AS
mx
+ |FROM t""".stripMargin
+ val baseline = withSQLConf(disableSegTree.toSeq: _*) {
+ spark.sql(sqlStr).collect().sortBy(_.toString)
+ }
+ withSQLConf(enableSegTree.toSeq: _*) {
+ val actual = spark.sql(sqlStr).collect().sortBy(_.toString)
Review Comment:
These BinaryType tests sort collected Rows by `_.toString`, but
`Row.toString` will include JVM address-based `Array[Byte].toString`, so
baseline vs actual can sort into different orders even when results are
identical. Prefer sorting by stable columns (e.g. `id`) or using
`prepareAnswer(..., isSorted = false)` / `sameRows` from `QueryTest`, which
normalizes binary values before sorting.
```suggestion
spark.sql(sqlStr).collect().sortBy(_.getInt(0))
}
withSQLConf(enableSegTree.toSeq: _*) {
val actual = spark.sql(sqlStr).collect().sortBy(_.getInt(0))
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]