yaooqinn commented on code in PR #55422: URL: https://github.com/apache/spark/pull/55422#discussion_r3115436054
########## sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala: ########## @@ -0,0 +1,610 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap} + +import scala.collection.mutable + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ArrayImplicits._ + +/** + * Block-chunked segment tree for range aggregate queries over window partitions. + * + * See `the class documentation` + * for the full design (API contract Section 2, block-chunked memory layout Section 3, + * DeclarativeAggregate binding Section 4, error handling Section 5, test hooks Section 6). + * + * initial implementation scope: correctness only. 
The data layer uses + * `ExternalAppendOnlyUnsafeRowArray` to hold input rows (spillable). Each + * block materializes its own small segment tree (levels 0..h). Internal + * nodes are cached in an LRU keyed by block index; block root aggregates + * (block pre-aggregates) stay resident for all blocks. + * + * Note: the design doc Section 3.3 specifies leaves are NOT materialized and + * recomputed from the spillable array on demand. For initial implementation simplicity + * we materialize leaves inside the per-block internal node arrays. + * // TODO(SPARK-XXXXX) re-assess after Frame integration. + * + * @note Instances are not thread-safe. + */ +private[window] class WindowSegmentTree( + functions: Array[DeclarativeAggregate], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + fanout: Int = WindowSegmentTree.DefaultFanout, + blockSize: Int = WindowSegmentTree.DefaultBlockSize, + maxCachedBlocks: Option[Int] = None, + spillThreshold: Int = Int.MaxValue, + inMemoryThreshold: Int = Int.MaxValue, + taskMemoryManager: TaskMemoryManager = null) + extends AutoCloseable { + + require(fanout >= 2, s"fanout must be >= 2, got $fanout") + require(blockSize >= 1, s"blockSize must be >= 1, got $blockSize") + require(functions.nonEmpty, "WindowSegmentTree requires at least one aggregate function") + maxCachedBlocks.foreach { n => + require(n >= 1, s"maxCachedBlocks must be >= 1 when specified, got $n") + } + require(taskMemoryManager != null, + "WindowSegmentTree requires a non-null TaskMemoryManager; " + + "in tests use `new TaskMemoryManager(new TestMemoryManager(conf), 0)`") + + // ---------- Schemas & projections ---------- + + private val bufferAttrs: Seq[AttributeReference] = + functions.flatMap(_.aggBufferAttributes).toImmutableArraySeq + private val rightAttrs: Seq[AttributeReference] = + functions.flatMap(_.inputAggBufferAttributes).toImmutableArraySeq + private val bufferDataTypes: IndexedSeq[DataType] = + 
bufferAttrs.map(_.dataType).toIndexedSeq + + private val initialValues: Seq[Expression] = functions.flatMap(_.initialValues).toIndexedSeq + private val updateExpressions: Seq[Expression] = + functions.flatMap(_.updateExpressions).toIndexedSeq + private val mergeExpressions: Seq[Expression] = + functions.flatMap(_.mergeExpressions).toIndexedSeq + + private[this] val initProj: MutableProjection = newMutableProjection(initialValues, Nil) + private[this] val updateProj: MutableProjection = + newMutableProjection(updateExpressions, bufferAttrs ++ inputSchema) + private[this] val mergeProj: MutableProjection = + newMutableProjection(mergeExpressions, bufferAttrs ++ rightAttrs) + + private val inputUnsafeProj: UnsafeProjection = + UnsafeProjection.create(inputSchema.map(_.dataType).toArray) + + private[this] val joinedRow = new JoinedRow() + + // ---------- State ---------- + + private var numRows: Int = 0 + private var numBlocks: Int = 0 + private var rowArray: ExternalAppendOnlyUnsafeRowArray = _ + private var closed: Boolean = false + + /** Always-resident per-block root aggregates. `blockAggregates(i)` = + * merged buffer over all rows in block i. */ + private var blockAggregates: Array[InternalRow] = Array.empty + + /** Rough byte width of one aggregate buffer row. Chosen at 16 B/field as a + * conservative heap-overhead-aware lower bound for a + * `SpecificInternalRow` slot: primitive `MutableValue` is 8 B, boxed + * references and object headers push the effective footprint higher. + * Tighter per-type sizing (real boxing cost, variable-length fields) is + * intentionally out of scope here; TaskMemoryManager remains the hard + * backstop via spill / OOM. + * TODO(SPARK-XXXXX): per-type width estimator keyed on + * `bufferDataTypes` (primitive 16 B, String/Binary/Decimal wider). 
*/ + private val bufferWidthBytes: Long = { + val bytesPerField = 16L + math.max(1L, bufferDataTypes.size.toLong * bytesPerField) + } + + /** Number of aggregate-buffer slots cached per block (contract I5). + * + * Invariant: equals `sum over levels L of levels(L).length` for any + * block materialized by [[buildBlockLevels]]. Level 0 holds `blockSize` + * leaf buffers and each subsequent level holds `ceil(prev / fanout)` + * parent buffers until a single root remains. The iterative ceiling + * matches the allocation in [[buildBlockLevels]] for every + * `(blockSize, fanout)` pair, including non-power-of-`fanout` cases. + * For `blockSize == 1` the block is a single leaf with no parent + * levels, so this returns 1. + * TODO(SPARK-XXXXX): drop the leaf term when [[buildBlockLevels]] + * switches to on-demand leaf recomputation. */ + private val cachedSlotsPerBlock: Long = { + var n = blockSize.toLong + var sum = n + while (n > 1L) { + n = (n + fanout - 1) / fanout + sum += n + } + sum + } + + /** Bytes accounted per cached block (contract I5). Conservative: assumes + * every block is full; tail block (`numRows % blockSize != 0`) will + * hold fewer leaves, giving a small headroom. */ + private[this] val blockBytes: Long = + math.max(1L, cachedSlotsPerBlock * bufferWidthBytes) + + /** `spans(L)` = number of leaves covered by a single node at level L. Depends + * only on fanout + blockSize, so precomputed once. */ + private val spans: Array[Int] = { + val maxLevel = { + var lvl = 0 + var span = 1L + while (span < blockSize) { span *= fanout; lvl += 1 } + lvl + } + val arr = new Array[Int](maxLevel + 1) + var s = 1L + var i = 0 + while (i <= maxLevel) { + arr(i) = if (s > Int.MaxValue) Int.MaxValue else s.toInt + s *= fanout + i += 1 + } + arr + } + + /** LRU cache of per-block internal node arrays. Key = blockIdx. + * Value = `Array[Array[InternalRow]]` with levels(0..h). 
Auto-eviction via + * `removeEldestEntry` is disabled (contract I3) -- eviction is driven + * explicitly from [[ensureBlockLevels]] (capacity overflow) or + * [[SegTreeSpiller.spill]] (TMM pressure). Each cache entry maps 1:1 to + * one [[acquireBlockMemory]] accounting. Callers (e.g. + * `SegmentTreeWindowFunctionFrame`) should pass a W-aware + * value like `ceil(W / blockSize) + 2`. */ + private val blockLevelsCache: JLinkedHashMap[Integer, Array[Array[InternalRow]]] = + new JLinkedHashMap[Integer, Array[Array[InternalRow]]](16, 0.75f, true) { + override def removeEldestEntry( + eldest: JMap.Entry[Integer, Array[Array[InternalRow]]]): Boolean = false + } + + // ---------- Memory consumer (contract Section 2.2) ---------- + + /** + * Private MemoryConsumer tracking cached block levels under TMM. + * + * Heap accounting only (no Tungsten pages): uses + * [[MemoryConsumer.acquireMemory]] / [[MemoryConsumer.freeMemory]]. The + * [[MemoryConsumer]] base class records `used` via an `AtomicLong` when + * we call these -- so TMM's consumer-priority sort in + * `acquireExecutionMemory` sees our pressure accurately. + * + * @note `spill()` MUST NOT call `acquireMemory` (contract I1). + */ + private final class SegTreeSpiller extends MemoryConsumer( + taskMemoryManager, + taskMemoryManager.pageSizeBytes(), + taskMemoryManager.getTungstenMemoryMode()) { + override def spill(size: Long, trigger: MemoryConsumer): Long = { + // I2: self-trigger short-circuit. Prevents re-entrant eviction when + // our own acquireMemory() triggers TMM to poll us. + if (trigger eq this) return 0L + // I8: rowArray spilled-to-disk detection. If the rowArray has already + // spilled, evicting our cache is counter-productive (rebuild would + // O(blockStart)-scan the spill file). Return 0L to let TMM fall + // through to the next consumer. + // + // TODO(SPARK-XXXXX) #segtree-spill-priority (contract Section 7 O4): current + // heuristic uses `spillSize > 0` as the "has spilled" signal. 
A more + // precise check would consult `UnsafeExternalSorter.getSpillWriters` + // state, but that API is not public. Re-evaluate after benchmark. + // FIXME(kentyao): upstream a public "hasSpilled" hook on the array. Review Comment: Addressed in `79e42da60ad` (R3-cleanup-oss). Thank you for the advice! ########## sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala: ########## @@ -191,13 +200,17 @@ trait WindowEvaluatorFactoryBase { // in a single Window physical node. Therefore, we can assume no SQL aggregation // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL // aggregation function in a single physical node. - def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) { + val aggFilters: Array[Option[Expression]] = expressions.map { + case WindowExpression(ae: AggregateExpression, _) => ae.filter + case _ => None + }.toArray + // Shared per (key) across the factory closure's invocations; each + // Frame calls `processor.initialize(...)` in `prepare`, so cross- Review Comment: Addressed in `79e42da60ad` (R3-cleanup-oss). Thank you for the advice! ########## sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala: ########## @@ -295,4 +344,63 @@ trait WindowEvaluatorFactoryBase { } } + /** + * Segment-tree path eligibility. The tree relies on `DeclarativeAggregate.mergeExpressions`, + * which [[AggregateWindowFunction]]s (e.g. NthValue, NTile, Rank, RowNumber, NullIndex) refuse + * with `mergeUnsupportedByWindowFunctionError`. They extend DeclarativeAggregate but are + * NOT merge-capable, so we must exclude them here even though the base-class check passes. + * Normal aggregate window expressions reach this code as the inner DeclarativeAggregate + * unwrapped from [[AggregateExpression]] (see `windowFrameExpressionFactoryPairs.collect`). 
+ */ + private def eligibleForSegTree( + functions: Array[Expression], + filters: Array[Option[Expression]], + frameType: FrameType): Boolean = { + // RANGE is accepted only for single-column order specs. Multi-column + // RANGE with non-zero offset is already rejected by `createBoundOrdering` + // (see the `RangeFrame, _` internalError branch above), so gating here + // on `orderSpec.size == 1` matches the Sliding-path invariant exactly. + val frameTypeOk = frameType match { + case RowFrame => true + case RangeFrame => orderSpec.size == 1 + } + SQLConf.get.windowSegmentTreeEnabled && + frameTypeOk && + filters.forall(_.isEmpty) && + functions.forall { f => + f.isInstanceOf[DeclarativeAggregate] && !f.isInstanceOf[AggregateWindowFunction] Review Comment: Addressed in `80fe0a75993` (F9 allowlist). Thank you for the advice! ########## sql/core/src/test/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionSuite.scala: ########## @@ -0,0 +1,862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.{DataFrame, Encoder, Encoders, QueryTest, Row} +import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, Window} +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, LongType, StructType} + +/** + * End-to-end tests for the block-chunked segment-tree moving window frame. + * + * Coverage by section: + * - Coverage: various cases (basic aggregates), various cases (frame boundaries), + * min-partition-rows fallback, AggregateWindowFunction + * regression. + * - Coverage: various cases (NULL, NaN/Infinity), various cases + * (numeric / string / date-timestamp types), various cases + * (unsupported-merge / DISTINCT / feature-flag fallback). + * + */ +class SegmentTreeWindowFunctionSuite extends QueryTest with SharedSparkSession { + + import testImplicits._ + + // Common config: force the segment-tree path regardless of partition size + // (we exercise the fallback explicitly below). + private val enableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1") + + private val disableSegTree: Map[String, String] = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false") + + /** Build `f(conf)` twice (enabled / disabled) and assert equal results. 
*/ + private def checkEquivalence(build: () => DataFrame): Unit = { + val baseline: Array[Row] = withSQLConf(disableSegTree.toSeq: _*) { + build().collect().sortBy(_.toString) + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = build().collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq, + s"segment-tree output differs from baseline.\nExpected: ${baseline.toSeq}\n" + + s"Actual: ${actual.toSeq}") + } + } + + /** Standard fixture: 3 partitions, sizes 40/40/40, values = row index. */ + private def baseDF: DataFrame = { + spark.range(0, 120).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS v") + } + + private def winSpec(lo: Int, hi: Int) = + Window.partitionBy($"pk").orderBy($"id").rowsBetween(lo, hi) + + // ---------------- A1: basic aggregate equivalence ---------------- + + test("MIN over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", min($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("MAX over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", max($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("SUM over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("COUNT over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", count($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("AVG over ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING") { + checkEquivalence(() => + baseDF.select($"id", $"pk", avg($"v").over(winSpec(-3, 3)).as("agg"))) + } + + test("MIN + MAX + SUM share a single window frame") { + checkEquivalence(() => + baseDF.select( + $"id", + $"pk", + min($"v").over(winSpec(-3, 3)).as("mn"), + max($"v").over(winSpec(-3, 3)).as("mx"), + sum($"v").over(winSpec(-3, 3)).as("sm"))) + } + + // ---------------- A2: frame-size boundaries ---------------- + + test("frame size = 1 
(CURRENT ROW only)") { + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(winSpec(0, 0)).as("agg"))) + } + + test("frame spans full partition") { + // 40 rows per partition; use a wide symmetric window covering it. + checkEquivalence(() => + baseDF.select($"id", $"pk", sum($"v").over(winSpec(-100, 100)).as("agg"))) + } + + test("frame extends past both partition edges") { + checkEquivalence(() => + baseDF.select($"id", $"pk", + sum($"v").over(winSpec(-50, 50)).as("agg"), + min($"v").over(winSpec(-50, 50)).as("mn"), + max($"v").over(winSpec(-50, 50)).as("mx"))) + } + + + test("partition below minPartitionRows falls back to SlidingWindowFunctionFrame") { + // 5-row partition, min threshold = 10 -> must fall back. + val df = spark.range(0, 5).selectExpr( + "id", "0 AS pk", "CAST(id AS INT) AS v") + val enabledConf = Map( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "true", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "10") + + // 1. Result correctness against baseline. + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", sum($"v").over(winSpec(-1, 1)).as("s")) + .collect().sortBy(_.toString) + } + val actual = withSQLConf(enabledConf.toSeq: _*) { + df.select($"id", sum($"v").over(winSpec(-1, 1)).as("s")) + .collect().sortBy(_.toString) + } + assert(actual.toSeq === baseline.toSeq) + + // 2. Directly exercise the frame to confirm the fallback flag flips. + withSQLConf(enabledConf.toSeq: _*) { + SegmentTreeWindowTestHelpers.withSmallPartitionFrame( + SQLConf.get, rows = 5) { frame => + assert(frame.fallbackUsed, + "expected fallbackUsed=true for partition smaller than minPartitionRows") + } + } + } + + + test("NTH_VALUE over ROWS frame falls back cleanly (no mergeExpressions crash)") { + // NthValue extends DeclarativeAggregate but its mergeExpressions throws + // mergeUnsupportedByWindowFunctionError. eligibleForSegTree must exclude it. 
+ val df = baseDF + val withSegTree = withSQLConf(enableSegTree.toSeq: _*) { + df.selectExpr( + "id", "pk", + "nth_value(v, 3) OVER (PARTITION BY pk ORDER BY id " + + "ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING) AS n3") + .collect().sortBy(_.toString) + } + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.selectExpr( + "id", "pk", + "nth_value(v, 3) OVER (PARTITION BY pk ORDER BY id " + + "ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING) AS n3") + .collect().sortBy(_.toString) + } + assert(withSegTree.toSeq === baseline.toSeq) + } + + test("ROW_NUMBER over ROWS frame falls back cleanly (no mergeExpressions crash)") { + val df = baseDF + val withSegTree = withSQLConf(enableSegTree.toSeq: _*) { + df.selectExpr( + "id", "pk", + "row_number() OVER (PARTITION BY pk ORDER BY id) AS rn") + .collect().sortBy(_.toString) + } + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + df.selectExpr( + "id", "pk", + "row_number() OVER (PARTITION BY pk ORDER BY id) AS rn") + .collect().sortBy(_.toString) + } + assert(withSegTree.toSeq === baseline.toSeq) + } + + // A3.*: NULL / NaN / Infinity handling + // A4.*: numeric / string / date-timestamp types + // various cases: unsupported-merge / DISTINCT / feature-flag fallback + // + // All tests use the same oracle strategy as the frame integration: run with + // `segmentTree.enabled=true` (forced via min-rows=1) and with `=false`, + // then assert bit-for-bit equal Row sequences. That gives us: + // - Correctness: seg-tree output matches SlidingWindowFunctionFrame + // (the community-validated baseline). + // - Fallback paths: various cases exercise the `eligibleForSegTree` filter, + // which must decline to drive the seg-tree path and hand off to the + // sliding frame; equal rows prove the hand-off preserves semantics. 
+ + // ---------------- A3: NULL / special values ---------------- + + test("all-NULL column: MIN/MAX/SUM/AVG/COUNT") { + val df = spark.range(0, 30).selectExpr( + "id", "(id % 3) AS pk", "CAST(NULL AS INT) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(winSpec(-3, 3)).as("mn"), + max($"v").over(winSpec(-3, 3)).as("mx"), + sum($"v").over(winSpec(-3, 3)).as("sm"), + avg($"v").over(winSpec(-3, 3)).as("av"), + count($"v").over(winSpec(-3, 3)).as("cn"))) + } + + test("mixed NULL and non-NULL: NULLs must not leak into MIN/MAX") { + // Every 3rd value is NULL. Aggregates must skip them (NULL-agnostic merge). + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS INT) END AS v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(winSpec(-4, 4)).as("mn"), + max($"v").over(winSpec(-4, 4)).as("mx"), + sum($"v").over(winSpec(-4, 4)).as("sm"), + count($"v").over(winSpec(-4, 4)).as("cn"))) + } + + test("Double NaN and +/-Infinity propagate correctly through MIN/MAX/SUM") { + // Spark's NaN ordering: NaN is treated as greater than +Inf for MIN/MAX. + // +Inf + -Inf = NaN for SUM. The seg-tree path uses DeclarativeAggregate's + // own merge, so behavior must match the baseline exactly. 
+ val df = spark.range(0, 30).selectExpr( + "id", + "(id % 2) AS pk", + """CASE + WHEN id % 7 = 0 THEN double('NaN') + WHEN id % 7 = 1 THEN double('Infinity') + WHEN id % 7 = 2 THEN double('-Infinity') + ELSE CAST(id AS DOUBLE) + END AS v""") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(winSpec(-3, 3)).as("mn"), + max($"v").over(winSpec(-3, 3)).as("mx"), + sum($"v").over(winSpec(-3, 3)).as("sm"))) + } + + // ---------------- A4: data types ---------------- + + test("numeric types: Int / Long / Double / Decimal") { + val df = spark.range(0, 60).selectExpr( + "id", + "(id % 3) AS pk", + "CAST(id AS INT) AS vi", + "CAST(id * 1000000L AS LONG) AS vl", + "CAST(id AS DOUBLE) + 0.25 AS vd", + "CAST(id AS DECIMAL(20,4)) AS vdec") + checkEquivalence(() => + df.select($"id", $"pk", + sum($"vi").over(winSpec(-2, 2)).as("si"), + min($"vl").over(winSpec(-2, 2)).as("ml"), + max($"vd").over(winSpec(-2, 2)).as("xd"), + sum($"vdec").over(winSpec(-2, 2)).as("sdec"), + avg($"vdec").over(winSpec(-2, 2)).as("adec"))) + } + + test("String lexicographic MIN/MAX") { + // Deliberately non-monotone string values so that MIN/MAX actually + // exercise segment-tree merge rather than trivially matching the edge. 
+ val df = spark.range(0, 40).selectExpr( + "id", + "(id % 2) AS pk", + "CONCAT('s', LPAD(CAST((id * 37) % 97 AS STRING), 3, '0')) AS v") + checkEquivalence(() => + df.select($"id", $"pk", + min($"v").over(winSpec(-3, 3)).as("mn"), + max($"v").over(winSpec(-3, 3)).as("mx"))) + } + + test("Date / Timestamp MIN/MAX") { + val df = spark.range(0, 40).selectExpr( + "id", + "(id % 2) AS pk", + "date_add(DATE'2020-01-01', CAST((id * 13) % 365 AS INT)) AS vd", + "CAST(TIMESTAMP'2020-01-01 00:00:00' + " + + "make_interval(0, 0, 0, 0, 0, 0, CAST(id AS DECIMAL(18,6))) AS TIMESTAMP) AS vt") + checkEquivalence(() => + df.select($"id", $"pk", + min($"vd").over(winSpec(-3, 3)).as("mnd"), + max($"vd").over(winSpec(-3, 3)).as("mxd"), + min($"vt").over(winSpec(-3, 3)).as("mnt"), + max($"vt").over(winSpec(-3, 3)).as("mxt"))) + } + + + test("collect_list falls back cleanly (non-DeclarativeAggregate)") { + // collect_list is a Collect(TypedImperativeAggregate) -- not a + // DeclarativeAggregate, so eligibleForSegTree must decline and the + // sliding frame must take over. + checkEquivalence(() => + baseDF.select($"id", $"pk", + collect_list($"v").over(winSpec(-2, 2)).as("lst"))) + } + + test("DISTINCT window aggregate is rejected by analyzer regardless of seg-tree flag") { + // Spark does not support DISTINCT window aggregates at all -- the analyzer + // throws DISTINCT_WINDOW_FUNCTION_UNSUPPORTED before we ever reach frame + // construction. The seg-tree feature flag must not alter this behavior. 
+ def run(): Unit = { + baseDF.select($"id", $"pk", + count_distinct($"v").over(winSpec(-3, 3)).as("cd")).collect() + } + withSQLConf(disableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + withSQLConf(enableSegTree.toSeq: _*) { + val e = intercept[org.apache.spark.sql.AnalysisException](run()) + assert(e.getMessage.contains("DISTINCT_WINDOW_FUNCTION_UNSUPPORTED")) + } + } + + test("feature flag off: segmentTree.enabled=false yields baseline semantics") { + // Sanity check: disabling the flag on a workload the seg-tree path would + // otherwise handle (MIN, wide frame, partitions above the min-rows + // threshold) still produces the SlidingWindowFunctionFrame answer. + val df = baseDF + val expected = withSQLConf(disableSegTree.toSeq: _*) { + df.select($"id", $"pk", min($"v").over(winSpec(-3, 3)).as("mn")) + .collect().sortBy(_.toString).toSeq + } + // Explicit disable with the full-size partition config (no min-rows + // override). This exercises the flag-off branch of eligibleForSegTree. + withSQLConf( + SQLConf.WINDOW_SEGMENT_TREE_ENABLED.key -> "false", + SQLConf.WINDOW_SEGMENT_TREE_MIN_PARTITION_ROWS.key -> "1024") { + val actual = df.select($"id", $"pk", min($"v").over(winSpec(-3, 3)).as("mn")) + .collect().sortBy(_.toString).toSeq + assert(actual === expected) + } + } + + // A5.*: RANGE frame equivalence between seg-tree path and sliding baseline. + // + // The factory (`eligibleForSegTree`) now admits RangeFrame when orderSpec + // has exactly one ordering expression. These tests exercise the + // frameType-aware admit/drop loops in SegmentTreeWindowFunctionFrame. + // + // All tests follow the same oracle pattern as NULL/NaN/collation coverage: run the same + // query twice (seg-tree on / off) and assert equal Row sequences. 
+ // Aggregates are MIN/MAX (non-invertible) to guarantee the seg-tree + // code path is actually exercised rather than short-circuited. + + /** Run `sql` twice (flag off / on) and checkAnswer equality. */ + private def checkRangeEquivalence(df: DataFrame, query: String): Unit = { + df.createOrReplaceTempView("t") + try { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(query).collect().sortBy(_.toString) + } + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(query).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq, + s"segment-tree output differs from baseline.\nExpected: ${baseline.toSeq}\n" + + s"Actual: ${actual.toSeq}") + } + } finally { + spark.catalog.dropTempView("t") + } + } + + test("-- RANGE INT offset basic (non-uniform gaps, MIN/MAX)") { + // Non-uniform k gaps (1, 3, 4, 4, 7, 10, 15, ...) so that the frame + // edges shift by variable amounts and admit/drop loops must consult + // the order-key comparator rather than just row count. + val df = spark.range(0, 40).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(CASE CAST(id AS INT) % 7 " + + "WHEN 0 THEN 1 WHEN 1 THEN 3 WHEN 2 THEN 4 WHEN 3 THEN 4 " + + "WHEN 4 THEN 7 WHEN 5 THEN 10 ELSE 15 END + (CAST(id AS INT) / 7) * 20 AS INT) AS k", + "CAST((id * 31) % 97 AS INT) AS v") + checkRangeEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("-- RANGE Timestamp with INTERVAL offset (MAX)") { + // Irregular gaps: 30min / 90min / 2h / 30min / ... so frame edges + // crossing the 1-hour bound must rely on the timestamp comparator. 
+ val df = spark.range(0, 30).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) % 2) AS pk", + "CAST(TIMESTAMP'2024-01-01 10:00:00' + " + + "make_interval(0, 0, 0, 0, 0, 30 * CAST(id AS INT) * " + + "(CASE CAST(id AS INT) % 3 WHEN 0 THEN 1 WHEN 1 THEN 3 ELSE 4 END), 0) " + + "AS TIMESTAMP) AS ts", + "CAST((id * 17) % 53 AS INT) AS v") + checkRangeEquivalence(df, + """SELECT id, pk, + | MAX(v) OVER (PARTITION BY pk ORDER BY ts + | RANGE BETWEEN INTERVAL '1' HOUR PRECEDING + | AND INTERVAL '1' HOUR FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("-- RANGE with tie (duplicate order keys) inclusion at boundary") { + // k = [1, 2, 2, 2, 3, 4, 5] repeated across partitions. A frame of + // `0 PRECEDING AND 0 FOLLOWING` must include the FULL tie group at + // the current row's k, not just the current row itself. If the + // seg-tree path confuses RANGE with ROWS the tie group of k=2 would + // return a per-row MIN/MAX rather than a group-level one. + val rows = (0 until 40).map { i => + val k = Seq(1, 2, 2, 2, 3, 4, 5)(i % 7) + (i, i % 2, k, (i * 13) % 41) + } + val df = rows.toDF("id", "pk", "k", "v") + checkRangeEquivalence(df, + """SELECT id, pk, k, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 0 PRECEDING AND 0 FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("-- RANGE frame wider than partition (C4: admit/drop loops no-op)") { + // Partition size 5 rows, frame covers everything. Once the first + // batch is admitted, the admit/drop loops in the seg-tree frame must + // detect the effective frame is unchanged and skip work. 
+ val df = spark.range(0, 25).selectExpr( + "CAST(id AS INT) AS id", + "(CAST(id AS INT) / 5) AS pk", + "CAST((id * 7) % 23 AS INT) AS k", + "CAST((id * 19) % 101 AS INT) AS v") + checkRangeEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 100 PRECEDING AND 100 FOLLOWING) AS mn, + | MAX(v) OVER (PARTITION BY pk ORDER BY k + | RANGE BETWEEN 100 PRECEDING AND 100 FOLLOWING) AS mx + |FROM t""".stripMargin) + } + + test("-- RANGE with NULL order key (NULLS FIRST / NULLS LAST)") { + // k = [NULL, NULL, 1, 2, 3, NULL] repeated. Spark groups all NULLs + // into a single equivalence class at head (NULLS FIRST) or tail + // (NULLS LAST). The seg-tree admit/drop loops must treat NULL as a + // tie group identically to the sliding baseline. + val rows = (0 until 36).map { i => + val kOpt: Option[Int] = (i % 6) match { + case 0 | 1 | 5 => None + case 2 => Some(1) + case 3 => Some(2) + case _ => Some(3) + } + (i, i % 2, kOpt, (i * 11) % 37) + } + val df = rows.toDF("id", "pk", "k", "v") + checkRangeEquivalence(df, + """SELECT id, pk, + | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST + | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mn_nf, + | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS FIRST + | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mx_nf, + | MIN(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST + | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mn_nl, + | MAX(v) OVER (PARTITION BY pk ORDER BY k ASC NULLS LAST + | RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS mx_nl + |FROM t""".stripMargin) + } + + // Decimal overflow across seg-tree block merge + // BinaryType MIN/MAX across seg-tree block merge + // UDAF (ScalaUDAF / ScalaAggregator) fallback + // + // All Decimal/Binary tests force blockSize=16 (the minimum the config + // allows) + window size > blockSize so the seg-tree path crosses at + // least one block boundary (exercises `mergeExpressions` rather than + // only `update`). 
Design doc asked for blockSize=4 but the SQLConf + // validator rejects anything below 16; scaling data/frame up preserves + // the merge-path coverage the doc intended. + + private val segTreeBlock: String = "16" + private val segTreeFramePrec: Int = 17 + private val segTreeRows: Int = 20 + + private def withSegTreeBlock(conf: (String, String)*)(body: => Unit): Unit = { + val extra = Seq(SQLConf.WINDOW_SEGMENT_TREE_BLOCK_SIZE.key -> segTreeBlock) ++ conf + withSQLConf(extra: _*)(body) + } + + + /** + * 20 rows in a single partition, Decimal(38, 0) values near the type's + * upper bound. Frame of `segTreeFramePrec` PRECEDING..CURRENT ROW means + * any window with >= 2 such rows overflows the widened Sum buffer + * (Decimal(38,0) widened to Decimal(38,0); any addition above 1e38 + * overflows). Block size 16 + frame 17 forces cross-block merge. + */ + private def decimalOverflowDF: DataFrame = { + // 9e37 -- below Decimal(38,0) MAX (~9.99e37), but 2x overflows. + val big = "90000000000000000000000000000000000000" // 38 digits + spark.range(0, segTreeRows.toLong).selectExpr( + "CAST(id AS INT) AS id", + "0 AS pk", + s"CAST('$big' AS DECIMAL(38, 0)) AS v") + } + + private val decimalOverflowSql: String = + s"""SELECT id, pk, + | SUM(v) OVER (PARTITION BY pk ORDER BY id + | ROWS BETWEEN $segTreeFramePrec PRECEDING AND CURRENT ROW) AS s + |FROM t""".stripMargin + + test("a -- Decimal overflow ANSI on, seg-tree matches sliding (both throw)") { + val df = decimalOverflowDF + df.createOrReplaceTempView("t") + try { + withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "true") { + withSQLConf(disableSegTree.toSeq: _*) { + val e = intercept[Exception] { + spark.sql(decimalOverflowSql).collect() + } + assert(rootArithmeticCause(e).isDefined, + s"expected ArithmeticException root cause, got: ${e.getMessage}") + } + withSQLConf(enableSegTree.toSeq: _*) { + val e = intercept[Exception] { + spark.sql(decimalOverflowSql).collect() + } + assert(rootArithmeticCause(e).isDefined, + 
s"expected ArithmeticException root cause, got: ${e.getMessage}") + } + } + } finally { + spark.catalog.dropTempView("t") + } + } + + test("b -- Decimal overflow ANSI off, seg-tree matches sliding (NULL on overflow)") { + val df = decimalOverflowDF + df.createOrReplaceTempView("t") + try { + withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "false") { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(decimalOverflowSql).collect().sortBy(_.toString) + } + // At least one row must be NULL so we know overflow actually fired. + assert(baseline.exists(_.isNullAt(2)), + "baseline should contain NULL overflow rows; test data may be too small") + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(decimalOverflowSql).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq) + } + } + } finally { + spark.catalog.dropTempView("t") + } + } + + test("c -- mid-window Decimal overflow slides past (seg-tree == sliding)") { + // 24 rows, blockSize=16. Frame ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + // (size 4). Near-MAX values at ids 14,15,16,17 so any 4-row window + // containing >=2 of them overflows. Windows at: + // id<=14 or id>=20 -> at most one big value -> safe + // id in [15..19] -> overlaps >=2 big values -> NULL + // The big-value band (14..17) straddles the block boundary at id=16, + // guaranteeing cross-block merge paths see overflowing buffers. 
+ val big = "90000000000000000000000000000000000000" + val df = spark.range(0, 24).selectExpr( + "CAST(id AS INT) AS id", + "0 AS pk", + s"""CASE WHEN id IN (14, 15, 16, 17) + THEN CAST('$big' AS DECIMAL(38, 0)) + ELSE CAST(id AS DECIMAL(38, 0)) + END AS v""") + df.createOrReplaceTempView("t") + try { + val sqlStr = + """SELECT id, pk, + | SUM(v) OVER (PARTITION BY pk ORDER BY id + | ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) AS s + |FROM t""".stripMargin + withSegTreeBlock(SQLConf.ANSI_ENABLED.key -> "false") { + val baseline = withSQLConf(disableSegTree.toSeq: _*) { + spark.sql(sqlStr).collect().sortBy(_.toString) + } + // Sanity: overflow fired on some rows AND window slides past to + // recover non-NULL on the tail rows. + assert(baseline.exists(_.isNullAt(2)), + "baseline should contain NULL overflow rows") + assert(baseline.exists(r => r.getInt(0) >= 21 && !r.isNullAt(2)), + "rows with id>=21 should be non-NULL (window slid past big values)") + withSQLConf(enableSegTree.toSeq: _*) { + val actual = spark.sql(sqlStr).collect().sortBy(_.toString) + assert(actual.toSeq === baseline.toSeq) + } + } + } finally { + spark.catalog.dropTempView("t") + } + } + + /** Walk a SparkException chain for an ArithmeticException (ANSI overflow). */ + private def rootArithmeticCause(t: Throwable): Option[Throwable] = { + var cur: Throwable = t + while (cur != null) { + if (cur.isInstanceOf[ArithmeticException]) return Some(cur) + cur = cur.getCause + } + None + } Review Comment: Addressed in `6c433a68d78` (R3-T5-exceptionutils). Thank you for the advice! ########## sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala: ########## @@ -181,7 +181,12 @@ private[window] final class AggregateProcessor( } /** Evaluate buffer. */ - def evaluate(target: InternalRow): Unit = { - evaluateProjection.target(target)(buffer) + def evaluate(target: InternalRow): Unit = evaluate(buffer, target) + + /** Evaluate using an arbitrary `source` buffer (e.g. 
a segment-tree query + * result) instead of the internal one. See + * `docs/frame-integration-contract.md` Section 3. */ + private[window] def evaluate(source: InternalRow, target: InternalRow): Unit = { Review Comment: Addressed in `639e04e7101` (R3-scaladoc-contract). Thank you for the advice! ########## sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowSegmentTree.scala: ########## @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap} + +import scala.collection.mutable + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ArrayImplicits._ + +/** Review Comment: Addressed in `639e04e7101` (R3-scaladoc-contract). 
Thank you for the advice! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
