This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dc4ce652b67 AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK (#46020)
dc4ce652b67 is described below

commit dc4ce652b67dc7930976bd9765e31bd0af6f923e
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jan 30 12:04:39 2025 +0530

    AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK (#46020)
---
 .../execution_api/datamodels/taskinstance.py       |  7 +++
 .../execution_api/routes/task_instances.py         | 42 +++++++++++++++
 airflow/models/taskinstance.py                     | 21 +++++---
 task_sdk/src/airflow/sdk/api/client.py             | 16 +++++-
 .../src/airflow/sdk/api/datamodels/_generated.py   |  9 ++++
 .../src/airflow/sdk/definitions/asset/__init__.py  |  9 ++++
 task_sdk/src/airflow/sdk/execution_time/comms.py   | 12 +++++
 .../src/airflow/sdk/execution_time/supervisor.py   |  4 ++
 .../src/airflow/sdk/execution_time/task_runner.py  | 52 +++++++++++-------
 task_sdk/tests/execution_time/test_supervisor.py   | 23 +++++++-
 task_sdk/tests/execution_time/test_task_runner.py  | 39 ++++++++++++++
 .../execution_api/routes/test_task_instances.py    | 61 ++++++++++++++++++++++
 12 files changed, 266 insertions(+), 29 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 6cc82259cf7..e427cac5f3d 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -249,3 +249,10 @@ class PrevSuccessfulDagRunResponse(BaseModel):
     data_interval_end: UtcDateTime | None = None
     start_date: UtcDateTime | None = None
     end_date: UtcDateTime | None = None
+
+
+class TIRuntimeCheckPayload(BaseModel):
+    """Payload for performing Runtime checks on the TaskInstance model as 
requested by the SDK."""
+
+    inlets: list[AssetProfile] | None = None
+    outlets: list[AssetProfile] | None = None
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 899017e612d..155f96f861a 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -37,10 +37,12 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     TIHeartbeatInfo,
     TIRescheduleStatePayload,
     TIRunContext,
+    TIRuntimeCheckPayload,
     TIStateUpdate,
     TISuccessStatePayload,
     TITerminalStatePayload,
 )
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
 from airflow.models.taskreschedule import TaskReschedule
@@ -442,6 +444,46 @@ def get_previous_successful_dagrun(
     return PrevSuccessfulDagRunResponse.model_validate(dag_run)
 
 
[email protected](
+    "/{task_instance_id}/runtime-checks",
+    status_code=status.HTTP_204_NO_CONTENT,
+    # TODO: Add description to the operation
+    # TODO: Add Operation ID to control the function name in the OpenAPI spec
+    # TODO: Do we need to use create_openapi_http_exception_doc here?
+    responses={
+        status.HTTP_400_BAD_REQUEST: {"description": "Task Instance failed the 
runtime checks."},
+        status.HTTP_409_CONFLICT: {
+            "description": "Task Instance isn't in a running state. Cannot 
perform runtime checks."
+        },
+        status.HTTP_422_UNPROCESSABLE_ENTITY: {
+            "description": "Invalid payload for requested runtime checks on 
the Task Instance."
+        },
+    },
+)
+def ti_runtime_checks(
+    task_instance_id: UUID,
+    payload: TIRuntimeCheckPayload,
+    session: SessionDep,
+):
+    ti_id_str = str(task_instance_id)
+    task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
+    if task_instance.state != State.RUNNING:
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT)
+
+    try:
+        TI.validate_inlet_outlet_assets_activeness(payload.inlets, 
payload.outlets, session)  # type: ignore
+    except AirflowInactiveAssetInInletOrOutletException as e:
+        log.error("Task Instance %s fails the runtime checks.", ti_id_str)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "reason": "validation_failed",
+                "message": "Task Instance fails the runtime checks",
+                "error": str(e),
+            },
+        )
+
+
 def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
     """Is task instance is eligible for retry."""
     if state == State.RESTARTING:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f1aa3a8236e..69b6d147ead 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -260,7 +260,10 @@ def _run_raw_task(
         context = ti.get_template_context(ignore_param_exceptions=False, 
session=session)
 
         try:
-            ti._validate_inlet_outlet_assets_activeness(session=session)
+            if ti.task:
+                inlets = [asset.asprofile() for asset in ti.task.inlets if 
isinstance(asset, Asset)]
+                outlets = [asset.asprofile() for asset in ti.task.outlets if 
isinstance(asset, Asset)]
+                TaskInstance.validate_inlet_outlet_assets_activeness(inlets, 
outlets, session=session)
             if not mark_success:
                 TaskInstance._execute_task_with_callbacks(
                     self=ti,  # type: ignore[arg-type]
@@ -3715,16 +3718,20 @@ class TaskInstance(Base, LoggingMixin):
             }
         )
 
