This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8e355843228 Remove no-longer-needed execution interface hacks (#55681)
8e355843228 is described below
commit 8e355843228533f9743590e957efb97902a8d3a8
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Sep 17 00:19:46 2025 +0800
Remove no-longer-needed execution interface hacks (#55681)
---
airflow-core/src/airflow/api/common/mark_tasks.py | 14 ++-------
.../src/airflow/cli/commands/task_command.py | 35 ++++++++++++----------
airflow-core/src/airflow/models/taskinstance.py | 13 +++++---
.../airflow/serialization/serialized_objects.py | 7 ++---
airflow-core/src/airflow/utils/cli.py | 4 +--
airflow-core/tests/unit/models/test_cleartasks.py | 4 +--
.../tests/unit/models/test_taskinstance.py | 8 ++---
task-sdk/src/airflow/sdk/definitions/dag.py | 24 ++++++++++-----
task-sdk/src/airflow/sdk/types.py | 2 ++
9 files changed, 59 insertions(+), 52 deletions(-)
diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index d424bab603a..5c0ed4b9f5f 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
from __future__ import annotations
from collections.abc import Collection, Iterable
-from typing import TYPE_CHECKING, TypeAlias, cast
+from typing import TYPE_CHECKING, TypeAlias
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import lazyload
@@ -228,9 +228,7 @@ def set_dag_run_state_to_success(
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")
- # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types
- # instead, but currently it inherits SDK's DAG.
- tasks = cast("list[Operator]", dag.tasks)
+ tasks = dag.tasks
# Mark all task instances of the dag run to success - except for unfinished teardown as they need to complete work.
teardown_tasks = [task for task in tasks if task.is_teardown]
@@ -312,13 +310,7 @@ def set_dag_run_state_to_failed(
task.dag = dag
return task
- # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types
- # instead, but currently it inherits SDK's DAG.
- running_tasks = [
- _set_runing_task(task)
- for task in cast("list[Operator]", dag.tasks)
- if task.task_id in task_ids_of_running_tis
- ]
+ running_tasks = [_set_runing_task(task) for task in dag.tasks if task.task_id in task_ids_of_running_tis]
# Mark non-finished tasks as SKIPPED.
pending_tis: list[TaskInstance] = session.scalars(
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py
b/airflow-core/src/airflow/cli/commands/task_command.py
index b19ea6161a2..9b4cd4114f4 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -46,6 +46,7 @@ from airflow.utils.cli import (
get_bagged_dag,
get_dag_by_file_location,
get_dags,
+ get_db_dag,
suppress_logs_and_warning,
)
from airflow.utils.helpers import ask_yesno
@@ -82,7 +83,7 @@ def _generate_temporary_run_id() -> str:
def _get_dag_run(
*,
- dag: DAG | SerializedDAG,
+ dag: SerializedDAG,
create_if_necessary: CreateIfNecessary,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
@@ -144,9 +145,8 @@ def _get_dag_run(
)
return dag_run, True
if create_if_necessary == "db":
- scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) # type: ignore[arg-type]
dag_run = get_or_create_dagrun(
- dag=scheduler_dag,
+ dag=dag,
run_id=_generate_temporary_run_id(),
logical_date=dag_run_logical_date,
data_interval=data_interval,
@@ -246,10 +246,7 @@ def task_failed_deps(args) -> None:
Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks
to have succeeded, but found 1 non-success(es).
"""
- dag = get_bagged_dag(args.bundle_name, args.dag_id)
- # TODO (GH-52141): get_task in scheduler needs to return scheduler types
- # instead, but currently it inherits SDK's DAG.
- task = cast("Operator", dag.get_task(task_id=args.task_id))
+ task = get_db_dag(args.bundle_name, args.dag_id).get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id)
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
@@ -387,29 +384,35 @@ def task_test(args, dag: DAG | None = None) -> None:
env_vars.update(args.env_vars)
os.environ.update(env_vars)
- dag = dag or get_bagged_dag(args.bundle_name, args.dag_id)
+ if dag:
+ sdk_dag = dag
+ scheduler_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+ else:
+ sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id)
+ scheduler_dag = get_db_dag(args.bundle_name, args.dag_id)
- # TODO (GH-52141): get_task in scheduler needs to return scheduler types
- # instead, but currently it inherits SDK's DAG.
- task = cast("Operator", dag.get_task(task_id=args.task_id))
+ sdk_task = sdk_dag.get_task(args.task_id)
# Add CLI provided task_params to task.params
if args.task_params:
passed_in_params = json.loads(args.task_params)
- task.params.update(passed_in_params)
+ sdk_task.params.update(passed_in_params)
- if task.params and isinstance(task.params, ParamsDict):
- task.params.validate()
+ if sdk_task.params and isinstance(sdk_task.params, ParamsDict):
+ sdk_task.params.validate()
ti, dr_created = _get_ti(
- task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
+ scheduler_dag.get_task(args.task_id),
+ args.map_index,
+ logical_date_or_run_id=args.logical_date_or_run_id,
+ create_if_necessary="db",
)
try:
# TODO: move bulk of this logic into the SDK: http://github.com/apache/airflow/issues/54658
from airflow.sdk._shared.secrets_masker import RedactedIO
with redirect_stdout(RedactedIO()):
- _run_task(ti=ti, task=task, run_triggerer=True)
+ _run_task(ti=ti, task=sdk_task, run_triggerer=True)
if ti.state == State.FAILED and args.post_mortem:
debugger = _guess_debugger()
debugger.set_trace()
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 72a5ab0b24b..5f3eb91f865 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1075,8 +1075,6 @@ class TaskInstance(Base, LoggingMixin):
ti: TaskInstance = task_instance
task = task_instance.task
- if TYPE_CHECKING:
- assert isinstance(task, Operator) # TODO (GH-52141): This shouldn't be needed.
ti.refresh_from_task(task, pool_override=pool)
ti.test_mode = test_mode
ti.refresh_from_db(session=session, lock_for_update=True)
@@ -1276,9 +1274,16 @@ class TaskInstance(Base, LoggingMixin):
log.info("[DAG TEST] Marking success for %s ", self.task_id)
return None
- taskrun_result = _run_task(ti=self, task=self.task)
- if taskrun_result is not None and taskrun_result.error:
+ # TODO (TaskSDK): This is the old ti execution path. The only usage is
+ # in TI.run(...), someone needs to analyse if it's still actually used
+ # somewhere and fix it, likely by rewriting TI.run(...) to use the same
+ # mechanism as Operator.test().
+ taskrun_result = _run_task(ti=self, task=self.task) # type: ignore[arg-type]
+ if taskrun_result is None:
+ return None
+ if taskrun_result.error:
raise taskrun_result.error
+ self.task = taskrun_result.ti.task # type: ignore[assignment]
return None
@staticmethod
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 37c09b4a925..caa40ce93dd 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1148,10 +1148,6 @@ class DependencyDetector:
from airflow.providers.standard.operators.trigger_dagrun import
TriggerDagRunOperator
from airflow.providers.standard.sensors.external_task import
ExternalTaskSensor
- # TODO (GH-52141): Separate MappedOperator implementation to get rid of this.
- if TYPE_CHECKING:
- assert isinstance(task.operator_class, type)
-
deps = []
if isinstance(task, TriggerDagRunOperator):
deps.append(
@@ -1409,7 +1405,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
link = self.operator_extra_link_dict.get(name) or
self.global_operator_extra_link_dict.get(name)
if not link:
return None
- return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator
+ # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator.
+ return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type]
@property
def operator_name(self) -> str:
diff --git a/airflow-core/src/airflow/utils/cli.py
b/airflow-core/src/airflow/utils/cli.py
index 8fef958ac0f..b6423c5af3a 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -300,7 +300,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str,
dagfile_path: str | N
)
-def _get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> SerializedDAG:
+def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> SerializedDAG:
"""
Return DAG of a given dag_id.
@@ -321,7 +321,7 @@ def get_dags(bundle_names: list | None, dag_id: str,
use_regex: bool = False, fr
if not use_regex:
if from_db:
- return [_get_db_dag(bundle_names=bundle_names, dag_id=dag_id)]
+ return [get_db_dag(bundle_names=bundle_names, dag_id=dag_id)]
return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)]
def _find_dag(bundle):
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py
b/airflow-core/tests/unit/models/test_cleartasks.py
index c44cdf5635e..9a1f37c89ca 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -633,11 +633,11 @@ class TestClearTasks:
assert ti.max_tries == 1
# test dry_run
- for i in range(num_of_dags):
+ for i, dag in enumerate(dags):
ti = _get_ti(tis[i])
ti.try_number += 1
session.commit()
- ti.refresh_from_task(tis[i].task)
+ ti.refresh_from_task(dag.get_task(ti.task_id))
ti.run(session=session)
assert ti.state == State.SUCCESS
assert ti.try_number == 2
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 0df775ca41e..315d3c4cc7a 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -488,12 +488,12 @@ class TestTaskInstance:
)
def run_with_error(ti):
+ orig_task, ti.task = ti.task, task
with contextlib.suppress(AirflowException):
ti.run()
+ ti.task = orig_task
ti =
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
- ti.task = task
-
with create_session() as session:
session.get(TaskInstance, ti.id).try_number += 1
@@ -539,13 +539,13 @@ class TestTaskInstance:
)
def run_with_error(ti):
+ orig_task, ti.task = ti.task, task
with contextlib.suppress(AirflowException):
ti.run()
+ ti.task = orig_task
ti =
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
- ti.task = task
assert ti.try_number == 0
-
session.get(TaskInstance, ti.id).try_number += 1
session.commit()
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py
b/task-sdk/src/airflow/sdk/definitions/dag.py
index 7448b321389..6523dead760 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -66,10 +66,12 @@ if TYPE_CHECKING:
from pendulum.tz.timezone import FixedTimezone, Timezone
+ from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance
from airflow.sdk.definitions.decorators import TaskDecoratorCollection
from airflow.sdk.definitions.edges import EdgeInfoType
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import TaskGroup
+ from airflow.sdk.execution_time.supervisor import TaskRunResult
from airflow.typing_compat import Self
Operator: TypeAlias = BaseOperator | MappedOperator
@@ -1304,7 +1306,12 @@ class DAG:
return dr
-def _run_task(*, ti, task, run_triggerer=False):
+def _run_task(
+ *,
+ ti: SchedulerTaskInstance,
+ task: Operator,
+ run_triggerer: bool = False,
+) -> TaskRunResult | None:
"""
Run a single task instance, and push result to Xcom for downstream tasks.
@@ -1314,6 +1321,7 @@ def _run_task(*, ti, task, run_triggerer=False):
from airflow.sdk.module_loading import import_string
from airflow.utils.state import State
+ taskrun_result: TaskRunResult | None
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id,
ti.map_index)
while True:
try:
@@ -1322,6 +1330,7 @@ def _run_task(*, ti, task, run_triggerer=False):
from airflow.sdk.api.datamodels._generated import TaskInstance as
TaskInstanceSDK
from airflow.sdk.execution_time.comms import DeferTask
from airflow.sdk.execution_time.supervisor import
run_task_in_process
+ from airflow.serialization.serialized_objects import
create_scheduler_operator
# The API Server expects the task instance to be in QUEUED state
before
# it is run.
@@ -1336,14 +1345,10 @@ def _run_task(*, ti, task, run_triggerer=False):
dag_version_id=ti.dag_version_id,
)
- taskrun_result = run_task_in_process(
- ti=task_sdk_ti,
- task=task,
- )
-
+ taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task)
msg = taskrun_result.msg
ti.set_state(taskrun_result.ti.state)
- ti.task = taskrun_result.ti.task
+ ti.task = create_scheduler_operator(taskrun_result.ti.task)
if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and
run_triggerer:
from airflow.utils.session import create_session
@@ -1363,16 +1368,19 @@ def _run_task(*, ti, task, run_triggerer=False):
with create_session() as session:
ti.state = State.SCHEDULED
session.add(ti)
+ continue
- return taskrun_result
+ break
except Exception:
log.exception("[DAG TEST] Error running task %s", ti)
if ti.state not in State.finished:
ti.set_state(State.FAILED)
+ taskrun_result = None
break
raise
log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id,
ti.map_index)
+ return taskrun_result
def _run_inline_trigger(trigger, task_sdk_ti):
diff --git a/task-sdk/src/airflow/sdk/types.py
b/task-sdk/src/airflow/sdk/types.py
index cfbcafe4201..c7084629085 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from pydantic import AwareDatetime
from airflow.sdk._shared.logging.types import Logger as Logger
+ from airflow.sdk.api.datamodels._generated import TaskInstanceState
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasEvent, AssetRef, BaseAssetUniqueKey
from airflow.sdk.definitions.context import Context
@@ -68,6 +69,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
hostname: str | None = None
start_date: AwareDatetime
end_date: AwareDatetime | None = None
+ state: TaskInstanceState | None = None
def xcom_pull(
self,