shahar1 commented on code in PR #55068:
URL: https://github.com/apache/airflow/pull/55068#discussion_r2772514215


##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1443,6 +1443,123 @@ def update_heartbeat(self):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    @property
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task and self.task.start_from_trigger is True:
+            return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the 
_create_ti_state_update_query_and_update_state
+    #       method of the task_instances module in the execution api when a 
TIDeferredStatePayload is being
+    #       processed. This is because of a TaskInstance being updated 
differently using SQLAlchemy.
+    #       If we use the approach from the execution api as common code in 
the DagRun schedule_tis method,
+    #       the side effect is the changes done to the task instance aren't 
picked up by the scheduler and
+    #       thus the task instance isn't processed until the scheduler is 
restarted.
+    @provide_session
+    def defer_task(self, session: Session = NEW_SESSION) -> bool:
+        """
+        Mark the task as deferred and sets up the trigger that is needed to 
resume it when TaskDeferred is raised.
+
+        :meta: private
+        """
+        from airflow.models.trigger import Trigger
+
+        if TYPE_CHECKING:
+            assert isinstance(self.task, Operator)
+
+        if start_trigger_args := self.start_trigger_args:
+            trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+            timeout = start_trigger_args.timeout
+
+            # Calculate timeout too if it was passed
+            if timeout is not None:
+                self.trigger_timeout = timezone.utcnow() + timeout
+            else:
+                self.trigger_timeout = None
+
+            trigger_row = Trigger(
+                classpath=start_trigger_args.trigger_cls,
+                kwargs=trigger_kwargs,
+            )
+
+            # First, make the trigger entry
+            session.add(trigger_row)
+            session.flush()
+
+            # Then, update ourselves so it matches the deferral request
+            # Keep an eye on the logic in 
`check_and_change_state_before_execution()`
+            # depending on self.next_method semantics
+            self.state = TaskInstanceState.DEFERRED
+            self.trigger_id = trigger_row.id
+            self.next_method = start_trigger_args.next_method
+            self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+            # If an execution_timeout is set, set the timeout to the minimum of
+            # it and the trigger timeout
+            if execution_timeout := self.task.execution_timeout:
+                if TYPE_CHECKING:
+                    assert self.start_date

Review Comment:
   1. Why is it safe to assume at this point that `self.start_date` is not 
`None`? 
   2. Why do we need to overwrite it in L1506?



##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -64,14 +77,67 @@ class BaseTrigger(abc.ABC, LoggingMixin):
     """
 
     def __init__(self, **kwargs):
+        super().__init__()
         # these values are set by triggerer when preparing to run the instance
         # when run, they are injected into logger record.
-        self.task_instance = None
+        self._task_instance = None
         self.trigger_id = None
+        self.template_fields = ()
+        self.template_ext = ()
 
     def _set_context(self, context):
         """Part of LoggingMixin and used mainly for configuration of task 
logging; not used for triggers."""
-        raise NotImplementedError
+        pass
+
+    @property
+    def task(self) -> Operator | None:
+        if self.task_instance:
+            return self.task_instance.task
+        return None
+
+    @property
+    def task_instance(self) -> TaskInstance:

Review Comment:
   ```suggestion
       def task_instance(self) -> TaskInstance | None:
   ```



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -986,24 +1027,28 @@ async def create_triggers(self):
                 # that could cause None values in collections.
                 kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
                 deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for 
k, v in kw.items()}
-                trigger_instance = trigger_class(**deserialised_kwargs)
+
+                if ti := workload.ti:
+                    runtime_ti = create_runtime_ti(workload.dag_data)

Review Comment:
   If `workload.dag_data` is `None` (L681), won't we get a `TypeError`?



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
-    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
-        serdag.load_op_links = self.load_op_links
-        if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+    def _read_dag(self, serialized_dag_model: SerializedDagModel) -> 
SerializedDAG | None:
+        serialized_dag_model.load_op_links = self.load_op_links
+        if dag := serialized_dag_model.dag:
+            self._dags[serialized_dag_model.dag_version_id] = 
serialized_dag_model
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | 
None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> 
SerializedDagModel | None:
+        if not (serialized_dag_model := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serialized_dag_model := 
dag_version.serialized_dag):
+                return None
+        self._read_dag(serialized_dag_model)
+        return serialized_dag_model
+
+    def get_dag(self, version_id: str, session: Session) -> SerializedDAG | 
None:
+        if serialized_dag_model := self.get_dag_model(version_id=version_id, 
session=session):
+            return self._read_dag(serialized_dag_model)

Review Comment:
   You call `self._read_dag` twice — once inside `self.get_dag_model` and a 
second time here.
   Could this be optimized?



##########
task-sdk/src/airflow/sdk/definitions/mappedoperator.py:
##########
@@ -226,6 +226,13 @@ def _expand(self, expand_input: ExpandInput, *, strict: 
bool) -> MappedOperator:
         task_group = partial_kwargs.pop("task_group")
         start_date = partial_kwargs.pop("start_date", None)
         end_date = partial_kwargs.pop("end_date", None)
+        start_from_trigger = bool(
+            partial_kwargs.get("start_from_trigger", False)
+            or getattr(self.operator_class, "start_from_trigger", False)
+        )
+        start_trigger_args = partial_kwargs.get("start_trigger_args", None) or 
getattr(
+            self.operator_class, "start_trigger_args", None
+        )

Review Comment:
   The `or` operator evaluates to the first truthy value, so if the operator 
class's default is `True` and you want to override it to `False` via 
`partial_kwargs`, you won't be able to do so with the current logic.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to