-    def _validate_inlet_outlet_assets_activeness(self, session: Session) -> 
None:
-        if not self.task or not (self.task.outlets or self.task.inlets):
+    @staticmethod
+    def validate_inlet_outlet_assets_activeness(
+        inlets: list[AssetProfile], outlets: list[AssetProfile], session: 
Session
+    ) -> None:
+        if not (inlets or outlets):
             return
 
         all_asset_unique_keys = {
-            AssetUniqueKey.from_asset(inlet_or_outlet)
-            for inlet_or_outlet in itertools.chain(self.task.inlets, 
self.task.outlets)
-            if isinstance(inlet_or_outlet, Asset)
+            AssetUniqueKey.from_asset(inlet_or_outlet)  # type: ignore
+            for inlet_or_outlet in itertools.chain(inlets, outlets)
         }
-        inactive_asset_unique_keys = 
self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
+        inactive_asset_unique_keys = 
TaskInstance._get_inactive_asset_unique_keys(
+            all_asset_unique_keys, session
+        )
         if inactive_asset_unique_keys:
             raise 
AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)
 
diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 443256e3a67..821e589ad52 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -32,6 +32,7 @@ from retryhttp import retry, wait_retry_after
 from tenacity import before_log, wait_random_exponential
 from uuid6 import uuid7
 
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIRuntimeCheckPayload
 from airflow.sdk import __version__
 from airflow.sdk.api.datamodels._generated import (
     AssetResponse,
@@ -52,7 +53,7 @@ from airflow.sdk.api.datamodels._generated import (
     XComResponse,
 )
 from airflow.sdk.exceptions import ErrorType
-from airflow.sdk.execution_time.comms import ErrorResponse
+from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, 
RuntimeCheckOnTask
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
 
@@ -177,6 +178,19 @@ class TaskInstanceOperations:
         resp = 
self.client.get(f"task-instances/{id}/previous-successful-dagrun")
         return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())
 
+    def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) -> 
OKResponse:
+        body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True))
+        try:
+            self.client.post(f"task-instances/{id}/runtime-checks", 
content=body.model_dump_json())
+            return OKResponse(ok=True)
+        except ServerResponseError as e:
+            if e.response.status_code == 400:
+                return OKResponse(ok=False)
+            elif e.response.status_code == 409:
+                # The TI isn't in the right state to perform the check, but we 
shouldn't fail the task for that
+                return OKResponse(ok=True)
+            raise
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 3383e61d1c3..1d6d0eb4156 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -148,6 +148,15 @@ class TIRescheduleStatePayload(BaseModel):
     end_date: Annotated[datetime, Field(title="End Date")]
 
 
