This is an automated email from the ASF dual-hosted git repository.
dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 97959da0877 Re-enable start_from_trigger feature with rendering of template fields (#55068)
97959da0877 is described below
commit 97959da087786dabda7c49f5512b9c5f14181735
Author: David Blain <[email protected]>
AuthorDate: Wed Mar 25 22:42:49 2026 +0100
Re-enable start_from_trigger feature with rendering of template fields (#55068)
* Fix rendering of template fields with start from trigger
* refactor: Check if TaskInstance exists or not in BaseTrigger
* Revert "refactor: Check if TaskInstance exists or not in BaseTrigger"
This reverts commit 5f7306d287aea41e0970122bb153987b6be311b8.
* refactor: Changed return type of task_instance property in BaseTrigger
* refactor: Make sure default values for start from trigger can be overridden in mapped operator
* refactor: Remove assert on start_date of TaskInstance
* refactor: Make sure to check if dag_data is not None in workloads before creating the RuntimeTaskInstance
* refactor: Only pass serialized dag model to workload if trigger contains templated fields.
* refactor: Don't invoke _read_dag twice in get_dag method of DBDagBag class
* refactor: Don't invoke _read_dag twice in get_dag method of DBDagBag class
* refactor: Make _version_from_dag_run method of DBDagBag failsafe for legacy fallback
* refactor: Moved None check on start_state together with the task in one type checking block to keep mypy happy
* Revert "refactor: Make _version_from_dag_run method of DBDagBag failsafe for legacy fallback"
This reverts commit 23d7aea62e48301f2edaa15f31b8db3296c793d3.
* refactor: Fixed test_get_dag_model
* refactor: Only pass serialized Dag model data to RunTrigger if start_from_trigger was enabled.
* refactor: Added docstrings for start_from_trigger and start_trigger_args
* refactor: Templated field must be checked on task of task instance
* refactor: Added start_from_trigger property on Trigger
* refactor: Reformatted trigger unit test
* refactor: Only the RuntimeTaskInstance has the task attribute; the generated Pydantic one doesn't have it, and we cannot use isinstance here as we fake the typing with models.TaskInstance
* refactor: Reformatted test trigger
* Update airflow-core/src/airflow/serialization/definitions/mappedoperator.py
Co-authored-by: Ash Berlin-Taylor <[email protected]>
* refactor: Removed obsolete run method from TaskInstance
* refactor: Added dag_data field to RunTrigger and made ti field optional
* refactor: Reformatted RunTrigger
* refactor: We cannot detect if a Trigger has a task associated with a task having start_from_trigger without using DBDagBag, thus removed the check for now
* refactor: Re-added check on start_from_trigger from serialized Dag
* refactor: Fixed call to dag_bag in get_dag_for_run_or_latest_version method due to refactor in DBDagBag
* refactor: Extracted _do_render_template_fields method into Templater so it can be re-used by AbstractOperator and BaseTrigger, which is more DRY
* refactor: task_id should be an instance field instead of property
* refactor: Added tests for _do_render_template_fields method in TestTemplater
* refactor: Fixed templater unit tests
* refactor: Raise NotImplementedError in _set_context
* refactor: Reverted logging back to structlog in mappedoperator
* refactor: Refactored _create_workload in trigger job runner
* refactor: Renamed get_dag_model to get_serialized_dag_model in DBDagBag
* refactor: Refactored templater using structlog
* refactor: Added docstring to get_serialized_dag_model
* Revert "refactor: Raise NotImplementError in _set_context"
This reverts commit eff4fbdeaf598a2e367eaf19cb2e520365b7ee0c.
* refactor: Fixed typing of render_log_fname
* Revert "refactor: Refactored templater using structlog"
This reverts commit 20f7ac0437d66ff030224f8bfeb9f2aa9e3060c2.
* refactor: Reformatted files
* refactor: Removed new line in get_serialized_dag_model
* refactor: Fixed test_get_dag_returns_none_when_model_missing
* refactor: Removed default NEW_SESSION from session parameter in _create_workload method
---------
Co-authored-by: Ash Berlin-Taylor <[email protected]>
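For readers unfamiliar with the feature, here is a minimal, hedged sketch of what this commit re-enables: an operator that starts directly in the triggerer and whose trigger kwargs may contain Jinja templates. The operator and trigger names below are illustrative, not part of this commit; only start_from_trigger and StartTriggerArgs are real.

    from datetime import timedelta

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.triggers.base import StartTriggerArgs

    class WaitOperator(BaseOperator):  # hypothetical operator
        # Start directly in the triggerer instead of occupying a worker slot first.
        start_from_trigger = True
        start_trigger_args = StartTriggerArgs(
            trigger_cls="my_pkg.triggers.WaitTrigger",  # hypothetical trigger classpath
            next_method="execute_complete",  # resumed in the worker once the trigger fires
            # Templated kwargs like this are what this commit makes work again:
            # they are now rendered inside the triggerer.
            trigger_kwargs={"duration": "{{ params.duration }}"},
            timeout=timedelta(hours=1),
        )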
---
airflow-core/.pre-commit-config.yaml | 1 +
.../src/airflow/api_fastapi/common/dagbag.py | 2 +-
.../src/airflow/executors/workloads/trigger.py | 5 +-
.../src/airflow/jobs/triggerer_job_runner.py | 200 ++++++++++++++-------
airflow-core/src/airflow/models/dagbag.py | 58 ++++--
airflow-core/src/airflow/models/dagrun.py | 28 +--
airflow-core/src/airflow/models/taskinstance.py | 69 ++++++-
airflow-core/src/airflow/triggers/base.py | 63 ++++++-
airflow-core/tests/unit/jobs/test_triggerer_job.py | 5 +-
airflow-core/tests/unit/models/test_dagbag.py | 79 ++++++++
.../tests/unit/models/test_taskinstance.py | 97 ++++++++++
.../tests/unit/triggers/test_base_trigger.py | 69 +++++++
devel-common/src/tests_common/pytest_plugin.py | 47 +++--
task-sdk/src/airflow/sdk/bases/operator.py | 22 +++
.../sdk/definitions/_internal/abstractoperator.py | 53 ------
.../airflow/sdk/definitions/_internal/templater.py | 94 ++++++++--
.../src/airflow/sdk/definitions/mappedoperator.py | 14 +-
task-sdk/tests/task_sdk/bases/test_operator.py | 20 +++
.../definitions/_internal/test_templater.py | 188 +++++++++++++++++++
19 files changed, 922 insertions(+), 192 deletions(-)
diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml
index 7573eec4e65..121b51d4e8b 100644
--- a/airflow-core/.pre-commit-config.yaml
+++ b/airflow-core/.pre-commit-config.yaml
@@ -376,6 +376,7 @@ repos:
^src/airflow/timetables/assets\.py$|
^src/airflow/timetables/base\.py$|
^src/airflow/timetables/simple\.py$|
+ ^src/airflow/triggers/base\.py$|
^src/airflow/utils/cli\.py$|
^src/airflow/utils/context\.py$|
^src/airflow/utils/dag_cycle_tester\.py$|
diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
index 3ca4483ce87..c7630cde9f7 100644
--- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py
+++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
@@ -84,7 +84,7 @@ def get_dag_for_run_or_latest_version(
dag: SerializedDAG | None = None
if dag_run:
if dag_run.created_dag_version_id:
- dag = dag_bag._get_dag(dag_run.created_dag_version_id, session=session)
+ dag = dag_bag.get_dag(dag_run.created_dag_version_id, session=session)
if not dag:
dag = dag_bag.get_dag_for_run(dag_run, session=session)
elif dag_id:
diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py b/airflow-core/src/airflow/executors/workloads/trigger.py
index 25bca9ce44b..2959cde6ee3 100644
--- a/airflow-core/src/airflow/executors/workloads/trigger.py
+++ b/airflow-core/src/airflow/executors/workloads/trigger.py
@@ -35,8 +35,11 @@ class RunTrigger(BaseModel):
"""
id: int
- ti: TaskInstanceDTO | None # Could be none for asset-based triggers.
classpath: str # Dot-separated name of the module+fn to import and run this workload.
encrypted_kwargs: str
+ ti: TaskInstanceDTO | None = None # Could be none for asset-based triggers.
timeout_after: datetime | None = None
type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
+ dag_data: dict | None = (
+ None # Serialized Dag model in dict format so it can be deserialized in trigger subprocess.
+ )
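For illustration, a hedged sketch of how the supervisor populates the new field (it mirrors the _create_workload hunk below; serialized_dag_model and ser_ti are as named in that diff):

    workload = RunTrigger(
        id=trigger.id,
        classpath=trigger.classpath,
        encrypted_kwargs=trigger.encrypted_kwargs,
        ti=ser_ti,
        timeout_after=trigger.task_instance.trigger_timeout,
        # Only set when the task has start_from_trigger enabled, so the trigger
        # subprocess can rebuild a RuntimeTaskInstance and render templated fields.
        dag_data=serialized_dag_model.data,
    )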
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 44c28a7a539..44f96589042 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -25,7 +25,7 @@ import signal
import sys
import time
from collections import deque
-from collections.abc import Generator, Iterable
+from collections.abc import Callable, Generator, Iterable
from contextlib import suppress
from datetime import datetime
from socket import socket
@@ -51,6 +51,7 @@ from airflow.executors import workloads
from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import perform_heartbeat
+from airflow.models.dagbag import DBDagBag
from airflow.models.trigger import Trigger
from airflow.observability.metrics import stats_utils
from airflow.sdk.api.datamodels._generated import HITLDetailResponse
@@ -84,10 +85,12 @@ from airflow.sdk.execution_time.comms import (
_RequestFrame,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
+from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+from airflow.serialization.serialized_objects import DagSerialization
from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent
from airflow.utils.helpers import log_filename_template_renderer
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import create_session, provide_session
if TYPE_CHECKING:
from opentelemetry.util._decorator import _AgnosticContextManager
@@ -97,6 +100,7 @@ if TYPE_CHECKING:
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
from airflow.jobs.job import Job
from airflow.sdk.api.client import Client
+ from airflow.sdk.definitions.context import Context
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
logger = logging.getLogger(__name__)
@@ -658,6 +662,65 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
extra_tags={"hostname": self.job.hostname},
)
+ def _create_workload(
+ self,
+ trigger: Trigger,
+ dag_bag: DBDagBag,
+ render_log_fname: Callable[..., str],
+ session: Session,
+ ) -> workloads.RunTrigger | None:
+ if trigger.task_instance is None:
+ return workloads.RunTrigger(
+ id=trigger.id,
+ classpath=trigger.classpath,
+ encrypted_kwargs=trigger.encrypted_kwargs,
+ )
+
+ if not trigger.task_instance.dag_version_id:
+ # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none
+ log.warning(
+ "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger",
+ ti_id=trigger.task_instance.id,
+ )
+ return None
+
+ log_path = render_log_fname(ti=trigger.task_instance)
+ ser_ti = TaskInstanceDTO.model_validate(trigger.task_instance, from_attributes=True)
+
+ # When producing logs from TIs, include the job id producing the logs to disambiguate it.
+ self.logger_cache[trigger.id] = TriggerLoggingFactory(
+ log_path=f"{log_path}.trigger.{self.job.id}.log",
+ ti=ser_ti, # type: ignore
+ )
+
+ serialized_dag_model = dag_bag.get_serialized_dag_model(
+ version_id=trigger.task_instance.dag_version_id,
+ session=session,
+ )
+
+ if serialized_dag_model:
+ task = serialized_dag_model.dag.get_task(trigger.task_instance.task_id)
+
+ # When a TaskInstance of a Trigger contains a task with start_from_trigger enabled,
+ # it means we need to load the SerializedDagModel so we can build a RuntimeTaskInstance later on which
+ # will allow us to build a context on which we will render the templated fields.
+ if task.start_from_trigger:
+ return workloads.RunTrigger(
+ id=trigger.id,
+ classpath=trigger.classpath,
+ encrypted_kwargs=trigger.encrypted_kwargs,
+ ti=ser_ti,
+ timeout_after=trigger.task_instance.trigger_timeout,
+ dag_data=serialized_dag_model.data,
+ )
+ return workloads.RunTrigger(
+ id=trigger.id,
+ classpath=trigger.classpath,
+ encrypted_kwargs=trigger.encrypted_kwargs,
+ ti=ser_ti,
+ timeout_after=trigger.task_instance.trigger_timeout,
+ )
+
def update_triggers(self, requested_trigger_ids: set[int]):
"""
Request that we update what triggers we're running.
@@ -666,8 +729,8 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
adds them to the dequeues so the subprocess can actually mutate the running
trigger set.
"""
+ dag_bag = DBDagBag()
render_log_fname = log_filename_template_renderer()
-
known_trigger_ids = (
self.running_triggers.union(x[0] for x in self.events)
.union(self.cancelling_triggers)
@@ -678,60 +741,48 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
new_trigger_ids = requested_trigger_ids - known_trigger_ids
cancel_trigger_ids = self.running_triggers - requested_trigger_ids
# Bulk-fetch new trigger records
- new_triggers = Trigger.bulk_fetch(new_trigger_ids)
- trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations()
- to_create: list[workloads.RunTrigger] = []
- # Add in new triggers
- for new_id in new_trigger_ids:
- # Check it didn't vanish in the meantime
- if new_id not in new_triggers:
- log.warning("Trigger disappeared before we could start it",
id=new_id)
- continue
-
- new_trigger_orm = new_triggers[new_id]
-
- # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance
- # row was updated by either Trigger.submit_event or Trigger.submit_failure
- # and can happen when a single trigger Job is being run on multiple TriggerRunners
- # in a High-Availability setup.
- if new_trigger_orm.task_instance is None and new_id not in trigger_ids_with_non_task_associations:
- log.info(
- (
- "TaskInstance Trigger is None. It was likely updated by another trigger job. "
- "Skipping trigger instantiation."
- ),
- id=new_id,
- )
- continue
-
- workload = workloads.RunTrigger(
- classpath=new_trigger_orm.classpath,
- id=new_id,
- encrypted_kwargs=new_trigger_orm.encrypted_kwargs,
- ti=None,
+ with create_session() as session:
+ # Bulk-fetch new trigger records
+ new_triggers = Trigger.bulk_fetch(new_trigger_ids, session=session)
+ trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations(
+ session=session
)
- if new_trigger_orm.task_instance:
- log_path = render_log_fname(ti=new_trigger_orm.task_instance)
- if not new_trigger_orm.task_instance.dag_version_id:
- # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none
- log.warning(
- "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger",
- ti_id=new_trigger_orm.task_instance.id,
- )
+ to_create: list[workloads.RunTrigger] = []
+ # Add in new triggers
+ for new_trigger_id in new_trigger_ids:
+ # Check it didn't vanish in the meantime
+ if new_trigger_id not in new_triggers:
+ log.warning("Trigger disappeared before we could start
it", id=new_trigger_id)
continue
- ser_ti = TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, from_attributes=True)
- # When producing logs from TIs, include the job id producing the logs to disambiguate it.
- self.logger_cache[new_id] = TriggerLoggingFactory(
- log_path=f"{log_path}.trigger.{self.job.id}.log",
- ti=ser_ti, # type: ignore
- )
- workload.ti = ser_ti
- workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout
+ new_trigger_orm = new_triggers[new_trigger_id]
+
+ # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance
+ # row was updated by either Trigger.submit_event or Trigger.submit_failure
+ # and can happen when a single trigger Job is being run on multiple TriggerRunners
+ # in a High-Availability setup.
+ if (
+ new_trigger_orm.task_instance is None
+ and new_trigger_id not in trigger_ids_with_non_task_associations
+ ):
+ log.info(
+ (
+ "TaskInstance of Trigger is None. It was likely updated by another trigger job. "
+ "Skipping trigger instantiation."
+ ),
+ id=new_trigger_id,
+ )
+ continue
- to_create.append(workload)
+ if workload := self._create_workload(
+ trigger=new_trigger_orm,
+ dag_bag=dag_bag,
+ render_log_fname=render_log_fname,
+ session=session,
+ ):
+ to_create.append(workload)
- self.creating_triggers.extend(to_create)
+ self.creating_triggers.extend(to_create)
if cancel_trigger_ids:
# Enqueue orphaned triggers for cancellation
@@ -986,9 +1037,19 @@ class TriggerRunner:
raise RuntimeError(f"Required first message to be a
messages.StartTriggerer, it was {msg}")
async def create_triggers(self):
+ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+ task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+ # I need to recreate a TaskInstance from task_runner before invoking get_template_context (airflow.executors.workloads.TaskInstance)
+ return RuntimeTaskInstance.model_construct(
+ **workload.ti.model_dump(exclude_unset=True),
+ task=task,
+ )
+
"""Drain the to_create queue and create all new triggers that have
been requested in the DB."""
while self.to_create:
await asyncio.sleep(0)
+ context: Context | None = None
workload = self.to_create.popleft()
trigger_id = workload.id
if trigger_id in self.triggers:
@@ -1016,24 +1077,32 @@ class TriggerRunner:
# that could cause None values in collections.
kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for
k, v in kw.items()}
- trigger_instance = trigger_class(**deserialised_kwargs)
+
+ if ti := workload.ti:
+ trigger_name = f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})"
+ trigger_instance = trigger_class(**deserialised_kwargs)
+
+ if workload.dag_data:
+ runtime_ti = create_runtime_ti(workload.dag_data)
+ context = runtime_ti.get_template_context()
+ trigger_instance.task_instance = runtime_ti
+ else:
+ trigger_instance.task_instance = ti
+ else:
+ trigger_name = f"ID {trigger_id}"
+ trigger_instance = trigger_class(**deserialised_kwargs)
except TypeError as err:
self.log.error("Trigger failed to inflate", error=err)
self.failed_triggers.append((trigger_id, err))
continue
trigger_instance.trigger_id = trigger_id
trigger_instance.triggerer_job_id = self.job_id
- trigger_instance.task_instance = ti = workload.ti
trigger_instance.timeout_after = workload.timeout_after
- trigger_name = (
- f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})"
- if ti
- else f"ID {trigger_id}"
- )
self.triggers[trigger_id] = {
"task": asyncio.create_task(
- self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name
+ self.run_trigger(trigger_id, trigger_instance, workload.timeout_after, context),
+ name=trigger_name,
),
"is_watcher": isinstance(trigger_instance, BaseEventTrigger),
"name": trigger_name,
@@ -1200,7 +1269,13 @@ class TriggerRunner:
)
Stats.incr("triggers.blocked_main_thread")
- async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after: datetime | None = None):
+ async def run_trigger(
+ self,
+ trigger_id: int,
+ trigger: BaseTrigger,
+ timeout_after: datetime | None = None,
+ context: Context | None = None,
+ ):
"""Run a trigger (they are async generators) and push their events
into our outbound event deque."""
if not os.environ.get("AIRFLOW_DISABLE_GREENBACK_PORTAL", "").lower()
== "true":
import greenback
@@ -1213,6 +1288,9 @@ class TriggerRunner:
self.log.info("trigger %s starting", name)
with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
try:
+ if context is not None:
+ trigger.render_template_fields(context=context)
+
async for event in trigger.run():
await self.log.ainfo(
"Trigger fired event",
name=self.triggers[trigger_id]["name"], result=event
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index e04f77d06df..98799bbde0c 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -45,24 +45,44 @@ class DBDagBag:
"""
def __init__(self, load_op_links: bool = True) -> None:
- self._dags: dict[UUID, SerializedDAG] = {} # dag_version_id to dag
+ self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to dag
self.load_op_links = load_op_links
- def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
- serdag.load_op_links = self.load_op_links
- if dag := serdag.dag:
- self._dags[serdag.dag_version_id] = dag
+ def _read_dag(self, serialized_dag_model: SerializedDagModel) -> SerializedDAG | None:
+ serialized_dag_model.load_op_links = self.load_op_links
+ if dag := serialized_dag_model.dag:
+ self._dags[serialized_dag_model.dag_version_id] = serialized_dag_model
return dag
- def _get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None:
- if dag := self._dags.get(version_id):
- return dag
- dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
- if not dag_version:
- return None
- if not (serdag := dag_version.serialized_dag):
- return None
- return self._read_dag(serdag)
+ def get_serialized_dag_model(self, version_id: UUID, session: Session) -> SerializedDagModel | None:
+ """
+ Return the SerializedDagModel for a given dag version id.
+
+ This will first consult the in-memory cache keyed by the dag version id. If the
+ model is not cached, the database is queried for a corresponding :class:`DagVersion`
+ and its associated :class:`SerializedDagModel`.
+
+ :param version_id: The UUID of the dag version to look up.
+ :param session: SQLAlchemy session used to query the database.
+ :return: The serialized DAG model if found either in the cache or the database; ``None``
+ is returned when no :class:`DagVersion` exists for the given ``version_id`` or
+ when that :class:`DagVersion` does not have an associated :class:`SerializedDagModel`.
+ :rtype: SerializedDagModel | None
+
+ Note: If a serialized dag model is found in the database it will be stored in the
+ internal cache (``self._dags``) before being returned.
+ """
+ if not (serialized_dag_model := self._dags.get(version_id)):
+ dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+ if not dag_version or not (serialized_dag_model := dag_version.serialized_dag):
+ return None
+ self._read_dag(serialized_dag_model)
+ return serialized_dag_model
+
+ def get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None:
+ if serialized_dag_model := self.get_serialized_dag_model(version_id=version_id, session=session):
+ return serialized_dag_model.dag
+ return None
@staticmethod
def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> UUID | None:
@@ -74,24 +94,24 @@ class DBDagBag:
def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG | None:
if version_id := self._version_from_dag_run(dag_run=dag_run, session=session):
- return self._get_dag(version_id=version_id, session=session)
+ return self.get_dag(version_id=version_id, session=session)
return None
def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]:
"""Walk through all latest version dags available in the database."""
from airflow.models.serialized_dag import SerializedDagModel
- for sdm in session.scalars(select(SerializedDagModel)):
- if dag := self._read_dag(sdm):
+ for serialized_dag_model in session.scalars(select(SerializedDagModel)):
+ if dag := self._read_dag(serialized_dag_model):
yield dag
def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None:
"""Get the latest version of a dag by its id."""
from airflow.models.serialized_dag import SerializedDagModel
- if not (serdag := SerializedDagModel.get(dag_id, session=session)):
+ if not (serialized_dag_model := SerializedDagModel.get(dag_id, session=session)):
return None
- return self._read_dag(serdag)
+ return self._read_dag(serialized_dag_model)
def generate_md5_hash(context):
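A short usage sketch of the renamed cache API (ti here stands for an assumed, pre-existing TaskInstance with a dag_version_id; this mirrors the call sites in the triggerer job runner above):

    from airflow.models.dagbag import DBDagBag
    from airflow.utils.session import create_session

    dag_bag = DBDagBag()
    with create_session() as session:
        # First lookup queries DagVersion + SerializedDagModel and caches the model;
        # subsequent lookups for the same version_id are served from _dags.
        sdm = dag_bag.get_serialized_dag_model(version_id=ti.dag_version_id, session=session)
        if sdm is not None:
            task = sdm.dag.get_task(ti.task_id)  # SerializedDAG built from the cached model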
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index bbff43aad9a..c93f0ed8e1e 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1986,7 +1986,14 @@ class DagRun(Base, LoggingMixin):
debug_try_number_check = self.log.isEnabledFor(logging.DEBUG)
expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] = {}
for ti in schedulable_tis:
- if ti.is_schedulable:
+ if not ti.is_schedulable:
+ empty_ti_ids.append(ti.id)
+ # The defer_task method will check "start_trigger_args" to see whether the operator
+ # starts execution from the triggerer. If so, we'll also check "start_from_trigger"
+ # to see whether this feature is turned on and defer this task.
+ # If not, we'll add this "ti" into "schedulable_ti_ids" and later
+ # execute it to run in the worker.
+ elif not ti.defer_task(session=session):
schedulable_ti_ids.append(ti.id)
if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
reschedule_ti_ids.add(ti.id)
@@ -1998,25 +2005,6 @@ class DagRun(Base, LoggingMixin):
ti.try_number,
ti.state,
)
- # Check "start_trigger_args" to see whether the operator supports
- # start execution from triggerer. If so, we'll check "start_from_trigger"
- # to see whether this feature is turned on and defer this task.
- # If not, we'll add this "ti" into "schedulable_ti_ids" and later
- # execute it to run in the worker.
- # TODO TaskSDK: This is disabled since we haven't figured out how
- # to render start_from_trigger in the scheduler. If we need to
- # render the value in a worker, it kind of defeats the purpose of
- # this feature (which is to save a worker process if possible).
- # elif task.start_trigger_args is not None:
- # if task.expand_start_from_trigger(context=ti.get_template_context()):
- # ti.start_date = timezone.utcnow()
- # if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
- # ti.try_number += 1
- # ti.defer_task(exception=None, session=session)
- # else:
- # schedulable_ti_ids.append(ti.id)
- else:
- empty_ti_ids.append(ti.id)
count = 0
# Don't only check if the TI.id is in id_chunk
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index e212ca68504..4c2137a5343 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -121,7 +121,7 @@ if TYPE_CHECKING:
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.mappedoperator import Operator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-
+ from airflow.triggers.base import StartTriggerArgs
PAST_DEPENDS_MET = "past_depends_met"
@@ -1590,6 +1590,73 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
.values(last_heartbeat_at=timezone.utcnow())
)
+ @property
+ def start_trigger_args(self) -> StartTriggerArgs | None:
+ if self.task and self.task.start_from_trigger is True:
+ return self.task.start_trigger_args
+ return None
+
+ # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state
+ # method of the task_instances module in the execution api when a TIDeferredStatePayload is being
+ # processed. This is because of a TaskInstance being updated differently using SQLAlchemy.
+ # If we use the approach from the execution api as common code in the DagRun schedule_tis method,
+ # the side effect is the changes done to the task instance aren't picked up by the scheduler and
+ # thus the task instance isn't processed until the scheduler is restarted.
+ @provide_session
+ def defer_task(self, session: Session = NEW_SESSION) -> bool:
+ """
+ Mark the task as deferred and set up the trigger needed to resume it when TaskDeferred is raised.
+
+ :meta: private
+ """
+ from airflow.models.trigger import Trigger
+
+ if TYPE_CHECKING:
+ assert self.start_date
+ assert isinstance(self.task, Operator)
+
+ if start_trigger_args := self.start_trigger_args:
+ trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+ timeout = start_trigger_args.timeout
+
+ # Calculate timeout too if it was passed
+ if timeout is not None:
+ self.trigger_timeout = timezone.utcnow() + timeout
+ else:
+ self.trigger_timeout = None
+
+ trigger_row = Trigger(
+ classpath=start_trigger_args.trigger_cls,
+ kwargs=trigger_kwargs,
+ )
+
+ # First, make the trigger entry
+ session.add(trigger_row)
+ session.flush()
+
+ # Then, update ourselves so it matches the deferral request
+ # Keep an eye on the logic in `check_and_change_state_before_execution()`
+ # depending on self.next_method semantics
+ self.state = TaskInstanceState.DEFERRED
+ self.trigger_id = trigger_row.id
+ self.next_method = start_trigger_args.next_method
+ self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if execution_timeout := self.task.execution_timeout:
+ if self.trigger_timeout:
+ self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
+ else:
+ self.trigger_timeout = self.start_date + execution_timeout
+ self.start_date = timezone.utcnow()
+ if self.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+ self.try_number += 1
+ if self.test_mode:
+ _add_log(event=self.state, task_instance=self, session=session)
+ return True
+ return False
+
@classmethod
def fetch_handle_failure_context(
cls,
diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py
index 416558242b8..7ca7ed20a74 100644
--- a/airflow-core/src/airflow/triggers/base.py
+++ b/airflow-core/src/airflow/triggers/base.py
@@ -21,7 +21,7 @@ import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import timedelta
-from typing import Annotated, Any
+from typing import TYPE_CHECKING, Annotated, Any
import structlog
from pydantic import (
@@ -32,11 +32,24 @@ from pydantic import (
model_serializer,
)
+from airflow.sdk.definitions._internal.templater import Templater
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState
log = structlog.get_logger(logger_name=__name__)
+if TYPE_CHECKING:
+ from typing import TypeAlias
+
+ import jinja2
+
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.sdk.definitions.context import Context
+ from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+ Operator: TypeAlias = MappedOperator | SerializedBaseOperator
+
@dataclass
class StartTriggerArgs:
@@ -49,7 +62,7 @@ class StartTriggerArgs:
timeout: timedelta | None = None
-class BaseTrigger(abc.ABC, LoggingMixin):
+class BaseTrigger(abc.ABC, Templater, LoggingMixin):
"""
Base class for all triggers.
@@ -66,14 +79,56 @@ class BaseTrigger(abc.ABC, LoggingMixin):
supports_triggerer_queue: bool = True
def __init__(self, **kwargs):
+ super().__init__()
# these values are set by triggerer when preparing to run the instance
# when run, they are injected into logger record.
- self.task_instance = None
+ self._task_instance = None
self.trigger_id = None
+ self.template_fields = ()
+ self.template_ext = ()
+ self.task_id = None
def _set_context(self, context):
"""Part of LoggingMixin and used mainly for configuration of task
logging; not used for triggers."""
- raise NotImplementedError
+ pass
+
+ @property
+ def task(self) -> Operator | None:
+ # We must check if the TaskInstance is the generated Pydantic one or the RuntimeTaskInstance
+ if self.task_instance and hasattr(self.task_instance, "task"):
+ return self.task_instance.task
+ return None
+
+ @property
+ def task_instance(self) -> TaskInstance:
+ return self._task_instance
+
+ @task_instance.setter
+ def task_instance(self, value: TaskInstance | None) -> None:
+ self._task_instance = value
+ if self.task_instance:
+ self.task_id = self.task_instance.task_id
+ if self.task:
+ self.template_fields = self.task.template_fields
+ self.template_ext = self.task.template_ext
+
+ def render_template_fields(
+ self,
+ context: Context,
+ jinja_env: jinja2.Environment | None = None,
+ ) -> None:
+ """
+ Template all attributes listed in *self.template_fields*.
+
+ This mutates the attributes in-place and is irreversible.
+
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja's environment to use for rendering.
+ """
+ if not jinja_env:
+ jinja_env = self.get_template_env()
+ # We only need to render templated fields if templated fields are part of the start_trigger_args
+ self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
@abc.abstractmethod
def serialize(self) -> tuple[str, dict[str, Any]]:
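To make the new rendering path concrete, a hedged sketch of a hypothetical trigger subclass (not part of this commit): assigning a RuntimeTaskInstance copies template_fields/template_ext from its task, after which render_template_fields mutates the kwargs in place, as the runner now does when dag_data is present.

    class WaitTrigger(BaseTrigger):  # hypothetical
        def __init__(self, duration: str, **kwargs):
            super().__init__(**kwargs)
            self.duration = duration  # may contain "{{ ... }}" when started from the trigger

        def serialize(self) -> tuple[str, dict[str, Any]]:
            return (f"{type(self).__module__}.{type(self).__qualname__}", {"duration": self.duration})

        async def run(self):
            yield TriggerEvent(self.duration)

    trigger = WaitTrigger(duration="{{ params.duration }}")
    trigger.task_instance = runtime_ti  # assumed RuntimeTaskInstance whose task templates "duration"
    trigger.render_template_fields(context=runtime_ti.get_template_context())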
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3761189bfeb..503a3f4834c 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -120,9 +120,9 @@ def create_trigger_in_db(session, trigger, operator=None):
session.merge(testing_bundle)
session.flush()
- dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name)
- dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=pendulum.datetime(2023, 1, 1))
date = pendulum.datetime(2023, 1, 1)
+ dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name)
+ dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=date)
run = DagRun(
dag_id=dag_model.dag_id,
run_id="test_run",
@@ -265,6 +265,7 @@ def test_trigger_lifecycle(spy_agency: SpyAgency, session, testing_dag_bundle):
classpath=trigger.serialize()[0],
encrypted_kwargs=trigger_orm.encrypted_kwargs,
kind="RunTrigger",
+ dag_data=ANY,
)
)
# OK, now remove it from the DB
diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py
index 3b5b9887726..48f249205bb 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -16,8 +16,14 @@
# under the License.
from __future__ import annotations
+from unittest.mock import MagicMock, patch
+
import pytest
+from airflow.models.dagbag import DBDagBag
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.serialization.serialized_objects import SerializedDAG
+
pytestmark = pytest.mark.db_test
# This file previously contained tests for DagBag functionality, but those tests
@@ -26,3 +32,76 @@ pytestmark = pytest.mark.db_test
#
# Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.)
# would remain in this file, but currently no such tests exist.
+
+
+class TestDBDagBag:
+ def setup_method(self):
+ self.db_dag_bag = DBDagBag()
+ self.session = MagicMock()
+
+ def test__read_dag_stores_and_returns_dag(self):
+ """It should store the SerializedDagModel in _dags and return the
dag."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
+ mock_serdag.dag_version_id = "v1"
+
+ result = self.db_dag_bag._read_dag(mock_serdag)
+
+ assert result == mock_dag
+ assert self.db_dag_bag._dags["v1"] == mock_serdag
+ assert mock_serdag.load_op_links is True
+
+ def test__read_dag_returns_none_when_no_dag(self):
+ """It should return None and not modify _dags when no DAG is
present."""
+ mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = None
+ mock_serdag.dag_version_id = "v1"
+
+ result = self.db_dag_bag._read_dag(mock_serdag)
+
+ assert result is None
+ assert "v1" not in self.db_dag_bag._dags
+
+ def test_get_serialized_dag_model(self):
+ """It should return the cached SerializedDagModel if already loaded."""
+ mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag_version_id = "v1"
+ mock_dag_version = MagicMock()
+ mock_dag_version.serialized_dag = mock_serdag
+ self.session.get.return_value = mock_dag_version
+
+ self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
+ result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+
+ assert result == mock_serdag
+ self.session.get.assert_called_once()
+
+ def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ """It should return None if version_id not found in DB."""
+ self.session.get.return_value = None
+
+ result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+
+ assert result is None
+
+ def test_get_dag_calls_get_dag_model_and__read_dag(self):
+ """It should call get_dag_model and then _read_dag."""
+ mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag_version_id = "v1"
+ mock_dag = MagicMock(spec=SerializedDAG)
+ mock_dag_version = MagicMock()
+ mock_dag_version.serialized_dag = mock_serdag
+ mock_serdag.dag = mock_dag
+ self.session.get.return_value = mock_dag_version
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ self.session.get.assert_called_once()
+ assert result == mock_dag
+
+ def test_get_dag_returns_none_when_model_missing(self):
+ """It should return None if no SerializedDagModel found."""
+ with patch.object(self.db_dag_bag, "get_serialized_dag_model", return_value=None):
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+ assert result is None
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index bb058d1a737..3fa09106ade 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -2653,6 +2653,103 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
assert ti.max_tries == expected_max_tries
+def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance):
+ session = mock.MagicMock()
+ ti = create_task_instance(
+ dag_id="test_defer_task",
+ task_id="test_defer_task_op",
+ )
+ assert not ti.defer_task(session=session)
+
+
+def test_defer_task_returns_false_when_no_start_trigger_args(create_task_instance):
+ session = mock.MagicMock()
+ ti = create_task_instance(
+ dag_id="test_defer_task",
+ task_id="test_defer_task",
+ start_from_trigger=True,
+ )
+ assert not ti.defer_task(session=session)
+
+
+def test_defer_task(create_task_instance):
+ from airflow.models.trigger import Trigger
+ from airflow.triggers.base import StartTriggerArgs
+
+ session = mock.MagicMock()
+ ti = create_task_instance(
+ dag_id="test_defer_task",
+ task_id="test_defer_task_op",
+ start_from_trigger=True,
+ start_trigger_args=StartTriggerArgs(
+ trigger_cls="trigger_cls",
+ next_method="next_method",
+ trigger_kwargs={"key": "value"},
+ ),
+ )
+ assert ti.defer_task(session=session)
+
+ # Check that session.add was called with a Trigger
+ assert session.add.call_count == 1
+ trigger_row = session.add.call_args[0][0]
+ assert isinstance(trigger_row, Trigger)
+ assert trigger_row.classpath == "trigger_cls"
+ assert trigger_row.kwargs == {"key": "value"}
+
+ # Check that session.flush was called
+ session.flush.assert_called_once()
+
+ # Check that TaskInstance state was updated
+ assert ti.state == TaskInstanceState.DEFERRED
+ assert ti.trigger_id == trigger_row.id
+ assert ti.next_method == "next_method"
+ assert ti.next_kwargs == {}
+
+ # Check trigger_timeout is set (should be None since no timeout provided)
+ assert ti.trigger_timeout is None
+
+
+def test_defer_task_with_trigger_timeout(create_task_instance):
+ from airflow.models.trigger import Trigger
+ from airflow.triggers.base import StartTriggerArgs
+
+ session = mock.MagicMock()
+ timeout = datetime.timedelta(hours=1)
+ ti = create_task_instance(
+ dag_id="test_defer_task_with_trigger_timeout",
+ task_id="test_defer_task_with_trigger_timeout_op",
+ start_from_trigger=True,
+ start_trigger_args=StartTriggerArgs(
+ trigger_cls="trigger_cls",
+ next_method="next_method",
+ trigger_kwargs={"key": "value"},
+ timeout=timeout,
+ ),
+ )
+
+ # Save start_date to calculate expected trigger_timeout
+ now = timezone.utcnow()
+ ti.start_date = now
+
+ ti.defer_task(session=session)
+
+ # Check session interactions
+ assert session.add.call_count == 1
+ trigger_row = session.add.call_args[0][0]
+ assert isinstance(trigger_row, Trigger)
+ session.flush.assert_called_once()
+
+ # TaskInstance fields
+ assert ti.state == TaskInstanceState.DEFERRED
+ assert ti.trigger_id == trigger_row.id
+ assert ti.next_method == "next_method"
+ assert ti.next_kwargs == {}
+
+ # Check trigger_timeout is set correctly (within a small tolerance)
+ expected_timeout = now + timeout
+ assert abs((ti.trigger_timeout - expected_timeout).total_seconds()) < 5
+
+
class TestTaskInstanceRecordTaskMapXComPush:
"""Test TI.xcom_push() correctly records return values for task-mapping."""
diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py b/airflow-core/tests/unit/triggers/test_base_trigger.py
new file mode 100644
index 00000000000..53066c46f6a
--- /dev/null
+++ b/airflow-core/tests/unit/triggers/test_base_trigger.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.triggers.base import BaseTrigger, StartTriggerArgs
+
+
+class DummyOperator(BaseOperator):
+ template_fields = ("name",)
+
+
+class DummyTrigger(BaseTrigger):
+ def __init__(self, name: str, **kwargs):
+ super().__init__(**kwargs)
+ self.name = name
+
+ def run(self):
+ return None
+
+ def serialize(self):
+ return {"name": self.name}
+
+
+@pytest.mark.db_test
+def test_render_template_fields(create_task_instance):
+ op = DummyOperator(task_id="dummy_task")
+ ti = create_task_instance(
+ task=op,
+ start_from_trigger=True,
+ start_trigger_args=StartTriggerArgs(
+ trigger_cls=f"{DummyTrigger.__module__}.{DummyTrigger.__qualname__}",
+ next_method="resume_method",
+ trigger_kwargs={"name": "Hello {{ name }}"},
+ ),
+ )
+
+ trigger = DummyTrigger(name="Hello {{ name }}")
+
+ assert not trigger.task_instance
+ assert not trigger.template_fields
+ assert not trigger.template_ext
+
+ trigger.task_instance = ti
+
+ assert trigger.task_instance == ti
+ assert "name" in trigger.template_fields
+ assert not trigger.template_ext
+
+ trigger.render_template_fields(context={"name": "world"})
+
+ assert trigger.name == "Hello world"
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index a4b590c5452..4fc35fe3514 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -53,6 +53,7 @@ if TYPE_CHECKING:
from airflow.sdk.types import DagRunProtocol, Operator
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.timetables.base import DagRunInfo, DataInterval
+ from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1564,6 +1565,9 @@ def create_task_instance(
hostname=None,
pid=None,
last_heartbeat_at=None,
+ task: Operator | None = None,
+ start_from_trigger: bool = False,
+ start_trigger_args: StartTriggerArgs | None = None,
**kwargs,
) -> TaskInstance:
timezone = _import_timezone()
@@ -1572,26 +1576,33 @@ def create_task_instance(
if logical_date is NOTSET:
# For now: default to having a logical date if None is not explicitly passed.
logical_date = timezone.utcnow()
- with dag_maker(dag_id, **kwargs):
+ with dag_maker(dag_id, **kwargs) as dag:
op_kwargs = {}
op_kwargs["task_display_name"] = task_display_name
- task = EmptyOperator(
- task_id=task_id,
- max_active_tis_per_dag=max_active_tis_per_dag,
- max_active_tis_per_dagrun=max_active_tis_per_dagrun,
- executor_config=executor_config or {},
- on_success_callback=on_success_callback,
- on_execute_callback=on_execute_callback,
- on_failure_callback=on_failure_callback,
- on_retry_callback=on_retry_callback,
- on_skipped_callback=on_skipped_callback,
- inlets=inlets,
- outlets=outlets,
- email=email,
- pool=pool,
- trigger_rule=trigger_rule,
- **op_kwargs,
- )
+ if not task:
+ task = EmptyOperator(
+ task_id=task_id,
+ max_active_tis_per_dag=max_active_tis_per_dag,
+ max_active_tis_per_dagrun=max_active_tis_per_dagrun,
+ executor_config=executor_config or {},
+ on_success_callback=on_success_callback,
+ on_execute_callback=on_execute_callback,
+ on_failure_callback=on_failure_callback,
+ on_retry_callback=on_retry_callback,
+ on_skipped_callback=on_skipped_callback,
+ inlets=inlets,
+ outlets=outlets,
+ email=email,
+ pool=pool,
+ trigger_rule=trigger_rule,
+ **op_kwargs,
+ )
+ else:
+ task_id = task.task_id
+ task.dag = dag
+ task.start_from_trigger = start_from_trigger
+ task.start_trigger_args = start_trigger_args
+
if AIRFLOW_V_3_0_PLUS:
dagrun_kwargs = {
"logical_date": logical_date,
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 6e88f0a94ad..4d5905ab73b 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -550,6 +550,11 @@ class BaseOperatorMeta(abc.ABCMeta):
# Store the args passed to init -- we need them to support task.map serialization!
self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
+ # Validate trigger kwargs.
+ # Make sure method exists as class can depend on metaclass without extending the BaseOperator.
+ if hasattr(self, "_validate_start_from_trigger_kwargs"):
+ self._validate_start_from_trigger_kwargs()
+
# Set upstream task defined by XComArgs passed to template fields of the operator.
# BUT: only do this _ONCE_, not once for each class in the hierarchy
if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
@@ -846,6 +851,14 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
to render templates as native Python types. If False, a Jinja
``Environment`` is used to render templates as string values.
If None (default), inherits from the DAG setting.
+ :param start_from_trigger: If True, the operator starts execution directly in the triggerer,
+ skipping the initial worker execution phase. In this mode, templated fields are rendered
+ inside the triggerer instead of the worker. This avoids an extra round trip to a worker,
+ but may increase load on the triggerer, since the DAG must be serialized in order to
+ render templated fields. Use with care for DAGs with many tasks or heavy templating.
+ :param start_trigger_args: Used together with ``start_from_trigger`` to explicitly specify
+ which operator fields should be passed to the trigger. This helps limit the amount of
+ data serialized and sent to the triggerer.
"""
task_id: str
@@ -1440,6 +1453,15 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return
XComArg.apply_upstream_relationship(self, newvalue)
+ def _validate_start_from_trigger_kwargs(self):
+ if self.start_from_trigger and self.start_trigger_args and self.start_trigger_args.trigger_kwargs:
+ for name, val in self.start_trigger_args.trigger_kwargs.items():
+ if callable(val):
+ raise ValueError(
+ f"{self.__class__.__name__} with task_id '{self.task_id}' has a callable in trigger kwargs named "
+ f"'{name}', which is not allowed when start_from_trigger is enabled."
+ )
+
def on_kill(self) -> None:
"""
Override this method to clean up subprocesses when a task instance gets killed.
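A hedged sketch of what the new validation rejects (same shape as the MockOperator unit test further below; the operator here is hypothetical):

    class BadOperator(BaseOperator):  # hypothetical
        start_from_trigger = True

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.start_trigger_args = StartTriggerArgs(
                trigger_cls="my_pkg.triggers.WaitTrigger",
                next_method="execute_complete",
                # Callables cannot be serialized and shipped to the triggerer.
                trigger_kwargs={"cb": lambda context, jinja_env: "value"},
            )

    # The metaclass runs _validate_start_from_trigger_kwargs after __init__,
    # so this raises ValueError naming the offending kwarg:
    BadOperator(task_id="bad")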
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index e32bd377f01..00b811146a6 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -285,59 +285,6 @@ class AbstractOperator(Templater, DAGNode):
dag = self.get_dag()
return super()._render(template, context, dag=dag)
- def _do_render_template_fields(
- self,
- parent: Any,
- template_fields: Iterable[str],
- context: Context,
- jinja_env: jinja2.Environment,
- seen_oids: set[int],
- ) -> None:
- """Override the base to use custom error logging."""
- for attr_name in template_fields:
- try:
- value = getattr(parent, attr_name)
- except AttributeError:
- raise AttributeError(
- f"{attr_name!r} is configured as a template field "
- f"but {parent.task_type} does not have this attribute."
- )
- try:
- if not value:
- continue
- except Exception:
- # This may happen if the templated field points to a class which does not support `__bool__`,
- # such as Pandas DataFrames:
- # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
- log.info(
- "Unable to check if the value of type '%s' is False for
task '%s', field '%s'.",
- type(value).__name__,
- self.task_id,
- attr_name,
- )
- # We may still want to render custom classes which do not support __bool__
- pass
-
- try:
- if callable(value):
- rendered_content = value(context=context, jinja_env=jinja_env)
- else:
- rendered_content = self.render_template(value, context, jinja_env, seen_oids)
- except Exception:
- # Mask sensitive values in the template before logging
- from airflow.sdk._shared.secrets_masker import redact
-
- masked_value = redact(value)
- log.exception(
- "Exception rendering Jinja template for task '%s', field
'%s'. Template: %r",
- self.task_id,
- attr_name,
- masked_value,
- )
- raise
- else:
- setattr(parent, attr_name, rendered_content)
-
def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
index f094ccd6b28..cfe4a6100e4 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import datetime
import logging
import os
-from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Collection, Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -117,6 +117,48 @@ class Templater:
return dag.render_template_as_native_obj if dag else False
+ def _iter_templated_fields(
+ self,
+ parent: Any,
+ template_fields: Iterable[str],
+ ) -> Iterator[tuple[str, Any]]:
+ """
+ Iterate over template fields yielding ``(attr_name, value)`` pairs for non-empty fields.
+
+ Fields whose value is falsy are skipped. Objects that do not support ``__bool__`` (e.g. Pandas DataFrames) are still yielded.
+ """
+ for attr_name in template_fields:
+ try:
+ value = getattr(parent, attr_name)
+ except AttributeError:
+ raise AttributeError(
+ f"{attr_name!r} is configured as a template field "
+ f"but {type(parent).__name__} does not have this
attribute."
+ )
+ try:
+ if not value:
+ continue
+ except Exception:
+ # This may happen if the templated field points to a class which does not support
+ # ``__bool__``, such as Pandas DataFrames:
+ # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
+ if hasattr(self, "task_id"):
+ log.info(
+ "Unable to check if the value of type '%s' is False
for task '%s', field '%s'.",
+ type(value).__name__,
+ self.task_id,
+ attr_name,
+ )
+ else:
+ log.info(
+ "Unable to check if the value of type '%s' is False
for field '%s'.",
+ type(value).__name__,
+ attr_name,
+ )
+ # We may still want to render custom classes which do not support __bool__
+ yield attr_name, value
+
def _do_render_template_fields(
self,
parent: Any,
@@ -125,15 +167,47 @@ class Templater:
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
- for attr_name in template_fields:
- value = getattr(parent, attr_name)
- rendered_content = self.render_template(
- value,
- context,
- jinja_env,
- seen_oids,
- )
- if rendered_content:
+ """
+ Render template fields on *parent* in-place.
+
+ For each non-empty field yielded by :meth:`_iter_templated_fields`, the value is
+ rendered (or called, when it is callable) and the result is written back via
+ ``setattr``. Rendering errors are logged with masked values before being re-raised.
+
+ :param parent: The object whose attributes will be templated.
+ :param template_fields: Names of the attributes to render.
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja2 environment to use for rendering.
+ :param seen_oids: Set of already-rendered object ids used to prevent infinite
+ recursion on circular references.
+ """
+ for attr_name, value in self._iter_templated_fields(parent, template_fields):
+ try:
+ if callable(value):
+ rendered_content = value(context=context, jinja_env=jinja_env)
+ else:
+ rendered_content = self.render_template(value, context, jinja_env, seen_oids)
+ except Exception:
+ # Mask sensitive values in the template before logging
+ from airflow.sdk._shared.secrets_masker import redact
+
+ masked_value = redact(value)
+ if hasattr(self, "task_id"):
+ log.exception(
+ "Exception rendering Jinja template for task '%s',
field '%s'. Template: %r",
+ self.task_id,
+ attr_name,
+ masked_value,
+ )
+ else:
+ log.exception(
+ "Exception rendering Jinja template for %s, field
'%s'. Template: %r",
+ type(parent).__name__,
+ attr_name,
+ masked_value,
+ )
+ raise
+ else:
setattr(parent, attr_name, rendered_content)
def _render(self, template, context, dag=None) -> Any:
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index abc4c86ed85..7c0540421d4 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -226,6 +226,16 @@ class OperatorPartial:
task_group = partial_kwargs.pop("task_group")
start_date = partial_kwargs.pop("start_date", None)
end_date = partial_kwargs.pop("end_date", None)
+ start_from_trigger = (
+ partial_kwargs["start_from_trigger"]
+ if "start_from_trigger" in partial_kwargs
+ else getattr(self.operator_class, "start_from_trigger", False)
+ )
+ start_trigger_args = (
+ partial_kwargs["start_trigger_args"]
+ if "start_trigger_args" in partial_kwargs
+ else getattr(self.operator_class, "start_trigger_args", None)
+ )
try:
operator_name = self.operator_class.custom_operator_name # type: ignore
@@ -259,8 +269,8 @@ class OperatorPartial:
# to BaseOperator.expand() contribute to operator arguments.
expand_input_attr="expand_input",
# TODO: Move these to task SDK's BaseOperator and remove getattr
- start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
- start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
+ start_trigger_args=start_trigger_args,
+ start_from_trigger=start_from_trigger,
)
return op
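For illustration, the override this change enables: a value passed to partial() now takes precedence over the class-level default (operator name hypothetical):

    # Assuming WaitOperator sets start_from_trigger = True at class level,
    # a mapped instantiation can now opt out explicitly:
    WaitOperator.partial(
        task_id="wait",
        start_from_trigger=False,  # overrides the class attribute default
    ).expand(duration=["10s", "20s"])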
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 9e6db88d5cf..dcb5240a83d 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -41,6 +41,7 @@ from airflow.sdk.bases.operator import (
)
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.template import literal
+from airflow.triggers.base import StartTriggerArgs
DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
@@ -108,9 +109,18 @@ class MockOperator(BaseOperator):
super().__init__(**kwargs)
self.arg1 = arg1
self.arg2 = arg2
+ if self.start_from_trigger:
+ self.start_trigger_args = StartTriggerArgs(
+ trigger_cls="trigger_cls",
+ next_method="next_method",
+ trigger_kwargs={"arg1": arg1, "arg2": arg2},
+ )
class TestBaseOperator:
+ def setup_method(self, method):
+ MockOperator.start_from_trigger = False
+
# Since we have a custom metaclass, lets double check the behaviour of
# passing args in the wrong way (args etc)
def test_kwargs_only(self):
@@ -800,6 +810,16 @@ class TestBaseOperator:
task.render_template_fields(context={"foo": "whatever", "bar":
"whatever"})
assert mock_jinja_env.call_count == 1
+ def test_validate_start_from_trigger_kwargs(self):
+ MockOperator.start_from_trigger = True
+
+ with pytest.raises(
+ ValueError,
+ match="MockOperator with task_id 'one' has a callable in trigger
kwargs named "
+ "'arg2', which is not allowed when start_from_trigger is enabled.",
+ ):
+ MockOperator(task_id="one", arg1="{{ foo }}", arg2=lambda context,
jinja_env: "bar")
+
def test_params_source(self):
# Test bug when copying an operator attached to a Dag
with DAG(
diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
index bcce3c89547..fccdfe8664c 100644
--- a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
+++ b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from datetime import datetime, timezone
+from unittest.mock import MagicMock, NonCallableMagicMock
import jinja2
import pytest
@@ -111,6 +112,193 @@ class TestTemplater:
assert rendered_content == "template_file.txt"
+ def test_do_render_template_fields_basic(self):
+ """Test that _do_render_template_fields renders a simple string
template field in-place."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["greeting"])
+ parent.greeting = "Hello {{ name }}"
+
+ context = {"name": "world"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["greeting"], context,
jinja_env, set())
+
+ assert parent.greeting == "Hello world"
+
+ def test_do_render_template_fields_multiple_fields(self):
+ """Test rendering multiple template fields at once."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["first", "second"])
+ parent.first = "Hello {{ name }}"
+ parent.second = "Date: {{ ds }}"
+
+ context = {"name": "world", "ds": "2024-01-01"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["first", "second"],
context, jinja_env, set())
+
+ assert parent.first == "Hello world"
+ assert parent.second == "Date: 2024-01-01"
+
+ def test_do_render_template_fields_callable_value(self):
+ """Test that callable field values are called with context and
jinja_env."""
+ templater = Templater()
+ templater.template_ext = []
+
+ callback = MagicMock(spec=lambda context, jinja_env: None, return_value="resolved")
+ parent = MagicMock(spec=["my_field"])
+ parent.my_field = callback
+
+ context = {"key": "value"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["my_field"], context,
jinja_env, set())
+
+ callback.assert_called_once_with(context=context, jinja_env=jinja_env)
+ assert parent.my_field == "resolved"
+
+ def test_do_render_template_fields_skips_falsy_values(self):
+ """Test that falsy field values (empty string, None, 0) are skipped."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["empty_str", "none_val"])
+ parent.empty_str = ""
+ parent.none_val = None
+
+ context = {"name": "world"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["empty_str",
"none_val"], context, jinja_env, set())
+
+ # Falsy values should not be touched
+ assert parent.empty_str == ""
+ assert parent.none_val is None
+
+ def test_do_render_template_fields_missing_attribute(self):
+ """Test that a missing attribute on parent raises AttributeError."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["existing"])
+ parent.existing = "value"
+
+ context = {}
+ jinja_env = templater.get_template_env()
+
+ with pytest.raises(
+ AttributeError,
+ match="'nonexistent' is configured as a template field",
+ ):
+ templater._do_render_template_fields(parent, ["nonexistent"],
context, jinja_env, set())
+
+ def test_do_render_template_fields_exception_logged_with_task_id(self, caplog):
+ """Test that rendering errors are logged with task_id when available and re-raised."""
+ templater = Templater()
+ templater.template_ext = []
+ templater.task_id = "my_task"
+
+ parent = MagicMock(spec=["bad_field"])
+ parent.bad_field = "{{ undefined_var }}"
+
+ context = {}
+ jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0)
+
+ with pytest.raises(jinja2.UndefinedError):
+ templater._do_render_template_fields(parent, ["bad_field"],
context, jinja_env, set())
+
+ assert "Exception rendering Jinja template for task 'my_task', field
'bad_field'" in caplog.text
+
+ def test_do_render_template_fields_exception_logged_without_task_id(self, caplog):
+ """Test that rendering errors are logged with parent type name when no task_id."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["bad_field"])
+ parent.bad_field = "{{ undefined_var }}"
+
+ context = {}
+ jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0)
+
+ with pytest.raises(jinja2.UndefinedError):
+ templater._do_render_template_fields(parent, ["bad_field"],
context, jinja_env, set())
+
+ assert "Exception rendering Jinja template for MagicMock, field
'bad_field'" in caplog.text
+
+ def test_do_render_template_fields_nested_template_fields(self):
+ """Test rendering nested objects that have their own
template_fields."""
+ templater = Templater()
+ templater.template_ext = []
+
+ inner = NonCallableMagicMock(spec=["template_fields", "message"])
+ inner.template_fields = ["message"]
+ inner.message = "Hello {{ name }}"
+
+ parent = MagicMock(spec=["nested"])
+ parent.nested = inner
+
+ context = {"name": "world"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["nested"], context,
jinja_env, set())
+
+ assert inner.message == "Hello world"
+
+ def test_do_render_template_fields_seen_oids_prevents_reprocessing(self):
+ """Test that already-seen objects (by id) are not re-rendered."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["greeting"])
+ parent.greeting = "Hello {{ name }}"
+
+ context = {"name": "world"}
+ jinja_env = templater.get_template_env()
+
+ # Pre-populate seen_oids with the parent's greeting value id
+ seen_oids = {id(parent.greeting)}
+
+ templater._do_render_template_fields(parent, ["greeting"], context,
jinja_env, seen_oids)
+
+ # The value should NOT be rendered because render_template checks
+ # `id(value) in seen_oids` and short-circuits, returning the original
+ # unrendered string.
+ assert parent.greeting == "Hello {{ name }}"
+
+ def test_do_render_template_fields_renders_dict_values(self):
+ """Test that dict field values have their inner templates rendered."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["params"])
+ parent.params = {"key": "{{ value }}"}
+
+ context = {"value": "rendered"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["params"], context,
jinja_env, set())
+
+ assert parent.params == {"key": "rendered"}
+
+ def test_do_render_template_fields_renders_list_values(self):
+ """Test that list field values have their inner templates rendered."""
+ templater = Templater()
+ templater.template_ext = []
+
+ parent = MagicMock(spec=["items"])
+ parent.items = ["{{ a }}", "{{ b }}"]
+
+ context = {"a": "first", "b": "second"}
+ jinja_env = templater.get_template_env()
+
+ templater._do_render_template_fields(parent, ["items"], context,
jinja_env, set())
+
+ assert parent.items == ["first", "second"]
+
@pytest.fixture
def env():