This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c03460ed9dd common.ai: Park approval reviews in awaiting_input on Airflow 3.3+ (#68489)
c03460ed9dd is described below

commit c03460ed9ddf342469dccc39f7094ddae6c72226
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 16 13:01:25 2026 +0100

    common.ai: Park approval reviews in awaiting_input on Airflow 3.3+ (#68489)
    
    LLMApprovalMixin (require_approval=True on LLMOperator/AgentOperator) now
    raises TaskAwaitingInput on Airflow 3.3+ so the task parks in the
    first-class awaiting_input state -- no trigger or triggerer involved --
    matching the standard provider's HITLOperator. On older cores it falls
    back to deferring to HITLTrigger as before. On 3.3+ the response deadline
    (approval_timeout) is enforced by the scheduler's awaiting_input timeout
    sweep rather than by HITLTrigger.
    
    Because nothing upstream schema-validates params_input on the
    awaiting_input path (HITLTrigger did on the legacy path),
    execute_complete now enforces the string contract for reviewer-modified
    output and raises HITLTriggerEventError for non-string values.
    
    The AIRFLOW_V_3_3_PLUS flag this uses was added in
    apache-airflow-providers-common-compat 1.15.0; the dependency line is
    marked "# use next version" so the release manager bumps the floor at
    release time.
---
 providers/common/ai/pyproject.toml                 |  2 +-
 .../airflow/providers/common/ai/mixins/approval.py | 39 ++++++++++++--
 .../tests/unit/common/ai/mixins/test_approval.py   | 62 +++++++++++++++++++++-
 .../ai/tests/unit/common/ai/operators/test_llm.py  | 55 ++++++++++++++-----
 .../common/ai/operators/test_llm_file_analysis.py  | 23 ++++----
 .../tests/unit/common/ai/operators/test_llm_sql.py | 18 +++++--
 6 files changed, 165 insertions(+), 34 deletions(-)

diff --git a/providers/common/ai/pyproject.toml 
b/providers/common/ai/pyproject.toml
index db08fab7374..e569ae88f91 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -67,7 +67,7 @@ requires-python = ">=3.10"
 # After you modify the dependencies, and rebuild your Breeze CI image with 
``breeze ci-image build``
 dependencies = [
     "apache-airflow>=3.0.0",
-    "apache-airflow-providers-common-compat>=1.14.1",
+    "apache-airflow-providers-common-compat>=1.14.1",  # use next version
     "apache-airflow-providers-standard>=1.12.1",
     "pydantic-ai-slim>=1.99.0",
 ]
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py 
b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
index 07855340c4b..5ebd679efcd 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py
@@ -23,6 +23,13 @@ from typing import TYPE_CHECKING, Any, Protocol
 
 from pydantic import BaseModel
 
+from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+    # On Airflow 3.3+ the review parks the task in the first-class 
AWAITING_INPUT state instead
+    # of deferring to a trigger. On older cores this name is absent and 
defer() is used.
+    from airflow.sdk.exceptions import TaskAwaitingInput
+
 log = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
@@ -45,7 +52,8 @@ class LLMApprovalMixin:
 
     When ``require_approval=True`` on the operator, the generated output is
     presented to a human reviewer via the Airflow Human-in-the-Loop (HITL)
-    interface.  The task defers until the reviewer approves or rejects.
+    interface.  The task waits (``awaiting_input`` on Airflow 3.3+, deferred on
+    older versions) until the reviewer approves or rejects.
 
     If ``allow_modifications=True``, the reviewer can also edit the output
     before approving.  The (possibly modified) output is then returned as the
@@ -71,7 +79,11 @@ class LLMApprovalMixin:
         body: str | None = None,
     ) -> None:
         """
-        Write HITL detail, then defer to HITLTrigger for human review.
+        Write HITL detail, then pause the task for human review.
+
+        On Airflow 3.3+ the task parks in the ``awaiting_input`` state (no 
trigger or triggerer
+        involved); on older versions it defers to :class:`HITLTrigger`. Either 
way it resumes in
+        ``execute_complete`` once a response (or timeout default) arrives.
 
         :param context: Airflow task context.
         :param output: The generated output to present for review.
@@ -100,7 +112,6 @@ class LLMApprovalMixin:
             output = str(output)
 
         ti_id = context["task_instance"].id
-        timeout_datetime = utcnow() + self.approval_timeout if 
self.approval_timeout else None
 
         if subject is None:
             subject = f"Review output for task `{self.task_id}`"
@@ -128,6 +139,16 @@ class LLMApprovalMixin:
             params=hitl_params,
         )
 
+        if AIRFLOW_V_3_3_PLUS:
+            # New core (3.3+): park the task in AWAITING_INPUT -- no trigger, 
no triggerer. The
+            # task is resumed by the Core API response handler or the 
scheduler timeout sweep.
+            raise TaskAwaitingInput(
+                method_name="execute_complete",
+                kwargs={"generated_output": output},
+                timeout=self.approval_timeout,
+            )
+
+        # Fallback for cores < 3.3: defer the response check to HITLTrigger on 
the triggerer.
         self.defer(
             trigger=HITLTrigger(
                 ti_id=ti_id,
@@ -135,7 +156,7 @@ class LLMApprovalMixin:
                 defaults=None,
                 params=hitl_params,
                 multiple=False,
-                timeout_datetime=timeout_datetime,
+                timeout_datetime=utcnow() + self.approval_timeout if 
self.approval_timeout else None,
             ),
             method_name="execute_complete",
             kwargs={"generated_output": output},
@@ -182,6 +203,16 @@ class LLMApprovalMixin:
         # when allow_modifications=False, bypassing the read-only approval 
flow.
         if getattr(self, "allow_modifications", False) and params_input:
             modified = params_input.get("output")
+            if modified is not None and not isinstance(modified, str):
+                # On the awaiting_input path nothing upstream schema-validates 
params_input
+                # (HITLTrigger did on the legacy path), so enforce the string 
contract here
+                # rather than returning a non-string as the task's output.
+                raise HITLTriggerEventError(
+                    {
+                        "error": f"Modified output must be a string, got 
{type(modified).__name__}.",
+                        "error_type": "validation",
+                    }
+                )
             if modified is not None and modified != generated_output:
                 log.info("output=%s modified by the reviewer=%s ", modified, 
responded_by_user)
                 return modified
diff --git a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py 
b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
index 464dfe38986..54b675da723 100644
--- a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
+++ b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, 
AIRFLOW_V_3_3_PLUS
 
 if not AIRFLOW_V_3_1_PLUS:
     pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", 
allow_module_level=True)
@@ -34,9 +34,13 @@ from airflow.providers.common.ai.mixins.approval import (
 )
 from airflow.providers.standard.exceptions import HITLRejectException, 
HITLTriggerEventError
 
+if AIRFLOW_V_3_3_PLUS:
+    from airflow.sdk.exceptions import TaskAwaitingInput
+
 HITL_TRIGGER_PATH = "airflow.providers.standard.triggers.hitl.HITLTrigger"
 UPSERT_HITL_PATH = "airflow.sdk.execution_time.hitl.upsert_hitl_detail"
 UTCNOW_PATH = "airflow.sdk.timezone.utcnow"
+AWAIT_INPUT_FLAG_PATH = 
"airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"
 
 
 class FakeOperator(LLMApprovalMixin):
@@ -76,6 +80,9 @@ def context():
     return MagicMock(**{"__getitem__": lambda self, key: {"task_instance": 
ti}[key]})
 
 
+# The legacy trigger path is taken on cores < 3.3; pin the flag so these tests 
keep
+# exercising the defer() fallback when run against newer cores.
+@patch(AWAIT_INPUT_FLAG_PATH, False)
 class TestDeferForApproval:
     @patch(HITL_TRIGGER_PATH, autospec=True)
     @patch(UPSERT_HITL_PATH)
@@ -253,6 +260,21 @@ class TestDeferForApproval:
 
         assert result == "modified output"
 
+    def test_approved_with_non_string_modified_output_raises(self, 
approval_op_with_modifications):
+        # On the awaiting_input path nothing upstream schema-validates 
params_input
+        # (HITLTrigger did on the legacy path), so execute_complete must 
enforce the
+        # string contract instead of returning a dict as the task's output.
+        event = {
+            "chosen_options": ["Approve"],
+            "responded_by_user": "editor",
+            "params_input": {"output": {"sneaky": "dict"}},
+        }
+
+        with pytest.raises(HITLTriggerEventError, match="must be a string"):
+            approval_op_with_modifications.execute_complete(
+                {}, generated_output="original output", event=event
+            )
+
     def test_approved_with_unmodified_output(self, 
approval_op_with_modifications):
         event = {
             "chosen_options": ["Approve"],
@@ -324,3 +346,41 @@ class TestDeferForApproval:
 
         with pytest.raises(HITLRejectException, match="alice"):
             approval_op.execute_complete({}, generated_output="output", 
event=event)
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="awaiting_input path requires Airflow 3.3+")
+class TestAwaitInputForApproval:
+    """On Airflow 3.3+ the review parks the task in AWAITING_INPUT instead of 
deferring."""
+
+    @patch(UPSERT_HITL_PATH)
+    def test_parks_task_in_awaiting_input(self, mock_upsert, approval_op, 
context):
+        with pytest.raises(TaskAwaitingInput) as exc_info:
+            approval_op.defer_for_approval(context, "some LLM output")
+
+        assert exc_info.value.method_name == "execute_complete"
+        assert exc_info.value.kwargs == {"generated_output": "some LLM output"}
+        assert exc_info.value.timeout is None
+        mock_upsert.assert_called_once()
+        assert mock_upsert.call_args[1]["options"] == ["Approve", "Reject"]
+        approval_op.defer.assert_not_called()
+
+    @patch(UPSERT_HITL_PATH)
+    def test_approval_timeout_carried_on_await(self, mock_upsert, context):
+        timeout = timedelta(hours=2)
+        op = FakeOperator(approval_timeout=timeout)
+
+        with pytest.raises(TaskAwaitingInput) as exc_info:
+            op.defer_for_approval(context, "output")
+
+        assert exc_info.value.timeout == timeout
+
+    @patch(UPSERT_HITL_PATH)
+    def test_pydantic_output_stringified_on_await(self, mock_upsert, 
approval_op, context):
+        class Answer(BaseModel):
+            text: str
+            confidence: float
+
+        with pytest.raises(TaskAwaitingInput) as exc_info:
+            approval_op.defer_for_approval(context, Answer(text="Paris", 
confidence=0.95))
+
+        assert exc_info.value.kwargs == {"generated_output": 
'{"text":"Paris","confidence":0.95}'}
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index d5ef8228d35..f9f3bf09099 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -29,13 +29,25 @@ from airflow.providers.common.ai.mixins.approval import (
 )
 from airflow.providers.common.ai.operators.llm import LLMOperator
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, 
AIRFLOW_V_3_3_PLUS
 
 try:
     from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as 
_CORE_WALKER
 except ImportError:
     _CORE_WALKER = False
 
+from airflow.providers.common.compat.sdk import TaskDeferred
+
+if AIRFLOW_V_3_3_PLUS:
+    # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older 
cores defer
+    # to HITLTrigger. Both exceptions carry method_name/kwargs/timeout, so the 
approval
+    # tests assert against whichever pause signal the running core uses.
+    from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+    ApprovalPauseSignal = TaskDeferred  # type: ignore[assignment, misc]
+
+AWAIT_INPUT_FLAG_PATH = 
"airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"
+
 # Returning the Pydantic instance through XCom (rather than a dict) only 
happens
 # on cores that register declared ``output_type`` classes from the worker-side
 # DAG walk. On older cores the operator dumps to a dict, so these tests skip.
@@ -187,8 +199,6 @@ class TestLLMOperatorApproval:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, 
mock_trigger_cls):
         """When require_approval=True, execute() defers instead of returning 
output."""
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("LLM 
response")
         mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
@@ -201,20 +211,43 @@ class TestLLMOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.method_name == "execute_complete"
         assert exc_info.value.kwargs["generated_output"] == "LLM response"
         mock_upsert.assert_called_once()
 
+    @patch(AWAIT_INPUT_FLAG_PATH, False)
+    @patch("airflow.providers.standard.triggers.hitl.HITLTrigger", 
autospec=True)
+    @patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
+    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
+    def test_execute_with_approval_defers_on_legacy_core(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
+        """On cores < 3.3 (flag pinned), execute() falls back to deferring to 
HITLTrigger."""
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_agent.run_sync.return_value = _make_mock_run_result("LLM 
response")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
+
+        op = LLMOperator(
+            task_id="legacy_approval_test",
+            prompt="Summarize this",
+            llm_conn_id="my_llm",
+            require_approval=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc_info:
+            op.execute(context=_make_context())
+
+        assert exc_info.value.method_name == "execute_complete"
+        assert exc_info.value.kwargs["generated_output"] == "LLM response"
+        mock_trigger_cls.assert_called_once()
+        mock_upsert.assert_called_once()
+
     @patch("airflow.providers.standard.triggers.hitl.HITLTrigger", 
autospec=True)
     @patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_with_approval_and_modifications(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """allow_modifications=True passes an editable 'output' param."""
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("draft 
output")
         mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
@@ -228,7 +261,7 @@ class TestLLMOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred):
+        with pytest.raises(ApprovalPauseSignal):
             op.execute(context=ctx)
 
         upsert_kwargs = mock_upsert.call_args[1]
@@ -239,8 +272,6 @@ class TestLLMOperatorApproval:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_with_approval_and_timeout(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """approval_timeout is passed to the trigger."""
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = _make_mock_run_result("output")
         mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
@@ -255,7 +286,7 @@ class TestLLMOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.timeout == timeout
@@ -265,8 +296,6 @@ class TestLLMOperatorApproval:
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
     def test_execute_with_approval_structured_output(self, mock_hook_cls, 
mock_upsert, mock_trigger_cls):
         """Structured (BaseModel) output is serialized before deferring."""
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_agent = MagicMock(spec=["run_sync"])
         mock_agent.run_sync.return_value = 
_make_mock_run_result(Summary(text="hello"))
         mock_hook_cls.get_hook.return_value.create_agent.return_value = 
mock_agent
@@ -280,7 +309,7 @@ class TestLLMOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.kwargs["generated_output"] == '{"text":"hello"}'
diff --git 
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
index 7c955a160b4..9e692b420f9 100644
--- 
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
+++ 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
@@ -25,8 +25,17 @@ from pydantic import BaseModel
 
 from airflow.providers.common.ai.operators.llm_file_analysis import 
LLMFileAnalysisOperator
 from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest
+from airflow.providers.common.compat.sdk import TaskDeferred
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, 
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+    # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older 
cores defer to
+    # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the 
approval tests assert
+    # against whichever pause signal the running core uses.
+    from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+    ApprovalPauseSignal = TaskDeferred  # type: ignore[assignment, misc]
 
 try:
     from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as 
_CORE_WALKER
@@ -208,8 +217,6 @@ class TestLLMFileAnalysisOperatorApproval:
     def test_execute_with_approval_defers(
         self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
     ):
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_build_request.return_value = FileAnalysisRequest(
             user_content="prepared prompt",
             resolved_paths=["/tmp/app.log"],
@@ -228,7 +235,7 @@ class TestLLMFileAnalysisOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.method_name == "execute_complete"
@@ -244,8 +251,6 @@ class TestLLMFileAnalysisOperatorApproval:
     def test_execute_with_approval_defers_structured_output_as_json(
         self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
     ):
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_build_request.return_value = FileAnalysisRequest(
             user_content="prepared prompt",
             resolved_paths=["/tmp/app.log"],
@@ -264,7 +269,7 @@ class TestLLMFileAnalysisOperatorApproval:
             require_approval=True,
         )
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=_make_context())
 
         assert exc_info.value.kwargs["generated_output"] == 
'{"findings":["error spike"]}'
@@ -318,8 +323,6 @@ class TestLLMFileAnalysisOperatorApproval:
     def test_execute_with_approval_timeout(
         self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
     ):
-        from airflow.providers.common.compat.sdk import TaskDeferred
-
         mock_build_request.return_value = FileAnalysisRequest(
             user_content="prepared prompt",
             resolved_paths=["/tmp/app.log"],
@@ -339,7 +342,7 @@ class TestLLMFileAnalysisOperatorApproval:
             approval_timeout=timeout,
         )
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=_make_context())
 
         assert exc_info.value.timeout == timeout
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index a994ae3d1cd..1862971c953 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -30,7 +30,15 @@ from airflow.providers.common.ai.utils.sql_validation import 
SQLSafetyError
 from airflow.providers.common.compat.sdk import TaskDeferred
 from airflow.providers.common.sql.config import DataSourceConfig
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, 
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+    # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older 
cores defer to
+    # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the 
approval tests assert
+    # against whichever pause signal the running core uses.
+    from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
+else:
+    ApprovalPauseSignal = TaskDeferred  # type: ignore[assignment, misc]
 
 
 def _make_mock_run_result(output):
@@ -475,7 +483,7 @@ class TestLLMSQLQueryOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.method_name == "execute_complete"
@@ -521,7 +529,7 @@ class TestLLMSQLQueryOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred):
+        with pytest.raises(ApprovalPauseSignal):
             op.execute(context=ctx)
 
         upsert_kwargs = mock_upsert.call_args[1]
@@ -545,7 +553,7 @@ class TestLLMSQLQueryOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.timeout == timeout
@@ -583,7 +591,7 @@ class TestLLMSQLQueryOperatorApproval:
         )
         ctx = _make_context()
 
-        with pytest.raises(TaskDeferred) as exc_info:
+        with pytest.raises(ApprovalPauseSignal) as exc_info:
             op.execute(context=ctx)
 
         assert exc_info.value.kwargs["generated_output"] == "SELECT 1"

Reply via email to