This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fefeeb910 [spark] Merge into supports _ROW_ID shortcut (#6745)
3fefeeb910 is described below

commit 3fefeeb9103689be674058b7222243845c5c5300
Author: Weitai Li <[email protected]>
AuthorDate: Thu Dec 4 21:43:21 2025 +0800

    [spark] Merge into supports _ROW_ID shortcut (#6745)
---
 .../MergeIntoPaimonDataEvolutionTable.scala        | 62 ++++++++++++++++++----
 .../paimon/spark/sql/RowTrackingTestBase.scala     | 56 +++++++++++++++++++
 2 files changed, 107 insertions(+), 11 deletions(-)
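
In short: when the merge condition is a plain equality between the target's
_ROW_ID metadata column and a single source column, targetRelatedSplits now
skips the target-vs-source inner join and collects the touched _FIRST_ROW_IDs
directly from the source dataset; any other condition shape falls back to the
full join. A minimal sketch of a qualifying statement, modeled on the new test
below (the shortcut only applies to tables created with 'row-tracking.enabled'
and 'data-evolution.enabled'):

    spark.sql(
      """MERGE INTO target
        |USING source
        |ON target._ROW_ID = source.target_ROW_ID
        |WHEN MATCHED THEN UPDATE SET b = source.b
        |""".stripMargin)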

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
index 768e244a38..e3c2d3ead0 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
@@ -32,7 +32,7 @@ import org.apache.paimon.table.source.DataSplit
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils._
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolver
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -148,18 +148,21 @@ case class MergeIntoPaimonDataEvolutionTable(
   }
 
   private def targetRelatedSplits(sparkSession: SparkSession): Seq[DataSplit] = {
-    val targetDss = createDataset(
-      sparkSession,
-      targetRelation
-    )
     val sourceDss = createDataset(sparkSession, sourceRelation)
 
-    val firstRowIdsTouched = mutable.Set.empty[Long]
-
-    firstRowIdsTouched ++= findRelatedFirstRowIds(
-      targetDss.alias("_left").join(sourceDss, toColumn(matchedCondition), "inner"),
-      sparkSession,
-      "_left." + ROW_ID_NAME)
+    val firstRowIdsTouched = extractSourceRowIdMapping match {
+      case Some(sourceRowIdAttr) =>
+        // Shortcut: Directly get _FIRST_ROW_IDs from the source table.
+        findRelatedFirstRowIds(sourceDss, sparkSession, sourceRowIdAttr.name).toSet
+
+      case None =>
+        // Perform the full join to find related _FIRST_ROW_IDs.
+        val targetDss = createDataset(sparkSession, targetRelation)
+        findRelatedFirstRowIds(
+          targetDss.alias("_left").join(sourceDss, toColumn(matchedCondition), "inner"),
+          sparkSession,
+          "_left." + ROW_ID_NAME).toSet
+    }
 
     table
       .newSnapshotReader()
@@ -312,6 +315,43 @@ case class MergeIntoPaimonDataEvolutionTable(
     writer.write(toWrite)
   }
 
+  /**
+   * Attempts to identify a direct mapping from a source table attribute to the target table's
+   * `_ROW_ID`.
+   *
+   * This is a shortcut optimization for `MERGE INTO` that avoids a full, expensive join when the
+   * merge condition is a simple equality on the target's `_ROW_ID`.
+   *
+   * @return
+   *   An `Option` containing the source table's attribute if a pattern like
+   *   `target._ROW_ID = source.col` (or its reverse) is found, otherwise `None`.
+   */
+  private def extractSourceRowIdMapping: Option[AttributeReference] = {
+
+    // Helper to check if an attribute is the target's _ROW_ID
+    def isTargetRowId(attr: AttributeReference): Boolean = {
+      attr.name == ROW_ID_NAME && (targetRelation.output ++ targetRelation.metadataOutput)
+        .exists(_.exprId.equals(attr.exprId))
+    }
+
+    // Helper to check if an attribute belongs to the source table
+    def isSourceAttribute(attr: AttributeReference): Boolean = {
+      (sourceRelation.output ++ sourceRelation.metadataOutput).exists(_.exprId.equals(attr.exprId))
+    }
+
+    matchedCondition match {
+      // Case 1: target._ROW_ID = source.col
+      case EqualTo(left: AttributeReference, right: AttributeReference)
+          if isTargetRowId(left) && isSourceAttribute(right) =>
+        Some(right)
+      // Case 2: source.col = target._ROW_ID
+      case EqualTo(left: AttributeReference, right: AttributeReference)
+          if isSourceAttribute(left) && isTargetRowId(right) =>
+        Some(left)
+      case _ => None
+    }
+  }
+
   private def findRelatedFirstRowIds(
       dataset: Dataset[Row],
       sparkSession: SparkSession,
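
For reference, the condition shapes the new matcher accepts can be reproduced
with bare catalyst expressions. A minimal sketch (these attributes are
hand-built stand-ins; in the patch the real attributes come from the resolved
target/source relations and are matched by exprId, not by name):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo}
    import org.apache.spark.sql.types.LongType

    val targetRowId = AttributeReference("_ROW_ID", LongType)()       // stand-in for the target's metadata column
    val sourceCol   = AttributeReference("target_ROW_ID", LongType)() // stand-in for a source column

    EqualTo(targetRowId, sourceCol)                 // Case 1: shortcut taken
    EqualTo(sourceCol, targetRowId)                 // Case 2: shortcut taken
    EqualTo(Cast(targetRowId, LongType), sourceCol) // no match (left side is a Cast): full join fallback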
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
index 5fb3de0e07..aab69e9ea1 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
@@ -22,6 +22,13 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.util.QueryExecutionListener
+
+import scala.collection.mutable
 
 abstract class RowTrackingTestBase extends PaimonSparkTestBase {
 
@@ -360,6 +367,55 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
     }
   }
 
+  test("Data Evolution: merge into table with data-evolution with _ROW_ID 
shortcut") {
+    withTable("source", "target") {
+      sql("CREATE TABLE source (target_ROW_ID BIGINT, b INT, c STRING)")
+      sql(
+        "INSERT INTO source VALUES (0, 100, 'c11'), (2, 300, 'c33'), (4, 500, 
'c55'), (6, 700, 'c77'), (8, 900, 'c99')")
+
+      sql(
+        "CREATE TABLE target (a INT, b INT, c STRING) TBLPROPERTIES 
('row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true')")
+      sql(
+        "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 
'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          capturedPlans += qe.analyzed
+        }
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+          capturedPlans += qe.analyzed
+        }
+      }
+      spark.listenerManager.register(listener)
+      sql(s"""
+             |MERGE INTO target
+             |USING source
+             |ON target._ROW_ID = source.target_ROW_ID
+             |WHEN MATCHED AND target.a = 5 THEN UPDATE SET b = source.b + target.b
+             |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b, c = source.c
+             |WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (target_ROW_ID, b * 1.1, c)
+             |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (target_ROW_ID, b, c)
+             |""".stripMargin)
+      // Assert that no Join operator was used during
+      // `org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.targetRelatedSplits`
+      assert(capturedPlans.head.collect { case plan: Join => plan }.isEmpty)
+      spark.listenerManager.unregister(listener)
+
+      checkAnswer(
+        sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM target ORDER BY a"),
+        Seq(
+          Row(1, 10, "c1", 0, 2),
+          Row(2, 20, "c2", 1, 2),
+          Row(3, 300, "c33", 2, 2),
+          Row(4, 40, "c4", 3, 2),
+          Row(5, 550, "c5", 4, 2),
+          Row(6, 700, "c77", 5, 2),
+          Row(8, 990, "c99", 6, 2))
+      )
+    }
+  }
+
   test("Data Evolution: update table throws exception") {
     withTable("t") {
       sql(

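The no-join assertion in the new test is a reusable way to verify this kind of
plan-level optimization. A standalone sketch, assuming an active SparkSession
named spark and an arbitrary statement in place of the "MERGE INTO ..."
placeholder (QueryExecutionListener callbacks are delivered asynchronously, so
strict code may need to drain the listener bus before asserting):

    import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    import scala.collection.mutable

    val plans = mutable.ListBuffer.empty[LogicalPlan]
    val listener = new QueryExecutionListener {
      // Capture the analyzed plan of every query that finishes while registered.
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        plans += qe.analyzed
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        plans += qe.analyzed
    }
    spark.listenerManager.register(listener)
    try {
      spark.sql("MERGE INTO ...") // statement under inspection
      assert(plans.forall(_.collect { case j: Join => j }.isEmpty), "unexpected Join in analyzed plan")
    } finally spark.listenerManager.unregister(listener)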