This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 97959da0877 Re-enable start_from_trigger feature with rendering of template fields (#55068)
97959da0877 is described below

commit 97959da087786dabda7c49f5512b9c5f14181735
Author: David Blain <[email protected]>
AuthorDate: Wed Mar 25 22:42:49 2026 +0100

    Re-enable start_from_trigger feature with rendering of template fields (#55068)
    
    * Fix rendering of template fields with start from trigger
    
    * refactor: Check if TaskInstance exists or not in BaseTrigger
    
    * Revert "refactor: Check if TaskInstance exists or not in BaseTrigger"
    
    This reverts commit 5f7306d287aea41e0970122bb153987b6be311b8.
    
    * refactor: Changed return type of task_instance property in BaseTrigger
    
    * refactor: Make sure default values for start from trigger can be overridden in mapped operator
    
    * refactor: Remove assert on start_date of TaskInstance
    
    * refactor: Make sure to check if dag_data is not None in workloads before creating the RuntimeTaskInstance
    
    * refactor: Only pass serialized dag model to workload if trigger contains 
templated fields.
    
    * refactor: Don't invoke _read_dag twice in get_dag method of DBDagBag class
    
    * refactor: Make _version_from_dag_run method of DBDagBag failsafe for 
legacy fallback
    
    * refactor: Moved None check on start_state together with the task in one 
type checking block to keep mypy happy
    
    * Revert "refactor: Make _version_from_dag_run method of DBDagBag failsafe 
for legacy fallback"
    
    This reverts commit 23d7aea62e48301f2edaa15f31b8db3296c793d3.
    
    * refactor: Fixed test_get_dag_model
    
    * refactor: Only pass serialized Dag model data to RunTrigger if 
start_from_trigger was enabled.
    
    * refactor: Added docstrings for start_from_trigger and start_trigger_args
    
    * refactor: Templated field must be checked on task of task instance
    
    * refactor: Added start_from_trigger property on Trigger
    
    * refactor: Reformatted trigger unit test
    
    * refactor: Only the RuntimeTaskInstance has the task attribute; the generated Pydantic one doesn't have it, and we cannot use isinstance here as we fake the typing with models.TaskInstance
    
    * refactor: Reformatted test trigger
    
    * Update 
airflow-core/src/airflow/serialization/definitions/mappedoperator.py
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    
    * refactor: Removed obsolete run method from TaskInstance
    
    * refactor: Added dag_data field to RunTrigger and made ti field optional
    
    * refactor: Reformatted RunTrigger
    
    * refactor: We cannot detect whether a Trigger is associated with a task that has start_from_trigger enabled without using DBDagBag, thus removed the check for now
    
    * refactor: Re-added check on start_from_trigger from serialized Dag
    
    * refactor: Fixed call to dag_bag in get_dag_for_run_or_latest_version 
method due to refactor in DBDagBag
    
    * refactor: Extracted _do_render_template_fields method into Templater so it can be re-used by AbstractOperator and BaseTrigger, which is more DRY
    
    * refactor: task_id should be an instance field instead of property
    
    * refactor: Added tests for _do_render_template_fields method in 
TestTemplater
    
    * refactor: Fixed templater unit tests
    
    * refactor: Raise NotImplementedError in _set_context
    
    * refactor: Reverted logging back to structlog in mappedoperator
    
    * refactor: Refactored _create_workload in trigger job runner
    
    * refactor: Renamed get_dag_model to get_serialized_dag_model in DBDagBag
    
    * refactor: Refactored templater using structlog
    
    * refactor: Added docstring to get_serialized_dag_model
    
    * Revert "refactor: Raise NotImplementedError in _set_context"
    
    This reverts commit eff4fbdeaf598a2e367eaf19cb2e520365b7ee0c.
    
    * refactor: Fixed typing of render_log_fname
    
    * Revert "refactor: Refactored templater using structlog"
    
    This reverts commit 20f7ac0437d66ff030224f8bfeb9f2aa9e3060c2.
    
    * refactor: Reformatted files
    
    * refactor: Removed new line in get_serialized_dag_model
    
    * refactor: Fixed test_get_dag_returns_none_when_model_missing
    
    * refactor: Removed default NEW_SESSION from session parameter in 
_create_workload method
    
    ---------
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
 airflow-core/.pre-commit-config.yaml               |   1 +
 .../src/airflow/api_fastapi/common/dagbag.py       |   2 +-
 .../src/airflow/executors/workloads/trigger.py     |   5 +-
 .../src/airflow/jobs/triggerer_job_runner.py       | 200 ++++++++++++++-------
 airflow-core/src/airflow/models/dagbag.py          |  58 ++++--
 airflow-core/src/airflow/models/dagrun.py          |  28 +--
 airflow-core/src/airflow/models/taskinstance.py    |  69 ++++++-
 airflow-core/src/airflow/triggers/base.py          |  63 ++++++-
 airflow-core/tests/unit/jobs/test_triggerer_job.py |   5 +-
 airflow-core/tests/unit/models/test_dagbag.py      |  79 ++++++++
 .../tests/unit/models/test_taskinstance.py         |  97 ++++++++++
 .../tests/unit/triggers/test_base_trigger.py       |  69 +++++++
 devel-common/src/tests_common/pytest_plugin.py     |  47 +++--
 task-sdk/src/airflow/sdk/bases/operator.py         |  22 +++
 .../sdk/definitions/_internal/abstractoperator.py  |  53 ------
 .../airflow/sdk/definitions/_internal/templater.py |  94 ++++++++--
 .../src/airflow/sdk/definitions/mappedoperator.py  |  14 +-
 task-sdk/tests/task_sdk/bases/test_operator.py     |  20 +++
 .../definitions/_internal/test_templater.py        | 188 +++++++++++++++++++
 19 files changed, 922 insertions(+), 192 deletions(-)

diff --git a/airflow-core/.pre-commit-config.yaml 
b/airflow-core/.pre-commit-config.yaml
index 7573eec4e65..121b51d4e8b 100644
--- a/airflow-core/.pre-commit-config.yaml
+++ b/airflow-core/.pre-commit-config.yaml
@@ -376,6 +376,7 @@ repos:
           ^src/airflow/timetables/assets\.py$|
           ^src/airflow/timetables/base\.py$|
           ^src/airflow/timetables/simple\.py$|
+          ^src/airflow/triggers/base\.py$|
           ^src/airflow/utils/cli\.py$|
           ^src/airflow/utils/context\.py$|
           ^src/airflow/utils/dag_cycle_tester\.py$|
diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py 
b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
index 3ca4483ce87..c7630cde9f7 100644
--- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py
+++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
@@ -84,7 +84,7 @@ def get_dag_for_run_or_latest_version(
     dag: SerializedDAG | None = None
     if dag_run:
         if dag_run.created_dag_version_id:
-            dag = dag_bag._get_dag(dag_run.created_dag_version_id, 
session=session)
+            dag = dag_bag.get_dag(dag_run.created_dag_version_id, 
session=session)
         if not dag:
             dag = dag_bag.get_dag_for_run(dag_run, session=session)
     elif dag_id:
diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py 
b/airflow-core/src/airflow/executors/workloads/trigger.py
index 25bca9ce44b..2959cde6ee3 100644
--- a/airflow-core/src/airflow/executors/workloads/trigger.py
+++ b/airflow-core/src/airflow/executors/workloads/trigger.py
@@ -35,8 +35,11 @@ class RunTrigger(BaseModel):
     """
 
     id: int
-    ti: TaskInstanceDTO | None  # Could be none for asset-based triggers.
     classpath: str  # Dot-separated name of the module+fn to import and run 
this workload.
     encrypted_kwargs: str
+    ti: TaskInstanceDTO | None = None  # Could be none for asset-based 
triggers.
     timeout_after: datetime | None = None
     type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
+    dag_data: dict | None = (
+        None  # Serialized Dag model in dict format so it can be deserialized 
in trigger subprocess.
+    )
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py 
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 44c28a7a539..44f96589042 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -25,7 +25,7 @@ import signal
 import sys
 import time
 from collections import deque
-from collections.abc import Generator, Iterable
+from collections.abc import Callable, Generator, Iterable
 from contextlib import suppress
 from datetime import datetime
 from socket import socket
@@ -51,6 +51,7 @@ from airflow.executors import workloads
 from airflow.executors.workloads.task import TaskInstanceDTO
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import perform_heartbeat
+from airflow.models.dagbag import DBDagBag
 from airflow.models.trigger import Trigger
 from airflow.observability.metrics import stats_utils
 from airflow.sdk.api.datamodels._generated import HITLDetailResponse
@@ -84,10 +85,12 @@ from airflow.sdk.execution_time.comms import (
     _RequestFrame,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess, 
make_buffered_socket_reader
+from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+from airflow.serialization.serialized_objects import DagSerialization
 from airflow.triggers.base import BaseEventTrigger, BaseTrigger, 
DiscrimatedTriggerEvent, TriggerEvent
 from airflow.utils.helpers import log_filename_template_renderer
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import create_session, provide_session
 
 if TYPE_CHECKING:
     from opentelemetry.util._decorator import _AgnosticContextManager
@@ -97,6 +100,7 @@ if TYPE_CHECKING:
     from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
     from airflow.jobs.job import Job
     from airflow.sdk.api.client import Client
+    from airflow.sdk.definitions.context import Context
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
 
 logger = logging.getLogger(__name__)
@@ -658,6 +662,65 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
             extra_tags={"hostname": self.job.hostname},
         )
 
+    def _create_workload(
+        self,
+        trigger: Trigger,
+        dag_bag: DBDagBag,
+        render_log_fname: Callable[..., str],
+        session: Session,
+    ) -> workloads.RunTrigger | None:
+        if trigger.task_instance is None:
+            return workloads.RunTrigger(
+                id=trigger.id,
+                classpath=trigger.classpath,
+                encrypted_kwargs=trigger.encrypted_kwargs,
+            )
+
+        if not trigger.task_instance.dag_version_id:
+            # This is to handle 2 to 3 upgrade where TI.dag_version_id can be 
none
+            log.warning(
+                "TaskInstance associated with Trigger has no associated Dag 
Version, skipping the trigger",
+                ti_id=trigger.task_instance.id,
+            )
+            return None
+
+        log_path = render_log_fname(ti=trigger.task_instance)
+        ser_ti = TaskInstanceDTO.model_validate(trigger.task_instance, 
from_attributes=True)
+
+        # When producing logs from TIs, include the job id producing the logs 
to disambiguate it.
+        self.logger_cache[trigger.id] = TriggerLoggingFactory(
+            log_path=f"{log_path}.trigger.{self.job.id}.log",
+            ti=ser_ti,  # type: ignore
+        )
+
+        serialized_dag_model = dag_bag.get_serialized_dag_model(
+            version_id=trigger.task_instance.dag_version_id,
+            session=session,
+        )
+
+        if serialized_dag_model:
+            task = 
serialized_dag_model.dag.get_task(trigger.task_instance.task_id)
+
+            # When a TaskInstance of a Trigger contains a task with 
start_from_trigger enabled,
+            # it means we need to load the SerializedDagModel so we can build 
a RuntimeTaskInstance later on which
+            # will allow us to build a context on which we will render the 
templated fields.
+            if task.start_from_trigger:
+                return workloads.RunTrigger(
+                    id=trigger.id,
+                    classpath=trigger.classpath,
+                    encrypted_kwargs=trigger.encrypted_kwargs,
+                    ti=ser_ti,
+                    timeout_after=trigger.task_instance.trigger_timeout,
+                    dag_data=serialized_dag_model.data,
+                )
+        return workloads.RunTrigger(
+            id=trigger.id,
+            classpath=trigger.classpath,
+            encrypted_kwargs=trigger.encrypted_kwargs,
+            ti=ser_ti,
+            timeout_after=trigger.task_instance.trigger_timeout,
+        )
+
     def update_triggers(self, requested_trigger_ids: set[int]):
         """
         Request that we update what triggers we're running.
@@ -666,8 +729,8 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
         adds them to the dequeues so the subprocess can actually mutate the 
running
         trigger set.
         """
+        dag_bag = DBDagBag()
         render_log_fname = log_filename_template_renderer()
-
         known_trigger_ids = (
             self.running_triggers.union(x[0] for x in self.events)
             .union(self.cancelling_triggers)
@@ -678,60 +741,48 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
         new_trigger_ids = requested_trigger_ids - known_trigger_ids
         cancel_trigger_ids = self.running_triggers - requested_trigger_ids
         # Bulk-fetch new trigger records
-        new_triggers = Trigger.bulk_fetch(new_trigger_ids)
-        trigger_ids_with_non_task_associations = 
Trigger.fetch_trigger_ids_with_non_task_associations()
-        to_create: list[workloads.RunTrigger] = []
-        # Add in new triggers
-        for new_id in new_trigger_ids:
-            # Check it didn't vanish in the meantime
-            if new_id not in new_triggers:
-                log.warning("Trigger disappeared before we could start it", 
id=new_id)
-                continue
-
-            new_trigger_orm = new_triggers[new_id]
-
-            # If the trigger is not associated to a task, an asset, or a 
callback, this means the TaskInstance
-            # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
-            # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
-            # in a High-Availability setup.
-            if new_trigger_orm.task_instance is None and new_id not in 
trigger_ids_with_non_task_associations:
-                log.info(
-                    (
-                        "TaskInstance Trigger is None. It was likely updated 
by another trigger job. "
-                        "Skipping trigger instantiation."
-                    ),
-                    id=new_id,
-                )
-                continue
-
-            workload = workloads.RunTrigger(
-                classpath=new_trigger_orm.classpath,
-                id=new_id,
-                encrypted_kwargs=new_trigger_orm.encrypted_kwargs,
-                ti=None,
+        with create_session() as session:
+            # Bulk-fetch new trigger records
+            new_triggers = Trigger.bulk_fetch(new_trigger_ids, session=session)
+            trigger_ids_with_non_task_associations = 
Trigger.fetch_trigger_ids_with_non_task_associations(
+                session=session
             )
-            if new_trigger_orm.task_instance:
-                log_path = render_log_fname(ti=new_trigger_orm.task_instance)
-                if not new_trigger_orm.task_instance.dag_version_id:
-                    # This is to handle 2 to 3 upgrade where TI.dag_version_id 
can be none
-                    log.warning(
-                        "TaskInstance associated with Trigger has no 
associated Dag Version, skipping the trigger",
-                        ti_id=new_trigger_orm.task_instance.id,
-                    )
+            to_create: list[workloads.RunTrigger] = []
+            # Add in new triggers
+            for new_trigger_id in new_trigger_ids:
+                # Check it didn't vanish in the meantime
+                if new_trigger_id not in new_triggers:
+                    log.warning("Trigger disappeared before we could start 
it", id=new_trigger_id)
                     continue
-                ser_ti = 
TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, 
from_attributes=True)
-                # When producing logs from TIs, include the job id producing 
the logs to disambiguate it.
-                self.logger_cache[new_id] = TriggerLoggingFactory(
-                    log_path=f"{log_path}.trigger.{self.job.id}.log",
-                    ti=ser_ti,  # type: ignore
-                )
 
-                workload.ti = ser_ti
-                workload.timeout_after = 
new_trigger_orm.task_instance.trigger_timeout
+                new_trigger_orm = new_triggers[new_trigger_id]
+
+                # If the trigger is not associated to a task, an asset, or a 
callback, this means the TaskInstance
+                # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
+                # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
+                # in a High-Availability setup.
+                if (
+                    new_trigger_orm.task_instance is None
+                    and new_trigger_id not in 
trigger_ids_with_non_task_associations
+                ):
+                    log.info(
+                        (
+                            "TaskInstance of Trigger is None. It was likely 
updated by another trigger job. "
+                            "Skipping trigger instantiation."
+                        ),
+                        id=new_trigger_id,
+                    )
+                    continue
 
-            to_create.append(workload)
+                if workload := self._create_workload(
+                    trigger=new_trigger_orm,
+                    dag_bag=dag_bag,
+                    render_log_fname=render_log_fname,
+                    session=session,
+                ):
+                    to_create.append(workload)
 
-        self.creating_triggers.extend(to_create)
+            self.creating_triggers.extend(to_create)
 
         if cancel_trigger_ids:
             # Enqueue orphaned triggers for cancellation
@@ -986,9 +1037,19 @@ class TriggerRunner:
             raise RuntimeError(f"Required first message to be a 
messages.StartTriggerer, it was {msg}")
 
     async def create_triggers(self):
         """Drain the to_create queue and create all new triggers that have been requested in the DB."""
+        def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+            task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+            # Rebuild a RuntimeTaskInstance (task_runner) from the workload's TaskInstanceDTO before invoking get_template_context.
+            return RuntimeTaskInstance.model_construct(
+                **workload.ti.model_dump(exclude_unset=True),
+                task=task,
+            )
+
         while self.to_create:
             await asyncio.sleep(0)
+            context: Context | None = None
             workload = self.to_create.popleft()
             trigger_id = workload.id
             if trigger_id in self.triggers:
@@ -1016,24 +1077,32 @@ class TriggerRunner:
                 # that could cause None values in collections.
                 kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
                 deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for 
k, v in kw.items()}
-                trigger_instance = trigger_class(**deserialised_kwargs)
+
+                if ti := workload.ti:
+                    trigger_name = 
f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID 
{trigger_id})"
+                    trigger_instance = trigger_class(**deserialised_kwargs)
+
+                    if workload.dag_data:
+                        runtime_ti = create_runtime_ti(workload.dag_data)
+                        context = runtime_ti.get_template_context()
+                        trigger_instance.task_instance = runtime_ti
+                    else:
+                        trigger_instance.task_instance = ti
+                else:
+                    trigger_name = f"ID {trigger_id}"
+                    trigger_instance = trigger_class(**deserialised_kwargs)
             except TypeError as err:
                 self.log.error("Trigger failed to inflate", error=err)
                 self.failed_triggers.append((trigger_id, err))
                 continue
             trigger_instance.trigger_id = trigger_id
             trigger_instance.triggerer_job_id = self.job_id
-            trigger_instance.task_instance = ti = workload.ti
             trigger_instance.timeout_after = workload.timeout_after
 
-            trigger_name = (
-                
f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID 
{trigger_id})"
-                if ti
-                else f"ID {trigger_id}"
-            )
             self.triggers[trigger_id] = {
                 "task": asyncio.create_task(
-                    self.run_trigger(trigger_id, trigger_instance, 
workload.timeout_after), name=trigger_name
+                    self.run_trigger(trigger_id, trigger_instance, 
workload.timeout_after, context),
+                    name=trigger_name,
                 ),
                 "is_watcher": isinstance(trigger_instance, BaseEventTrigger),
                 "name": trigger_name,
@@ -1200,7 +1269,13 @@ class TriggerRunner:
                 )
                 Stats.incr("triggers.blocked_main_thread")
 
-    async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, 
timeout_after: datetime | None = None):
+    async def run_trigger(
+        self,
+        trigger_id: int,
+        trigger: BaseTrigger,
+        timeout_after: datetime | None = None,
+        context: Context | None = None,
+    ):
         """Run a trigger (they are async generators) and push their events 
into our outbound event deque."""
         if not os.environ.get("AIRFLOW_DISABLE_GREENBACK_PORTAL", "").lower() 
== "true":
             import greenback
@@ -1213,6 +1288,9 @@ class TriggerRunner:
         self.log.info("trigger %s starting", name)
         with _make_trigger_span(ti=trigger.task_instance, 
trigger_id=trigger_id, name=name) as span:
             try:
+                if context is not None:
+                    trigger.render_template_fields(context=context)
+
                 async for event in trigger.run():
                     await self.log.ainfo(
                         "Trigger fired event", 
name=self.triggers[trigger_id]["name"], result=event
diff --git a/airflow-core/src/airflow/models/dagbag.py 
b/airflow-core/src/airflow/models/dagbag.py
index e04f77d06df..98799bbde0c 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -45,24 +45,44 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[UUID, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[UUID, SerializedDagModel] = {}  # dag_version_id to 
dag
         self.load_op_links = load_op_links
 
-    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
-        serdag.load_op_links = self.load_op_links
-        if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+    def _read_dag(self, serialized_dag_model: SerializedDagModel) -> 
SerializedDAG | None:
+        serialized_dag_model.load_op_links = self.load_op_links
+        if dag := serialized_dag_model.dag:
+            self._dags[serialized_dag_model.dag_version_id] = 
serialized_dag_model
         return dag
 
-    def _get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | 
None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_serialized_dag_model(self, version_id: UUID, session: Session) -> 
SerializedDagModel | None:
+        """
+        Return the SerializedDagModel for a given dag version id.
+
+        This will first consult the in-memory cache keyed by the dag version 
id. If the
+        model is not cached, the database is queried for a corresponding 
:class:`DagVersion`
+        and its associated :class:`SerializedDagModel`.
+
+        :param version_id: The UUID of the dag version to look up.
+        :param session: SQLAlchemy session used to query the database.
+        :return: The serialized DAG model if found either in the cache or the 
database; ``None``
+                 is returned when no :class:`DagVersion` exists for the given 
``version_id`` or
+                 when that :class:`DagVersion` does not have an associated 
:class:`SerializedDagModel`.
+        :rtype: SerializedDagModel | None
+
+        Note: If a serialized dag model is found in the database it will be 
stored in the
+        internal cache (``self._dags``) before being returned.
+        """
+        if not (serialized_dag_model := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serialized_dag_model := 
dag_version.serialized_dag):
+                return None
+            self._read_dag(serialized_dag_model)
+        return serialized_dag_model
+
+    def get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | 
None:
+        if serialized_dag_model := 
self.get_serialized_dag_model(version_id=version_id, session=session):
+            return serialized_dag_model.dag
+        return None
 
     @staticmethod
     def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> UUID | 
None:
@@ -74,24 +94,24 @@ class DBDagBag:
 
     def get_dag_for_run(self, dag_run: DagRun, session: Session) -> 
SerializedDAG | None:
         if version_id := self._version_from_dag_run(dag_run=dag_run, 
session=session):
-            return self._get_dag(version_id=version_id, session=session)
+            return self.get_dag(version_id=version_id, session=session)
         return None
 
     def iter_all_latest_version_dags(self, *, session: Session) -> 
Generator[SerializedDAG, None, None]:
         """Walk through all latest version dags available in the database."""
         from airflow.models.serialized_dag import SerializedDagModel
 
-        for sdm in session.scalars(select(SerializedDagModel)):
-            if dag := self._read_dag(sdm):
+        for serialized_dag_model in 
session.scalars(select(SerializedDagModel)):
+            if dag := self._read_dag(serialized_dag_model):
                 yield dag
 
     def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> 
SerializedDAG | None:
         """Get the latest version of a dag by its id."""
         from airflow.models.serialized_dag import SerializedDagModel
 
-        if not (serdag := SerializedDagModel.get(dag_id, session=session)):
+        if not (serialized_dag_model := SerializedDagModel.get(dag_id, 
session=session)):
             return None
-        return self._read_dag(serdag)
+        return self._read_dag(serialized_dag_model)
 
 
 def generate_md5_hash(context):
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index bbff43aad9a..c93f0ed8e1e 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1986,7 +1986,14 @@ class DagRun(Base, LoggingMixin):
         debug_try_number_check = self.log.isEnabledFor(logging.DEBUG)
         expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] 
= {}
         for ti in schedulable_tis:
-            if ti.is_schedulable:
+            if not ti.is_schedulable:
+                empty_ti_ids.append(ti.id)
+            # The defer_task method checks "start_trigger_args" to see whether the operator
+            # supports starting execution from the triggerer. If so, it also checks
+            # "start_from_trigger" to see whether this feature is turned on, and defers the task.
+            # If not, we add this "ti" to "schedulable_ti_ids" and later
+            # execute it in the worker.
+            elif not ti.defer_task(session=session):
                 schedulable_ti_ids.append(ti.id)
                 if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                     reschedule_ti_ids.add(ti.id)
@@ -1998,25 +2005,6 @@ class DagRun(Base, LoggingMixin):
                         ti.try_number,
                         ti.state,
                     )
-            # Check "start_trigger_args" to see whether the operator supports
-            # start execution from triggerer. If so, we'll check 
"start_from_trigger"
-            # to see whether this feature is turned on and defer this task.
-            # If not, we'll add this "ti" into "schedulable_ti_ids" and later
-            # execute it to run in the worker.
-            # TODO TaskSDK: This is disabled since we haven't figured out how
-            # to render start_from_trigger in the scheduler. If we need to
-            # render the value in a worker, it kind of defeats the purpose of
-            # this feature (which is to save a worker process if possible).
-            # elif task.start_trigger_args is not None:
-            #     if 
task.expand_start_from_trigger(context=ti.get_template_context()):
-            #         ti.start_date = timezone.utcnow()
-            #         if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
-            #             ti.try_number += 1
-            #         ti.defer_task(exception=None, session=session)
-            #     else:
-            #         schedulable_ti_ids.append(ti.id)
-            else:
-                empty_ti_ids.append(ti.id)
 
         count = 0
         # Don't only check if the TI.id is in id_chunk
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index e212ca68504..4c2137a5343 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -121,7 +121,7 @@ if TYPE_CHECKING:
     from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.serialization.definitions.mappedoperator import Operator
     from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-
+    from airflow.triggers.base import StartTriggerArgs
 
 PAST_DEPENDS_MET = "past_depends_met"
 
@@ -1590,6 +1590,73 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    @property
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task and self.task.start_from_trigger is True:
+            return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state
+    #       method of the task_instances module in the execution api when a TIDeferredStatePayload is being
+    #       processed. This is because of a TaskInstance being updated differently using SQLAlchemy.
+    #       If we use the approach from the execution api as common code in the DagRun schedule_tis method,
+    #       the side effect is the changes done to the task instance aren't picked up by the scheduler and
+    #       thus the task instance isn't processed until the scheduler is restarted.
+    @provide_session
+    def defer_task(self, session: Session = NEW_SESSION) -> bool:
+        """
+        Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised.
+
+        :meta: private
+        """
+        from airflow.models.trigger import Trigger
+
+        if TYPE_CHECKING:
+            assert self.start_date
+            assert isinstance(self.task, Operator)
+
+        if start_trigger_args := self.start_trigger_args:
+            trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+            timeout = start_trigger_args.timeout
+
+            # Calculate timeout too if it was passed
+            if timeout is not None:
+                self.trigger_timeout = timezone.utcnow() + timeout
+            else:
+                self.trigger_timeout = None
+
+            trigger_row = Trigger(
+                classpath=start_trigger_args.trigger_cls,
+                kwargs=trigger_kwargs,
+            )
+
+            # First, make the trigger entry
+            session.add(trigger_row)
+            session.flush()
+
+            # Then, update ourselves so it matches the deferral request
+            # Keep an eye on the logic in `check_and_change_state_before_execution()`
+            # depending on self.next_method semantics
+            self.state = TaskInstanceState.DEFERRED
+            self.trigger_id = trigger_row.id
+            self.next_method = start_trigger_args.next_method
+            self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+            # If an execution_timeout is set, set the timeout to the minimum of
+            # it and the trigger timeout
+            if execution_timeout := self.task.execution_timeout:
+                if self.trigger_timeout:
+                    self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
+                else:
+                    self.trigger_timeout = self.start_date + execution_timeout
+            self.start_date = timezone.utcnow()
+            if self.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+                self.try_number += 1
+            if self.test_mode:
+                _add_log(event=self.state, task_instance=self, session=session)
+            return True
+        return False
+
     @classmethod
     def fetch_handle_failure_context(
         cls,
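Reviewer note: the deadline arithmetic in the defer_task hunk above can be looked at in isolation. What follows is a minimal standalone sketch, not Airflow code. The helper name compute_trigger_deadline and its parameters are hypothetical, and it assumes the start time is already known when the deadline is computed.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def compute_trigger_deadline(
    now: datetime,
    start_date: datetime,
    trigger_timeout: Optional[timedelta],
    execution_timeout: Optional[timedelta],
) -> Optional[datetime]:
    """Hypothetical helper: pick the earlier of the trigger and execution deadlines."""
    # The trigger deadline only exists when a trigger timeout was passed.
    deadline = now + trigger_timeout if trigger_timeout is not None else None
    if execution_timeout:
        execution_deadline = start_date + execution_timeout
        # With both limits set, the stricter (earlier) one wins.
        deadline = min(execution_deadline, deadline) if deadline else execution_deadline
    return deadline


now = datetime(2026, 1, 1, tzinfo=timezone.utc)
both = compute_trigger_deadline(now, now, timedelta(hours=1), timedelta(minutes=10))
```

With a one hour trigger timeout and a ten minute execution timeout, the sketch returns the ten minute deadline; with neither set it returns None.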
diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py
index 416558242b8..7ca7ed20a74 100644
--- a/airflow-core/src/airflow/triggers/base.py
+++ b/airflow-core/src/airflow/triggers/base.py
@@ -21,7 +21,7 @@ import json
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Annotated, Any
+from typing import TYPE_CHECKING, Annotated, Any
 
 import structlog
 from pydantic import (
@@ -32,11 +32,24 @@ from pydantic import (
     model_serializer,
 )
 
+from airflow.sdk.definitions._internal.templater import Templater
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import TaskInstanceState
 
 log = structlog.get_logger(logger_name=__name__)
 
+if TYPE_CHECKING:
+    from typing import TypeAlias
+
+    import jinja2
+
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.definitions.context import Context
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    Operator: TypeAlias = MappedOperator | SerializedBaseOperator
+
 
 @dataclass
 class StartTriggerArgs:
@@ -49,7 +62,7 @@ class StartTriggerArgs:
     timeout: timedelta | None = None
 
 
-class BaseTrigger(abc.ABC, LoggingMixin):
+class BaseTrigger(abc.ABC, Templater, LoggingMixin):
     """
     Base class for all triggers.
 
@@ -66,14 +79,56 @@ class BaseTrigger(abc.ABC, LoggingMixin):
     supports_triggerer_queue: bool = True
 
     def __init__(self, **kwargs):
+        super().__init__()
         # these values are set by triggerer when preparing to run the instance
         # when run, they are injected into logger record.
-        self.task_instance = None
+        self._task_instance = None
         self.trigger_id = None
+        self.template_fields = ()
+        self.template_ext = ()
+        self.task_id = None
 
     def _set_context(self, context):
         """Part of LoggingMixin and used mainly for configuration of task logging; not used for triggers."""
-        raise NotImplementedError
+        pass
+
+    @property
+    def task(self) -> Operator | None:
+        # We must check if the TaskInstance is the generated Pydantic one or the RuntimeTaskInstance
+        if self.task_instance and hasattr(self.task_instance, "task"):
+            return self.task_instance.task
+        return None
+
+    @property
+    def task_instance(self) -> TaskInstance:
+        return self._task_instance
+
+    @task_instance.setter
+    def task_instance(self, value: TaskInstance | None) -> None:
+        self._task_instance = value
+        if self.task_instance:
+            self.task_id = self.task_instance.task_id
+        if self.task:
+            self.template_fields = self.task.template_fields
+            self.template_ext = self.task.template_ext
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: jinja2.Environment | None = None,
+    ) -> None:
+        """
+        Template all attributes listed in *self.template_fields*.
+
+        This mutates the attributes in-place and is irreversible.
+
+        :param context: Context dict with values to apply on content.
+        :param jinja_env: Jinja's environment to use for rendering.
+        """
+        if not jinja_env:
+            jinja_env = self.get_template_env()
+        # We only need to render templated fields if templated fields are part of the start_trigger_args
+        self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
 
     @abc.abstractmethod
     def serialize(self) -> tuple[str, dict[str, Any]]:
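Reviewer note: the task_instance setter in the hunk above copies task_id and the template field names from the task onto the trigger. The following is a minimal sketch of that pattern with plain Python objects. SketchTrigger is a hypothetical stand-in, not the Airflow BaseTrigger, and SimpleNamespace objects play the role of the task instance and the task.

```python
from types import SimpleNamespace


class SketchTrigger:
    """Hypothetical stand-in for a trigger; not the Airflow BaseTrigger."""

    def __init__(self):
        self._task_instance = None
        self.template_fields = ()
        self.task_id = None

    @property
    def task_instance(self):
        return self._task_instance

    @task_instance.setter
    def task_instance(self, value):
        self._task_instance = value
        if value is not None:
            self.task_id = value.task_id
            # Not every task instance flavour carries a task, so probe for it.
            task = getattr(value, "task", None)
            if task is not None:
                self.template_fields = task.template_fields


trigger = SketchTrigger()
trigger.task_instance = SimpleNamespace(
    task_id="wait", task=SimpleNamespace(template_fields=("name",))
)

bare = SketchTrigger()
bare.task_instance = SimpleNamespace(task_id="no_task")  # no task attribute
```

When the assigned object has no task attribute, the sketch leaves template_fields empty, so nothing would be rendered for that trigger.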
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 3761189bfeb..503a3f4834c 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -120,9 +120,9 @@ def create_trigger_in_db(session, trigger, operator=None):
     session.merge(testing_bundle)
     session.flush()
 
-    dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name)
-    dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=pendulum.datetime(2023, 1, 1))
     date = pendulum.datetime(2023, 1, 1)
+    dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name)
+    dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=date)
     run = DagRun(
         dag_id=dag_model.dag_id,
         run_id="test_run",
@@ -265,6 +265,7 @@ def test_trigger_lifecycle(spy_agency: SpyAgency, session, testing_dag_bundle):
                 classpath=trigger.serialize()[0],
                 encrypted_kwargs=trigger_orm.encrypted_kwargs,
                 kind="RunTrigger",
+                dag_data=ANY,
             )
         )
         # OK, now remove it from the DB
diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py
index 3b5b9887726..48f249205bb 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -16,8 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+from unittest.mock import MagicMock, patch
+
 import pytest
 
+from airflow.models.dagbag import DBDagBag
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.serialization.serialized_objects import SerializedDAG
+
 pytestmark = pytest.mark.db_test
 
 # This file previously contained tests for DagBag functionality, but those tests
@@ -26,3 +32,76 @@ pytestmark = pytest.mark.db_test
 #
 # Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.)
 # would remain in this file, but currently no such tests exist.
+
+
+class TestDBDagBag:
+    def setup_method(self):
+        self.db_dag_bag = DBDagBag()
+        self.session = MagicMock()
+
+    def test__read_dag_stores_and_returns_dag(self):
+        """It should store the SerializedDagModel in _dags and return the dag."""
+        mock_dag = MagicMock(spec=SerializedDAG)
+        mock_serdag = MagicMock(spec=SerializedDagModel)
+        mock_serdag.dag = mock_dag
+        mock_serdag.dag_version_id = "v1"
+
+        result = self.db_dag_bag._read_dag(mock_serdag)
+
+        assert result == mock_dag
+        assert self.db_dag_bag._dags["v1"] == mock_serdag
+        assert mock_serdag.load_op_links is True
+
+    def test__read_dag_returns_none_when_no_dag(self):
+        """It should return None and not modify _dags when no DAG is present."""
+        mock_serdag = MagicMock(spec=SerializedDagModel)
+        mock_serdag.dag = None
+        mock_serdag.dag_version_id = "v1"
+
+        result = self.db_dag_bag._read_dag(mock_serdag)
+
+        assert result is None
+        assert "v1" not in self.db_dag_bag._dags
+
+    def test_get_serialized_dag_model(self):
+        """It should return the cached SerializedDagModel if already loaded."""
+        mock_serdag = MagicMock(spec=SerializedDagModel)
+        mock_serdag.dag_version_id = "v1"
+        mock_dag_version = MagicMock()
+        mock_dag_version.serialized_dag = mock_serdag
+        self.session.get.return_value = mock_dag_version
+
+        self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
+        result = self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
+
+        assert result == mock_serdag
+        self.session.get.assert_called_once()
+
+    def test_get_serialized_dag_model_returns_none_when_not_found(self):
+        """It should return None if version_id not found in DB."""
+        self.session.get.return_value = None
+
+        result = self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
+
+        assert result is None
+
+    def test_get_dag_calls_get_dag_model_and__read_dag(self):
+        """It should call get_dag_model and then _read_dag."""
+        mock_serdag = MagicMock(spec=SerializedDagModel)
+        mock_serdag.dag_version_id = "v1"
+        mock_dag = MagicMock(spec=SerializedDAG)
+        mock_dag_version = MagicMock()
+        mock_dag_version.serialized_dag = mock_serdag
+        mock_serdag.dag = mock_dag
+        self.session.get.return_value = mock_dag_version
+
+        result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+        self.session.get.assert_called_once()
+        assert result == mock_dag
+
+    def test_get_dag_returns_none_when_model_missing(self):
+        """It should return None if no SerializedDagModel found."""
+        with patch.object(self.db_dag_bag, "get_serialized_dag_model", return_value=None):
+            result = self.db_dag_bag.get_dag("v1", session=self.session)
+        assert result is None
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index bb058d1a737..3fa09106ade 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -2653,6 +2653,103 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
     assert ti.max_tries == expected_max_tries
 
 
+def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance):
+    session = mock.MagicMock()
+    ti = create_task_instance(
+        dag_id="test_defer_task",
+        task_id="test_defer_task_op",
+    )
+    assert not ti.defer_task(session=session)
+
+
+def test_defer_task_returns_false_when_no_start_trigger_args(create_task_instance):
+    session = mock.MagicMock()
+    ti = create_task_instance(
+        dag_id="test_defer_task",
+        task_id="test_defer_task",
+        start_from_trigger=True,
+    )
+    assert not ti.defer_task(session=session)
+
+
+def test_defer_task(create_task_instance):
+    from airflow.models.trigger import Trigger
+    from airflow.triggers.base import StartTriggerArgs
+
+    session = mock.MagicMock()
+    ti = create_task_instance(
+        dag_id="test_defer_task",
+        task_id="test_defer_task_op",
+        start_from_trigger=True,
+        start_trigger_args=StartTriggerArgs(
+            trigger_cls="trigger_cls",
+            next_method="next_method",
+            trigger_kwargs={"key": "value"},
+        ),
+    )
+    assert ti.defer_task(session=session)
+
+    # Check that session.add was called with a Trigger
+    assert session.add.call_count == 1
+    trigger_row = session.add.call_args[0][0]
+    assert isinstance(trigger_row, Trigger)
+    assert trigger_row.classpath == "trigger_cls"
+    assert trigger_row.kwargs == {"key": "value"}
+
+    # Check that session.flush was called
+    session.flush.assert_called_once()
+
+    # Check that TaskInstance state was updated
+    assert ti.state == TaskInstanceState.DEFERRED
+    assert ti.trigger_id == trigger_row.id
+    assert ti.next_method == "next_method"
+    assert ti.next_kwargs == {}
+
+    # Check trigger_timeout is set (should be None since no timeout provided)
+    assert ti.trigger_timeout is None
+
+
+def test_defer_task_with_trigger_timeout(create_task_instance):
+    from airflow.models.trigger import Trigger
+    from airflow.triggers.base import StartTriggerArgs
+
+    session = mock.MagicMock()
+    timeout = datetime.timedelta(hours=1)
+    ti = create_task_instance(
+        dag_id="test_defer_task_with_trigger_timeout",
+        task_id="test_defer_task_with_trigger_timeout_op",
+        start_from_trigger=True,
+        start_trigger_args=StartTriggerArgs(
+            trigger_cls="trigger_cls",
+            next_method="next_method",
+            trigger_kwargs={"key": "value"},
+            timeout=timeout,
+        ),
+    )
+
+    # Save start_date to calculate expected trigger_timeout
+    now = timezone.utcnow()
+    ti.start_date = now
+
+    ti.defer_task(session=session)
+
+    # Check session interactions
+    assert session.add.call_count == 1
+    trigger_row = session.add.call_args[0][0]
+    assert isinstance(trigger_row, Trigger)
+    session.flush.assert_called_once()
+
+    # TaskInstance fields
+    assert ti.state == TaskInstanceState.DEFERRED
+    assert ti.trigger_id == trigger_row.id
+    assert ti.next_method == "next_method"
+    assert ti.next_kwargs == {}
+
+    # Check trigger_timeout is set correctly (within a small tolerance)
+    expected_timeout = now + timeout
+    assert abs((ti.trigger_timeout - expected_timeout).total_seconds()) < 5
+
+
 class TestTaskInstanceRecordTaskMapXComPush:
     """Test TI.xcom_push() correctly records return values for task-mapping."""
 
diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py b/airflow-core/tests/unit/triggers/test_base_trigger.py
new file mode 100644
index 00000000000..53066c46f6a
--- /dev/null
+++ b/airflow-core/tests/unit/triggers/test_base_trigger.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.triggers.base import BaseTrigger, StartTriggerArgs
+
+
+class DummyOperator(BaseOperator):
+    template_fields = ("name",)
+
+
+class DummyTrigger(BaseTrigger):
+    def __init__(self, name: str, **kwargs):
+        super().__init__(**kwargs)
+        self.name = name
+
+    def run(self):
+        return None
+
+    def serialize(self):
+        return {"name": self.name}
+
+
+@pytest.mark.db_test
+def test_render_template_fields(create_task_instance):
+    op = DummyOperator(task_id="dummy_task")
+    ti = create_task_instance(
+        task=op,
+        start_from_trigger=True,
+        start_trigger_args=StartTriggerArgs(
+            trigger_cls=f"{DummyTrigger.__module__}.{DummyTrigger.__qualname__}",
+            next_method="resume_method",
+            trigger_kwargs={"name": "Hello {{ name }}"},
+        ),
+    )
+
+    trigger = DummyTrigger(name="Hello {{ name }}")
+
+    assert not trigger.task_instance
+    assert not trigger.template_fields
+    assert not trigger.template_ext
+
+    trigger.task_instance = ti
+
+    assert trigger.task_instance == ti
+    assert "name" in trigger.template_fields
+    assert not trigger.template_ext
+
+    trigger.render_template_fields(context={"name": "world"})
+
+    assert trigger.name == "Hello world"
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index a4b590c5452..4fc35fe3514 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -53,6 +53,7 @@ if TYPE_CHECKING:
     from airflow.sdk.types import DagRunProtocol, Operator
     from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.timetables.base import DagRunInfo, DataInterval
+    from airflow.triggers.base import StartTriggerArgs
     from airflow.typing_compat import Self
     from airflow.utils.state import DagRunState, TaskInstanceState
 
@@ -1564,6 +1565,9 @@ def create_task_instance(
         hostname=None,
         pid=None,
         last_heartbeat_at=None,
+        task: Operator | None = None,
+        start_from_trigger: bool = False,
+        start_trigger_args: StartTriggerArgs | None = None,
         **kwargs,
     ) -> TaskInstance:
         timezone = _import_timezone()
@@ -1572,26 +1576,33 @@ def create_task_instance(
         if logical_date is NOTSET:
             # For now: default to having a logical date if None is not explicitly passed.
             logical_date = timezone.utcnow()
-        with dag_maker(dag_id, **kwargs):
+        with dag_maker(dag_id, **kwargs) as dag:
             op_kwargs = {}
             op_kwargs["task_display_name"] = task_display_name
-            task = EmptyOperator(
-                task_id=task_id,
-                max_active_tis_per_dag=max_active_tis_per_dag,
-                max_active_tis_per_dagrun=max_active_tis_per_dagrun,
-                executor_config=executor_config or {},
-                on_success_callback=on_success_callback,
-                on_execute_callback=on_execute_callback,
-                on_failure_callback=on_failure_callback,
-                on_retry_callback=on_retry_callback,
-                on_skipped_callback=on_skipped_callback,
-                inlets=inlets,
-                outlets=outlets,
-                email=email,
-                pool=pool,
-                trigger_rule=trigger_rule,
-                **op_kwargs,
-            )
+            if not task:
+                task = EmptyOperator(
+                    task_id=task_id,
+                    max_active_tis_per_dag=max_active_tis_per_dag,
+                    max_active_tis_per_dagrun=max_active_tis_per_dagrun,
+                    executor_config=executor_config or {},
+                    on_success_callback=on_success_callback,
+                    on_execute_callback=on_execute_callback,
+                    on_failure_callback=on_failure_callback,
+                    on_retry_callback=on_retry_callback,
+                    on_skipped_callback=on_skipped_callback,
+                    inlets=inlets,
+                    outlets=outlets,
+                    email=email,
+                    pool=pool,
+                    trigger_rule=trigger_rule,
+                    **op_kwargs,
+                )
+            else:
+                task_id = task.task_id
+                task.dag = dag
+            task.start_from_trigger = start_from_trigger
+            task.start_trigger_args = start_trigger_args
+
         if AIRFLOW_V_3_0_PLUS:
             dagrun_kwargs = {
                 "logical_date": logical_date,
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 6e88f0a94ad..4d5905ab73b 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -550,6 +550,11 @@ class BaseOperatorMeta(abc.ABCMeta):
             # Store the args passed to init -- we need them to support task.map serialization!
             self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore
 
+            # Validate trigger kwargs.
+            # Make sure method exists as class can depend on metaclass without extending the BaseOperator.
+            if hasattr(self, "_validate_start_from_trigger_kwargs"):
+                self._validate_start_from_trigger_kwargs()
+
             # Set upstream task defined by XComArgs passed to template fields of the operator.
             # BUT: only do this _ONCE_, not once for each class in the hierarchy
             if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
@@ -846,6 +851,14 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         to render templates as native Python types. If False, a Jinja
         ``Environment`` is used to render templates as string values.
         If None (default), inherits from the DAG setting.
+    :param start_from_trigger: If True, the operator starts execution directly in the triggerer,
+        skipping the initial worker execution phase. In this mode, templated fields are rendered
+        inside the triggerer instead of the worker. This avoids an extra round trip to a worker,
+        but may increase load on the triggerer, since the DAG must be serialized in order to
+        render templated fields. Use with care for DAGs with many tasks or heavy templating.
+    :param start_trigger_args: Used together with ``start_from_trigger`` to explicitly specify
+        which operator fields should be passed to the trigger. This helps limit the amount of
+        data serialized and sent to the triggerer.
     """
 
     task_id: str
@@ -1440,6 +1453,15 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
             return
         XComArg.apply_upstream_relationship(self, newvalue)
 
+    def _validate_start_from_trigger_kwargs(self):
+        if self.start_from_trigger and self.start_trigger_args and self.start_trigger_args.trigger_kwargs:
+            for name, val in self.start_trigger_args.trigger_kwargs.items():
+                if callable(val):
+                    raise ValueError(
+                        f"{self.__class__.__name__} with task_id '{self.task_id}' has a callable in trigger kwargs named "
+                        f"'{name}', which is not allowed when start_from_trigger is enabled."
+                    )
+
     def on_kill(self) -> None:
         """
         Override this method to clean up subprocesses when a task instance gets killed.
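Reviewer note: the _validate_start_from_trigger_kwargs hunk above rejects callables among the trigger kwargs. The hunk does not state the reason, so the sketch below only illustrates the check itself. find_callable_kwargs is a hypothetical helper name, not part of Airflow.

```python
def find_callable_kwargs(trigger_kwargs):
    """Hypothetical helper: names of trigger kwargs whose values are callables."""
    return [name for name, val in (trigger_kwargs or {}).items() if callable(val)]


# Plain values pass; built-in functions and lambdas are both reported.
offending = find_callable_kwargs({"path": "/tmp/x", "size_fn": len, "hook": lambda: 0})
```

A caller could raise a ValueError naming each offending key, which is what the hunk does for the first one it finds.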
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index e32bd377f01..00b811146a6 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -285,59 +285,6 @@ class AbstractOperator(Templater, DAGNode):
             dag = self.get_dag()
         return super()._render(template, context, dag=dag)
 
-    def _do_render_template_fields(
-        self,
-        parent: Any,
-        template_fields: Iterable[str],
-        context: Context,
-        jinja_env: jinja2.Environment,
-        seen_oids: set[int],
-    ) -> None:
-        """Override the base to use custom error logging."""
-        for attr_name in template_fields:
-            try:
-                value = getattr(parent, attr_name)
-            except AttributeError:
-                raise AttributeError(
-                    f"{attr_name!r} is configured as a template field "
-                    f"but {parent.task_type} does not have this attribute."
-                )
-            try:
-                if not value:
-                    continue
-            except Exception:
-                # This may happen if the templated field points to a class which does not support `__bool__`,
-                # such as Pandas DataFrames:
-                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
-                log.info(
-                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
-                    type(value).__name__,
-                    self.task_id,
-                    attr_name,
-                )
-                # We may still want to render custom classes which do not support __bool__
-                pass
-
-            try:
-                if callable(value):
-                    rendered_content = value(context=context, jinja_env=jinja_env)
-                else:
-                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
-            except Exception:
-                # Mask sensitive values in the template before logging
-                from airflow.sdk._shared.secrets_masker import redact
-
-                masked_value = redact(value)
-                log.exception(
-                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
-                    self.task_id,
-                    attr_name,
-                    masked_value,
-                )
-                raise
-            else:
-                setattr(parent, attr_name, rendered_content)
-
     def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
         """
         Return mapped nodes that are direct dependencies of the current task.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
index f094ccd6b28..cfe4a6100e4 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import datetime
 import logging
 import os
-from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Collection, Iterable, Iterator, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
@@ -117,6 +117,48 @@ class Templater:
 
         return dag.render_template_as_native_obj if dag else False
 
+    def _iter_templated_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+    ) -> Iterator[tuple[str, Any]]:
+        """
+        Iterate over template fields yielding ``(attr_name, value)`` pairs for non-empty fields.
+
+        Fields whose value is falsy are skipped.  Objects that do not support
+        ``__bool__`` (e.g. Pandas DataFrames) are still yielded.
+        """
+        for attr_name in template_fields:
+            try:
+                value = getattr(parent, attr_name)
+            except AttributeError:
+                raise AttributeError(
+                    f"{attr_name!r} is configured as a template field "
+                    f"but {type(parent).__name__} does not have this attribute."
+                )
+            try:
+                if not value:
+                    continue
+            except Exception:
+                # This may happen if the templated field points to a class which does not support
+                # ``__bool__``, such as Pandas DataFrames:
+                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
+                if hasattr(self, "task_id"):
+                    log.info(
+                        "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
+                        type(value).__name__,
+                        self.task_id,
+                        attr_name,
+                    )
+                else:
+                    log.info(
+                        "Unable to check if the value of type '%s' is False for field '%s'.",
+                        type(value).__name__,
+                        attr_name,
+                    )
+                # We may still want to render custom classes which do not support __bool__
+            yield attr_name, value
+
     def _do_render_template_fields(
         self,
         parent: Any,
@@ -125,15 +167,47 @@ class Templater:
         jinja_env: jinja2.Environment,
         seen_oids: set[int],
     ) -> None:
-        for attr_name in template_fields:
-            value = getattr(parent, attr_name)
-            rendered_content = self.render_template(
-                value,
-                context,
-                jinja_env,
-                seen_oids,
-            )
-            if rendered_content:
+        """
+        Render template fields on *parent* in-place.
+
+        For each non-empty field yielded by :meth:`_iter_templated_fields`, the value is
+        rendered (or called, when it is callable) and the result is written back via
+        ``setattr``.  Rendering errors are logged with masked values before being re-raised.
+
+        :param parent: The object whose attributes will be templated.
+        :param template_fields: Names of the attributes to render.
+        :param context: Context dict with values to apply on content.
+        :param jinja_env: Jinja2 environment to use for rendering.
+        :param seen_oids: Set of already-rendered object ids used to prevent infinite
+            recursion on circular references.
+        """
+        for attr_name, value in self._iter_templated_fields(parent, template_fields):
+            try:
+                if callable(value):
+                    rendered_content = value(context=context, jinja_env=jinja_env)
+                else:
+                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
+            except Exception:
+                # Mask sensitive values in the template before logging
+                from airflow.sdk._shared.secrets_masker import redact
+
+                masked_value = redact(value)
+                if hasattr(self, "task_id"):
+                    log.exception(
+                        "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
+                        self.task_id,
+                        attr_name,
+                        masked_value,
+                    )
+                else:
+                    log.exception(
+                        "Exception rendering Jinja template for %s, field '%s'. Template: %r",
+                        type(parent).__name__,
+                        attr_name,
+                        masked_value,
+                    )
+                raise
+            else:
                 setattr(parent, attr_name, rendered_content)
 
     def _render(self, template, context, dag=None) -> Any:
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index abc4c86ed85..7c0540421d4 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -226,6 +226,16 @@ class OperatorPartial:
         task_group = partial_kwargs.pop("task_group")
         start_date = partial_kwargs.pop("start_date", None)
         end_date = partial_kwargs.pop("end_date", None)
+        start_from_trigger = (
+            partial_kwargs["start_from_trigger"]
+            if "start_from_trigger" in partial_kwargs
+            else getattr(self.operator_class, "start_from_trigger", False)
+        )
+        start_trigger_args = (
+            partial_kwargs["start_trigger_args"]
+            if "start_trigger_args" in partial_kwargs
+            else getattr(self.operator_class, "start_trigger_args", None)
+        )
 
         try:
             operator_name = self.operator_class.custom_operator_name  # type: ignore
@@ -259,8 +269,8 @@ class OperatorPartial:
             # to BaseOperator.expand() contribute to operator arguments.
             expand_input_attr="expand_input",
             # TODO: Move these to task SDK's BaseOperator and remove getattr
-            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
-            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
+            start_trigger_args=start_trigger_args,
+            start_from_trigger=start_from_trigger,
         )
         return op
 
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 9e6db88d5cf..dcb5240a83d 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -41,6 +41,7 @@ from airflow.sdk.bases.operator import (
 )
 from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.definitions.template import literal
+from airflow.triggers.base import StartTriggerArgs
 
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
@@ -108,9 +109,18 @@ class MockOperator(BaseOperator):
         super().__init__(**kwargs)
         self.arg1 = arg1
         self.arg2 = arg2
+        if self.start_from_trigger:
+            self.start_trigger_args = StartTriggerArgs(
+                trigger_cls="trigger_cls",
+                next_method="next_method",
+                trigger_kwargs={"arg1": arg1, "arg2": arg2},
+            )
 
 
 class TestBaseOperator:
+    def setup_method(self, method):
+        MockOperator.start_from_trigger = False
+
     # Since we have a custom metaclass, lets double check the behaviour of
     # passing args in the wrong way (args etc)
     def test_kwargs_only(self):
@@ -800,6 +810,16 @@ class TestBaseOperator:
         task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
         assert mock_jinja_env.call_count == 1
 
+    def test_validate_start_from_trigger_kwargs(self):
+        MockOperator.start_from_trigger = True
+
+        with pytest.raises(
+            ValueError,
+            match="MockOperator with task_id 'one' has a callable in trigger kwargs named "
+            "'arg2', which is not allowed when start_from_trigger is enabled.",
+        ):
+            MockOperator(task_id="one", arg1="{{ foo }}", arg2=lambda context, jinja_env: "bar")
+
     def test_params_source(self):
         # Test bug when copying an operator attached to a Dag
         with DAG(
diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
index bcce3c89547..fccdfe8664c 100644
--- a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
+++ b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime, timezone
+from unittest.mock import MagicMock, NonCallableMagicMock
 
 import jinja2
 import pytest
@@ -111,6 +112,193 @@ class TestTemplater:
 
         assert rendered_content == "template_file.txt"
 
+    def test_do_render_template_fields_basic(self):
+        """Test that _do_render_template_fields renders a simple string template field in-place."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["greeting"])
+        parent.greeting = "Hello {{ name }}"
+
+        context = {"name": "world"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["greeting"], context, jinja_env, set())
+
+        assert parent.greeting == "Hello world"
+
+    def test_do_render_template_fields_multiple_fields(self):
+        """Test rendering multiple template fields at once."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["first", "second"])
+        parent.first = "Hello {{ name }}"
+        parent.second = "Date: {{ ds }}"
+
+        context = {"name": "world", "ds": "2024-01-01"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["first", "second"], context, jinja_env, set())
+
+        assert parent.first == "Hello world"
+        assert parent.second == "Date: 2024-01-01"
+
+    def test_do_render_template_fields_callable_value(self):
+        """Test that callable field values are called with context and jinja_env."""
+        templater = Templater()
+        templater.template_ext = []
+
+        callback = MagicMock(spec=lambda context, jinja_env: None, return_value="resolved")
+        parent = MagicMock(spec=["my_field"])
+        parent.my_field = callback
+
+        context = {"key": "value"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["my_field"], context, jinja_env, set())
+
+        callback.assert_called_once_with(context=context, jinja_env=jinja_env)
+        assert parent.my_field == "resolved"
+
+    def test_do_render_template_fields_skips_falsy_values(self):
+        """Test that falsy field values (empty string, None, 0) are skipped."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["empty_str", "none_val"])
+        parent.empty_str = ""
+        parent.none_val = None
+
+        context = {"name": "world"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["empty_str", "none_val"], context, jinja_env, set())
+
+        # Falsy values should not be touched
+        assert parent.empty_str == ""
+        assert parent.none_val is None
+
+    def test_do_render_template_fields_missing_attribute(self):
+        """Test that a missing attribute on parent raises AttributeError."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["existing"])
+        parent.existing = "value"
+
+        context = {}
+        jinja_env = templater.get_template_env()
+
+        with pytest.raises(
+            AttributeError,
+            match="'nonexistent' is configured as a template field",
+        ):
+            templater._do_render_template_fields(parent, ["nonexistent"], context, jinja_env, set())
+
+    def test_do_render_template_fields_exception_logged_with_task_id(self, caplog):
+        """Test that rendering errors are logged with task_id when available and re-raised."""
+        templater = Templater()
+        templater.template_ext = []
+        templater.task_id = "my_task"
+
+        parent = MagicMock(spec=["bad_field"])
+        parent.bad_field = "{{ undefined_var }}"
+
+        context = {}
+        jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0)
+
+        with pytest.raises(jinja2.UndefinedError):
+            templater._do_render_template_fields(parent, ["bad_field"], context, jinja_env, set())
+
+        assert "Exception rendering Jinja template for task 'my_task', field 'bad_field'" in caplog.text
+
+    def test_do_render_template_fields_exception_logged_without_task_id(self, caplog):
+        """Test that rendering errors are logged with parent type name when no task_id."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["bad_field"])
+        parent.bad_field = "{{ undefined_var }}"
+
+        context = {}
+        jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0)
+
+        with pytest.raises(jinja2.UndefinedError):
+            templater._do_render_template_fields(parent, ["bad_field"], context, jinja_env, set())
+
+        assert "Exception rendering Jinja template for MagicMock, field 'bad_field'" in caplog.text
+
+    def test_do_render_template_fields_nested_template_fields(self):
+        """Test rendering nested objects that have their own template_fields."""
+        templater = Templater()
+        templater.template_ext = []
+
+        inner = NonCallableMagicMock(spec=["template_fields", "message"])
+        inner.template_fields = ["message"]
+        inner.message = "Hello {{ name }}"
+
+        parent = MagicMock(spec=["nested"])
+        parent.nested = inner
+
+        context = {"name": "world"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["nested"], context, jinja_env, set())
+
+        assert inner.message == "Hello world"
+
+    def test_do_render_template_fields_seen_oids_prevents_reprocessing(self):
+        """Test that already-seen objects (by id) are not re-rendered."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["greeting"])
+        parent.greeting = "Hello {{ name }}"
+
+        context = {"name": "world"}
+        jinja_env = templater.get_template_env()
+
+        # Pre-populate seen_oids with the parent's greeting value id
+        seen_oids = {id(parent.greeting)}
+
+        templater._do_render_template_fields(parent, ["greeting"], context, jinja_env, seen_oids)
+
+        # The value should NOT be rendered because render_template checks
+        # `id(value) in seen_oids` and short-circuits, returning the original
+        # unrendered string.
+        assert parent.greeting == "Hello {{ name }}"
+
+    def test_do_render_template_fields_renders_dict_values(self):
+        """Test that dict field values have their inner templates rendered."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["params"])
+        parent.params = {"key": "{{ value }}"}
+
+        context = {"value": "rendered"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["params"], context, jinja_env, set())
+
+        assert parent.params == {"key": "rendered"}
+
+    def test_do_render_template_fields_renders_list_values(self):
+        """Test that list field values have their inner templates rendered."""
+        templater = Templater()
+        templater.template_ext = []
+
+        parent = MagicMock(spec=["items"])
+        parent.items = ["{{ a }}", "{{ b }}"]
+
+        context = {"a": "first", "b": "second"}
+        jinja_env = templater.get_template_env()
+
+        templater._do_render_template_fields(parent, ["items"], context, jinja_env, set())
+
+        assert parent.items == ["first", "second"]
+
 
 @pytest.fixture
 def env():

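Note on the mappedoperator.py hunk above: the new lookup checks whether the key is present in partial_kwargs rather than whether its value is truthy, so an explicit False passed to partial() can override a class-level default of True. The following is a minimal standalone sketch of that resolution order, not the Airflow implementation. `DeferrableSensor` and `resolve_start_from_trigger` are hypothetical names introduced only for illustration.

```python
from __future__ import annotations

from typing import Any


class DeferrableSensor:
    """Hypothetical operator class with class-level defaults (not an Airflow class)."""

    start_from_trigger = True
    start_trigger_args = None


def resolve_start_from_trigger(
    operator_class: type, partial_kwargs: dict[str, Any]
) -> tuple[Any, Any]:
    """Mirror the hunk's lookup order: explicit kwarg first, class attribute second."""
    # Key presence is tested, not truthiness, so an explicit False is honoured.
    start_from_trigger = (
        partial_kwargs["start_from_trigger"]
        if "start_from_trigger" in partial_kwargs
        else getattr(operator_class, "start_from_trigger", False)
    )
    start_trigger_args = (
        partial_kwargs["start_trigger_args"]
        if "start_trigger_args" in partial_kwargs
        else getattr(operator_class, "start_trigger_args", None)
    )
    return start_from_trigger, start_trigger_args


# No override: the class-level defaults are used.
default_result = resolve_start_from_trigger(DeferrableSensor, {})

# Explicit False overrides the class default of True.
override_result = resolve_start_from_trigger(
    DeferrableSensor, {"start_from_trigger": False}
)

# A class without either attribute falls back to (False, None).
fallback_result = resolve_start_from_trigger(object, {})
```

A lookup based on truthiness (for example `partial_kwargs.get(...) or getattr(...)`) would silently ignore an explicit False, which is presumably why the hunk tests for the key instead.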