This is an automated email from the ASF dual-hosted git repository.

biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 50a1c88051 [spark] Extract postCommit for both v1 and V2 write (#5705)
50a1c88051 is described below

commit 50a1c88051bf07171d692087b3bf933a9eb70db1
Author: Zouxxyy <[email protected]>
AuthorDate: Fri Jun 6 21:36:17 2025 +0800

    [spark] Extract postCommit for both v1 and V2 write (#5705)
---
 .../catalyst/analysis/ReplacePaimonFunctions.scala |   2 +-
 .../paimon/spark/commands/PaimonSparkWriter.scala  |  56 +---------
 .../spark/commands/WriteIntoPaimonTable.scala      |  38 +------
 .../apache/paimon/spark/write/PaimonV2Write.scala  |  64 +++++++-----
 .../apache/paimon/spark/write/WriteHelper.scala    | 114 +++++++++++++++++++++
 .../spark/sql/InsertOverwriteTableTestBase.scala   |   2 +-
 .../paimon/spark/sql/PaimonTagDdlTestBase.scala    |  44 ++++----
 7 files changed, 183 insertions(+), 137 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/ReplacePaimonFunctions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/ReplacePaimonFunctions.scala
index d3650d27f8..3437966953 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/ReplacePaimonFunctions.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/ReplacePaimonFunctions.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.catalog.PaimonCatalogImplicits._
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.unsafe.types.UTF8String
 
-import scala.jdk.CollectionConverters._
+import scala.collection.JavaConverters._
 
 /** A rule to replace Paimon functions with literal values. */
 case class ReplacePaimonFunctions(spark: SparkSession) extends 
Rule[LogicalPlan] {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
index 33332f6e0a..bd6ff63c2e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
@@ -18,7 +18,6 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.CoreOptions
 import org.apache.paimon.CoreOptions.WRITE_ONLY
 import org.apache.paimon.codegen.CodeGenUtils
 import org.apache.paimon.crosspartition.{IndexBootstrap, KeyPartOrRow}
@@ -33,24 +32,24 @@ import org.apache.paimon.spark.catalog.functions.BucketFunction
 import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, 
ROW_KIND_COL}
 import org.apache.paimon.spark.util.OptionUtils.paimonExtensionEnabled
 import org.apache.paimon.spark.util.SparkRowUtils
+import org.apache.paimon.spark.write.WriteHelper
 import org.apache.paimon.table.BucketMode._
 import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.sink._
 import org.apache.paimon.types.{RowKind, RowType}
-import org.apache.paimon.utils.{InternalRowPartitionComputer, 
PartitionPathUtils, PartitionStatisticsReporter, SerializationUtils}
+import org.apache.paimon.utils.SerializationUtils
 
 import org.apache.spark.{Partitioner, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
-import org.slf4j.LoggerFactory
 
 import java.io.IOException
 import java.util.Collections.singletonMap
 
 import scala.collection.JavaConverters._
 
-case class PaimonSparkWriter(table: FileStoreTable) {
+case class PaimonSparkWriter(table: FileStoreTable) extends WriteHelper {
 
   private lazy val tableSchema = table.schema
 
@@ -58,18 +57,6 @@ case class PaimonSparkWriter(table: FileStoreTable) {
 
   private lazy val bucketMode = table.bucketMode
 
-  private lazy val coreOptions = table.coreOptions()
-
-  private lazy val disableReportStats = {
-    val config = coreOptions.toConfiguration
-    config.get(CoreOptions.PARTITION_IDLE_TIME_TO_REPORT_STATISTIC).toMillis 
<= 0 ||
-    table.partitionKeys.isEmpty ||
-    !coreOptions.partitionedTableInMetastore ||
-    table.catalogEnvironment.partitionHandler() == null
-  }
-
-  private lazy val log = LoggerFactory.getLogger(classOf[PaimonSparkWriter])
-
   @transient private lazy val serializer = new CommitMessageSerializer
 
   val writeBuilder: BatchWriteBuilder = table.newBatchWriteBuilder()
@@ -336,40 +323,6 @@ case class PaimonSparkWriter(table: FileStoreTable) {
       .map(deserializeCommitMessage(serializer, _))
   }
 
-  private def reportToHms(messages: Seq[CommitMessage]): Unit = {
-    if (disableReportStats) {
-      return
-    }
-
-    val partitionComputer = new InternalRowPartitionComputer(
-      coreOptions.partitionDefaultName,
-      table.schema.logicalPartitionType,
-      table.partitionKeys.toArray(new Array[String](0)),
-      coreOptions.legacyPartitionName()
-    )
-    val hmsReporter = new PartitionStatisticsReporter(
-      table,
-      table.catalogEnvironment.partitionHandler()
-    )
-
-    val partitions = messages.map(_.partition()).distinct
-    val currentTime = System.currentTimeMillis()
-    try {
-      partitions.foreach {
-        partition =>
-          val partitionPath = PartitionPathUtils.generatePartitionPath(
-            partitionComputer.generatePartValues(partition))
-          hmsReporter.report(partitionPath, currentTime)
-      }
-    } catch {
-      case e: Throwable =>
-        log.warn("Failed to report to hms", e)
-
-    } finally {
-      hmsReporter.close()
-    }
-  }
-
   def commit(commitMessages: Seq[CommitMessage]): Unit = {
     val tableCommit = writeBuilder.newCommit()
     try {
@@ -379,8 +332,7 @@ case class PaimonSparkWriter(table: FileStoreTable) {
     } finally {
       tableCommit.close()
     }
-
-    reportToHms(commitMessages)
+    postCommit(commitMessages)
   }
 
   /** Boostrap and repartition for cross partition mode. */
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index d526ac2c3b..c56000ff99 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -18,16 +18,11 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.CoreOptions
-import org.apache.paimon.CoreOptions.{DYNAMIC_PARTITION_OVERWRITE, 
TagCreationMode}
+import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
 import org.apache.paimon.options.Options
-import org.apache.paimon.partition.actions.PartitionMarkDoneAction
 import org.apache.paimon.spark._
 import org.apache.paimon.spark.schema.SparkSystemColumns
 import org.apache.paimon.table.FileStoreTable
-import org.apache.paimon.table.sink.CommitMessage
-import org.apache.paimon.tag.TagBatchCreation
-import org.apache.paimon.utils.{InternalRowPartitionComputer, 
PartitionPathUtils, TypeUtils}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, PaimonUtils, Row, SparkSession}
@@ -84,39 +79,9 @@ case class WriteIntoPaimonTable(
     val commitMessages = writer.write(data)
     writer.commit(commitMessages)
 
-    preFinish(commitMessages)
     Seq.empty
   }
 
-  private def preFinish(commitMessages: Seq[CommitMessage]): Unit = {
-    if (table.coreOptions().tagCreationMode() == TagCreationMode.BATCH) {
-      val tagCreation = new TagBatchCreation(table)
-      tagCreation.createTag()
-    }
-    markDoneIfNeeded(commitMessages)
-  }
-
-  private def markDoneIfNeeded(commitMessages: Seq[CommitMessage]): Unit = {
-    val coreOptions = table.coreOptions()
-    if 
(coreOptions.toConfiguration.get(CoreOptions.PARTITION_MARK_DONE_WHEN_END_INPUT))
 {
-      val actions =
-        PartitionMarkDoneAction.createActions(getClass.getClassLoader, table, 
table.coreOptions())
-      val partitionComputer = new InternalRowPartitionComputer(
-        coreOptions.partitionDefaultName,
-        TypeUtils.project(table.rowType(), table.partitionKeys()),
-        table.partitionKeys().asScala.toArray,
-        coreOptions.legacyPartitionName()
-      )
-      val partitions = commitMessages
-        .map(c => c.partition())
-        .distinct
-        .map(p => 
PartitionPathUtils.generatePartitionPath(partitionComputer.generatePartValues(p)))
-      for (partition <- partitions) {
-        actions.forEach(a => a.markDone(partition))
-      }
-    }
-  }
-
   private def parseSaveMode(): (Boolean, Map[String, String]) = {
     var dynamicPartitionOverwriteMode = false
     val overwritePartition = saveMode match {
@@ -140,5 +105,4 @@ case class WriteIntoPaimonTable(
 
   override def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): 
LogicalPlan =
     this.asInstanceOf[WriteIntoPaimonTable]
-
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
index a97cccfbbc..5c325c26fd 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.StructType
 
 import java.io.{IOException, UncheckedIOException}
 
-import scala.jdk.CollectionConverters._
+import scala.collection.JavaConverters._
 import scala.util.{Failure, Success, Try}
 
 class PaimonV2Write(
@@ -52,12 +52,6 @@ class PaimonV2Write(
     storeTable.copy(
       Map(CoreOptions.DYNAMIC_PARTITION_OVERWRITE.key -> 
overwriteDynamic.toString).asJava)
 
-  private val batchWriteBuilder = {
-    val builder = table.newBatchWriteBuilder()
-    overwritePartitions.foreach(partitions => 
builder.withOverwrite(partitions.asJava))
-    builder
-  }
-
   private val writeRequirement = PaimonWriteRequirement(table)
 
   override def requiredDistribution(): Distribution = {
@@ -72,7 +66,7 @@ class PaimonV2Write(
     ordering
   }
 
-  override def toBatch: BatchWrite = new PaimonBatchWrite
+  override def toBatch: BatchWrite = PaimonBatchWrite(table, writeSchema, 
overwritePartitions)
 
   override def toString: String = {
     val overwriteDynamicStr = if (overwriteDynamic) {
@@ -87,35 +81,51 @@ class PaimonV2Write(
     }
     
s"PaimonWrite(table=${table.fullName()}$overwriteDynamicStr$overwritePartitionsStr)"
   }
+}
 
-  private class PaimonBatchWrite extends BatchWrite {
-    override def createBatchWriterFactory(info: PhysicalWriteInfo): 
DataWriterFactory =
-      WriterFactory(writeSchema, batchWriteBuilder)
+private case class PaimonBatchWrite(
+    table: FileStoreTable,
+    writeSchema: StructType,
+    overwritePartitions: Option[Map[String, String]])
+  extends BatchWrite
+  with WriteHelper {
 
-    override def useCommitCoordinator(): Boolean = false
+  private val batchWriteBuilder = {
+    val builder = table.newBatchWriteBuilder()
+    overwritePartitions.foreach(partitions => 
builder.withOverwrite(partitions.asJava))
+    builder
+  }
+
+  override def createBatchWriterFactory(info: PhysicalWriteInfo): 
DataWriterFactory =
+    WriterFactory(writeSchema, batchWriteBuilder)
 
-    override def commit(messages: Array[WriterCommitMessage]): Unit = {
-      logInfo(s"Committing to table ${table.name()}")
-      val batchTableCommit = batchWriteBuilder.newCommit()
+  override def useCommitCoordinator(): Boolean = false
 
-      val commitMessages = messages.collect {
+  override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    logInfo(s"Committing to table ${table.name()}")
+    val batchTableCommit = batchWriteBuilder.newCommit()
+
+    val commitMessages = messages
+      .collect {
         case taskCommit: TaskCommit => taskCommit.commitMessages()
         case other =>
           throw new IllegalArgumentException(s"${other.getClass.getName} is 
not supported")
-      }.flatten
-
-      try {
-        val start = System.currentTimeMillis()
-        batchTableCommit.commit(commitMessages.toList.asJava)
-        logInfo(s"Committed in ${System.currentTimeMillis() - start} ms")
-      } finally {
-        batchTableCommit.close()
       }
-    }
+      .flatten
+      .toSeq
 
-    override def abort(messages: Array[WriterCommitMessage]): Unit = {
-      // TODO clean uncommitted files
+    try {
+      val start = System.currentTimeMillis()
+      batchTableCommit.commit(commitMessages.asJava)
+      logInfo(s"Committed in ${System.currentTimeMillis() - start} ms")
+    } finally {
+      batchTableCommit.close()
     }
+    postCommit(commitMessages)
+  }
+
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {
+    // TODO clean uncommitted files
   }
 }
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/WriteHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/WriteHelper.scala
new file mode 100644
index 0000000000..bc07a310c0
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/WriteHelper.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.write
+
+import org.apache.paimon.CoreOptions
+import org.apache.paimon.CoreOptions.TagCreationMode
+import org.apache.paimon.partition.actions.PartitionMarkDoneAction
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.table.sink.CommitMessage
+import org.apache.paimon.tag.TagBatchCreation
+import org.apache.paimon.utils.{InternalRowPartitionComputer, 
PartitionPathUtils, PartitionStatisticsReporter, TypeUtils}
+
+import org.apache.spark.internal.Logging
+
+import scala.collection.JavaConverters._
+
+trait WriteHelper extends Logging {
+
+  val table: FileStoreTable
+
+  lazy val coreOptions: CoreOptions = table.coreOptions()
+
+  def postCommit(messages: Seq[CommitMessage]): Unit = {
+    if (messages.isEmpty) {
+      return
+    }
+
+    reportToHms(messages)
+    batchCreateTag()
+    markDoneIfNeeded(messages)
+  }
+
+  private def reportToHms(messages: Seq[CommitMessage]): Unit = {
+    val config = coreOptions.toConfiguration
+    if (
+      config.get(CoreOptions.PARTITION_IDLE_TIME_TO_REPORT_STATISTIC).toMillis 
<= 0 ||
+      table.partitionKeys.isEmpty ||
+      !coreOptions.partitionedTableInMetastore ||
+      table.catalogEnvironment.partitionHandler() == null
+    ) {
+      return
+    }
+
+    val partitionComputer = new InternalRowPartitionComputer(
+      coreOptions.partitionDefaultName,
+      table.schema.logicalPartitionType,
+      table.partitionKeys.toArray(new Array[String](0)),
+      coreOptions.legacyPartitionName()
+    )
+    val hmsReporter = new PartitionStatisticsReporter(
+      table,
+      table.catalogEnvironment.partitionHandler()
+    )
+
+    val partitions = messages.map(_.partition()).distinct
+    val currentTime = System.currentTimeMillis()
+    try {
+      partitions.foreach {
+        partition =>
+          val partitionPath = PartitionPathUtils.generatePartitionPath(
+            partitionComputer.generatePartValues(partition))
+          hmsReporter.report(partitionPath, currentTime)
+      }
+    } catch {
+      case e: Throwable =>
+        logWarning("Failed to report to hms", e)
+    } finally {
+      hmsReporter.close()
+    }
+  }
+
+  private def batchCreateTag(): Unit = {
+    if (coreOptions.tagCreationMode() == TagCreationMode.BATCH) {
+      val tagCreation = new TagBatchCreation(table)
+      tagCreation.createTag()
+    }
+  }
+
+  private def markDoneIfNeeded(commitMessages: Seq[CommitMessage]): Unit = {
+    if 
(coreOptions.toConfiguration.get(CoreOptions.PARTITION_MARK_DONE_WHEN_END_INPUT))
 {
+      val actions =
+        PartitionMarkDoneAction.createActions(getClass.getClassLoader, table, 
coreOptions)
+      val partitionComputer = new InternalRowPartitionComputer(
+        coreOptions.partitionDefaultName,
+        TypeUtils.project(table.rowType(), table.partitionKeys()),
+        table.partitionKeys().asScala.toArray,
+        coreOptions.legacyPartitionName()
+      )
+      val partitions = commitMessages
+        .map(c => c.partition())
+        .distinct
+        .map(p => 
PartitionPathUtils.generatePartitionPath(partitionComputer.generatePartValues(p)))
+      for (partition <- partitions) {
+        actions.forEach(a => a.markDone(partition))
+      }
+    }
+  }
+}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
index e6e17d4848..82fb1aed0a 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
@@ -610,7 +610,7 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
   }
 
   test("Paimon Insert: dynamic insert into table with partition columns 
contain primary key") {
-    withSQLConf("spark.sql.shuffle.partitions" -> "10") {
+    withSparkSQLConf("spark.sql.shuffle.partitions" -> "10") {
       withTable("pk_pt") {
         sql("""
               |create table pk_pt (c1 int) partitioned by(p1 string, p2 string)
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonTagDdlTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonTagDdlTestBase.scala
index 7bbc7ace04..cd260df973 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonTagDdlTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonTagDdlTestBase.scala
@@ -188,25 +188,31 @@ abstract class PaimonTagDdlTestBase extends PaimonSparkTestBase {
   }
 
   test("Tag expiration: batch write expire tag") {
-    spark.sql("""CREATE TABLE T (id INT, name STRING)
-                |USING PAIMON
-                |TBLPROPERTIES (
-                |'file.format' = 'avro',
-                |'tag.automatic-creation'='batch',
-                |'tag.num-retained-max'='1')""".stripMargin)
-
-    val table = loadTable("T")
-
-    withSparkSQLConf("spark.paimon.tag.batch.customized-name" -> 
"batch-tag-1") {
-      spark.sql("insert into T values(1, 'a')")
-      assertResult(1)(table.tagManager().tagObjects().size())
-      
assertResult("batch-tag-1")(loadTable("T").tagManager().tagObjects().get(0).getRight)
-    }
-
-    withSparkSQLConf("spark.paimon.tag.batch.customized-name" -> 
"batch-tag-2") {
-      spark.sql("insert into T values(2, 'b')")
-      assertResult(1)(table.tagManager().tagObjects().size())
-      
assertResult("batch-tag-2")(loadTable("T").tagManager().tagObjects().get(0).getRight)
+    for (useV2Write <- Seq("true", "false")) {
+      withSparkSQLConf("spark.paimon.write.use-v2-write" -> useV2Write) {
+        withTable("T") {
+          spark.sql("""CREATE TABLE T (id INT, name STRING)
+                      |USING PAIMON
+                      |TBLPROPERTIES (
+                      |'file.format' = 'avro',
+                      |'tag.automatic-creation'='batch',
+                      |'tag.num-retained-max'='1')""".stripMargin)
+
+          val table = loadTable("T")
+
+          withSparkSQLConf("spark.paimon.tag.batch.customized-name" -> 
"batch-tag-1") {
+            spark.sql("insert into T values(1, 'a')")
+            assertResult(1)(table.tagManager().tagObjects().size())
+            
assertResult("batch-tag-1")(loadTable("T").tagManager().tagObjects().get(0).getRight)
+          }
+
+          withSparkSQLConf("spark.paimon.tag.batch.customized-name" -> 
"batch-tag-2") {
+            spark.sql("insert into T values(2, 'b')")
+            assertResult(1)(table.tagManager().tagObjects().size())
+            
assertResult("batch-tag-2")(loadTable("T").tagManager().tagObjects().get(0).getRight)
+          }
+        }
+      }
     }
   }
 }

Reply via email to