This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 745898a17e5 Fail fast for non-serializable retry_args in deferrable
operators and triggers (#64960)
745898a17e5 is described below
commit 745898a17e5faadeb929b5cdee8deba0c0cea819
Author: kosiew <[email protected]>
AuthorDate: Tue Jun 2 19:57:54 2026 +0800
Fail fast for non-serializable retry_args in deferrable operators and
triggers (#64960)
* Add validation for non-serializable retry_args
Implement a shared validation guard to reject
non-serializable databricks_retry_args before
deferrable Databricks tasks cross the trigger boundary.
Enforce this check for deferrable operators and the SQL
sensor in databricks.py. Add regression tests covering the
failure modes of both in test_databricks.py.
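A minimal sketch of the guard described above, assuming stdlib `json.dumps` as a stand-in for Airflow's serializer. `validate_retry_args` and `FakeWaitStrategy` are hypothetical names, not the provider's code. The sketch shows that `None` short-circuits and that the original serializer error is chained with `from err`:

```python
from __future__ import annotations

import json
from collections.abc import Mapping
from typing import Any


def validate_retry_args(retry_args: Mapping[str, Any] | None, *, owner: str) -> None:
    """Hypothetical guard: reject retry args that cannot be serialized.

    json.dumps stands in for Airflow's serializer here.
    """
    if retry_args is None:
        return  # nothing to validate
    try:
        json.dumps(retry_args)
    except (TypeError, ValueError) as err:
        # Re-raise with a caller-specific message, keeping the original error as __cause__.
        raise ValueError(
            f"{owner} does not support non-serializable retry_args when deferrable=True"
        ) from err


class FakeWaitStrategy:
    """Hypothetical stand-in for a Tenacity wait object (a callable instance)."""

    def __call__(self, retry_state: Any) -> float:
        return 1.0


validate_retry_args({"retry_limit": 3, "retry_delay": 10}, owner="ExampleOperator")  # passes
try:
    validate_retry_args({"wait": FakeWaitStrategy()}, owner="ExampleOperator")
except ValueError as exc:
    print(exc)  # names ExampleOperator; the TypeError from json.dumps is on exc.__cause__
```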
* Refactor validation to utility module and enhance tests
Move validation logic to retry.py for better cohesion. Enforce
validation in both trigger constructors within databricks.py.
Add direct trigger regression tests in test_databricks.py and
update sensor test setup to maintain deferrable branch coverage.
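To illustrate why the check sits in the trigger constructors, here is a hypothetical `ExampleTrigger` (not Airflow's `BaseTrigger`). Its `serialize()` returns a `(classpath, kwargs)` pair, the shape Airflow triggers use so they can be persisted and re-created later. A JSON round trip stands in for the real serialization boundary, and validation runs before any trigger-specific attribute is assigned:

```python
from __future__ import annotations

import json
from typing import Any


class ExampleTrigger:
    """Hypothetical trigger; not Airflow's BaseTrigger."""

    def __init__(self, run_id: int, retry_args: dict[str, Any] | None = None) -> None:
        # Fail fast before any trigger state is stored.
        if retry_args is not None:
            try:
                json.dumps(retry_args)  # stand-in for the real serializer
            except (TypeError, ValueError) as err:
                raise ValueError("ExampleTrigger: retry_args must be serializable") from err
        self.run_id = run_id
        self.retry_args = retry_args

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # (classpath, kwargs): the kwargs are what crosses the serialization boundary.
        return f"{__name__}.ExampleTrigger", {"run_id": self.run_id, "retry_args": self.retry_args}


# Round trip: serialize, encode, decode, and rebuild the trigger from its kwargs.
classpath, kwargs = ExampleTrigger(1, {"retry_limit": 3}).serialize()
rebuilt = ExampleTrigger(**json.loads(json.dumps(kwargs)))
print(classpath, rebuilt.run_id, rebuilt.retry_args)
```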
* Add parameterized validation tests for Tenacity shapes
Extend the operator, sensor, and trigger tests to cover two
unsupported Tenacity shapes. Tests are now parameterized for
{"wait": wait_incrementing(...)} and {"stop":
stop_after_attempt(...)} scenarios.
* Refactor retry argument tests and reduce duplication
Extract shared invalid retry-arg test data and pytest.raises
assertion into _retry_test_utils.py. Remove duplicated
UNSUPPORTED_RETRY_ARGS definitions from operator, sensor, and
trigger test files. Simplify setup in operator and sensor
negative tests with local helpers for the running deferrable
path. Combine two trigger-construction negative tests into
one shared parametrized test in test_databricks.py.
* Tighten API and update retry tests
Require owner explicitly in retry.py's private helper.
Define an UNSUPPORTED_RETRY_ARGS constant in
_retry_test_utils.py and update operator, sensor, and
trigger tests to parametrize directly from it in
test_databricks.py.
* Update retry logic and Databricks tests
Refactor retry.py to catch ValueErrors and clarify
retry_args/databricks_retry_args messages. Adjust
validation in databricks.py to use owner=caller. Update
tests in operators, sensors, and triggers for
Databricks. Fix test-helper import to follow repo style.
* Refactor retry.py to use stdlib JSON serialization
Replace SDK serde import with stdlib JSON serialization
in retry.py. Update validation call to use json.dumps()
instead of serde_serialize() to improve simplicity and
reduce dependencies.
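One caveat of the stdlib approach, sketched with a hypothetical `is_json_serializable` helper: `json.dumps` rejects values such as timezone-aware `datetime` objects. That is the kind of serde-supported, non-JSON value that the later switch to `airflow.sdk.serde.serialize` covers in test_retry.py.

```python
import datetime
import json


def is_json_serializable(value: object) -> bool:
    """Hypothetical helper: True if the stdlib json module can encode value."""
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        return False
    return True


primitives = {"retry_limit": 3, "retry_delay": 10.5, "retry_codes": ["429", "500"]}
with_deadline = {"deadline": datetime.datetime(2026, 5, 29, 12, 30, tzinfo=datetime.timezone.utc)}

print(is_json_serializable(primitives))     # True
print(is_json_serializable(with_deadline))  # False: datetime is not JSON-encodable
```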
* Add unit test for validate_deferrable_databricks_retry_args
Implement tests for the retry validation function in the
Databricks provider. Handle cases for `None` and valid
JSON-serializable primitive retry configurations, while
ensuring unsupported Tenacity retry arguments are rejected.
* Add dev/databricks_retry_args_repro.py
* rm dev/databricks_retry_args_repro.py
* trigger ci
* feat(tests): refactor retry test utilities and improve assertions
- Removed unnecessary retry test utility file.
- Moved retry test constants into the test modules that use them.
- Inlined retry error assertions in sensor and trigger tests for clarity.
- Added explicit `is None` assertions to the success-path validation
tests.
- Added comments to both trigger constructors explaining the fail-fast
validation at the serialization boundary.
* chore(databricks): update retry utility to use airflow.sdk.serde.serialize
- Modified retry.py to utilize airflow.sdk.serde.serialize
- Retained wrapping of AttributeError, RecursionError, TypeError, and
ValueError
- Updated error message to indicate "Airflow-serializable"
- Enhanced test_retry.py with datetime coverage for serde-supported
non-JSON values
- Kept Tenacity rejection tests unchanged
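The exception wrapping listed above can be illustrated with a self-referential dict, for which `json.dumps` (again standing in for the serde call) raises `ValueError("Circular reference detected")`. This is a hypothetical sketch: `check_serializable` and `WRAPPED_ERRORS` are illustrative names. It only shows that catching the listed exception types and re-raising with `from err` gives callers one `ValueError` while keeping the original error on `__cause__`.

```python
from __future__ import annotations

import json
from typing import Any

# The exception types the change lists as wrapped; a serializer may raise any of them.
WRAPPED_ERRORS = (AttributeError, RecursionError, TypeError, ValueError)


def check_serializable(value: Any, *, owner: str) -> None:
    """Hypothetical validator that normalizes serializer failures to ValueError."""
    try:
        json.dumps(value)  # stand-in for the Airflow serde call
    except WRAPPED_ERRORS as err:
        raise ValueError(f"{owner}: value is not Airflow-serializable") from err


circular: dict[str, Any] = {}
circular["self"] = circular  # json.dumps detects this cycle and raises ValueError
try:
    check_serializable(circular, owner="ExampleOperator")
except ValueError as exc:
    print(type(exc.__cause__).__name__, "->", exc)
```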
* trigger ci
* fix: replace datetime.UTC with datetime.timezone.utc in test_retry.py
* feat(databricks): add compat import fallback for serialization module in
retry.py
* feat(databricks): enhance retry utility with improved serialization and
cleanup
- Removed duplicate fallback import names for clarity.
- Added `get_serde_serialize()` function utilizing `import_module(...)`.
- Updated validator to call `get_serde_serialize()(retry_args)`.
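The fallback pattern behind `get_serde_serialize()` amounts to "use the first module that imports". A hypothetical generalized helper, with stdlib `json` standing in for the serde modules and a made-up module name standing in for one that is missing:

```python
from __future__ import annotations

from collections.abc import Callable
from importlib import import_module
from typing import Any


def first_available(attr: str, *module_names: str) -> Callable[..., Any]:
    """Hypothetical helper: return `attr` from the first module in module_names that imports."""
    for name in module_names:
        try:
            module = import_module(name)
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            continue
        return getattr(module, attr)
    raise ImportError(f"none of {module_names!r} could be imported")


# "example_missing_serde" is a made-up name that should not exist, so json wins.
dumps = first_available("dumps", "example_missing_serde", "json")
print(dumps({"retry_limit": 3}))  # {"retry_limit": 3}
```

Resolving the import inside a function, as `get_serde_serialize()` does, defers the lookup until validation actually runs, so importing retry.py does not itself import either serde module.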
---
.../providers/databricks/triggers/databricks.py | 7 +++
.../airflow/providers/databricks/utils/retry.py | 43 ++++++++++++++++
.../unit/databricks/operators/test_databricks.py | 28 ++++++++++
.../unit/databricks/sensors/test_databricks.py | 29 +++++++++++
.../unit/databricks/triggers/test_databricks.py | 39 +++++++++++++-
.../tests/unit/databricks/utils/test_retry.py | 60 ++++++++++++++++++++++
6 files changed, 204 insertions(+), 2 deletions(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
index 25cade7fc80..2bb626b9114 100644
--- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py
@@ -23,6 +23,7 @@ from typing import Any
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.utils.databricks import extract_failed_task_errors_async
+from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -55,6 +56,9 @@ class DatabricksExecutionTrigger(BaseTrigger):
caller: str = "DatabricksExecutionTrigger",
) -> None:
super().__init__()
+ # Trigger kwargs cross Airflow's serialization boundary, so fail before storing invalid
+ # trigger state or surfacing a generic serializer error without Databricks-specific guidance.
+ validate_deferrable_databricks_retry_args(retry_args, owner=caller)
self.run_id = run_id
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
@@ -151,6 +155,9 @@ class DatabricksSQLStatementExecutionTrigger(BaseTrigger):
caller: str = "DatabricksSQLStatementExecutionTrigger",
) -> None:
super().__init__()
+ # Trigger kwargs cross Airflow's serialization boundary, so fail before storing invalid
+ # trigger state or surfacing a generic serializer error without Databricks-specific guidance.
+ validate_deferrable_databricks_retry_args(retry_args, owner=caller)
self.statement_id = statement_id
self.databricks_conn_id = databricks_conn_id
self.end_time = end_time
diff --git a/providers/databricks/src/airflow/providers/databricks/utils/retry.py b/providers/databricks/src/airflow/providers/databricks/utils/retry.py
new file mode 100644
index 00000000000..508f4b71628
--- /dev/null
+++ b/providers/databricks/src/airflow/providers/databricks/utils/retry.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping
+from importlib import import_module
+from typing import Any
+
+
+def get_serde_serialize() -> Callable[[Any], Any]:
+ try:
+ return import_module("airflow.sdk.serde").serialize
+ except ImportError:
+ return import_module("airflow.serialization.serde").serialize
+
+
+def validate_deferrable_databricks_retry_args(retry_args: Mapping[str, Any] | None, *, owner: str) -> None:
+ """Validate retry args that need to cross the trigger serialization boundary."""
+ if retry_args is None:
+ return
+
+ try:
+ get_serde_serialize()(retry_args)
+ except (AttributeError, RecursionError, TypeError, ValueError) as err:
+ raise ValueError(
+ f"{owner} does not support non-serializable retry_args/databricks_retry_args "
+ "when deferrable=True. "
+ "Use Airflow-serializable values, remove callable retry strategies, or disable deferrable mode."
+ ) from err
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
index 986cf51f1cf..a1b7b4f11b3 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
@@ -24,6 +24,7 @@ from unittest import mock
from unittest.mock import MagicMock, call
import pytest
+from tenacity import stop_after_attempt, wait_incrementing
# Do not run the tests when FAB / Flask is not installed
pytest.importorskip("flask_session")
@@ -95,6 +96,13 @@ TAGS = {
"cost-center": "engineering",
"team": "jobs",
}
+INVALID_RETRY_ARGS_PATTERN = (
+ "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+ pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+ pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
TASKS = [
{
"task_key": "Sessionize",
@@ -666,6 +674,11 @@ class TestDatabricksCreateJobsOperator:
class TestDatabricksSubmitRunOperator:
+ @staticmethod
+ def _configure_running_deferrable_hook(db_mock):
+ db_mock.submit_run.return_value = RUN_ID
+ db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
+
def test_init_with_notebook_task_named_parameters(self):
"""
Test the initializer with the named parameters.
@@ -1089,6 +1102,21 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
assert op.run_id == RUN_ID
+ @pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+ @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_execute_task_deferred_rejects_non_serializable_retry_args(self, db_mock_class, retry_args):
+ op = DatabricksSubmitRunOperator(
+ deferrable=True,
+ task_id=TASK_ID,
+ json={"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK},
+ databricks_retry_args=retry_args,
+ )
+ db_mock = db_mock_class.return_value
+ self._configure_running_deferrable_hook(db_mock)
+
+ with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+ op.execute(None)
+
def test_execute_complete_success(self):
"""
Test `execute_complete` function in case the Trigger has returned a
successful completion event.
diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
index 615364f89ed..8f94274e7a4 100644
--- a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from unittest import mock
import pytest
+from tenacity import stop_after_attempt, wait_incrementing
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
from airflow.providers.databricks.hooks.databricks import SQLStatementState
@@ -31,6 +32,13 @@ STATEMENT = "select * from test.test;"
STATEMENT_ID = "statement_id"
TASK_ID = "task_id"
WAREHOUSE_ID = "warehouse_id"
+INVALID_RETRY_ARGS_PATTERN = (
+ "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+ pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+ pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
class TestDatabricksSQLStatementsSensor:
@@ -39,6 +47,11 @@ class TestDatabricksSQLStatementsSensor:
from the DatabricksSQLStatementOperator, meaning that much of the testing
logic is also reused.
"""
+ @staticmethod
+ def _configure_running_deferrable_hook(db_mock):
+ db_mock.post_sql_statement.return_value = STATEMENT_ID
+ db_mock.get_sql_statement_state.return_value = SQLStatementState("RUNNING")
+
def test_init_statement(self):
"""Test initialization for traditional use-case (statement)."""
op = DatabricksSQLStatementsSensor(task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID)
@@ -167,6 +180,22 @@ class TestDatabricksSQLStatementsSensor:
assert isinstance(exc.value.trigger, DatabricksSQLStatementExecutionTrigger)
assert exc.value.method_name == "execute_complete"
+ @pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+ @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+ def test_execute_task_deferred_rejects_non_serializable_retry_args(self, db_mock_class, retry_args):
+ op = DatabricksSQLStatementsSensor(
+ task_id=TASK_ID,
+ statement=STATEMENT,
+ warehouse_id=WAREHOUSE_ID,
+ deferrable=True,
+ databricks_retry_args=retry_args,
+ )
+ db_mock = db_mock_class.return_value
+ self._configure_running_deferrable_hook(db_mock)
+
+ with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+ op.execute(None)
+
def test_execute_complete_success(self):
"""
Test the execute_complete function in case the Trigger has returned a
successful completion event.
diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
index 903173774b7..8854eb03fb5 100644
--- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py
@@ -17,10 +17,10 @@
# under the License.
from __future__ import annotations
-import time
from unittest import mock
import pytest
+from tenacity import stop_after_attempt, wait_incrementing
from airflow.models import Connection
from airflow.providers.databricks.hooks.databricks import RunState, SQLStatementState
@@ -42,6 +42,7 @@ RETRY_DELAY = 10
RETRY_LIMIT = 3
RUN_ID = 1
STATEMENT_ID = "statement_id"
+STATEMENT_END_TIME = 9999999999.0
TASK_RUN_ID1 = 11
TASK_RUN_ID1_KEY = "first_task"
TASK_RUN_ID2 = 22
@@ -53,6 +54,13 @@ RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
CALLER = "DatabricksSubmitRunOperator"
ERROR_MESSAGE = "error message from databricks API"
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
+INVALID_RETRY_ARGS_PATTERN = (
+ "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+ pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+ pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]
@@ -119,6 +127,33 @@ GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = {
],
}
+TRIGGER_INIT_CASES = [
+ pytest.param(
+ DatabricksExecutionTrigger,
+ {
+ "run_id": RUN_ID,
+ "databricks_conn_id": DEFAULT_CONN_ID,
+ },
+ id="execution_trigger",
+ ),
+ pytest.param(
+ DatabricksSQLStatementExecutionTrigger,
+ {
+ "statement_id": STATEMENT_ID,
+ "databricks_conn_id": DEFAULT_CONN_ID,
+ "end_time": 1234567890.0,
+ },
+ id="sql_statement_trigger",
+ ),
+]
+
+
+@pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+@pytest.mark.parametrize(("trigger_cls", "trigger_kwargs"), TRIGGER_INIT_CASES)
+def test_trigger_init_rejects_non_serializable_retry_args(trigger_cls, trigger_kwargs, retry_args):
+ with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+ trigger_cls(**trigger_kwargs, retry_args=retry_args)
+
class TestDatabricksExecutionTrigger:
@pytest.fixture(autouse=True)
@@ -281,7 +316,7 @@ class TestDatabricksExecutionTrigger:
class TestDatabricksSQLStatementExecutionTrigger:
@pytest.fixture(autouse=True)
def setup_connections(self, create_connection_without_db):
- self.end_time = time.time() + 60
+ self.end_time = STATEMENT_END_TIME
create_connection_without_db(
Connection(
conn_id=DEFAULT_CONN_ID,
diff --git a/providers/databricks/tests/unit/databricks/utils/test_retry.py b/providers/databricks/tests/unit/databricks/utils/test_retry.py
new file mode 100644
index 00000000000..7ff46a71f7d
--- /dev/null
+++ b/providers/databricks/tests/unit/databricks/utils/test_retry.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+
+import pytest
+from tenacity import stop_after_attempt, wait_incrementing
+
+from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args
+
+INVALID_RETRY_ARGS_PATTERN = (
+ "does not support non-serializable retry_args/databricks_retry_args when deferrable=True"
+)
+UNSUPPORTED_RETRY_ARGS = [
+ pytest.param({"wait": wait_incrementing(start=1, increment=1, max=3)}, id="wait_incrementing"),
+ pytest.param({"stop": stop_after_attempt(3)}, id="stop_after_attempt"),
+]
+
+
+def test_validate_deferrable_databricks_retry_args_accepts_none():
+ assert validate_deferrable_databricks_retry_args(None, owner="test-owner") is None
+
+
+@pytest.mark.parametrize(
+ "retry_args",
+ [
+ {},
+ {"retry_limit": 3, "retry_delay": 10},
+ {"retry_limit": 3, "retry_delay": 10.5, "retry_enabled": True, "retry_codes": ["429", "500"]},
+ ],
+)
+def test_validate_deferrable_databricks_retry_args_accepts_serde_serializable_values(retry_args):
+ assert validate_deferrable_databricks_retry_args(retry_args, owner="test-owner") is None
+
+
+def test_validate_deferrable_databricks_retry_args_accepts_airflow_serde_serializable_values():
+ retry_args = {"deadline": datetime.datetime(2026, 5, 29, 12, 30, tzinfo=datetime.timezone.utc)}
+
+ assert validate_deferrable_databricks_retry_args(retry_args, owner="test-owner") is None
+
+
+@pytest.mark.parametrize("retry_args", UNSUPPORTED_RETRY_ARGS)
+def test_validate_deferrable_databricks_retry_args_rejects_non_serializable_values(retry_args):
+ with pytest.raises(ValueError, match=INVALID_RETRY_ARGS_PATTERN):
+ validate_deferrable_databricks_retry_args(retry_args, owner="test-owner")