This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b482e39ef9 Add transient-error retry to SalesforceBulkOperator (#64575)
0b482e39ef9 is described below

commit 0b482e39ef9a96dfb113e4e8b34a2b7c8cf1bde8
Author: nagasrisai <[email protected]>
AuthorDate: Tue Apr 7 03:50:26 2026 +0530

    Add transient-error retry to SalesforceBulkOperator (#64575)
    
    * Add tests for SalesforceBulkOperator transient-error retry
    
    * Add transient-error retry to SalesforceBulkOperator
    
    Introduces max_retries, retry_delay, and transient_error_codes params.
    When max_retries > 0, records that fail with a transient Salesforce error
    (UNABLE_TO_LOCK_ROW or API_TEMPORARILY_UNAVAILABLE by default) are
    re-submitted after retry_delay seconds, up to max_retries times.
    Only the failed records are re-submitted, not the entire payload.
    
    Related to #64519
    
    * Fix lint: remove unused pytest import and dead variable assignments
    
    * Add input validation for max_retries, retry_delay, and transient_error_codes
    
    * Fix IndentationError in _validate_inputs: use consistent 8-space indent
    
    * Rename retry_delay → bulk_retry_delay to avoid collision with BaseOperator.retry_delay (timedelta)
    
    * Update tests: retry_delay → bulk_retry_delay
    
    * Fix: correct mock chain for hook conn.bulk; ruff format long dicts
    
    * Fix: remove list() from _run_operation, add to retry call; fix ruff format
    
    * Apply ruff format: split long lines, wrap method signature and retry call
    
    * Apply ruff format: split long dicts in test helpers
    
    * Fix ruff: reformat with line-length=110 (Airflow project standard)
    
    * Fix mypy: cast(list,...) in _run_operation; fix ruff: use line-length=110
    
    * Fix ruff: use cast("list",...) string-quoted form (TC rule)
---
 .../airflow/providers/salesforce/operators/bulk.py | 125 ++++++++++---
 .../unit/salesforce/operators/test_bulk_retry.py   | 199 +++++++++++++++++++++
 2 files changed, 300 insertions(+), 24 deletions(-)

diff --git a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
index 7b5d21030db..720a3c6ad9d 100644
--- a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
+++ b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import time
 from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, cast
 
@@ -29,6 +30,13 @@ if TYPE_CHECKING:
 
     from airflow.providers.common.compat.sdk import Context
 
+# Salesforce error statusCode values that indicate a transient server-side
+# condition rather than a permanent data problem. Records that fail with one of
+# these codes can reasonably be re-submitted after a short delay.
+_DEFAULT_TRANSIENT_ERROR_CODES: frozenset[str] = frozenset(
+    {"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}
+)
+
 
 class SalesforceBulkOperator(BaseOperator):
     """
@@ -46,6 +54,14 @@ class SalesforceBulkOperator(BaseOperator):
     :param batch_size: number of records to assign for each batch in the job
     :param use_serial: Process batches in serial mode
     :param salesforce_conn_id: The :ref:`Salesforce Connection id 
<howto/connection:salesforce>`.
+    :param max_retries: Number of times to re-submit records that failed with a
+        transient error code such as ``UNABLE_TO_LOCK_ROW`` or
+        ``API_TEMPORARILY_UNAVAILABLE``.  Set to ``0`` (the default) to disable
+        automatic retries.
+    :param bulk_retry_delay: Seconds to wait before each retry attempt within 
the Bulk API retry loop. Defaults to ``5``.
+    :param transient_error_codes: Collection of Salesforce error ``statusCode``
+        values that should trigger a retry.  Defaults to
+        ``{"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}``.
     """
 
     template_fields: Sequence[str] = ("object_name", "payload", 
"external_id_field")
@@ -62,6 +78,9 @@ class SalesforceBulkOperator(BaseOperator):
         batch_size: int = 10000,
         use_serial: bool = False,
         salesforce_conn_id: str = "salesforce_default",
+        max_retries: int = 0,
+        bulk_retry_delay: float = 5.0,
+        transient_error_codes: Iterable[str] = _DEFAULT_TRANSIENT_ERROR_CODES,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -72,9 +91,25 @@ class SalesforceBulkOperator(BaseOperator):
         self.batch_size = batch_size
         self.use_serial = use_serial
         self.salesforce_conn_id = salesforce_conn_id
+        self.max_retries = max_retries
+        self.bulk_retry_delay = bulk_retry_delay
+        if isinstance(transient_error_codes, str):
+            raise ValueError(
+                "'transient_error_codes' must be a non-string iterable of 
strings, "
+                f"got {transient_error_codes!r}. Wrap it in a list: 
[{transient_error_codes!r}]"
+            )
+        self.transient_error_codes = frozenset(transient_error_codes)
         self._validate_inputs()
 
     def _validate_inputs(self) -> None:
+        if self.max_retries < 0:
+            raise ValueError(f"'max_retries' must be a non-negative integer, 
got {self.max_retries!r}.")
+
+        if self.bulk_retry_delay < 0:
+            raise ValueError(
+                f"'bulk_retry_delay' must be a non-negative number, got 
{self.bulk_retry_delay!r}."
+            )
+
         if not self.object_name:
             raise ValueError("The required parameter 'object_name' cannot have 
an empty value.")
 
@@ -84,6 +119,68 @@ class SalesforceBulkOperator(BaseOperator):
                 f"Available operations are {self.available_operations}."
             )
 
+    def _run_operation(self, bulk: SFBulkHandler, payload: list) -> list:
+        """Submit *payload* through the configured Bulk API operation and 
return the result list."""
+        obj = bulk.__getattr__(self.object_name)
+        if self.operation == "upsert":
+            return cast(
+                "list",
+                obj.upsert(
+                    data=payload,
+                    external_id_field=self.external_id_field,
+                    batch_size=self.batch_size,
+                    use_serial=self.use_serial,
+                ),
+            )
+        return cast(
+            "list",
+            getattr(obj, self.operation)(
+                data=payload,
+                batch_size=self.batch_size,
+                use_serial=self.use_serial,
+            ),
+        )
+
+    def _retry_transient_failures(self, bulk: SFBulkHandler, payload: list, 
result: list) -> list:
+        """
+        Re-submit records that failed with a transient error, up to 
*max_retries* times.
+
+        Salesforce Bulk API results are ordered identically to the input 
payload, so
+        failed records are located by index and their retry results are 
written back
+        into the same positions.
+        """
+        final = list(result)
+
+        for attempt in range(1, self.max_retries + 1):
+            retry_indices = [
+                i
+                for i, r in enumerate(final)
+                if not r.get("success")
+                and {e.get("statusCode") for e in r.get("errors", [])} & 
self.transient_error_codes
+            ]
+
+            if not retry_indices:
+                break
+
+            self.log.warning(
+                "Salesforce Bulk API %s on %s: retrying %d record(s) with 
transient errors "
+                "(attempt %d/%d, waiting %.1f second(s)).",
+                self.operation,
+                self.object_name,
+                len(retry_indices),
+                attempt,
+                self.max_retries,
+                self.bulk_retry_delay,
+            )
+            time.sleep(self.bulk_retry_delay)
+
+            retry_result = list(self._run_operation(bulk, [payload[i] for i in 
retry_indices]))
+
+            for list_pos, original_idx in enumerate(retry_indices):
+                final[original_idx] = retry_result[list_pos]
+
+        return final
+
     def execute(self, context: Context):
         """
         Make an HTTP request to Salesforce Bulk API.
@@ -95,30 +192,10 @@ class SalesforceBulkOperator(BaseOperator):
         conn = sf_hook.get_conn()
         bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk"))
 
-        result: Iterable = []
-        if self.operation == "insert":
-            result = bulk.__getattr__(self.object_name).insert(
-                data=self.payload, batch_size=self.batch_size, 
use_serial=self.use_serial
-            )
-        elif self.operation == "update":
-            result = bulk.__getattr__(self.object_name).update(
-                data=self.payload, batch_size=self.batch_size, 
use_serial=self.use_serial
-            )
-        elif self.operation == "upsert":
-            result = bulk.__getattr__(self.object_name).upsert(
-                data=self.payload,
-                external_id_field=self.external_id_field,
-                batch_size=self.batch_size,
-                use_serial=self.use_serial,
-            )
-        elif self.operation == "delete":
-            result = bulk.__getattr__(self.object_name).delete(
-                data=self.payload, batch_size=self.batch_size, 
use_serial=self.use_serial
-            )
-        elif self.operation == "hard_delete":
-            result = bulk.__getattr__(self.object_name).hard_delete(
-                data=self.payload, batch_size=self.batch_size, 
use_serial=self.use_serial
-            )
+        result = self._run_operation(bulk, self.payload)
+
+        if self.max_retries > 0:
+            result = self._retry_transient_failures(bulk, self.payload, result)
 
         if self.do_xcom_push and result:
             return result
diff --git a/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py
new file mode 100644
index 00000000000..d373c208a39
--- /dev/null
+++ b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator
+
+
+def _make_op(**kwargs):
+    defaults = dict(
+        task_id="test_task",
+        operation="insert",
+        object_name="Contact",
+        payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+    )
+    defaults.update(kwargs)
+    return SalesforceBulkOperator(**defaults)
+
+
+def _transient_failure(status_code="UNABLE_TO_LOCK_ROW"):
+    return {
+        "success": False,
+        "errors": [{"statusCode": status_code, "message": "locked", "fields": 
[]}],
+    }
+
+
+def _permanent_failure():
+    return {
+        "success": False,
+        "errors": [
+            {
+                "statusCode": "REQUIRED_FIELD_MISSING",
+                "message": "missing",
+                "fields": ["Name"],
+            }
+        ],
+    }
+
+
+def _success():
+    return {"success": True, "errors": []}
+
+
+class TestSalesforceBulkOperatorRetry:
+    def test_no_retry_when_max_retries_zero(self):
+        op = _make_op(max_retries=0)
+        assert op.max_retries == 0
+
+        bulk_mock = mock.MagicMock()
+        bulk_mock.__getattr__("Contact").insert.return_value = [_success(), 
_success()]
+
+        with 
mock.patch("airflow.providers.salesforce.operators.bulk.SalesforceHook") as 
hook_cls:
+            hook_cls.return_value.get_conn.return_value.bulk = bulk_mock
+            result = op.execute(context={})
+
+        assert result == [_success(), _success()]
+        assert bulk_mock.__getattr__("Contact").insert.call_count == 1
+
+    def test_transient_failure_is_retried(self):
+        op = _make_op(max_retries=2, bulk_retry_delay=0)
+
+        first_result = [_transient_failure(), _success()]
+        second_result = [_success()]
+
+        run_mock = mock.MagicMock(side_effect=[first_result, second_result])
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+                final = op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+                    result=first_result,
+                )
+
+        assert final[0] == _success()
+        assert final[1] == _success()
+        assert run_mock.call_count == 2
+        retry_call = run_mock.call_args_list[1]
+        assert retry_call == mock.call(mock.ANY, [{"FirstName": "Ada"}])
+
+    def test_permanent_failure_is_not_retried(self):
+        op = _make_op(max_retries=3, bulk_retry_delay=0)
+        result = [_permanent_failure(), _success()]
+
+        run_mock = mock.MagicMock()
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            final = op._retry_transient_failures(
+                bulk=mock.MagicMock(),
+                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+                result=result,
+            )
+
+        run_mock.assert_not_called()
+        assert final[0] == _permanent_failure()
+
+    def test_retries_stop_after_max_retries(self):
+        op = _make_op(max_retries=2, bulk_retry_delay=0)
+
+        always_transient = [_transient_failure()]
+        run_mock = mock.MagicMock(return_value=always_transient)
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+                final = op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=[{"FirstName": "Ada"}],
+                    result=always_transient,
+                )
+
+        assert run_mock.call_count == 2
+        assert final[0]["success"] is False
+
+    def test_retry_delay_is_respected(self):
+        op = _make_op(max_retries=1, bulk_retry_delay=30.0)
+
+        run_mock = mock.MagicMock(return_value=[_success()])
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep") as 
sleep_mock:
+                op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=[{"FirstName": "Ada"}],
+                    result=[_transient_failure()],
+                )
+
+        sleep_mock.assert_called_once_with(30.0)
+
+    def test_custom_transient_error_codes(self):
+        op = _make_op(max_retries=1, bulk_retry_delay=0, 
transient_error_codes=["MY_CUSTOM_ERROR"])
+        assert op.transient_error_codes == frozenset({"MY_CUSTOM_ERROR"})
+
+        custom_failure = {
+            "success": False,
+            "errors": [{"statusCode": "MY_CUSTOM_ERROR", "message": "custom"}],
+        }
+        run_mock = mock.MagicMock(return_value=[_success()])
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+                final = op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=[{"FirstName": "Ada"}],
+                    result=[custom_failure],
+                )
+
+        run_mock.assert_called_once()
+        assert final[0] == _success()
+
+    def test_api_temporarily_unavailable_is_retried(self):
+        op = _make_op(max_retries=1, bulk_retry_delay=0)
+        failure = _transient_failure("API_TEMPORARILY_UNAVAILABLE")
+        run_mock = mock.MagicMock(return_value=[_success()])
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+                final = op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=[{"FirstName": "Ada"}],
+                    result=[failure],
+                )
+
+        run_mock.assert_called_once()
+        assert final[0] == _success()
+
+    def test_mixed_failures_only_retries_transient(self):
+        op = _make_op(max_retries=1, bulk_retry_delay=0)
+        payload = [{"FirstName": "A"}, {"FirstName": "B"}, {"FirstName": "C"}]
+        initial = [_transient_failure(), _permanent_failure(), _success()]
+
+        run_mock = mock.MagicMock(return_value=[_success()])
+
+        with mock.patch.object(op, "_run_operation", run_mock):
+            with 
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+                final = op._retry_transient_failures(
+                    bulk=mock.MagicMock(),
+                    payload=payload,
+                    result=initial,
+                )
+
+        run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "A"}])
+        assert final[0] == _success()
+        assert final[1] == _permanent_failure()
+        assert final[2] == _success()

Reply via email to