This is an automated email from the ASF dual-hosted git repository.

junhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bacfdbf71 [python] Fix tryCommit failed. (#6969)
2bacfdbf71 is described below

commit 2bacfdbf712827c1630b399db7b5f470211aed2d
Author: umi <[email protected]>
AuthorDate: Thu Jan 8 12:10:49 2026 +0800

    [python] Fix tryCommit failed. (#6969)
---
 paimon-python/pypaimon/common/file_io.py           |  18 +-
 .../pypaimon/common/options/core_options.py        |  48 +++
 .../pypaimon/common/options/options_utils.py       |  13 +
 paimon-python/pypaimon/common/time_utils.py        |  81 +++++
 .../pypaimon/snapshot/snapshot_manager.py          |  47 ++-
 paimon-python/pypaimon/tests/blob_table_test.py    | 135 ++++++++
 .../pypaimon/tests/reader_append_only_test.py      | 106 +++++++
 .../pypaimon/tests/reader_primary_key_test.py      | 104 +++++++
 .../pypaimon/tests/schema_evolution_read_test.py   |   1 +
 paimon-python/pypaimon/write/file_store_commit.py  | 343 ++++++++++++++++-----
 10 files changed, 810 insertions(+), 86 deletions(-)

diff --git a/paimon-python/pypaimon/common/file_io.py b/paimon-python/pypaimon/common/file_io.py
index 497d711a49..2ec1909306 100644
--- a/paimon-python/pypaimon/common/file_io.py
+++ b/paimon-python/pypaimon/common/file_io.py
@@ -18,13 +18,15 @@
 import logging
 import os
 import subprocess
+import threading
+import uuid
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 from urllib.parse import splitport, urlparse
 
 import pyarrow
 from packaging.version import parse
-from pyarrow._fs import FileSystem
+from pyarrow._fs import FileSystem, LocalFileSystem
 
 from pypaimon.common.options import Options
 from pypaimon.common.options.config import OssOptions, S3Options
@@ -37,6 +39,8 @@ from pypaimon.write.blob_format_writer import BlobFormatWriter
 
 
 class FileIO:
+    rename_lock = threading.Lock()
+
     def __init__(self, path: str, catalog_options: Options):
         self.properties = catalog_options
         self.logger = logging.getLogger(__name__)
@@ -251,7 +255,15 @@ class FileIO:
                 self.mkdirs(str(dst_parent))
 
             src_str = self.to_filesystem_path(src)
-            self.filesystem.move(src_str, dst_str)
+            if isinstance(self.filesystem, LocalFileSystem):
+                if self.exists(dst):
+                    return False
+                with FileIO.rename_lock:
+                    if self.exists(dst):
+                        return False
+                    self.filesystem.move(src_str, dst_str)
+            else:
+                self.filesystem.move(src_str, dst_str)
             return True
         except Exception as e:
             self.logger.warning(f"Failed to rename {src} to {dst}: {e}")
@@ -303,7 +315,7 @@ class FileIO:
             return input_stream.read().decode('utf-8')
 
     def try_to_write_atomic(self, path: str, content: str) -> bool:
-        temp_path = path + ".tmp"
+        temp_path = path + str(uuid.uuid4()) + ".tmp"
         success = False
         try:
             self.write_file(temp_path, content, False)
diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 4ab5a253d7..49230240a7 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -15,9 +15,12 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
+import sys
 from enum import Enum
 from typing import Dict
 
+from datetime import timedelta
+
 from pypaimon.common.memory_size import MemorySize
 from pypaimon.common.options import Options
 from pypaimon.common.options.config_options import ConfigOptions
@@ -239,6 +242,34 @@ class CoreOptions:
         .with_description("The prefix for commit user.")
     )
 
+    COMMIT_MAX_RETRIES: ConfigOption[int] = (
+        ConfigOptions.key("commit.max-retries")
+        .int_type()
+        .default_value(10)
+        .with_description("Maximum number of retries for commit operations.")
+    )
+
+    COMMIT_TIMEOUT: ConfigOption[timedelta] = (
+        ConfigOptions.key("commit.timeout")
+        .duration_type()
+        .no_default_value()
+        .with_description("Timeout for commit operations (e.g., '10s', '5m'). If not set, effectively unlimited.")
+    )
+
+    COMMIT_MIN_RETRY_WAIT: ConfigOption[timedelta] = (
+        ConfigOptions.key("commit.min-retry-wait")
+        .duration_type()
+        .default_value(timedelta(milliseconds=10))
+        .with_description("Minimum wait time between commit retries (e.g., '10ms', '100ms').")
+    )
+
+    COMMIT_MAX_RETRY_WAIT: ConfigOption[timedelta] = (
+        ConfigOptions.key("commit.max-retry-wait")
+        .duration_type()
+        .default_value(timedelta(seconds=10))
+        .with_description("Maximum wait time between commit retries (e.g., '1s', '10s').")
+    )
+
     ROW_TRACKING_ENABLED: ConfigOption[bool] = (
         ConfigOptions.key("row-tracking.enabled")
         .boolean_type()
@@ -390,3 +421,20 @@ class CoreOptions:
 
     def data_file_external_paths_specific_fs(self, default=None):
         return self.options.get(CoreOptions.DATA_FILE_EXTERNAL_PATHS_SPECIFIC_FS, default)
+
+    def commit_max_retries(self) -> int:
+        return self.options.get(CoreOptions.COMMIT_MAX_RETRIES)
+
+    def commit_timeout(self) -> int:
+        timeout = self.options.get(CoreOptions.COMMIT_TIMEOUT)
+        if timeout is None:
+            return sys.maxsize
+        return timeout // timedelta(milliseconds=1)
+
+    def commit_min_retry_wait(self) -> int:
+        wait = self.options.get(CoreOptions.COMMIT_MIN_RETRY_WAIT)
+        return wait // timedelta(milliseconds=1)
+
+    def commit_max_retry_wait(self) -> int:
+        wait = self.options.get(CoreOptions.COMMIT_MAX_RETRY_WAIT)
+        return wait // timedelta(milliseconds=1)
diff --git a/paimon-python/pypaimon/common/options/options_utils.py b/paimon-python/pypaimon/common/options/options_utils.py
index f48f549df4..9938e87e74 100644
--- a/paimon-python/pypaimon/common/options/options_utils.py
+++ b/paimon-python/pypaimon/common/options/options_utils.py
@@ -16,10 +16,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from datetime import timedelta
 from enum import Enum
 from typing import Any, Type
 
 from pypaimon.common.memory_size import MemorySize
+from pypaimon.common.time_utils import parse_duration
 
 
 class OptionsUtils:
@@ -63,6 +65,8 @@ class OptionsUtils:
             return OptionsUtils.convert_to_double(value)
         elif target_type == MemorySize:
             return OptionsUtils.convert_to_memory_size(value)
+        elif target_type == timedelta:
+            return OptionsUtils.convert_to_duration(value)
         else:
             raise ValueError(f"Unsupported type: {target_type}")
 
@@ -125,6 +129,15 @@ class OptionsUtils:
             return MemorySize.parse(value)
         raise ValueError(f"Cannot convert {type(value)} to MemorySize")
 
+    @staticmethod
+    def convert_to_duration(value: Any) -> timedelta:
+        if isinstance(value, timedelta):
+            return value
+        if isinstance(value, str):
+            milliseconds = parse_duration(value)
+            return timedelta(milliseconds=milliseconds)
+        raise ValueError(f"Cannot convert {type(value)} to timedelta")
+
     @staticmethod
     def convert_to_enum(value: Any, enum_class: Type[Enum]) -> Enum:
 
diff --git a/paimon-python/pypaimon/common/time_utils.py b/paimon-python/pypaimon/common/time_utils.py
new file mode 100644
index 0000000000..1a02bd4398
--- /dev/null
+++ b/paimon-python/pypaimon/common/time_utils.py
@@ -0,0 +1,81 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+
+def parse_duration(text: str) -> int:
+    if text is None:
+        raise ValueError("text cannot be None")
+
+    trimmed = text.strip().lower()
+    if not trimmed:
+        raise ValueError("argument is an empty- or whitespace-only string")
+
+    pos = 0
+    while pos < len(trimmed) and '0' <= trimmed[pos] <= '9':
+        pos += 1
+
+    number_str = trimmed[:pos]
+    unit_str = trimmed[pos:].strip()
+
+    if not number_str:
+        raise ValueError("text does not start with a number")
+
+    value = int(number_str)
+    # Mirror Java's long range check; Python ints never overflow on their own.
+    if value > 9223372036854775807:
+        raise ValueError(
+            f"The value '{number_str}' cannot be represented as 64bit number (numeric overflow)."
+        )
+
+    if not unit_str:
+        result_ms = value
+    elif unit_str in ('ns', 'nano', 'nanosecond', 'nanoseconds'):
+        result_ms = value / 1_000_000
+    elif unit_str in ('µs', 'micro', 'microsecond', 'microseconds'):
+        result_ms = value / 1_000
+    elif unit_str in ('ms', 'milli', 'millisecond', 'milliseconds'):
+        result_ms = value
+    elif unit_str in ('s', 'sec', 'second', 'seconds'):
+        result_ms = value * 1000
+    elif unit_str in ('m', 'min', 'minute', 'minutes'):
+        result_ms = value * 60 * 1000
+    elif unit_str in ('h', 'hour', 'hours'):
+        result_ms = value * 60 * 60 * 1000
+    elif unit_str in ('d', 'day', 'days'):
+        result_ms = value * 24 * 60 * 60 * 1000
+    else:
+        supported_units = (
+            'DAYS: (d | day | days), '
+            'HOURS: (h | hour | hours), '
+            'MINUTES: (m | min | minute | minutes), '
+            'SECONDS: (s | sec | second | seconds), '
+            'MILLISECONDS: (ms | milli | millisecond | milliseconds), '
+            'MICROSECONDS: (µs | micro | microsecond | microseconds), '
+            'NANOSECONDS: (ns | nano | nanosecond | nanoseconds)'
+        )
+        raise ValueError(
+            f"Time interval unit label '{unit_str}' does not match any of the recognized units: "
+            f"{supported_units}"
+        )
+
+    result_ms_int = int(round(result_ms))
+
+    if result_ms_int < 0:
+        raise ValueError(f"Duration cannot be negative: {text}")
+
+    return result_ms_int
diff --git a/paimon-python/pypaimon/snapshot/snapshot_manager.py b/paimon-python/pypaimon/snapshot/snapshot_manager.py
index 0d96563057..8291d9cf2c 100644
--- a/paimon-python/pypaimon/snapshot/snapshot_manager.py
+++ b/paimon-python/pypaimon/snapshot/snapshot_manager.py
@@ -38,7 +38,7 @@ class SnapshotManager:
         if not self.file_io.exists(self.latest_file):
             return None
 
-        latest_content = self.file_io.read_file_utf8(self.latest_file)
+        latest_content = self.read_latest_file()
         latest_snapshot_id = int(latest_content.strip())
 
         snapshot_file = f"{self.snapshot_dir}/snapshot-{latest_snapshot_id}"
@@ -48,6 +48,51 @@ class SnapshotManager:
         snapshot_content = self.file_io.read_file_utf8(snapshot_file)
         return JSON.from_json(snapshot_content, Snapshot)
 
+    def read_latest_file(self, max_retries: int = 5):
+        """
+        Read the latest snapshot ID from LATEST file with retry mechanism.
+        If file doesn't exist or is empty after retries, scan snapshot directory for max ID.
+        """
+        import re
+        import time
+
+        # Try to read LATEST file with retries
+        for retry_count in range(max_retries):
+            try:
+                if self.file_io.exists(self.latest_file):
+                    content = self.file_io.read_file_utf8(self.latest_file)
+                    if content and content.strip():
+                        return content.strip()
+
+                # File doesn't exist or is empty, wait a bit before retry
+                if retry_count < max_retries - 1:
+                    time.sleep(0.001)
+
+            except Exception:
+                # On exception, wait and retry
+                if retry_count < max_retries - 1:
+                    time.sleep(0.001)
+
+        # List all files in snapshot directory
+        file_infos = self.file_io.list_status(self.snapshot_dir)
+
+        max_snapshot_id = None
+        snapshot_pattern = re.compile(r'^snapshot-(\d+)$')
+
+        for file_info in file_infos:
+            # Get filename from path
+            filename = file_info.path.split('/')[-1]
+            match = snapshot_pattern.match(filename)
+            if match:
+                snapshot_id = int(match.group(1))
+                if max_snapshot_id is None or snapshot_id > max_snapshot_id:
+                    max_snapshot_id = snapshot_id
+
+        if max_snapshot_id is None:
+            raise RuntimeError(f"No snapshot content found in {self.snapshot_dir}")
+
+        return str(max_snapshot_id)
+
     def get_snapshot_path(self, snapshot_id: int) -> str:
         """
         Get the path for a snapshot file.
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index de59d0398e..f87f73ded7 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2567,6 +2567,141 @@ class DataBlobWriterTest(unittest.TestCase):
 
         self.assertEqual(actual, expected)
 
+    def test_concurrent_blob_writes_with_retry(self):
+        """Test concurrent blob writes to verify retry mechanism works correctly."""
+        import threading
+        from pypaimon import Schema
+        from pypaimon.snapshot.snapshot_manager import SnapshotManager
+
+        # Run the test several times to verify stability
+        iter_num = 2
+        for test_iteration in range(iter_num):
+            # Create a unique table for each iteration
+            table_name = f'test_db.blob_concurrent_writes_{test_iteration}'
+
+            # Create schema with blob column
+            pa_schema = pa.schema([
+                ('id', pa.int32()),
+                ('thread_id', pa.int32()),
+                ('metadata', pa.string()),
+                ('blob_data', pa.large_binary()),
+            ])
+
+            schema = Schema.from_pyarrow_schema(
+                pa_schema,
+                options={
+                    'row-tracking.enabled': 'true',
+                    'data-evolution.enabled': 'true'
+                }
+            )
+            self.catalog.create_table(table_name, schema, False)
+            table = self.catalog.get_table(table_name)
+
+            write_results = []
+            write_errors = []
+
+            # Create blob pattern for testing
+            blob_size = 5 * 1024  # 5KB
+            blob_pattern = b'BLOB_PATTERN_' + b'X' * 1024
+            pattern_size = len(blob_pattern)
+            repetitions = blob_size // pattern_size
+            base_blob_data = blob_pattern * repetitions
+
+            def write_blob_data(thread_id, start_id):
+                """Write blob data in a separate thread."""
+                try:
+                    threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+                    write_builder = table.new_batch_write_builder()
+                    table_write = write_builder.new_write()
+                    table_commit = write_builder.new_commit()
+
+                    # Create unique blob data for this thread
+                    data = {
+                        'id': list(range(start_id, start_id + 5)),
+                        'thread_id': [thread_id] * 5,
+                        'metadata': [f'thread{thread_id}_blob_{i}' for i in range(5)],
+                        'blob_data': [i.to_bytes(2, byteorder='little') + base_blob_data for i in range(5)]
+                    }
+                    pa_table = pa.Table.from_pydict(data, schema=pa_schema)
+
+                    table_write.write_arrow(pa_table)
+                    commit_messages = table_write.prepare_commit()
+
+                    table_commit.commit(commit_messages)
+                    table_write.close()
+                    table_commit.close()
+
+                    write_results.append({
+                        'thread_id': thread_id,
+                        'start_id': start_id,
+                        'success': True
+                    })
+                except Exception as e:
+                    write_errors.append({
+                        'thread_id': thread_id,
+                        'error': str(e)
+                    })
+
+            # Create and start multiple threads
+            threads = []
+            num_threads = 100
+            for i in range(num_threads):
+                thread = threading.Thread(
+                    target=write_blob_data,
+                    args=(i, i * 10)
+                )
+                threads.append(thread)
+                thread.start()
+
+            # Wait for all threads to complete
+            for thread in threads:
+                thread.join()
+
+            # Verify all writes succeeded (retry mechanism should handle conflicts)
+            self.assertEqual(num_threads, len(write_results),
+                             f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+                             f"got {len(write_results)}. Errors: {write_errors}")
+            self.assertEqual(0, len(write_errors),
+                             f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+            read_builder = table.new_read_builder()
+            table_scan = read_builder.new_scan()
+            table_read = read_builder.new_read()
+            actual = table_read.to_arrow(table_scan.plan().splits()).sort_by('id')
+
+            # Verify data rows
+            self.assertEqual(num_threads * 5, actual.num_rows,
+                             f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+            # Verify id column
+            ids = actual.column('id').to_pylist()
+            expected_ids = []
+            for i in range(num_threads):
+                expected_ids.extend(range(i * 10, i * 10 + 5))
+            expected_ids.sort()
+
+            self.assertEqual(ids, expected_ids,
+                             f"Iteration {test_iteration}: IDs mismatch")
+
+            # Verify blob data integrity (spot check)
+            blob_data_list = actual.column('blob_data').to_pylist()
+            for i in range(0, len(blob_data_list), 100):  # Check every 100th blob
+                blob = blob_data_list[i]
+                self.assertGreater(len(blob), 2, f"Blob {i} should have data")
+                # Verify blob contains the pattern
+                self.assertIn(b'BLOB_PATTERN_', blob, f"Blob {i} should contain pattern")
+
+            # Verify snapshot count (should have num_threads snapshots)
+            snapshot_manager = SnapshotManager(table)
+            latest_snapshot = snapshot_manager.get_latest_snapshot()
+            self.assertIsNotNone(latest_snapshot,
+                                 f"Iteration {test_iteration}: Latest snapshot should not be None")
+            self.assertEqual(latest_snapshot.id, num_threads,
+                             f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+                             f"got {latest_snapshot.id}")
+
+            print(f"✓ Blob Table Iteration {test_iteration + 1}/{iter_num} completed successfully")
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 2661723919..b47f5d1f67 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -17,6 +17,7 @@
 ################################################################################
 
 import os
+import shutil
 import tempfile
 import time
 import unittest
@@ -53,6 +54,10 @@ class AoReaderTest(unittest.TestCase):
             'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
         }, schema=cls.pa_schema)
 
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
     def test_parquet_ao_reader(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
         self.catalog.create_table('default.test_append_only_parquet', schema, False)
@@ -410,3 +415,104 @@ class AoReaderTest(unittest.TestCase):
         table_read = read_builder.new_read()
         splits = read_builder.new_scan().plan().splits()
         return table_read.to_arrow(splits)
+
+    def test_concurrent_writes_with_retry(self):
+        """Test concurrent writes to verify retry mechanism works correctly."""
+        import threading
+
+        # Run the test several times to verify stability
+        iter_num = 5
+        for test_iteration in range(iter_num):
+            # Create a unique table for each iteration
+            table_name = f'default.test_concurrent_writes_{test_iteration}'
+            schema = Schema.from_pyarrow_schema(self.pa_schema)
+            self.catalog.create_table(table_name, schema, False)
+            table = self.catalog.get_table(table_name)
+
+            write_results = []
+            write_errors = []
+
+            def write_data(thread_id, start_user_id):
+                """Write data in a separate thread."""
+                try:
+                    threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+                    write_builder = table.new_batch_write_builder()
+                    table_write = write_builder.new_write()
+                    table_commit = write_builder.new_commit()
+
+                    # Create unique data for this thread
+                    data = {
+                        'user_id': list(range(start_user_id, start_user_id + 5)),
+                        'item_id': [1000 + i for i in range(start_user_id, start_user_id + 5)],
+                        'behavior': [f'thread{thread_id}_{i}' for i in range(5)],
+                        'dt': ['p1' if i % 2 == 0 else 'p2' for i in range(5)],
+                    }
+                    pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+
+                    table_write.write_arrow(pa_table)
+                    commit_messages = table_write.prepare_commit()
+
+                    table_commit.commit(commit_messages)
+                    table_write.close()
+                    table_commit.close()
+
+                    write_results.append({
+                        'thread_id': thread_id,
+                        'start_user_id': start_user_id,
+                        'success': True
+                    })
+                except Exception as e:
+                    write_errors.append({
+                        'thread_id': thread_id,
+                        'error': str(e)
+                    })
+
+            # Create and start multiple threads
+            threads = []
+            num_threads = 100
+            for i in range(num_threads):
+                thread = threading.Thread(
+                    target=write_data,
+                    args=(i, i * 10)
+                )
+                threads.append(thread)
+                thread.start()
+
+            # Wait for all threads to complete
+            for thread in threads:
+                thread.join()
+
+            # Verify all writes succeeded (retry mechanism should handle conflicts)
+            self.assertEqual(num_threads, len(write_results),
+                             f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+                             f"got {len(write_results)}. Errors: {write_errors}")
+            self.assertEqual(0, len(write_errors),
+                             f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+            read_builder = table.new_read_builder()
+            actual = self._read_test_table(read_builder).sort_by('user_id')
+
+            # Verify data rows
+            self.assertEqual(num_threads * 5, actual.num_rows,
+                             f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+            # Verify user_id
+            user_ids = actual.column('user_id').to_pylist()
+            expected_user_ids = []
+            for i in range(num_threads):
+                expected_user_ids.extend(range(i * 10, i * 10 + 5))
+            expected_user_ids.sort()
+
+            self.assertEqual(user_ids, expected_user_ids,
+                             f"Iteration {test_iteration}: User IDs mismatch")
+
+            # Verify snapshot count (should have num_threads snapshots)
+            snapshot_manager = SnapshotManager(table)
+            latest_snapshot = snapshot_manager.get_latest_snapshot()
+            self.assertIsNotNone(latest_snapshot,
+                                 f"Iteration {test_iteration}: Latest snapshot should not be None")
+            self.assertEqual(latest_snapshot.id, num_threads,
+                             f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+                             f"got {latest_snapshot.id}")
+
+            print(f"✓ Iteration {test_iteration + 1}/{iter_num} completed successfully")
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index 7077b2fd44..731203385d 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -422,3 +422,107 @@ class PkReaderTest(unittest.TestCase):
         table_read = read_builder.new_read()
         splits = read_builder.new_scan().plan().splits()
         return table_read.to_arrow(splits)
+
+    def test_concurrent_writes_with_retry(self):
+        """Test concurrent writes to verify retry mechanism works correctly for PK tables."""
+        import threading
+
+        # Run the test 3 times to verify stability
+        iter_num = 3
+        for test_iteration in range(iter_num):
+            # Create a unique table for each iteration
+            table_name = f'default.test_pk_concurrent_writes_{test_iteration}'
+            schema = Schema.from_pyarrow_schema(self.pa_schema,
+                                                partition_keys=['dt'],
+                                                primary_keys=['user_id', 'dt'],
+                                                options={'bucket': '2'})
+            self.catalog.create_table(table_name, schema, False)
+            table = self.catalog.get_table(table_name)
+
+            write_results = []
+            write_errors = []
+
+            def write_data(thread_id, start_user_id):
+                """Write data in a separate thread."""
+                try:
+                    threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+                    write_builder = table.new_batch_write_builder()
+                    table_write = write_builder.new_write()
+                    table_commit = write_builder.new_commit()
+
+                    # Create unique data for this thread
+                    data = {
+                        'user_id': list(range(start_user_id, start_user_id + 5)),
+                        'item_id': [1000 + i for i in range(start_user_id, start_user_id + 5)],
+                        'behavior': [f'thread{thread_id}_{i}' for i in range(5)],
+                        'dt': ['p1' if i % 2 == 0 else 'p2' for i in range(5)],
+                    }
+                    pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+
+                    table_write.write_arrow(pa_table)
+                    commit_messages = table_write.prepare_commit()
+
+                    table_commit.commit(commit_messages)
+                    table_write.close()
+                    table_commit.close()
+
+                    write_results.append({
+                        'thread_id': thread_id,
+                        'start_user_id': start_user_id,
+                        'success': True
+                    })
+                except Exception as e:
+                    write_errors.append({
+                        'thread_id': thread_id,
+                        'error': str(e)
+                    })
+
+            # Create and start multiple threads
+            threads = []
+            num_threads = 100
+            for i in range(num_threads):
+                thread = threading.Thread(
+                    target=write_data,
+                    args=(i, i * 10)
+                )
+                threads.append(thread)
+                thread.start()
+
+            # Wait for all threads to complete
+            for thread in threads:
+                thread.join()
+
+            # Verify all writes succeeded (retry mechanism should handle conflicts)
+            self.assertEqual(num_threads, len(write_results),
+                             f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+                             f"got {len(write_results)}. Errors: {write_errors}")
+            self.assertEqual(0, len(write_errors),
+                             f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+            read_builder = table.new_read_builder()
+            actual = self._read_test_table(read_builder).sort_by('user_id')
+
+            # Verify data rows (PK table should have unique user_id+dt combinations)
+            self.assertEqual(num_threads * 5, actual.num_rows,
+                             f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+            # Verify user_id
+            user_ids = actual.column('user_id').to_pylist()
+            expected_user_ids = []
+            for i in range(num_threads):
+                expected_user_ids.extend(range(i * 10, i * 10 + 5))
+            expected_user_ids.sort()
+
+            self.assertEqual(user_ids, expected_user_ids,
+                             f"Iteration {test_iteration}: User IDs mismatch")
+
+            # Verify snapshot count (should have num_threads snapshots)
+            snapshot_manager = SnapshotManager(table)
+            latest_snapshot = snapshot_manager.get_latest_snapshot()
+            self.assertIsNotNone(latest_snapshot,
+                                 f"Iteration {test_iteration}: Latest snapshot should not be None")
+            self.assertEqual(latest_snapshot.id, num_threads,
+                             f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+                             f"got {latest_snapshot.id}")
+
+            print(f"✓ PK Table Iteration {test_iteration + 1}/{iter_num} completed successfully")
diff --git a/paimon-python/pypaimon/tests/schema_evolution_read_test.py b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
index f5dafaae35..a67a927a5e 100644
--- a/paimon-python/pypaimon/tests/schema_evolution_read_test.py
+++ b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
@@ -322,6 +322,7 @@ class SchemaEvolutionReadTest(unittest.TestCase):
 
         # write schema-0 and schema-1 to table2
         schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.file_io.delete_quietly(table2.table_path + "/schema/schema-0")
         schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
         schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
 
diff --git a/paimon-python/pypaimon/write/file_store_commit.py 
b/paimon-python/pypaimon/write/file_store_commit.py
index a5b9fd9693..e55e25f7c8 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -16,9 +16,11 @@
 # limitations under the License.
 
################################################################################
 
+import logging
+import random
 import time
 import uuid
-from typing import List
+from typing import List, Optional
 
 from pypaimon.common.predicate_builder import PredicateBuilder
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager
@@ -35,6 +37,33 @@ from pypaimon.table.row.generic_row import GenericRow
 from pypaimon.table.row.offset_row import OffsetRow
 from pypaimon.write.commit_message import CommitMessage
 
+logger = logging.getLogger(__name__)
+
+
+class CommitResult:
+    """Base class for commit results."""
+
+    def is_success(self) -> bool:
+        """Returns True if commit was successful."""
+        raise NotImplementedError
+
+
+class SuccessResult(CommitResult):
+    """Result indicating successful commit."""
+
+    def is_success(self) -> bool:
+        return True
+
+
+class RetryResult(CommitResult):
+
+    def __init__(self, latest_snapshot, exception: Optional[Exception] = None):
+        self.latest_snapshot = latest_snapshot
+        self.exception = exception
+
+    def is_success(self) -> bool:
+        return False
+
 
 class FileStoreCommit:
     """
@@ -58,6 +87,11 @@ class FileStoreCommit:
         self.manifest_target_size = 8 * 1024 * 1024
         self.manifest_merge_min_count = 30
 
+        self.commit_max_retries = table.options.commit_max_retries()
+        self.commit_timeout = table.options.commit_timeout()
+        self.commit_min_retry_wait = table.options.commit_min_retry_wait()
+        self.commit_max_retry_wait = table.options.commit_max_retry_wait()
+
     def commit(self, commit_messages: List[CommitMessage], commit_identifier: 
int):
         """Commit the given commit messages in normal append mode."""
         if not commit_messages:
@@ -99,27 +133,81 @@ class FileStoreCommit:
                     raise RuntimeError(f"Trying to overwrite partition 
{overwrite_partition}, but the changes "
                                        f"in {msg.partition} does not belong to 
this partition")
 
-        commit_entries = []
-        current_entries = FullStartingScanner(self.table, partition_filter, 
None).plan_files()
-        for entry in current_entries:
-            entry.kind = 1
-            commit_entries.append(entry)
-        for msg in commit_messages:
-            partition = GenericRow(list(msg.partition), 
self.table.partition_keys_fields)
-            for file in msg.new_files:
-                commit_entries.append(ManifestEntry(
-                    kind=0,
-                    partition=partition,
-                    bucket=msg.bucket,
-                    total_buckets=self.table.total_buckets,
-                    file=file
-                ))
+        self._overwrite_partition_filter = partition_filter
+        self._overwrite_commit_messages = commit_messages
 
-        self._try_commit(commit_kind="OVERWRITE",
-                         commit_entries=commit_entries,
-                         commit_identifier=commit_identifier)
+        self._try_commit(
+            commit_kind="OVERWRITE",
+            commit_entries=None,  # Will be generated in _try_commit based on 
latest snapshot
+            commit_identifier=commit_identifier
+        )
 
     def _try_commit(self, commit_kind, commit_entries, commit_identifier):
+        import threading
+
+        retry_count = 0
+        retry_result = None
+        start_time_ms = int(time.time() * 1000)
+        thread_id = threading.current_thread().name
+        while True:
+            latest_snapshot = self.snapshot_manager.get_latest_snapshot()
+
+            if commit_kind == "OVERWRITE":
+                commit_entries = self._generate_overwrite_entries()
+
+            result = self._try_commit_once(
+                retry_result=retry_result,
+                commit_kind=commit_kind,
+                commit_entries=commit_entries,
+                commit_identifier=commit_identifier,
+                latest_snapshot=latest_snapshot
+            )
+
+            if result.is_success():
+                logger.warning(
+                    f"Thread {thread_id}: commit success {latest_snapshot.id + 
1 if latest_snapshot else 1} "
+                    f"after {retry_count} retries"
+                )
+                break
+
+            retry_result = result
+
+            elapsed_ms = int(time.time() * 1000) - start_time_ms
+            if elapsed_ms > self.commit_timeout or retry_count >= 
self.commit_max_retries:
+                error_msg = (
+                    f"Commit failed {latest_snapshot.id + 1 if latest_snapshot 
else 1} "
+                    f"after {elapsed_ms} millis with {retry_count} retries, "
+                    f"there maybe exist commit conflicts between multiple 
jobs."
+                )
+                if retry_result.exception:
+                    raise RuntimeError(error_msg) from retry_result.exception
+                else:
+                    raise RuntimeError(error_msg)
+
+            self._commit_retry_wait(retry_count)
+            retry_count += 1
+
+    def _try_commit_once(self, retry_result: Optional[RetryResult], 
commit_kind: str,
+                         commit_entries: List[ManifestEntry], 
commit_identifier: int,
+                         latest_snapshot: Optional[Snapshot]) -> CommitResult:
+        start_time_ms = int(time.time() * 1000)
+
+        if retry_result is not None and latest_snapshot is not None:
+            start_check_snapshot_id = 1  # Snapshot.FIRST_SNAPSHOT_ID
+            if retry_result.latest_snapshot is not None:
+                start_check_snapshot_id = retry_result.latest_snapshot.id + 1
+
+            for snapshot_id in range(start_check_snapshot_id, 
latest_snapshot.id + 2):
+                snapshot = 
self.snapshot_manager.get_snapshot_by_id(snapshot_id)
+                if (snapshot and snapshot.commit_user == self.commit_user and
+                        snapshot.commit_identifier == commit_identifier and
+                        snapshot.commit_kind == commit_kind):
+                    logger.info(
+                        f"Commit already completed (snapshot {snapshot_id}), "
+                        f"user: {self.commit_user}, identifier: 
{commit_identifier}"
+                    )
+                    return SuccessResult()
+
         unique_id = uuid.uuid4()
         base_manifest_list = f"manifest-list-{unique_id}-0"
         delta_manifest_list = f"manifest-list-{unique_id}-1"
@@ -130,7 +218,6 @@ class FileStoreCommit:
         deleted_file_count = 0
         delta_record_count = 0
         # process snapshot
-        latest_snapshot = self.snapshot_manager.get_latest_snapshot()
         new_snapshot_id = latest_snapshot.id + 1 if latest_snapshot else 1
 
         # Check if row tracking is enabled
@@ -143,7 +230,7 @@ class FileStoreCommit:
             commit_entries = self._assign_snapshot_id(new_snapshot_id, 
commit_entries)
 
             # Get the next row ID start from the latest snapshot
-            first_row_id_start = self._get_next_row_id_start()
+            first_row_id_start = self._get_next_row_id_start(latest_snapshot)
 
             # Assign row IDs to new files and get the next row ID for the 
snapshot
             commit_entries, next_row_id = 
self._assign_row_tracking_meta(first_row_id_start, commit_entries)
@@ -155,71 +242,164 @@ class FileStoreCommit:
             else:
                 deleted_file_count += 1
                 delta_record_count -= entry.file.row_count
-        self.manifest_file_manager.write(new_manifest_file, commit_entries)
-        # TODO: implement noConflictsOrFail logic
-        partition_columns = list(zip(*(entry.partition.values for entry in 
commit_entries)))
-        partition_min_stats = [min(col) for col in partition_columns]
-        partition_max_stats = [max(col) for col in partition_columns]
-        partition_null_counts = [sum(value == 0 for value in col) for col in 
partition_columns]
-        if not all(count == 0 for count in partition_null_counts):
-            raise RuntimeError("Partition value should not be null")
-        manifest_file_path = 
f"{self.manifest_file_manager.manifest_path}/{new_manifest_file}"
-        new_manifest_list = ManifestFileMeta(
-            file_name=new_manifest_file,
-            file_size=self.table.file_io.get_file_size(manifest_file_path),
-            num_added_files=added_file_count,
-            num_deleted_files=deleted_file_count,
-            partition_stats=SimpleStats(
-                min_values=GenericRow(
-                    values=partition_min_stats,
-                    fields=self.table.partition_keys_fields
-                ),
-                max_values=GenericRow(
-                    values=partition_max_stats,
-                    fields=self.table.partition_keys_fields
+
+        try:
+            self.manifest_file_manager.write(new_manifest_file, commit_entries)
+
+            # TODO: implement noConflictsOrFail logic
+            partition_columns = list(zip(*(entry.partition.values for entry in 
commit_entries)))
+            partition_min_stats = [min(col) for col in partition_columns]
+            partition_max_stats = [max(col) for col in partition_columns]
+            partition_null_counts = [sum(value == 0 for value in col) for col 
in partition_columns]
+            if not all(count == 0 for count in partition_null_counts):
+                raise RuntimeError("Partition value should not be null")
+
+            manifest_file_path = 
f"{self.manifest_file_manager.manifest_path}/{new_manifest_file}"
+            file_size = self.table.file_io.get_file_size(manifest_file_path)
+
+            new_manifest_file_meta = ManifestFileMeta(
+                file_name=new_manifest_file,
+                file_size=file_size,
+                num_added_files=added_file_count,
+                num_deleted_files=deleted_file_count,
+                partition_stats=SimpleStats(
+                    min_values=GenericRow(
+                        values=partition_min_stats,
+                        fields=self.table.partition_keys_fields
+                    ),
+                    max_values=GenericRow(
+                        values=partition_max_stats,
+                        fields=self.table.partition_keys_fields
+                    ),
+                    null_counts=partition_null_counts,
                 ),
-                null_counts=partition_null_counts,
-            ),
-            schema_id=self.table.table_schema.id,
+                schema_id=self.table.table_schema.id,
+            )
+
+            self.manifest_list_manager.write(delta_manifest_list, 
[new_manifest_file_meta])
+
+            # process existing_manifest
+            total_record_count = 0
+            if latest_snapshot:
+                existing_manifest_files = 
self.manifest_list_manager.read_all(latest_snapshot)
+                previous_record_count = latest_snapshot.total_record_count
+                if previous_record_count:
+                    total_record_count += previous_record_count
+            else:
+                existing_manifest_files = []
+
+            self.manifest_list_manager.write(base_manifest_list, 
existing_manifest_files)
+            total_record_count += delta_record_count
+            snapshot_data = Snapshot(
+                version=3,
+                id=new_snapshot_id,
+                schema_id=self.table.table_schema.id,
+                base_manifest_list=base_manifest_list,
+                delta_manifest_list=delta_manifest_list,
+                total_record_count=total_record_count,
+                delta_record_count=delta_record_count,
+                commit_user=self.commit_user,
+                commit_identifier=commit_identifier,
+                commit_kind=commit_kind,
+                time_millis=int(time.time() * 1000),
+                next_row_id=next_row_id,
+            )
+            # Generate partition statistics for the commit
+            statistics = self._generate_partition_statistics(commit_entries)
+        except Exception as e:
+            self._cleanup_preparation_failure(new_manifest_file, 
delta_manifest_list,
+                                              base_manifest_list)
+            logger.warning(f"Exception occurs when preparing snapshot: {e}", 
exc_info=True)
+            raise RuntimeError(f"Failed to prepare snapshot: {e}")
+
+        # Use SnapshotCommit for atomic commit
+        try:
+            with self.snapshot_commit:
+                success = self.snapshot_commit.commit(snapshot_data, 
self.table.current_branch(), statistics)
+                if not success:
+                    # Commit failed, clean up temporary files and retry
+                    commit_time_sec = (int(time.time() * 1000) - 
start_time_ms) / 1000
+                    logger.warning(
+                        f"Atomic commit failed for snapshot #{new_snapshot_id} 
"
+                        f"by user {self.commit_user} "
+                        f"with identifier {commit_identifier} and kind 
{commit_kind} after {commit_time_sec}s. "
+                        f"Clean up and try again."
+                    )
+                    self._cleanup_preparation_failure(new_manifest_file, 
delta_manifest_list,
+                                                      base_manifest_list)
+                    return RetryResult(latest_snapshot, None)
+        except Exception as e:
+            # Commit exception, not sure about the situation and should not 
clean up the files
+            logger.warning("Retry commit for exception")
+            return RetryResult(latest_snapshot, e)
+
+        logger.warning(
+            f"Successfully commit snapshot {new_snapshot_id} to table 
{self.table.identifier} "
+            f"for snapshot-{new_snapshot_id} by user {self.commit_user} "
+            + f"with identifier {commit_identifier} and kind {commit_kind}."
         )
-        self.manifest_list_manager.write(delta_manifest_list, 
[new_manifest_list])
-
-        # process existing_manifest
-        total_record_count = 0
-        if latest_snapshot:
-            existing_manifest_files = 
self.manifest_list_manager.read_all(latest_snapshot)
-            previous_record_count = latest_snapshot.total_record_count
-            if previous_record_count:
-                total_record_count += previous_record_count
-        else:
-            existing_manifest_files = []
-        self.manifest_list_manager.write(base_manifest_list, 
existing_manifest_files)
+        return SuccessResult()
 
-        # process snapshot
-        total_record_count += delta_record_count
-        snapshot_data = Snapshot(
-            version=3,
-            id=new_snapshot_id,
-            schema_id=self.table.table_schema.id,
-            base_manifest_list=base_manifest_list,
-            delta_manifest_list=delta_manifest_list,
-            total_record_count=total_record_count,
-            delta_record_count=delta_record_count,
-            commit_user=self.commit_user,
-            commit_identifier=commit_identifier,
-            commit_kind=commit_kind,
-            time_millis=int(time.time() * 1000),
-            next_row_id=next_row_id,
+    def _generate_overwrite_entries(self):
+        """Generate commit entries for OVERWRITE mode based on latest 
snapshot."""
+        entries = []
+        current_entries = FullStartingScanner(self.table, 
self._overwrite_partition_filter, None).plan_files()
+        for entry in current_entries:
+            entry.kind = 1  # DELETE
+            entries.append(entry)
+        for msg in self._overwrite_commit_messages:
+            partition = GenericRow(list(msg.partition), 
self.table.partition_keys_fields)
+            for file in msg.new_files:
+                entries.append(ManifestEntry(
+                    kind=0,  # ADD
+                    partition=partition,
+                    bucket=msg.bucket,
+                    total_buckets=self.table.total_buckets,
+                    file=file
+                ))
+        return entries
+
+    def _commit_retry_wait(self, retry_count: int):
+        import threading
+        thread_id = threading.get_ident()
+
+        retry_wait_ms = min(
+            self.commit_min_retry_wait * (2 ** retry_count),
+            self.commit_max_retry_wait
         )
 
-        # Generate partition statistics for the commit
-        statistics = self._generate_partition_statistics(commit_entries)
+        jitter_ms = random.randint(0, max(1, int(retry_wait_ms * 0.2)))
+        total_wait_ms = retry_wait_ms + jitter_ms
 
-        # Use SnapshotCommit for atomic commit
-        with self.snapshot_commit:
-            success = self.snapshot_commit.commit(snapshot_data, 
self.table.current_branch(), statistics)
-            if not success:
-                raise RuntimeError(f"Failed to commit snapshot 
{new_snapshot_id}")
+        logger.debug(
+            f"Thread {thread_id}: Waiting {total_wait_ms}ms before retry 
(base: {retry_wait_ms}ms, "
+            f"jitter: {jitter_ms}ms)"
+        )
+        time.sleep(total_wait_ms / 1000.0)
+
+    def _cleanup_preparation_failure(self, manifest_file: Optional[str],
+                                     delta_manifest_list: Optional[str],
+                                     base_manifest_list: Optional[str]):
+        try:
+            manifest_path = self.manifest_list_manager.manifest_path
+
+            if delta_manifest_list:
+                manifest_files = 
self.manifest_list_manager.read(delta_manifest_list)
+                for manifest_meta in manifest_files:
+                    manifest_file_path = 
f"{self.manifest_file_manager.manifest_path}/{manifest_meta.file_name}"
+                    self.table.file_io.delete_quietly(manifest_file_path)
+                delta_path = f"{manifest_path}/{delta_manifest_list}"
+                self.table.file_io.delete_quietly(delta_path)
+
+            if base_manifest_list:
+                base_path = f"{manifest_path}/{base_manifest_list}"
+                self.table.file_io.delete_quietly(base_path)
+
+            if manifest_file:
+                manifest_file_path = 
f"{self.manifest_file_manager.manifest_path}/{manifest_file}"
+                self.table.file_io.delete_quietly(manifest_file_path)
+        except Exception as e:
+            logger.warning(f"Failed to clean up temporary files during 
preparation failure: {e}", exc_info=True)
 
     def abort(self, commit_messages: List[CommitMessage]):
         """Abort commit and delete files. Uses external_path if available to 
ensure proper scheme handling."""
@@ -332,9 +512,8 @@ class FileStoreCommit:
         """Assign snapshot ID to all commit entries."""
         return [entry.assign_sequence_number(snapshot_id, snapshot_id) for 
entry in commit_entries]
 
-    def _get_next_row_id_start(self) -> int:
+    def _get_next_row_id_start(self, latest_snapshot) -> int:
         """Get the next row ID start from the latest snapshot."""
-        latest_snapshot = self.snapshot_manager.get_latest_snapshot()
         if latest_snapshot and hasattr(latest_snapshot, 'next_row_id') and 
latest_snapshot.next_row_id is not None:
             return latest_snapshot.next_row_id
         return 0

Reply via email to