This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new 036a2c60840 Add back invalid inlet and outlet check before running tasks (#50773)
036a2c60840 is described below

commit 036a2c60840eee24f6d93da168d3884034a0ef33
Author: Wei Lee <[email protected]>
AuthorDate: Tue Jun 3 16:35:27 2025 +0800

    Add back invalid inlet and outlet check before running tasks (#50773)
    
    * feat(task-sdk): check invalid inlets or outlets before running tasks
    
    * test(pytest_plugin): extend mock_supervisor_comms to ignore invalid assets in inlets and outlets
    
    * test(task_instances): add test cases TestInvalidInletsAndOutlets
    
    * Revert "test(pytest_plugin): extend mock_supervisor_comms to ignore invalid assets in inlets and outlets"
    
    This reverts commit 5f6956d89cf4300b8bc374c999af19c02292289a.
    
    * feat(task_runner): early inlets and outlets check
    
    * test(task_runner): fix asset inlet outlets tests
    
    * test(asset): add test cases to AssetUniqueKey
    
    * fix(task_instances): guard AirflowInactiveAssetInInletOrOutletException in ti_update_state
    
    As we already check before running the task, this should normally not happen, unless the
    asset becomes inactive after the task has succeeded, which is not expected to happen.
    
    * refactor: replace invalid with inactive
    
    * refactor(task_instance): update exception
    
    * test(task_runner): improve mocking check
    
    * test(supervisor): improve test_handle_requests
    
    (cherry picked from commit 083e03a909f923527d9d2f8d978962ddfb6e5b7a)
---
 .../execution_api/datamodels/taskinstance.py       |  6 ++
 .../execution_api/routes/task_instances.py         | 85 ++++++++++++++++++++--
 airflow-core/src/airflow/models/dagrun.py          |  1 +
 .../versions/head/test_task_instances.py           | 53 +++++++++++++-
 task-sdk/src/airflow/sdk/api/client.py             |  6 ++
 .../src/airflow/sdk/api/datamodels/_generated.py   |  8 ++
 .../src/airflow/sdk/definitions/asset/__init__.py  | 12 +++
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 26 +++++++
 .../src/airflow/sdk/execution_time/supervisor.py   |  6 ++
 .../src/airflow/sdk/execution_time/task_runner.py  | 21 ++++++
 task-sdk/tests/task_sdk/definitions/test_asset.py  | 35 +++++++++
 .../task_sdk/execution_time/test_supervisor.py     | 15 ++++
 .../task_sdk/execution_time/test_task_runner.py    |  7 +-
 13 files changed, 273 insertions(+), 8 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index b83d731a54e..c43c931f3e2 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -345,3 +345,9 @@ class TaskStatesResponse(BaseModel):
     """Response for task states with run_id, task and state."""
 
     task_states: dict[str, Any]
+
+
+class InactiveAssetsResponse(BaseModel):
+    """Response for inactive assets."""
+
+    inactive_assets: Annotated[list[AssetProfile], Field(default_factory=list)]
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index ac1d1602460..15cdb0a40ca 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -17,12 +17,15 @@
 
 from __future__ import annotations
 
+import contextlib
+import itertools
 import json
 from collections import defaultdict
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, Annotated, Any
 from uuid import UUID
 
+import attrs
 import structlog
 from cadwyn import VersionedAPIRouter
 from fastapi import Body, HTTPException, Query, status
@@ -37,6 +40,7 @@ from airflow.api_fastapi.common.dagbag import DagBagDep
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskStatesResponse,
     TIDeferredStatePayload,
@@ -51,6 +55,8 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     TITerminalStatePayload,
 )
 from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, 
TaskNotFound
+from airflow.models.asset import AssetActive
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, 
_stop_remaining_tasks
@@ -58,6 +64,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
 from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
+from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
 from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -400,12 +407,16 @@ def ti_update_state(
         query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
         updated_state = ti_patch_payload.state
         task_instance = session.get(TI, ti_id_str)
-        TI.register_asset_changes_in_db(
-            task_instance,
-            ti_patch_payload.task_outlets,  # type: ignore
-            ti_patch_payload.outlet_events,
-            session,
-        )
+        try:
+            TI.register_asset_changes_in_db(
+                task_instance,
+                ti_patch_payload.task_outlets,  # type: ignore
+                ti_patch_payload.outlet_events,
+                session,
+            )
+        except AirflowInactiveAssetInInletOrOutletException as err:
+            log.error("Asset registration failed due to conflicting asset: 
%s", err)
+
         query = query.values(state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
@@ -840,5 +851,67 @@ def _get_group_tasks(dag_id: str, task_group_id: str, 
session: SessionDep, logic
     return group_tasks
 
 
+@ti_id_router.get(
+    "/{task_instance_id}/validate-inlets-and-outlets",
+    status_code=status.HTTP_200_OK,
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+    },
+)
+def validate_inlets_and_outlets(
+    task_instance_id: UUID,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+) -> InactiveAssetsResponse:
+    """Validate whether there're inactive assets in inlets and outlets of a 
given task instance."""
+    ti_id_str = str(task_instance_id)
+    bind_contextvars(ti_id=ti_id_str)
+
+    ti = session.scalar(select(TI).where(TI.id == ti_id_str))
+    if not ti or not ti.logical_date:
+        log.error("Task Instance not found")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": "Task Instance not found",
+            },
+        )
+
+    if not ti.task:
+        dag = dag_bag.get_dag(ti.dag_id)
+        if dag:
+            with contextlib.suppress(TaskNotFound):
+                ti.task = dag.get_task(ti.task_id)
+
+    inlets = [asset.asprofile() for asset in ti.task.inlets if 
isinstance(asset, Asset)]
+    outlets = [asset.asprofile() for asset in ti.task.outlets if 
isinstance(asset, Asset)]
+    if not (inlets or outlets):
+        return InactiveAssetsResponse(inactive_assets=[])
+
+    all_asset_unique_keys: set[AssetUniqueKey] = {
+        AssetUniqueKey.from_asset(inlet_or_outlet)  # type: ignore
+        for inlet_or_outlet in itertools.chain(inlets, outlets)
+    }
+    active_asset_unique_keys = {
+        AssetUniqueKey(name, uri)
+        for name, uri in session.execute(
+            select(AssetActive.name, AssetActive.uri).where(
+                tuple_(AssetActive.name, AssetActive.uri).in_(
+                    attrs.astuple(key) for key in all_asset_unique_keys
+                )
+            )
+        )
+    }
+    different = all_asset_unique_keys - active_asset_unique_keys
+
+    return InactiveAssetsResponse(
+        inactive_assets=[
+            asset_unique_key.to_asset().asprofile()  # type: ignore
+            for asset_unique_key in different
+        ]
+    )
+
+
 # This line should be at the end of the file to ensure all routes are 
registered
 router.include_router(ti_id_router)
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 11a65f87055..3feb0f8794d 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1887,6 +1887,7 @@ class DagRun(Base, LoggingMixin):
                 and not ti.task.on_execute_callback
                 and not ti.task.on_success_callback
                 and not ti.task.outlets
+                and not ti.task.inlets
             ):
                 empty_ti_ids.append(ti.id)
             # check "start_trigger_args" to see whether the operator supports 
start execution from triggerer
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 21c733a5cce..49a8717c36a 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,7 +33,7 @@ from airflow.models.asset import AssetActive, 
AssetAliasModel, AssetEvent, Asset
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import TaskGroup, task, task_group
+from airflow.sdk import Asset, TaskGroup, task, task_group
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
@@ -2139,3 +2139,54 @@ class TestGetTaskStates:
         response = client.get("/execution/task-instances/states", 
params={"dag_id": dr.dag_id, **params})
         assert response.status_code == 200
         assert response.json() == {"task_states": {dr.run_id: expected}}
+
+
+class TestInvactiveInletsAndOutlets:
+    def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
+        """Test the inactive assets in inlets and outlets can be found."""
+        with dag_maker("test_inlets_and_outlets"):
+            EmptyOperator(
+                task_id="task1",
+                inlets=[Asset(name="inlet-name"), Asset(name="inlet-name", 
uri="but-different-uri")],
+                outlets=[
+                    Asset(name="outlet-name", uri="uri"),
+                    Asset(name="outlet-name", uri="second-different-uri"),
+                ],
+            )
+
+        dr = dag_maker.create_dagrun()
+
+        task1_ti = dr.get_task_instance("task1")
+        response = 
client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+        assert response.status_code == 200
+        inactive_assets = response.json()["inactive_assets"]
+        expected_inactive_assets = (
+            {
+                "name": "inlet-name",
+                "type": "Asset",
+                "uri": "but-different-uri",
+            },
+            {
+                "name": "outlet-name",
+                "type": "Asset",
+                "uri": "second-different-uri",
+            },
+        )
+        for asset in expected_inactive_assets:
+            assert asset in inactive_assets
+
+    def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, 
client, dag_maker):
+        """Test the task without inactive assets in its inlets or outlets 
returns empty list."""
+        with dag_maker("test_inlets_and_outlets_inactive"):
+            EmptyOperator(
+                task_id="inactive_task1",
+                inlets=[Asset(name="inlet-name")],
+                outlets=[Asset(name="outlet-name", uri="uri")],
+            )
+
+        dr = dag_maker.create_dagrun()
+
+        task1_ti = dr.get_task_instance("inactive_task1")
+        response = 
client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+        assert response.status_code == 200
+        assert response.json() == {"inactive_assets": []}
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 919e11ac5b3..7ef0a68a6c0 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -40,6 +40,7 @@ from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
     DagRunStateResponse,
     DagRunType,
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskInstanceState,
     TaskStatesResponse,
