cloud-fan commented on code in PR #55912: URL: https://github.com/apache/spark/pull/55912#discussion_r3285347188
########## sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala: ########## @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. 
+ * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. For each equi-key group, all right rows with that + * key are also buffered in memory; skewed equi-key groups can OOM. + */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _: InnerLike => + left.output ++ right.output + case other => + throw SparkException.internalError( + s"$nodeName does not support join type: $other") + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = left.outputPartitioning + + // Determine scan direction based on the order expression (distance metric). 
+ // This is a performance heuristic only -- if it misclassifies, the scan + // still produces the correct result; only the early-termination shortcut + // is lost. + // + // orderExpression is direction-unique by construction: + // Backward: Subtract(leftAsOf, rightAsOf) -> right-to-left + // Forward: Subtract(rightAsOf, leftAsOf) -> left-to-right + // Nearest: If(...) -> left-to-right + private val scanRightToLeft: Boolean = orderExpression match { + case Subtract(l, _, _) if l.semanticEquals(leftAsOfExpr) => true + case _ => false + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val scanFromRight = scanRightToLeft + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows, + scanFromRight + ) + // Register cleanup to release the right-side buffer on task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). 
+ */ +private[joins] class SortMergeAsOfJoinScanner( + leftIter: Iterator[InternalRow], + rightIter: Iterator[InternalRow], + leftOutput: Seq[Attribute], + rightOutput: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + residualCondition: Option[Expression], + numOutputRows: SQLMetric, + scanRightToLeft: Boolean) { + + private val joinedOutput = leftOutput ++ rightOutput + private val joinedRow = new JoinedRow() + private val resultProjection = + UnsafeProjection.create(joinedOutput, joinedOutput) + + // Bound expressions for evaluating conditions on joined rows + private val boundAsOfCond = bindReference(asOfCondition, joinedOutput) + private val boundOrderExpr = bindReference(orderExpression, joinedOutput) + private val boundResidualCond = + residualCondition.map(bindReference(_, joinedOutput)) + + // Key ordering for equi-join keys + private val equiKeyOrdering: Option[BaseOrdering] = + if (leftKeys.nonEmpty) { + val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) => + AttributeReference(s"key_$i", key.dataType, key.nullable)() + } + Some(GenerateOrdering.generate( + keyAttributes.map(SortOrder(_, Ascending)), keyAttributes)) + } else { + None + } + + // Projections to extract equi-keys for comparison + private val leftKeyProj = UnsafeProjection.create(leftKeys, leftOutput) + private val rightKeyProj = UnsafeProjection.create(rightKeys, rightOutput) + + // Ordering for the distance metric + private val distanceOrdering = + TypeUtils.getInterpretedOrdering(orderExpression.dataType) + + // Null row for LeftOuter when no match is found + private val nullRightRow = new GenericInternalRow(rightOutput.length) + + // Right-side buffer: holds right rows for the current equi-key group. + // Rows are sorted by as-of key ascending (guaranteed by requiredChildOrdering). 
+ private val rightGroupBuffer = new ArrayBuffer[InternalRow]() + private var rightGroupKey: UnsafeRow = _ + private var rightPeek: InternalRow = _ + private var rightDone: Boolean = !rightIter.hasNext + + // Initialize: read first right row + if (!rightDone) { + rightPeek = rightIter.next().copy() + } + + /** Release resources held by this scanner. */ + def close(): Unit = { + rightGroupBuffer.clear() + rightGroupBuffer.trimToSize() + } + + def iterator: Iterator[InternalRow] = new Iterator[InternalRow] { + private var nextRow: InternalRow = _ + private val leftIterBuffered = leftIter.buffered + + override def hasNext: Boolean = { + if (nextRow != null) return true + nextRow = findNext() + nextRow != null + } + + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + val result = nextRow + nextRow = null + result + } + + private def findNext(): InternalRow = { + while (leftIterBuffered.hasNext) { + val leftRow = leftIterBuffered.next() + val leftKey = leftKeyProj(leftRow).copy() + + // Advance right side to the matching equi-key group + advanceRightTo(leftKey) + + // Search for best match exploiting sort order + val bestMatch = findBestInGroup(leftRow) + + if (bestMatch != null) { + numOutputRows += 1 + joinedRow.withLeft(leftRow).withRight(bestMatch) + return resultProjection(joinedRow).copy() + } else if (joinType == LeftOuter) { + numOutputRows += 1 + joinedRow.withLeft(leftRow).withRight(nullRightRow) + return resultProjection(joinedRow).copy() + } + // Inner join: no match, skip + } + null + } + } + + /** + * Advance the right side so that rightGroupBuffer contains all right + * rows whose equi-key matches `leftKey`. + */ + private def advanceRightTo(leftKey: UnsafeRow): Unit = { + equiKeyOrdering match { + case None => + // No equi-keys: buffer all right rows once. + // WARNING: This loads the entire right partition into memory. 
+ if (rightGroupBuffer.isEmpty && !rightDone) { + bufferAllRight() + } + case Some(ordering) => + // Check if current buffer already matches + if (rightGroupKey != null && + ordering.compare(leftKey, rightGroupKey) == 0) { + return + } + + // Skip right rows with keys < leftKey + while (!rightDone && rightPeek != null) { + val rightKey = rightKeyProj(rightPeek) + val cmp = ordering.compare(leftKey, rightKey) + if (cmp > 0) { + rightPeek = if (rightIter.hasNext) { + rightIter.next().copy() + } else { + rightDone = true; null + } + } else if (cmp == 0) { + bufferRightGroup(leftKey, ordering) + return + } else { + rightGroupBuffer.clear() + rightGroupKey = null + return + } + } + rightGroupBuffer.clear() + rightGroupKey = null + } + } + + /** Buffer all right rows with the same equi-key as leftKey. */ + private def bufferRightGroup( + leftKey: UnsafeRow, ordering: BaseOrdering): Unit = { + rightGroupBuffer.clear() + rightGroupKey = leftKey.copy() + + while (!rightDone && rightPeek != null) { + val rightKey = rightKeyProj(rightPeek) + if (ordering.compare(leftKey, rightKey) == 0) { + rightGroupBuffer += rightPeek + rightPeek = if (rightIter.hasNext) { + rightIter.next().copy() + } else { + rightDone = true; null + } + } else { + return + } + } + } + + /** Buffer all remaining right rows (no equi-keys case). */ + private def bufferAllRight(): Unit = { + rightGroupBuffer.clear() + if (rightPeek != null) { + rightGroupBuffer += rightPeek + rightPeek = null + } + while (rightIter.hasNext) { + rightGroupBuffer += rightIter.next().copy() + } + rightDone = true + } + + /** + * Find the best matching right row for the given left row within the + * current group buffer. + * + * The buffer is sorted by as-of key ascending. 
The scan direction is + * chosen based on where the best match is expected: + * - Backward (left >= right): best match near the end -> right-to-left + * - Forward (left <= right): best match near the start -> left-to-right + * - Nearest: full scan needed (left-to-right, stop when distance + * increases after finding a match) + */ + private def findBestInGroup(leftRow: InternalRow): InternalRow = { + if (scanRightToLeft) { + findBestRightToLeft(leftRow) + } else { + findBestLeftToRight(leftRow) + } + } + + /** Scan from end to start (optimal for Backward joins). */ + private def findBestRightToLeft(leftRow: InternalRow): InternalRow = { + var bestMatch: InternalRow = null + var bestDistance: Any = null + + var i = rightGroupBuffer.size - 1 + while (i >= 0) { + val rightRow = rightGroupBuffer(i) + joinedRow.withLeft(leftRow).withRight(rightRow) + + val asOfSatisfied = boundAsOfCond.eval(joinedRow) + if (asOfSatisfied != null && asOfSatisfied.asInstanceOf[Boolean]) { + val residualSatisfied = boundResidualCond.forall { cond => + val result = cond.eval(joinedRow) + result != null && result.asInstanceOf[Boolean] + } + if (residualSatisfied) { + val distance = boundOrderExpr.eval(joinedRow) + if (distance != null) { + if (bestMatch == null) { + bestMatch = rightRow + bestDistance = distance + } else if (distanceOrdering.lt(distance, bestDistance)) { + bestMatch = rightRow + bestDistance = distance + } else { + return bestMatch + } + } + } + } else if (bestMatch != null) { + return bestMatch + } + i -= 1 + } + bestMatch + } + + /** Scan from start to end (optimal for Forward/Nearest joins). 
*/ + private def findBestLeftToRight(leftRow: InternalRow): InternalRow = { + var bestMatch: InternalRow = null + var bestDistance: Any = null + + var i = 0 + while (i < rightGroupBuffer.size) { + val rightRow = rightGroupBuffer(i) + joinedRow.withLeft(leftRow).withRight(rightRow) + + val asOfSatisfied = boundAsOfCond.eval(joinedRow) + if (asOfSatisfied != null && asOfSatisfied.asInstanceOf[Boolean]) { + val residualSatisfied = boundResidualCond.forall { cond => + val result = cond.eval(joinedRow) + result != null && result.asInstanceOf[Boolean] + } + if (residualSatisfied) { + val distance = boundOrderExpr.eval(joinedRow) + if (distance != null) { + if (bestMatch == null) { + bestMatch = rightRow + bestDistance = distance + } else if (distanceOrdering.lt(distance, bestDistance)) { + bestMatch = rightRow + bestDistance = distance + } else { + return bestMatch + } + } + } + } else if (bestMatch != null) { Review Comment: **Correctness bug for `direction=nearest, allowExactMatches=false`.** The early-termination at `else if (bestMatch != null) return bestMatch` assumes the as-of-false zone is at the trailing end of the scan. That's true for Backward (right-to-left scan) and Forward (left-to-right scan, with or without tolerance) -- the false zone is always at a boundary. But Nearest + `!allowExactMatches` has the as-of condition `Not(EqualTo(leftAsOf, rightAsOf))`, which is false only at the single interior point `right == left`, with valid matches on both sides. 
**Counterexample:** - `left.ts = [10]` - `right.ts = [1, 10, 11]` - `direction = nearest, allowExactMatches = false` Scan left-to-right: - `r=1`: asOf true (1 != 10), distance=9 → `bestMatch=1` - `r=10`: asOf false (10 == 10) → returns `bestMatch=1` - `r=11`: never visited; distance would have been 1 (correct answer) The existing `nearest join - allowExactMatches = false` test at `SortMergeAsOfJoinSuite.scala:517` passes only because its right-side data doesn't include a value past the equal point that's closer than the best-so-far from below. **Fix:** pass direction (or a flag) into the scanner and skip this branch for Nearest. The distance-based termination is then sufficient on its own because the `|left - right|` distance is V-shaped -- once it starts increasing past the minimum, no later row can beat it -- but only if it returns when the distance *strictly* increases. As written (`distanceOrdering.lt` else return) it also returns on ties, so a duplicate right as-of value seen while the distance is still decreasing ends the scan early: with `left=10, right=[8, 8, 9]` (nearest) it returns the first `8` upon reaching the second `8` and never visits `9` (distance 1). Use `else if (distanceOrdering.gt(distance, bestDistance)) return bestMatch` instead; ties still keep the first match, so tie-breaking is unchanged. Regression tests to add: a Nearest + `!allowExactMatches` case where right values bracket the left key, e.g. `left=10, right=[1, 10, 11]`, expecting the row at `11`; and a Nearest case with duplicates before the minimum, e.g. `left=10, right=[8, 8, 9]`, expecting the row at `9`. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala: ########## @@ -178,6 +178,149 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ + /** + * Plans AS-OF joins using a dedicated sort-merge operator when the + * conf is enabled.
+ */ + object AsOfJoinSelection extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case j @ AsOfJoin(left, right, asOfCondition, condition, joinType, + orderExpression, _) if conf.sortMergeAsOfJoinEnabled => + val (leftKeys, rightKeys, residual) = condition match { + case Some(cond) => extractEquiJoinKeys(cond, left, right) + case None => (Seq.empty[Expression], Seq.empty[Expression], None) + } + val (leftAsOf, rightAsOf) = extractAsOfExprs( + asOfCondition, orderExpression, left, right) + + joins.SortMergeAsOfJoinExec( + leftKeys, rightKeys, leftAsOf, rightAsOf, + asOfCondition, orderExpression, joinType, residual, + planLater(left), planLater(right)) :: Nil + case _ => Nil + } + + /** + * Extract equi-join key pairs and residual (non-equi) condition + * from a conjunction. Only EqualTo is treated as equi-key; + * EqualNullSafe is excluded because the Scanner does not implement + * null-safe comparison semantics. + */ + private def extractEquiJoinKeys( + condition: Expression, + left: LogicalPlan, + right: LogicalPlan): (Seq[Expression], Seq[Expression], Option[Expression]) = { + val leftKeys = + new scala.collection.mutable.ArrayBuffer[Expression]() + val rightKeys = + new scala.collection.mutable.ArrayBuffer[Expression]() + val residuals = + new scala.collection.mutable.ArrayBuffer[Expression]() + + flattenAnd(condition).foreach { + case EqualTo(l, r) + if l.references.subsetOf(left.outputSet) && + r.references.subsetOf(right.outputSet) => + leftKeys += l; rightKeys += r + case EqualTo(l, r) + if r.references.subsetOf(left.outputSet) && + l.references.subsetOf(right.outputSet) => + leftKeys += r; rightKeys += l + case other => + residuals += other + } + val residual = residuals.reduceOption(And) + (leftKeys.toSeq, rightKeys.toSeq, residual) + } + + private def flattenAnd(expr: Expression): Seq[Expression] = expr match { + case And(l, r) => flattenAnd(l) ++ flattenAnd(r) + case other => Seq(other) + } + + private def 
extractAsOfExprs( + asOfCondition: Expression, + orderExpression: Expression, + left: LogicalPlan, + right: LogicalPlan): (Expression, Expression) = { + val leftAttrs = left.outputSet + val rightAttrs = right.outputSet + + def find(expr: Expression): Option[(Expression, Expression)] = expr match { + case GreaterThanOrEqual(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case GreaterThan(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case LessThanOrEqual(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case LessThan(l, r) + if l.references.subsetOf(leftAttrs) && r.references.subsetOf(rightAttrs) => + Some((l, r)) + case GreaterThanOrEqual(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case GreaterThan(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case LessThanOrEqual(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case LessThan(l, r) + if l.references.subsetOf(rightAttrs) && r.references.subsetOf(leftAttrs) => + Some((r, l)) + case And(l, r) => find(l).orElse(find(r)) + case _ => None + } + + find(asOfCondition).orElse { Review Comment: **Wrong `leftAsOfExpr` extracted for Nearest + tolerance.** For `Nearest + allowExactMatches=true + tolerance`, the analyzer emits (`basicLogicalOperators.scala:2417-2419`): ``` And(GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)), LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance))) ``` `find(asOfCondition)` walks into the first conjunct `GTE(rightAsOf, Subtract(leftAsOf, tolerance))` and matches the 5th case (`GTE` with `l ⊆ rightAttrs && r ⊆ leftAttrs`), returning `(r, l) = (Subtract(leftAsOf, tolerance), rightAsOf)`. So `leftAsOfExpr` becomes `Subtract(leftAsOf, tolerance)` instead of `leftAsOf`. 
The `Nearest + !allowExactMatches + tolerance` case has the same misextraction via the `GT` 6th case. **Effect:** `requiredChildOrdering` requires the left child sorted by `Subtract(leftAsOf, tolerance)` ascending. The row order is the same (tolerance is constant), so results are correct, but `EnsureRequirements` can't recognize that an already-`leftAsOf`-sorted child satisfies the requirement and inserts a redundant sort. **Suggested simplification:** always extract via `findFromOrder(orderExpression)` since `orderExpression` is direction-unique by construction (`basicLogicalOperators.scala:2433-2438`) -- `Subtract(leftAsOf, rightAsOf)` for Backward, `Subtract(rightAsOf, leftAsOf)` for Forward, `If(_, Subtract, _)` for Nearest. `findFromOrder` already returns `(leftAsOf, rightAsOf)` cleanly for every direction × tolerance × `allowExactMatches` combination (3 × 2 × 2 = 12), because `orderExpression` varies only with direction. That single change lets you delete the `find(asOfCondition)` walk *and* the `Last resort` fallback -- the planner no longer needs to interpret the analyzer's as-of-condition shape at all. **Even cleaner long-term:** add `leftAsOf: Expression, rightAsOf: Expression` as fields on the `AsOfJoin` logical node (the analyzer has them at construction time, see `Dataset.scala:846-887`). Then the planner doesn't reverse-engineer anything. ########## sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala: ########## @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.joins.SortMergeAsOfJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class SortMergeAsOfJoinSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key, "true") + } + + override def afterAll(): Unit = { + spark.conf.unset(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key) + super.afterAll() + } + + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { + val schema1 = StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil) + val rowSeq1: List[Row] = List( + Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) + val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) + + val schema2 = StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil) + val rowSeq2: List[Row] = List( + Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), + Row(6, "y", 6), Row(7, "z", 7)) + val df2 = spark.createDataFrame(rowSeq2.asJava, schema2) + + (df1, df2) + } + + test("uses SortMergeAsOfJoinExec physical operator") { + val (df1, df2) = prepareForAsOfJoin() + val result = df1.joinAsOf( + 
df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.nonEmpty, s"Expected SortMergeAsOfJoinExec in plan:\n$plan") + } + + test("backward join - simple") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - usingColumns") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - left outer") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("forward join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) // inner join: no match for 10 + ) + } + + test("nearest join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), 
usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "nearest"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - tolerance = 1") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", + tolerance = functions.lit(1), + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) + ) + } + + test("backward join - allowExactMatches = false") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = false, direction = "backward"), + Seq( + // left.a=1: no right row with a < 1 → no match + // left.a=5: right.a=3 (3 < 5) → match + Row(5, "y", "b", 3, "x", 3), + // left.a=10: right.a=7 (7 < 10) → match + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("empty left side") { + val (_, df2) = prepareForAsOfJoin() + val emptyDf = spark.createDataFrame( + java.util.Collections.emptyList[Row](), + StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil)) + checkAnswer( + emptyDf.joinAsOf( + df2, emptyDf.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq.empty + ) + } + + test("empty right side") { + val (df1, _) = prepareForAsOfJoin() + val emptyDf = spark.createDataFrame( + java.util.Collections.emptyList[Row](), + StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil)) + // Inner join: no matches possible + checkAnswer( + df1.joinAsOf( + emptyDf, 
df1.col("a"), emptyDf.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq.empty + ) + // Left outer: all left rows with null right + checkAnswer( + df1.joinAsOf( + emptyDf, df1.col("a"), emptyDf.col("a"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", null, null, null) + ) + ) + } + + test("null as-of keys") { + val schema1 = StructType( + StructField("a", IntegerType, true) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("a", IntegerType, true) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(null, "x"), Row(3, "y"), Row(7, "z")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1, "a"), Row(null, "b"), Row(5, "c")).asJava, schema2) + // Null as-of keys should not match anything (as-of condition + // evaluates to null for null inputs) + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(null, "x", null, null), + Row(3, "y", 1, "a"), + Row(7, "z", 5, "c") + ) + ) + } + + test("multiple rows with same equi-key") { + val schema1 = StructType( + StructField("grp", StringType) :: + StructField("ts", IntegerType) :: Nil) + val schema2 = StructType( + StructField("grp", StringType) :: + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List( + Row("A", 5), Row("A", 10), Row("A", 15), + Row("B", 3), Row("B", 8) + ).asJava, schema1) + val df2 = spark.createDataFrame( + List( + Row("A", 2, "a1"), Row("A", 7, "a2"), Row("A", 12, "a3"), + Row("B", 1, "b1"), Row("B", 6, "b2"), Row("B", 10, "b3") + ).asJava, schema2) + 
checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq("grp"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row("A", 5, "A", 2, "a1"), + Row("A", 10, "A", 7, "a2"), + Row("A", 15, "A", 12, "a3"), + Row("B", 3, "B", 1, "b1"), + Row("B", 8, "B", 6, "b2") + ) + ) + } + + test("long type as-of key") { + val schema1 = StructType( + StructField("ts", LongType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", LongType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(100L, "a"), Row(200L, "b"), Row(300L, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(50L, "x"), Row(150L, "y"), Row(250L, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(100L, "a", 50L, "x"), + Row(200L, "b", 150L, "y"), + Row(300L, "c", 250L, "z") + ) + ) + } + + test("double type as-of key") { + val schema1 = StructType( + StructField("ts", DoubleType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", DoubleType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(1.5, "a"), Row(3.0, "b"), Row(5.5, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1.0, "x"), Row(2.5, "y"), Row(4.0, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1.5, "a", 1.0, "x"), + Row(3.0, "b", 2.5, "y"), + Row(5.5, "c", 4.0, "z") + ) + ) + } + + test("conf disabled falls back to correlated subquery rewrite") { + val (df1, df2) = prepareForAsOfJoin() + withSQLConf(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key 
-> "false") { + val result = df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.isEmpty, "Should NOT use SortMergeAsOfJoinExec when conf is disabled") + // Results should still be correct + checkAnswer(result, Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + )) + } + } + + test("self join") { + val schema = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df = spark.createDataFrame( + List(Row(1, "a"), Row(3, "b"), Row(5, "c")).asJava, schema) + checkAnswer( + df.joinAsOf( + df, df.col("ts"), df.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "a", 1, "a"), + Row(3, "b", 3, "b"), + Row(5, "c", 5, "c") + ) + ) + } + + test("no equi-key - all rows in single partition") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(2, "a"), Row(5, "b"), Row(9, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(1, "x"), Row(4, "y"), Row(7, "z")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(2, "a", 1, "x"), + Row(5, "b", 4, "y"), + Row(9, "c", 7, "z") + ) + ) + } + + test("forward join - left outer with no match") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + 
StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(1, "a"), Row(5, "b"), Row(10, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(3, "x"), Row(7, "y")).asJava, schema2) + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "a", 3, "x"), + Row(5, "b", 7, "y"), + Row(10, "c", null, null) // no right row >= 10 + ) + ) + } + + test("forward join - tolerance") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(1, "a"), Row(5, "b"), Row(10, "c")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(2, "x"), Row(7, "y"), Row(15, "z")).asJava, schema2) + // tolerance = 3: only match if right.ts <= left.ts + 3 + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", + tolerance = functions.lit(3), + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "a", 2, "x"), // 2 <= 1+3=4, match + Row(5, "b", 7, "y") // 7 <= 5+3=8, match + // 10: right.ts=15, 15 > 10+3=13, no match + ) + ) + } + + test("nearest join - tolerance") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(10, "a"), Row(20, "b")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(5, "x"), Row(12, "y"), Row(25, "z")).asJava, schema2) + // tolerance = 3: only match if |left.ts - right.ts| <= 3 + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = 
Seq.empty, + joinType = "leftouter", + tolerance = functions.lit(3), + allowExactMatches = true, direction = "nearest"), + Seq( + Row(10, "a", 12, "y"), // |10-12|=2 <= 3, match + Row(20, "b", null, null) // |20-25|=5 > 3, no match + ) + ) + } + + test("nearest join - equidistant right rows") { + val schema1 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(10, "a")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(8, "x"), Row(12, "y")).asJava, schema2) + // Both are distance 2 from left.ts=10. The scan is left-to-right + // (Nearest direction), so the first match (ts=8) wins when distances + // are equal (distanceOrdering.lt is strict). + val result = df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", + tolerance = functions.lit(5), + allowExactMatches = true, direction = "nearest") + // Verify we get exactly one row (tie-breaking is deterministic) + assert(result.count() == 1) + val row = result.collect().head + assert(row.getInt(0) == 10) + // The tie-breaker picks the first encountered in scan order + assert(row.getInt(2) == 8 || row.getInt(2) == 12) + } + + test("forward join - allowExactMatches = false") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = false, direction = "forward"), + Seq( + // left.a=1: right.a must be > 1 -> right.a=2 + Row(1, "x", "a", 2, "w", 2), + // left.a=5: right.a must be > 5 -> right.a=6 + Row(5, "y", "b", 6, "y", 6), + // left.a=10: no right.a > 10 + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) + ) + } + + test("nearest join - allowExactMatches = false") { + val schema1 = StructType( + StructField("ts", IntegerType) 
:: + StructField("val", StringType) :: Nil) + val schema2 = StructType( + StructField("ts", IntegerType) :: + StructField("val", StringType) :: Nil) + val df1 = spark.createDataFrame( + List(Row(5, "a"), Row(10, "b")).asJava, schema1) + val df2 = spark.createDataFrame( + List(Row(5, "x"), Row(8, "y"), Row(10, "z")).asJava, schema2) + // allowExactMatches=false: exact matches excluded + checkAnswer( + df1.joinAsOf( + df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = false, direction = "nearest"), + Seq( + // left.ts=5: exclude right.ts=5, nearest is 8 (distance 3) + Row(5, "a", 8, "y"), + // left.ts=10: exclude right.ts=10, nearest is 8 (distance 2) + Row(10, "b", 8, "y") + ) + ) + } Review Comment: I suggest adding a regression test for the Nearest + `!allowExactMatches` bug noted on `SortMergeAsOfJoinExec.scala:425`. Something like: ```scala test("nearest join - allowExactMatches = false with right rows on both sides") { val schema = StructType( StructField("ts", IntegerType) :: StructField("val", StringType) :: Nil) val df1 = spark.createDataFrame(List(Row(10, "a")).asJava, schema) val df2 = spark.createDataFrame( List(Row(1, "x"), Row(10, "y"), Row(11, "z")).asJava, schema) checkAnswer( df1.joinAsOf(df2, df1.col("ts"), df2.col("ts"), usingColumns = Seq.empty, joinType = "inner", tolerance = null, allowExactMatches = false, direction = "nearest"), Seq(Row(10, "a", 11, "z")) // distance 1, not Row(10, "a", 1, "x") with distance 9 ) } ``` This test currently fails: the operator returns `Row(10, "a", 1, "x")` (distance 9) instead of the nearer `Row(10, "a", 11, "z")` (distance 1). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
