This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new af80c491a8c Refactor timeout handling in DatabricksSqlHook to use explicit signaling (#62623)
af80c491a8c is described below

commit af80c491a8cabbc9d3697b7ebf243cecb48816a7
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 10 18:35:13 2026 +0000

    Refactor timeout handling in DatabricksSqlHook to use explicit signaling (#62623)
    
    Replace implicit timeout detection based on Timer.is_alive() with explicit
    timeout signaling via threading.Event. Timeout classification now checks an
    explicit signal set by the timeout callback instead of inferring state from
    thread lifecycle behavior.
    
    Preserves existing cancellation semantics and exception types. Unit tests
    have been adjusted accordingly.
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../providers/databricks/hooks/databricks_sql.py   | 48 ++++++++++++++--------
 .../unit/databricks/hooks/test_databricks_sql.py   | 39 +++++++++++-------
 2 files changed, 55 insertions(+), 32 deletions(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 127f6b71c70..2c2164bd9c7 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -52,14 +52,23 @@ if TYPE_CHECKING:
 T = TypeVar("T")
 
 
-def create_timeout_thread(cur, execution_timeout: timedelta | None) -> 
threading.Timer | None:
-    if execution_timeout is not None:
-        seconds_to_timeout = execution_timeout.total_seconds()
-        t = threading.Timer(seconds_to_timeout, cur.connection.cancel)
-    else:
-        t = None
+def create_timeout_thread(
+    cur, execution_timeout: timedelta | None
+) -> tuple[threading.Timer | None, threading.Event | None]:
+    """Create a timeout timer that cancels the connection and sets a timeout 
flag."""
+    if not execution_timeout:
+        return None, None
 
-    return t
+    timeout_event = threading.Event()
+
+    def _cancel():
+        timeout_event.set()
+        cur.connection.cancel()
+
+    timer = threading.Timer(execution_timeout.total_seconds(), _cancel)
+    timer.start()
+
+    return timer, timeout_event
 
 
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
@@ -290,22 +299,25 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
                 self.set_autocommit(conn, autocommit)
 
                 with closing(conn.cursor()) as cur:
-                    t = create_timeout_thread(cur, execution_timeout)
+                    timer, timeout_event = create_timeout_thread(cur, 
execution_timeout)
 
-                    # TODO: adjust this to make testing easier
                     try:
                         self._run_command(cur, sql_statement, parameters)
+
                     except Exception as e:
-                        if t is None or t.is_alive():
-                            raise DatabricksSqlExecutionError(
-                                f"Error running SQL statement: 
{sql_statement}. {str(e)}"
-                            )
-                        raise DatabricksSqlExecutionTimeout(
-                            f"Timeout threshold exceeded for SQL statement: 
{sql_statement} was cancelled."
-                        )
+                        if timeout_event and timeout_event.is_set():
+                            raise DatabricksSqlExecutionTimeout(
+                                f"Timeout threshold exceeded for SQL 
statement: "
+                                f"{sql_statement} was cancelled."
+                            ) from e
+
+                        raise DatabricksSqlExecutionError(
+                            f"Error running SQL statement: {sql_statement}. 
{str(e)}"
+                        ) from e
+
                     finally:
-                        if t is not None:
-                            t.cancel()
+                        if timer:
+                            timer.cancel()
 
                     if query_id := cur.query_id:
                         self.log.info("Databricks query id: %s", query_id)
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 98ea7e1d347..d661f5b0714 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -18,7 +18,6 @@
 #
 from __future__ import annotations
 
-import threading
 from collections import namedtuple
 from datetime import timedelta
 from unittest import mock
@@ -509,8 +508,12 @@ def test_execution_timeout_exceeded(
             description=get_cursor_descriptions(cursor_descriptions),
         )
 
-        # Simulate a timeout
-        mock_create_timeout_thread.return_value = threading.Timer(cur, 
execution_timeout)
+        mock_event = mock.MagicMock()
+        mock_event.is_set.return_value = True  # simulate timeout
+
+        mock_timer = mock.MagicMock()
+
+        mock_create_timeout_thread.return_value = (mock_timer, mock_event)
 
         mock_run_command.side_effect = Exception("Mocked exception")
 
@@ -532,20 +535,22 @@ def test_execution_timeout_exceeded(
     "cursor_descriptions",
     [(("id", "value"),)],
 )
-def test_create_timeout_thread(
-    mock_get_conn,
-    mock_get_requests,
-    mock_timer,
-    cursor_descriptions,
-):
+def test_create_timeout_thread(mock_get_conn, mock_get_requests, 
cursor_descriptions):
+
     cur = mock.MagicMock(
         rowcount=1,
         description=get_cursor_descriptions(cursor_descriptions),
     )
+
     timeout = timedelta(seconds=1)
-    thread = create_timeout_thread(cur=cur, execution_timeout=timeout)
-    mock_timer.assert_called_once_with(timeout.total_seconds(), 
cur.connection.cancel)
-    assert thread is not None
+
+    timer, event = create_timeout_thread(cur=cur, execution_timeout=timeout)
+
+    assert timer is not None
+    assert event is not None
+    assert not event.is_set()
+
+    timer.cancel()
 
 
 @pytest.mark.parametrize(
@@ -562,9 +567,15 @@ def test_create_timeout_thread_no_timeout(
         rowcount=1,
         description=get_cursor_descriptions(cursor_descriptions),
     )
-    thread = create_timeout_thread(cur=cur, execution_timeout=None)
+
+    timer, timeout_event = create_timeout_thread(
+        cur=cur,
+        execution_timeout=None,
+    )
+
     mock_timer.assert_not_called()
-    assert thread is None
+    assert timer is None
+    assert timeout_event is None
 
 
 def test_get_openlineage_default_schema_with_no_schema_set():

Reply via email to