@@ -274,6 +275,11 @@ class TaskInstanceOperations:
         resp = self.client.get("task-instances/states", params=params)
         return TaskStatesResponse.model_validate_json(resp.read())
 
+    def validate_inlets_and_outlets(self, id: uuid.UUID) -> 
InactiveAssetsResponse:
+        """Validate whether there're inactive assets in inlets and outlets of 
a given task instance."""
+        resp = 
self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
+        return InactiveAssetsResponse.model_validate_json(resp.read())
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index f6b1c907ef5..ac1e51d5e55 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -154,6 +154,14 @@ class DagRunType(str, Enum):
     ASSET_TRIGGERED = "asset_triggered"
 
 
+class InactiveAssetsResponse(BaseModel):
+    """
+    Response for inactive assets.
+    """
+
+    inactive_assets: Annotated[list[AssetProfile] | None, 
Field(title="Inactive Assets")] = None
+
+
 class IntermediateTIState(str, Enum):
     """
     States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state.
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py 
b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
index c81732cf404..9cb913807ee 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -86,6 +86,18 @@ class AssetUniqueKey(attrs.AttrsInstance):
     def to_str(self) -> str:
         return json.dumps(attrs.asdict(self))
 
+    @staticmethod
+    def from_profile(profile: AssetProfile) -> AssetUniqueKey:
+        if profile.name and profile.uri:
+            return AssetUniqueKey(name=profile.name, uri=profile.uri)
+
+        if name := profile.name:
+            return AssetUniqueKey(name=name, uri=name)
+        if uri := profile.uri:
+            return AssetUniqueKey(name=uri, uri=uri)
+
+        raise ValueError("name and uri cannot both be empty")
+
 
 @attrs.define(frozen=True)
 class AssetAliasUniqueKey:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index ecc34852252..d0622cf6ffc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -61,6 +61,7 @@ from airflow.sdk.api.datamodels._generated import (
     BundleInfo,
     ConnectionResponse,
     DagRunStateResponse,
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskInstance,
     TaskInstanceState,
@@ -208,6 +209,24 @@ class 
AssetEventDagRunReferenceResult(AssetEventDagRunReference):
         )
 
 
+class InactiveAssetsResult(InactiveAssetsResponse):
+    """Response of InactiveAssets requests."""
+
+    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"
+
+    @classmethod
+    def from_inactive_assets_response(
+        cls, inactive_assets_response: InactiveAssetsResponse
+    ) -> InactiveAssetsResult:
+        """
+        Get InactiveAssetsResponse from InactiveAssetsResult.
+
+        InactiveAssetsResponse is autogenerated from the API schema, so we 
need to convert it to InactiveAssetsResult
+        for communication between the Supervisor and the task process.
+        """
+        return 
cls(**inactive_assets_response.model_dump(exclude_defaults=True), 
type="InactiveAssetsResult")
+
+
 class XComResult(XComResponse):
     """Response to ReadXCom request."""
 
@@ -376,6 +395,7 @@ ToTask = Annotated[
         XComResult,
         XComSequenceIndexResult,
         XComSequenceSliceResult,
+        InactiveAssetsResult,
         OKResponse,
     ],
     Field(discriminator="type"),
@@ -590,6 +610,11 @@ class GetAssetEventByAssetAlias(BaseModel):
     type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
 
 
+class ValidateInletsAndOutlets(BaseModel):
+    ti_id: UUID
+    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
+
+
 class GetPrevSuccessfulDagRun(BaseModel):
     ti_id: UUID
     type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
@@ -657,6 +682,7 @@ ToSupervisor = Annotated[
         SetXCom,
         SkipDownstreamTasks,
         SucceedTask,
+        ValidateInletsAndOutlets,
         TaskState,
         TriggerDagRun,
         DeleteVariable,
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 74dec987c5f..b2dd76d85b1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -88,6 +88,7 @@ from airflow.sdk.execution_time.comms import (
     GetXComCount,
     GetXComSequenceItem,
     GetXComSequenceSlice,
+    InactiveAssetsResult,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
@@ -101,6 +102,7 @@ from airflow.sdk.execution_time.comms import (
     TaskStatesResult,
     ToSupervisor,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
     VariableResult,
     XComCountResponse,
     XComResult,
@@ -1162,6 +1164,10 @@ class ActivitySubprocess(WatchedSubprocess):
             )
         elif isinstance(msg, DeleteVariable):
             resp = self.client.variables.delete(msg.key)
+        elif isinstance(msg, ValidateInletsAndOutlets):
+            inactive_assets_resp = 
self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
+            resp = 
InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
+            dump_opts = {"exclude_unset": True}
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index f7faea7489c..1f7219a7a8e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -40,6 +40,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, 
Field, JsonValue, Typ
 
 from airflow.dag_processing.bundles.base import BaseDagBundle, 
BundleVersionLock
 from airflow.dag_processing.bundles.manager import DagBundlesManager
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
 from airflow.listeners.listener import get_listener_manager
 from airflow.sdk.api.datamodels._generated import (
     AssetProfile,
@@ -66,6 +67,7 @@ from airflow.sdk.execution_time.comms import (
     GetTaskRescheduleStartDate,
     GetTaskStates,
     GetTICount,
+    InactiveAssetsResult,
     RescheduleTask,
     RetryTask,
     SetRenderedFields,
@@ -79,6 +81,7 @@ from airflow.sdk.execution_time.comms import (
     ToSupervisor,
     ToTask,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
 )
 from airflow.sdk.execution_time.context import (
     ConnectionAccessor,
@@ -763,6 +766,8 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: 
Context) -> ToSuperv
         # so that we do not call the API unnecessarily
         SUPERVISOR_COMMS.send_request(log=log, 
msg=SetRenderedFields(rendered_fields=rendered_fields))
 
+    _validate_task_inlets_and_outlets(ti=ti, log=log)
+
     try:
         # TODO: Call pre execute etc.
         get_listener_manager().hook.on_task_instance_running(
@@ -775,6 +780,22 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, 
context: Context) -> ToSuperv
     return None
 
 
+def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) 
-> None:
+    if not ti.task.inlets and not ti.task.outlets:
+        return
+
+    SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), 
log=log)
+    inactive_assets_resp = SUPERVISOR_COMMS.get_message()
+    if TYPE_CHECKING:
+        assert isinstance(inactive_assets_resp, InactiveAssetsResult)
+    if inactive_assets := inactive_assets_resp.inactive_assets:
+        raise AirflowInactiveAssetInInletOrOutletException(
+            inactive_asset_keys=[
+                AssetUniqueKey.from_profile(asset_profile) for asset_profile 
in inactive_assets
+            ]
+        )
+
+
 def _defer_task(
     defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
 ) -> tuple[ToSupervisor, TaskInstanceState]:
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py 
b/task-sdk/tests/task_sdk/definitions/test_asset.py
index 8328b061811..2a25c0907c7 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Callable
 from unittest import mock
@@ -24,6 +25,7 @@ from unittest import mock
 import pytest
 
 from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk.api.datamodels._generated import AssetProfile
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
@@ -384,6 +386,39 @@ def 
test_normalize_uri_valid_uri(mock_get_normalized_scheme):
     assert asset.normalized_uri == "valid_aip60_uri"
 
 
+class TestAssetUniqueKey:
+    def test_from_asset(self):
+        asset = Asset(name="test", uri="test://test/")
+
+        assert AssetUniqueKey.from_asset(asset) == AssetUniqueKey(name="test", 
uri="test://test/")
+
+    def test_to_asset(self):
+        assert AssetUniqueKey(name="test", uri="test://test/").to_asset() == 
Asset(
+            name="test", uri="test://test/"
+        )
+
+    def test_from_str(self):
+        json_str = json.dumps({"name": "test", "uri": "test://test/"})
+        assert AssetUniqueKey.from_str(json_str) == 
AssetUniqueKey(name="test", uri="test://test/")
+
+    def test_to_str(self):
+        assert AssetUniqueKey(name="test", uri="test://test/").to_str() == 
json.dumps(
+            {"name": "test", "uri": "test://test/"}
+        )
+
+    @pytest.mark.parametrize(
+        "name, uri, expected_asset_unique_key",
+        [
+            ("test", None, AssetUniqueKey(name="test", uri="test")),
+            (None, "test://test/", AssetUniqueKey(name="test://test/", 
uri="test://test/")),
+            ("test", "test://test/", AssetUniqueKey(name="test", 
uri="test://test/")),
+        ],
+    )
+    def test_from_profile(self, name, uri, expected_asset_unique_key):
+        profile = AssetProfile(name=name, uri=uri, type="Asset")
+        assert AssetUniqueKey.from_profile(profile) == 
expected_asset_unique_key
+
+
 class TestAssetAlias:
     def test_as_expression(self):
         alias = AssetAlias(name="test_name", group="test")
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index b4fc1340064..4f6341c89da 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -45,6 +45,7 @@ from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import (
     AssetEventResponse,
+    AssetProfile,
     AssetResponse,
     DagRunState,
     TaskInstance,
@@ -77,6 +78,7 @@ from airflow.sdk.execution_time.comms import (
     GetXCom,
     GetXComSequenceItem,
     GetXComSequenceSlice,
+    InactiveAssetsResult,
     OKResponse,
     PrevSuccessfulDagRunResult,
     PutVariable,
@@ -90,6 +92,7 @@ from airflow.sdk.execution_time.comms import (
     TaskStatesResult,
     TICount,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
     VariableResult,
     XComResult,
     XComSequenceIndexResult,
@@ -1442,6 +1445,18 @@ class TestHandleRequest:
                 None,
                 id="get_asset_events_by_asset_alias",
             ),
+            pytest.param(
+                ValidateInletsAndOutlets(ti_id=TI_ID),
+                
b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n',
+                "task_instances.validate_inlets_and_outlets",
+                (TI_ID,),
+                {},
+                InactiveAssetsResult(
+                    inactive_assets=[AssetProfile(name="asset_name", 
uri="asset_uri", type="asset")]
+                ),
+                None,
+                id="validate_inlets_and_outlets",
+            ),
             pytest.param(
                 SucceedTask(
                     end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test success task"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 3077a47f41b..09d4d9709e7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -945,7 +945,12 @@ def test_run_with_asset_outlets(
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
-    run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+    with mock.patch(
+        
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+    ) as validate_mock:
+        run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+    validate_mock.assert_called_once()
 
     mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, 
log=mock.ANY)
 

Reply via email to