This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new 036a2c60840 Add back invalid inlet and outlet check before running
tasks (#50773)
036a2c60840 is described below
commit 036a2c60840eee24f6d93da168d3884034a0ef33
Author: Wei Lee <[email protected]>
AuthorDate: Tue Jun 3 16:35:27 2025 +0800
Add back invalid inlet and outlet check before running tasks (#50773)
* feat(task-sdk): check invalid inlets or outlets before running tasks
* test(pytest_plugin): extend mock_supervisor_comms to ignore invalid
assets in inlets and outlets
* test(task_instances): add test cases TestInvalidInletsAndOutlets
* Revert "test(pytest_plugin): extend mock_supervisor_comms to ignore
invalid assets in inlets and outlets"
This reverts commit 5f6956d89cf4300b8bc374c999af19c02292289a.
* feat(task_runner): early inlets and outlets check
* test(task_runner): fix asset inlet outlets tests
* test(asset): add test cases to AssetUniqueKey
* fix(task_instances): guard AirflowInactiveAssetInInletOrOutletException
in ti_update_state
As we already check before scheduling, it should normally not happen.
Unless the asset becomes invalid after the task succeeded,
which is not something expected to happen
* refactor: replace invalid with inactive
* refactor(task_instance): update exception
* test(task_runner): improve mocking check
* test(supervisor): improve test_handle_requests
(cherry picked from commit 083e03a909f923527d9d2f8d978962ddfb6e5b7a)
---
.../execution_api/datamodels/taskinstance.py | 6 ++
.../execution_api/routes/task_instances.py | 85 ++++++++++++++++++++--
airflow-core/src/airflow/models/dagrun.py | 1 +
.../versions/head/test_task_instances.py | 53 +++++++++++++-
task-sdk/src/airflow/sdk/api/client.py | 6 ++
.../src/airflow/sdk/api/datamodels/_generated.py | 8 ++
.../src/airflow/sdk/definitions/asset/__init__.py | 12 +++
task-sdk/src/airflow/sdk/execution_time/comms.py | 26 +++++++
.../src/airflow/sdk/execution_time/supervisor.py | 6 ++
.../src/airflow/sdk/execution_time/task_runner.py | 21 ++++++
task-sdk/tests/task_sdk/definitions/test_asset.py | 35 +++++++++
.../task_sdk/execution_time/test_supervisor.py | 15 ++++
.../task_sdk/execution_time/test_task_runner.py | 7 +-
13 files changed, 273 insertions(+), 8 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index b83d731a54e..c43c931f3e2 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -345,3 +345,9 @@ class TaskStatesResponse(BaseModel):
"""Response for task states with run_id, task and state."""
task_states: dict[str, Any]
+
+
+class InactiveAssetsResponse(BaseModel):
+ """Response for inactive assets."""
+
+ inactive_assets: Annotated[list[AssetProfile], Field(default_factory=list)]
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index ac1d1602460..15cdb0a40ca 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -17,12 +17,15 @@
from __future__ import annotations
+import contextlib
+import itertools
import json
from collections import defaultdict
from collections.abc import Iterator
from typing import TYPE_CHECKING, Annotated, Any
from uuid import UUID
+import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
@@ -37,6 +40,7 @@ from airflow.api_fastapi.common.dagbag import DagBagDep
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskStatesResponse,
TIDeferredStatePayload,
@@ -51,6 +55,8 @@ from
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException,
TaskNotFound
+from airflow.models.asset import AssetActive
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI,
_stop_remaining_tasks
@@ -58,6 +64,7 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
+from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -400,12 +407,16 @@ def ti_update_state(
query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
updated_state = ti_patch_payload.state
task_instance = session.get(TI, ti_id_str)
- TI.register_asset_changes_in_db(
- task_instance,
- ti_patch_payload.task_outlets, # type: ignore
- ti_patch_payload.outlet_events,
- session,
- )
+ try:
+ TI.register_asset_changes_in_db(
+ task_instance,
+ ti_patch_payload.task_outlets, # type: ignore
+ ti_patch_payload.outlet_events,
+ session,
+ )
+ except AirflowInactiveAssetInInletOrOutletException as err:
+ log.error("Asset registration failed due to conflicting asset:
%s", err)
+
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
@@ -840,5 +851,67 @@ def _get_group_tasks(dag_id: str, task_group_id: str,
session: SessionDep, logic
return group_tasks
+@ti_id_router.get(
+ "/{task_instance_id}/validate-inlets-and-outlets",
+ status_code=status.HTTP_200_OK,
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+ },
+)
+def validate_inlets_and_outlets(
+ task_instance_id: UUID,
+ session: SessionDep,
+ dag_bag: DagBagDep,
+) -> InactiveAssetsResponse:
+ """Validate whether there're inactive assets in inlets and outlets of a
given task instance."""
+ ti_id_str = str(task_instance_id)
+ bind_contextvars(ti_id=ti_id_str)
+
+ ti = session.scalar(select(TI).where(TI.id == ti_id_str))
+ if not ti or not ti.logical_date:
+ log.error("Task Instance not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": "Task Instance not found",
+ },
+ )
+
+ if not ti.task:
+ dag = dag_bag.get_dag(ti.dag_id)
+ if dag:
+ with contextlib.suppress(TaskNotFound):
+ ti.task = dag.get_task(ti.task_id)
+
+ inlets = [asset.asprofile() for asset in ti.task.inlets if
isinstance(asset, Asset)]
+ outlets = [asset.asprofile() for asset in ti.task.outlets if
isinstance(asset, Asset)]
+ if not (inlets or outlets):
+ return InactiveAssetsResponse(inactive_assets=[])
+
+ all_asset_unique_keys: set[AssetUniqueKey] = {
+ AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore
+ for inlet_or_outlet in itertools.chain(inlets, outlets)
+ }
+ active_asset_unique_keys = {
+ AssetUniqueKey(name, uri)
+ for name, uri in session.execute(
+ select(AssetActive.name, AssetActive.uri).where(
+ tuple_(AssetActive.name, AssetActive.uri).in_(
+ attrs.astuple(key) for key in all_asset_unique_keys
+ )
+ )
+ )
+ }
+ different = all_asset_unique_keys - active_asset_unique_keys
+
+ return InactiveAssetsResponse(
+ inactive_assets=[
+ asset_unique_key.to_asset().asprofile() # type: ignore
+ for asset_unique_key in different
+ ]
+ )
+
+
# This line should be at the end of the file to ensure all routes are
registered
router.include_router(ti_id_router)
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index 11a65f87055..3feb0f8794d 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1887,6 +1887,7 @@ class DagRun(Base, LoggingMixin):
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
and not ti.task.outlets
+ and not ti.task.inlets
):
empty_ti_ids.append(ti.id)
# check "start_trigger_args" to see whether the operator supports
start execution from triggerer
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 21c733a5cce..49a8717c36a 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,7 +33,7 @@ from airflow.models.asset import AssetActive,
AssetAliasModel, AssetEvent, Asset
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import TaskGroup, task, task_group
+from airflow.sdk import Asset, TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
@@ -2139,3 +2139,54 @@ class TestGetTaskStates:
response = client.get("/execution/task-instances/states",
params={"dag_id": dr.dag_id, **params})
assert response.status_code == 200
assert response.json() == {"task_states": {dr.run_id: expected}}
+
+
+class TestInactiveInletsAndOutlets:
+ def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
+ """Test the inactive assets in inlets and outlets can be found."""
+ with dag_maker("test_inlets_and_outlets"):
+ EmptyOperator(
+ task_id="task1",
+ inlets=[Asset(name="inlet-name"), Asset(name="inlet-name",
uri="but-different-uri")],
+ outlets=[
+ Asset(name="outlet-name", uri="uri"),
+ Asset(name="outlet-name", uri="second-different-uri"),
+ ],
+ )
+
+ dr = dag_maker.create_dagrun()
+
+ task1_ti = dr.get_task_instance("task1")
+ response =
client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+ assert response.status_code == 200
+ inactive_assets = response.json()["inactive_assets"]
+ expected_inactive_assets = (
+ {
+ "name": "inlet-name",
+ "type": "Asset",
+ "uri": "but-different-uri",
+ },
+ {
+ "name": "outlet-name",
+ "type": "Asset",
+ "uri": "second-different-uri",
+ },
+ )
+ for asset in expected_inactive_assets:
+ assert asset in inactive_assets
+
+ def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self,
client, dag_maker):
+ """Test the task without inactive assets in its inlets or outlets
returns empty list."""
+ with dag_maker("test_inlets_and_outlets_inactive"):
+ EmptyOperator(
+ task_id="inactive_task1",
+ inlets=[Asset(name="inlet-name")],
+ outlets=[Asset(name="outlet-name", uri="uri")],
+ )
+
+ dr = dag_maker.create_dagrun()
+
+ task1_ti = dr.get_task_instance("inactive_task1")
+ response =
client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+ assert response.status_code == 200
+ assert response.json() == {"inactive_assets": []}
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 919e11ac5b3..7ef0a68a6c0 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -40,6 +40,7 @@ from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
DagRunStateResponse,
DagRunType,
+ InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskInstanceState,
TaskStatesResponse,
@@ -274,6 +275,11 @@ class TaskInstanceOperations:
resp = self.client.get("task-instances/states", params=params)
return TaskStatesResponse.model_validate_json(resp.read())
+ def validate_inlets_and_outlets(self, id: uuid.UUID) ->
InactiveAssetsResponse:
+ """Validate whether there're inactive assets in inlets and outlets of
a given task instance."""
+ resp =
self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
+ return InactiveAssetsResponse.model_validate_json(resp.read())
+
class ConnectionOperations:
__slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index f6b1c907ef5..ac1e51d5e55 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -154,6 +154,14 @@ class DagRunType(str, Enum):
ASSET_TRIGGERED = "asset_triggered"
+class InactiveAssetsResponse(BaseModel):
+ """
+ Response for inactive assets.
+ """
+
+ inactive_assets: Annotated[list[AssetProfile] | None,
Field(title="Inactive Assets")] = None
+
+
class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a
terminal or running state.
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
index c81732cf404..9cb913807ee 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -86,6 +86,18 @@ class AssetUniqueKey(attrs.AttrsInstance):
def to_str(self) -> str:
return json.dumps(attrs.asdict(self))
+ @staticmethod
+ def from_profile(profile: AssetProfile) -> AssetUniqueKey:
+ if profile.name and profile.uri:
+ return AssetUniqueKey(name=profile.name, uri=profile.uri)
+
+ if name := profile.name:
+ return AssetUniqueKey(name=name, uri=name)
+ if uri := profile.uri:
+ return AssetUniqueKey(name=uri, uri=uri)
+
+ raise ValueError("name and uri cannot both be empty")
+
@attrs.define(frozen=True)
class AssetAliasUniqueKey:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index ecc34852252..d0622cf6ffc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -61,6 +61,7 @@ from airflow.sdk.api.datamodels._generated import (
BundleInfo,
ConnectionResponse,
DagRunStateResponse,
+ InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskInstance,
TaskInstanceState,
@@ -208,6 +209,24 @@ class
AssetEventDagRunReferenceResult(AssetEventDagRunReference):
)
+class InactiveAssetsResult(InactiveAssetsResponse):
+ """Response of InactiveAssets requests."""
+
+ type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"
+
+ @classmethod
+ def from_inactive_assets_response(
+ cls, inactive_assets_response: InactiveAssetsResponse
+ ) -> InactiveAssetsResult:
+ """
+ Get InactiveAssetsResult from InactiveAssetsResponse.
+
+ InactiveAssetsResponse is autogenerated from the API schema, so we
need to convert it to InactiveAssetsResult
+ for communication between the Supervisor and the task process.
+ """
+ return
cls(**inactive_assets_response.model_dump(exclude_defaults=True),
type="InactiveAssetsResult")
+
+
class XComResult(XComResponse):
"""Response to ReadXCom request."""
@@ -376,6 +395,7 @@ ToTask = Annotated[
XComResult,
XComSequenceIndexResult,
XComSequenceSliceResult,
+ InactiveAssetsResult,
OKResponse,
],
Field(discriminator="type"),
@@ -590,6 +610,11 @@ class GetAssetEventByAssetAlias(BaseModel):
type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
+class ValidateInletsAndOutlets(BaseModel):
+ ti_id: UUID
+ type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
+
+
class GetPrevSuccessfulDagRun(BaseModel):
ti_id: UUID
type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
@@ -657,6 +682,7 @@ ToSupervisor = Annotated[
SetXCom,
SkipDownstreamTasks,
SucceedTask,
+ ValidateInletsAndOutlets,
TaskState,
TriggerDagRun,
DeleteVariable,
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 74dec987c5f..b2dd76d85b1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -88,6 +88,7 @@ from airflow.sdk.execution_time.comms import (
GetXComCount,
GetXComSequenceItem,
GetXComSequenceSlice,
+ InactiveAssetsResult,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
@@ -101,6 +102,7 @@ from airflow.sdk.execution_time.comms import (
TaskStatesResult,
ToSupervisor,
TriggerDagRun,
+ ValidateInletsAndOutlets,
VariableResult,
XComCountResponse,
XComResult,
@@ -1162,6 +1164,10 @@ class ActivitySubprocess(WatchedSubprocess):
)
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
+ elif isinstance(msg, ValidateInletsAndOutlets):
+ inactive_assets_resp =
self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
+ resp =
InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
+ dump_opts = {"exclude_unset": True}
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index f7faea7489c..1f7219a7a8e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -40,6 +40,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict,
Field, JsonValue, Typ
from airflow.dag_processing.bundles.base import BaseDagBundle,
BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
from airflow.listeners.listener import get_listener_manager
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
@@ -66,6 +67,7 @@ from airflow.sdk.execution_time.comms import (
GetTaskRescheduleStartDate,
GetTaskStates,
GetTICount,
+ InactiveAssetsResult,
RescheduleTask,
RetryTask,
SetRenderedFields,
@@ -79,6 +81,7 @@ from airflow.sdk.execution_time.comms import (
ToSupervisor,
ToTask,
TriggerDagRun,
+ ValidateInletsAndOutlets,
)
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
@@ -763,6 +766,8 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context:
Context) -> ToSuperv
# so that we do not call the API unnecessarily
SUPERVISOR_COMMS.send_request(log=log,
msg=SetRenderedFields(rendered_fields=rendered_fields))
+ _validate_task_inlets_and_outlets(ti=ti, log=log)
+
try:
# TODO: Call pre execute etc.
get_listener_manager().hook.on_task_instance_running(
@@ -775,6 +780,22 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger,
context: Context) -> ToSuperv
return None
+def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger)
-> None:
+ if not ti.task.inlets and not ti.task.outlets:
+ return
+
+ SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id),
log=log)
+ inactive_assets_resp = SUPERVISOR_COMMS.get_message()
+ if TYPE_CHECKING:
+ assert isinstance(inactive_assets_resp, InactiveAssetsResult)
+ if inactive_assets := inactive_assets_resp.inactive_assets:
+ raise AirflowInactiveAssetInInletOrOutletException(
+ inactive_asset_keys=[
+ AssetUniqueKey.from_profile(asset_profile) for asset_profile
in inactive_assets
+ ]
+ )
+
+
def _defer_task(
defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py
b/task-sdk/tests/task_sdk/definitions/test_asset.py
index 8328b061811..2a25c0907c7 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+import json
import os
from typing import Callable
from unittest import mock
@@ -24,6 +25,7 @@ from unittest import mock
import pytest
from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
@@ -384,6 +386,39 @@ def
test_normalize_uri_valid_uri(mock_get_normalized_scheme):
assert asset.normalized_uri == "valid_aip60_uri"
+class TestAssetUniqueKey:
+ def test_from_asset(self):
+ asset = Asset(name="test", uri="test://test/")
+
+ assert AssetUniqueKey.from_asset(asset) == AssetUniqueKey(name="test",
uri="test://test/")
+
+ def test_to_asset(self):
+ assert AssetUniqueKey(name="test", uri="test://test/").to_asset() ==
Asset(
+ name="test", uri="test://test/"
+ )
+
+ def test_from_str(self):
+ json_str = json.dumps({"name": "test", "uri": "test://test/"})
+ assert AssetUniqueKey.from_str(json_str) ==
AssetUniqueKey(name="test", uri="test://test/")
+
+ def test_to_str(self):
+ assert AssetUniqueKey(name="test", uri="test://test/").to_str() ==
json.dumps(
+ {"name": "test", "uri": "test://test/"}
+ )
+
+ @pytest.mark.parametrize(
+ "name, uri, expected_asset_unique_key",
+ [
+ ("test", None, AssetUniqueKey(name="test", uri="test")),
+ (None, "test://test/", AssetUniqueKey(name="test://test/",
uri="test://test/")),
+ ("test", "test://test/", AssetUniqueKey(name="test",
uri="test://test/")),
+ ],
+ )
+ def test_from_profile(self, name, uri, expected_asset_unique_key):
+ profile = AssetProfile(name=name, uri=uri, type="Asset")
+ assert AssetUniqueKey.from_profile(profile) ==
expected_asset_unique_key
+
+
class TestAssetAlias:
def test_as_expression(self):
alias = AssetAlias(name="test_name", group="test")
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index b4fc1340064..4f6341c89da 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -45,6 +45,7 @@ from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import (
AssetEventResponse,
+ AssetProfile,
AssetResponse,
DagRunState,
TaskInstance,
@@ -77,6 +78,7 @@ from airflow.sdk.execution_time.comms import (
GetXCom,
GetXComSequenceItem,
GetXComSequenceSlice,
+ InactiveAssetsResult,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
@@ -90,6 +92,7 @@ from airflow.sdk.execution_time.comms import (
TaskStatesResult,
TICount,
TriggerDagRun,
+ ValidateInletsAndOutlets,
VariableResult,
XComResult,
XComSequenceIndexResult,
@@ -1442,6 +1445,18 @@ class TestHandleRequest:
None,
id="get_asset_events_by_asset_alias",
),
+ pytest.param(
+ ValidateInletsAndOutlets(ti_id=TI_ID),
+
b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n',
+ "task_instances.validate_inlets_and_outlets",
+ (TI_ID,),
+ {},
+ InactiveAssetsResult(
+ inactive_assets=[AssetProfile(name="asset_name",
uri="asset_uri", type="asset")]
+ ),
+ None,
+ id="validate_inlets_and_outlets",
+ ),
pytest.param(
SucceedTask(
end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test success task"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 3077a47f41b..09d4d9709e7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -945,7 +945,12 @@ def test_run_with_asset_outlets(
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+ with mock.patch(
+
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+ ) as validate_mock:
+ run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+ validate_mock.assert_called_once()
mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg,
log=mock.ANY)