This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b482e39ef9 Add transient-error retry to SalesforceBulkOperator (#64575)
0b482e39ef9 is described below
commit 0b482e39ef9a96dfb113e4e8b34a2b7c8cf1bde8
Author: nagasrisai <[email protected]>
AuthorDate: Tue Apr 7 03:50:26 2026 +0530
Add transient-error retry to SalesforceBulkOperator (#64575)
* Add tests for SalesforceBulkOperator transient-error retry
* Add transient-error retry to SalesforceBulkOperator
Introduces max_retries, retry_delay (renamed to bulk_retry_delay later in
this PR), and transient_error_codes params.
When max_retries > 0, records that fail with a transient Salesforce error
(UNABLE_TO_LOCK_ROW or API_TEMPORARILY_UNAVAILABLE by default) are
re-submitted after the configured delay (in seconds), up to max_retries times.
Only the failed records are re-submitted, not the entire payload.
Related to #64519
* Fix lint: remove unused pytest import and dead variable assignments
* Add input validation for max_retries, retry_delay, and transient_error_codes
* Fix IndentationError in _validate_inputs: use consistent 8-space indent
* Rename retry_delay → bulk_retry_delay to avoid collision with BaseOperator.retry_delay (timedelta)
* Update tests: retry_delay → bulk_retry_delay
* Fix: correct mock chain for hook conn.bulk; ruff format long dicts
* Fix: remove list() from _run_operation, add to retry call; fix ruff format
* Apply ruff format: split long lines, wrap method signature and retry call
* Apply ruff format: split long dicts in test helpers
* Fix ruff: reformat with line-length=110 (Airflow project standard)
* Fix mypy: cast(list,...) in _run_operation; fix ruff: use line-length=110
* Fix ruff: use cast("list",...) string-quoted form (TC rule)
---
.../airflow/providers/salesforce/operators/bulk.py | 125 ++++++++++---
.../unit/salesforce/operators/test_bulk_retry.py | 199 +++++++++++++++++++++
2 files changed, 300 insertions(+), 24 deletions(-)
diff --git
a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
index 7b5d21030db..720a3c6ad9d 100644
--- a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
+++ b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, cast
@@ -29,6 +30,13 @@ if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
+# Salesforce error statusCode values that indicate a transient server-side
+# condition rather than a permanent data problem. Records that fail with one of
+# these codes can reasonably be re-submitted after a short delay.
+_DEFAULT_TRANSIENT_ERROR_CODES: frozenset[str] = frozenset(
+ {"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}
+)
+
class SalesforceBulkOperator(BaseOperator):
"""
@@ -46,6 +54,14 @@ class SalesforceBulkOperator(BaseOperator):
:param batch_size: number of records to assign for each batch in the job
:param use_serial: Process batches in serial mode
:param salesforce_conn_id: The :ref:`Salesforce Connection id
<howto/connection:salesforce>`.
+ :param max_retries: Number of times to re-submit records that failed with a
+ transient error code such as ``UNABLE_TO_LOCK_ROW`` or
+ ``API_TEMPORARILY_UNAVAILABLE``. Set to ``0`` (the default) to disable
+ automatic retries.
+ :param bulk_retry_delay: Seconds to wait before each retry attempt within
the Bulk API retry loop. Defaults to ``5``.
+ :param transient_error_codes: Collection of Salesforce error ``statusCode``
+ values that should trigger a retry. Defaults to
+ ``{"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}``.
"""
template_fields: Sequence[str] = ("object_name", "payload",
"external_id_field")
@@ -62,6 +78,9 @@ class SalesforceBulkOperator(BaseOperator):
batch_size: int = 10000,
use_serial: bool = False,
salesforce_conn_id: str = "salesforce_default",
+ max_retries: int = 0,
+ bulk_retry_delay: float = 5.0,
+ transient_error_codes: Iterable[str] = _DEFAULT_TRANSIENT_ERROR_CODES,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -72,9 +91,25 @@ class SalesforceBulkOperator(BaseOperator):
self.batch_size = batch_size
self.use_serial = use_serial
self.salesforce_conn_id = salesforce_conn_id
+ self.max_retries = max_retries
+ self.bulk_retry_delay = bulk_retry_delay
+ if isinstance(transient_error_codes, str):
+ raise ValueError(
+ "'transient_error_codes' must be a non-string iterable of
strings, "
+ f"got {transient_error_codes!r}. Wrap it in a list:
[{transient_error_codes!r}]"
+ )
+ self.transient_error_codes = frozenset(transient_error_codes)
self._validate_inputs()
def _validate_inputs(self) -> None:
+ if self.max_retries < 0:
+ raise ValueError(f"'max_retries' must be a non-negative integer,
got {self.max_retries!r}.")
+
+ if self.bulk_retry_delay < 0:
+ raise ValueError(
+ f"'bulk_retry_delay' must be a non-negative number, got
{self.bulk_retry_delay!r}."
+ )
+
if not self.object_name:
raise ValueError("The required parameter 'object_name' cannot have
an empty value.")
@@ -84,6 +119,68 @@ class SalesforceBulkOperator(BaseOperator):
f"Available operations are {self.available_operations}."
)
+ def _run_operation(self, bulk: SFBulkHandler, payload: list) -> list:
+ """Submit *payload* through the configured Bulk API operation and
return the result list."""
+ obj = bulk.__getattr__(self.object_name)
+ if self.operation == "upsert":
+ return cast(
+ "list",
+ obj.upsert(
+ data=payload,
+ external_id_field=self.external_id_field,
+ batch_size=self.batch_size,
+ use_serial=self.use_serial,
+ ),
+ )
+ return cast(
+ "list",
+ getattr(obj, self.operation)(
+ data=payload,
+ batch_size=self.batch_size,
+ use_serial=self.use_serial,
+ ),
+ )
+
+ def _retry_transient_failures(self, bulk: SFBulkHandler, payload: list,
result: list) -> list:
+ """
+ Re-submit records that failed with a transient error, up to
*max_retries* times.
+
+ Salesforce Bulk API results are ordered identically to the input
payload, so
+ failed records are located by index and their retry results are
written back
+ into the same positions.
+ """
+ final = list(result)
+
+ for attempt in range(1, self.max_retries + 1):
+ retry_indices = [
+ i
+ for i, r in enumerate(final)
+ if not r.get("success")
+ and {e.get("statusCode") for e in r.get("errors", [])} &
self.transient_error_codes
+ ]
+
+ if not retry_indices:
+ break
+
+ self.log.warning(
+ "Salesforce Bulk API %s on %s: retrying %d record(s) with
transient errors "
+ "(attempt %d/%d, waiting %.1f second(s)).",
+ self.operation,
+ self.object_name,
+ len(retry_indices),
+ attempt,
+ self.max_retries,
+ self.bulk_retry_delay,
+ )
+ time.sleep(self.bulk_retry_delay)
+
+ retry_result = list(self._run_operation(bulk, [payload[i] for i in
retry_indices]))
+
+ for list_pos, original_idx in enumerate(retry_indices):
+ final[original_idx] = retry_result[list_pos]
+
+ return final
+
def execute(self, context: Context):
"""
Make an HTTP request to Salesforce Bulk API.
@@ -95,30 +192,10 @@ class SalesforceBulkOperator(BaseOperator):
conn = sf_hook.get_conn()
bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk"))
- result: Iterable = []
- if self.operation == "insert":
- result = bulk.__getattr__(self.object_name).insert(
- data=self.payload, batch_size=self.batch_size,
use_serial=self.use_serial
- )
- elif self.operation == "update":
- result = bulk.__getattr__(self.object_name).update(
- data=self.payload, batch_size=self.batch_size,
use_serial=self.use_serial
- )
- elif self.operation == "upsert":
- result = bulk.__getattr__(self.object_name).upsert(
- data=self.payload,
- external_id_field=self.external_id_field,
- batch_size=self.batch_size,
- use_serial=self.use_serial,
- )
- elif self.operation == "delete":
- result = bulk.__getattr__(self.object_name).delete(
- data=self.payload, batch_size=self.batch_size,
use_serial=self.use_serial
- )
- elif self.operation == "hard_delete":
- result = bulk.__getattr__(self.object_name).hard_delete(
- data=self.payload, batch_size=self.batch_size,
use_serial=self.use_serial
- )
+ result = self._run_operation(bulk, self.payload)
+
+ if self.max_retries > 0:
+ result = self._retry_transient_failures(bulk, self.payload, result)
if self.do_xcom_push and result:
return result
diff --git
a/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py
b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py
new file mode 100644
index 00000000000..d373c208a39
--- /dev/null
+++ b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator
+
+
+def _make_op(**kwargs):
+ defaults = dict(
+ task_id="test_task",
+ operation="insert",
+ object_name="Contact",
+ payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+ )
+ defaults.update(kwargs)
+ return SalesforceBulkOperator(**defaults)
+
+
+def _transient_failure(status_code="UNABLE_TO_LOCK_ROW"):
+ return {
+ "success": False,
+ "errors": [{"statusCode": status_code, "message": "locked", "fields":
[]}],
+ }
+
+
+def _permanent_failure():
+ return {
+ "success": False,
+ "errors": [
+ {
+ "statusCode": "REQUIRED_FIELD_MISSING",
+ "message": "missing",
+ "fields": ["Name"],
+ }
+ ],
+ }
+
+
+def _success():
+ return {"success": True, "errors": []}
+
+
+class TestSalesforceBulkOperatorRetry:
+ def test_no_retry_when_max_retries_zero(self):
+ op = _make_op(max_retries=0)
+ assert op.max_retries == 0
+
+ bulk_mock = mock.MagicMock()
+ bulk_mock.__getattr__("Contact").insert.return_value = [_success(),
_success()]
+
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.SalesforceHook") as
hook_cls:
+ hook_cls.return_value.get_conn.return_value.bulk = bulk_mock
+ result = op.execute(context={})
+
+ assert result == [_success(), _success()]
+ assert bulk_mock.__getattr__("Contact").insert.call_count == 1
+
+ def test_transient_failure_is_retried(self):
+ op = _make_op(max_retries=2, bulk_retry_delay=0)
+
+ first_result = [_transient_failure(), _success()]
+ second_result = [_success()]
+
+ run_mock = mock.MagicMock(side_effect=[first_result, second_result])
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+ result=first_result,
+ )
+
+ assert final[0] == _success()
+ assert final[1] == _success()
+ assert run_mock.call_count == 2
+ retry_call = run_mock.call_args_list[1]
+ assert retry_call == mock.call(mock.ANY, [{"FirstName": "Ada"}])
+
+ def test_permanent_failure_is_not_retried(self):
+ op = _make_op(max_retries=3, bulk_retry_delay=0)
+ result = [_permanent_failure(), _success()]
+
+ run_mock = mock.MagicMock()
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
+ result=result,
+ )
+
+ run_mock.assert_not_called()
+ assert final[0] == _permanent_failure()
+
+ def test_retries_stop_after_max_retries(self):
+ op = _make_op(max_retries=2, bulk_retry_delay=0)
+
+ always_transient = [_transient_failure()]
+ run_mock = mock.MagicMock(return_value=always_transient)
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}],
+ result=always_transient,
+ )
+
+ assert run_mock.call_count == 2
+ assert final[0]["success"] is False
+
+ def test_retry_delay_is_respected(self):
+ op = _make_op(max_retries=1, bulk_retry_delay=30.0)
+
+ run_mock = mock.MagicMock(return_value=[_success()])
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep") as
sleep_mock:
+ op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}],
+ result=[_transient_failure()],
+ )
+
+ sleep_mock.assert_called_once_with(30.0)
+
+ def test_custom_transient_error_codes(self):
+ op = _make_op(max_retries=1, bulk_retry_delay=0,
transient_error_codes=["MY_CUSTOM_ERROR"])
+ assert op.transient_error_codes == frozenset({"MY_CUSTOM_ERROR"})
+
+ custom_failure = {
+ "success": False,
+ "errors": [{"statusCode": "MY_CUSTOM_ERROR", "message": "custom"}],
+ }
+ run_mock = mock.MagicMock(return_value=[_success()])
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}],
+ result=[custom_failure],
+ )
+
+ run_mock.assert_called_once()
+ assert final[0] == _success()
+
+ def test_api_temporarily_unavailable_is_retried(self):
+ op = _make_op(max_retries=1, bulk_retry_delay=0)
+ failure = _transient_failure("API_TEMPORARILY_UNAVAILABLE")
+ run_mock = mock.MagicMock(return_value=[_success()])
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=[{"FirstName": "Ada"}],
+ result=[failure],
+ )
+
+ run_mock.assert_called_once()
+ assert final[0] == _success()
+
+ def test_mixed_failures_only_retries_transient(self):
+ op = _make_op(max_retries=1, bulk_retry_delay=0)
+ payload = [{"FirstName": "A"}, {"FirstName": "B"}, {"FirstName": "C"}]
+ initial = [_transient_failure(), _permanent_failure(), _success()]
+
+ run_mock = mock.MagicMock(return_value=[_success()])
+
+ with mock.patch.object(op, "_run_operation", run_mock):
+ with
mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
+ final = op._retry_transient_failures(
+ bulk=mock.MagicMock(),
+ payload=payload,
+ result=initial,
+ )
+
+ run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "A"}])
+ assert final[0] == _success()
+ assert final[1] == _permanent_failure()
+ assert final[2] == _success()