+class TIRuntimeCheckPayload(BaseModel):
+    """
+    Payload for performing Runtime checks on the TaskInstance model as 
requested by the SDK.
+    """
+
+    inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None
+    outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] = 
None
+
+
 class TISuccessStatePayload(BaseModel):
     """
     Schema for updating TaskInstance to success state.
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 91ebc4ab6cb..b976bb8c156 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Union, overload
 
 import attrs
 
+from airflow.sdk.api.datamodels._generated import AssetProfile
 from airflow.serialization.dag_dependency import DagDependency
 
 if TYPE_CHECKING:
@@ -428,6 +429,14 @@ class Asset(os.PathLike, BaseAsset):
             dependency_id=self.name,
         )
 
+    def asprofile(self) -> AssetProfile:
+        """
+        Profiles Asset to AssetProfile.
+
+        :meta private:
+        """
+        return AssetProfile(name=self.name or None, uri=self.uri or None, 
asset_type=Asset.__name__)
+
 
 class AssetRef(BaseAsset, AttrsInstance):
     """
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 3ab8addc8bb..93ea133f489 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -60,6 +60,7 @@ from airflow.sdk.api.datamodels._generated import (
     TIDeferredStatePayload,
     TIRescheduleStatePayload,
     TIRunContext,
+    TIRuntimeCheckPayload,
     TISuccessStatePayload,
     VariableResponse,
     XComResponse,
@@ -169,6 +170,11 @@ class ErrorResponse(BaseModel):
     type: Literal["ErrorResponse"] = "ErrorResponse"
 
 
+class OKResponse(BaseModel):
+    ok: bool
+    type: Literal["OKResponse"] = "OKResponse"
+
+
 ToTask = Annotated[
     Union[
         AssetResult,
@@ -178,6 +184,7 @@ ToTask = Annotated[
         StartupDetails,
         VariableResult,
         XComResult,
+        OKResponse,
     ],
     Field(discriminator="type"),
 ]
@@ -220,6 +227,10 @@ class RescheduleTask(TIRescheduleStatePayload):
     type: Literal["RescheduleTask"] = "RescheduleTask"
 
 
+class RuntimeCheckOnTask(TIRuntimeCheckPayload):
+    type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask"
+
+
 class GetXCom(BaseModel):
     key: str
     dag_id: str
@@ -317,6 +328,7 @@ ToSupervisor = Annotated[
         SetRenderedFields,
         SetXCom,
         TaskState,
+        RuntimeCheckOnTask,
     ],
     Field(discriminator="type"),
 ]
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 569855016cf..30050c0b955 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -73,6 +73,7 @@ from airflow.sdk.execution_time.comms import (
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    RuntimeCheckOnTask,
     SetRenderedFields,
     SetXCom,
     StartupDetails,
@@ -767,6 +768,9 @@ class ActivitySubprocess(WatchedSubprocess):
         if isinstance(msg, TaskState):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+        elif isinstance(msg, RuntimeCheckOnTask):
+            runtime_check_resp = 
self.client.task_instances.runtime_checks(id=self.id, msg=msg)
+            resp = runtime_check_resp.model_dump_json().encode()
         elif isinstance(msg, SucceedTask):
             self._terminal_state = msg.state
             self.client.task_instances.succeed(
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c2d2c51b630..715dbf75dd7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -40,7 +40,9 @@ from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.sdk.execution_time.comms import (
     DeferTask,
     GetXCom,
+    OKResponse,
     RescheduleTask,
+    RuntimeCheckOnTask,
     SetRenderedFields,
     SetXCom,
     StartupDetails,
@@ -501,26 +503,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO: Get a real context object
         ti.hostname = get_hostname()
         ti.task = ti.task.prepare_for_execution()
-        context = ti.get_template_context()
-        with set_current_context(context):
-            jinja_env = ti.task.dag.get_template_env()
-            ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
-            result = _execute_task(context, ti.task)
-
-        _push_xcom_if_needed(result, ti)
-
-        task_outlets, outlet_events = _process_outlets(context, 
ti.task.outlets)
-
-        # TODO: Get things from _execute_task_with_callbacks
-        #   - Clearing XCom
-        #   - Update RTIF
-        #   - Pre Execute
-        #   etc
-        msg = SucceedTask(
-            end_date=datetime.now(tz=timezone.utc),
-            task_outlets=task_outlets,
-            outlet_events=outlet_events,
-        )
+        if ti.task.inlets or ti.task.outlets:
+            inlets = [asset.asprofile() for asset in ti.task.inlets if 
isinstance(asset, Asset)]
+            outlets = [asset.asprofile() for asset in ti.task.outlets if 
isinstance(asset, Asset)]
+            
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, 
outlets=outlets), log=log)  # type: ignore
+            msg = SUPERVISOR_COMMS.get_message()  # type: ignore
+
+        if isinstance(msg, OKResponse) and not msg.ok:
+            log.info("Runtime checks failed for task, marking task as 
failed..")
+            msg = TaskState(
+                state=TerminalTIState.FAILED,
+                end_date=datetime.now(tz=timezone.utc),
+            )
+        else:
+            context = ti.get_template_context()
+            with set_current_context(context):
+                jinja_env = ti.task.dag.get_template_env()
+                ti.task = ti.render_templates(context=context, 
jinja_env=jinja_env)
+                # TODO: Get things from _execute_task_with_callbacks
+                #   - Pre Execute
+                #   etc
+                result = _execute_task(context, ti.task)
+
+            _push_xcom_if_needed(result, ti)
+
+            task_outlets, outlet_events = _process_outlets(context, 
ti.task.outlets)
+            msg = SucceedTask(
+                end_date=datetime.now(tz=timezone.utc),
+                task_outlets=task_outlets,
+                outlet_events=outlet_events,
+            )
     except TaskDeferred as defer:
         # TODO: Should we use structlog.bind_contextvars here for dag_id, 
task_id & run_id?
         log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, 
task_id=ti.task_id, run_id=ti.run_id)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 2bd631fc080..4bc8febc67c 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -39,7 +39,7 @@ from uuid6 import uuid7
 from airflow.executors.workloads import BundleInfo
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
-from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
+from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, 
TerminalTIState
 from airflow.sdk.execution_time.comms import (
     AssetResult,
     ConnectionResult,
@@ -50,9 +50,11 @@ from airflow.sdk.execution_time.comms import (
     GetPrevSuccessfulDagRun,
     GetVariable,
     GetXCom,
+    OKResponse,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    RuntimeCheckOnTask,
     SetRenderedFields,
     SetXCom,
     SucceedTask,
@@ -1011,6 +1013,25 @@ class TestHandleRequest:
                 ),
                 id="get_prev_successful_dagrun",
             ),
+            pytest.param(
+                RuntimeCheckOnTask(
+                    inlets=[AssetProfile(name="alias", uri="alias", 
asset_type="asset")],
+                    outlets=[AssetProfile(name="alias", uri="alias", 
asset_type="asset")],
+                ),
+                b'{"ok":true,"type":"OKResponse"}\n',
+                "task_instances.runtime_checks",
+                (),
+                {
+                    "id": TI_ID,
+                    "msg": RuntimeCheckOnTask(
+                        inlets=[AssetProfile(name="alias", uri="alias", 
asset_type="asset")],  # type: ignore
+                        outlets=[AssetProfile(name="alias", uri="alias", 
asset_type="asset")],  # type: ignore
+                        type="RuntimeCheckOnTask",
+                    ),
+                },
+                OKResponse(ok=True),
+                id="runtime_check_on_task",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 250b7e765a0..d9aa675242c 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -47,7 +47,9 @@ from airflow.sdk.execution_time.comms import (
     GetConnection,
     GetVariable,
     GetXCom,
+    OKResponse,
     PrevSuccessfulDagRunResult,
+    RuntimeCheckOnTask,
     SetRenderedFields,
     StartupDetails,
     SucceedTask,
@@ -651,6 +653,43 @@ def test_run_with_asset_outlets(
     mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, 
log=mock.ANY)
 
 
+def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
+    """Test running a basic tasks with inlets and outlets."""
+    from airflow.providers.standard.operators.bash import BashOperator
+
+    task = BashOperator(
+        outlets=[
+            Asset(name="name", uri="s3://bucket/my-task"),
+            Asset(name="new-name", uri="s3://bucket/my-task"),
+        ],
+        inlets=[
+            Asset(name="name", uri="s3://bucket/my-task"),
+            Asset(name="new-name", uri="s3://bucket/my-task"),
+        ],
+        task_id="inlets-and-outlets",
+        bash_command="echo 'hi'",
+    )
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
+    mock_supervisor_comms.get_message.return_value = OKResponse(
+        ok=True,
+    )
+
+    run(ti, log=mock.MagicMock())
+
+    expected = RuntimeCheckOnTask(
+        inlets=[
+            AssetProfile(name="name", uri="s3://bucket/my-task", 
asset_type="Asset"),
+            AssetProfile(name="new-name", uri="s3://bucket/my-task", 
asset_type="Asset"),
+        ],
+        outlets=[
+            AssetProfile(name="name", uri="s3://bucket/my-task", 
asset_type="Asset"),
+            AssetProfile(name="new-name", uri="s3://bucket/my-task", 
asset_type="Asset"),
+        ],
+    )
+    mock_supervisor_comms.send_request.assert_any_call(msg=expected, 
log=mock.ANY)
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, 
make_ti_context):
         """Test get_template_context without ti_context_from_server."""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 9ccd1b0d088..8a7b21699f2 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -25,9 +25,11 @@ import uuid6
 from sqlalchemy import select, update
 from sqlalchemy.exc import SQLAlchemyError
 
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
 from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, 
AssetModel
 from airflow.models.taskinstance import TaskInstance
+from airflow.sdk.definitions.asset import AssetUniqueKey
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
@@ -655,6 +657,65 @@ class TestTIUpdateState:
         assert ti.next_kwargs is None
         assert ti.duration == 3600.00
 
+    @pytest.mark.parametrize(
+        ("state", "expected_status_code"),
+        [
+            (State.RUNNING, 204),
+            (State.SUCCESS, 409),
+            (State.QUEUED, 409),
+            (State.FAILED, 409),
+        ],
+    )
+    def test_ti_runtime_checks_success(
+        self, client, session, create_task_instance, state, 
expected_status_code
+    ):
+        ti = create_task_instance(
+            task_id="test_ti_runtime_checks",
+            state=state,
+        )
+        session.commit()
+
+        with mock.patch(
+            
"airflow.models.taskinstance.TaskInstance.validate_inlet_outlet_assets_activeness"
+        ) as mock_validate_inlet_outlet_assets_activeness:
+            mock_validate_inlet_outlet_assets_activeness.return_value = None
+            response = client.post(
+                f"/execution/task-instances/{ti.id}/runtime-checks",
+                json={
+                    "inlets": [],
+                    "outlets": [],
+                },
+            )
+
+            assert response.status_code == expected_status_code
+
+        session.expire_all()
+
+    def test_ti_runtime_checks_failure(self, client, session, 
create_task_instance):
+        ti = create_task_instance(
+            task_id="test_ti_runtime_checks_failure",
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        with mock.patch(
+            
"airflow.models.taskinstance.TaskInstance.validate_inlet_outlet_assets_activeness"
+        ) as mock_validate_inlet_outlet_assets_activeness:
+            mock_validate_inlet_outlet_assets_activeness.side_effect = (
+                
AirflowInactiveAssetInInletOrOutletException([AssetUniqueKey(name="abc", 
uri="something")])
+            )
+            response = client.post(
+                f"/execution/task-instances/{ti.id}/runtime-checks",
+                json={
+                    "inlets": [],
+                    "outlets": [],
+                },
+            )
+
+            assert response.status_code == 400
+
+        session.expire_all()
+
 
 class TestTIHealthEndpoint:
     def setup_method(self):

Reply via email to