ashb commented on code in PR #55068:
URL: https://github.com/apache/airflow/pull/55068#discussion_r2693803870


##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########


Review Comment:
   Given the size of the changes in triggerer_job_runner.py I would have expected to see more additions to this test file.



##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -26,3 +32,76 @@
 #
 # Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.)
 # would remain in this file, but currently no such tests exist.
+
+
+class TestDBDagBag:

Review Comment:
   Does this not already exist somewhere? I'm surprised we don't have tests for the DBDagBag already. cc @ephraimbuddy 



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
         serdag.load_op_links = self.load_op_links
         if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serdag := dag_version.serialized_dag):
+                return None
+        self._read_dag(serdag)
+        return serdag

Review Comment:
   Something about this function isn't quite making sense in my head.
   
   `self._dags` is now SerializedDagModel (i.e. the SQLA ORM objects), but I no longer understand what `_read_dag` is doing. Why do we need to eagerly load/deserialize the serialized dag (which is I think what L63 is doing) when we are asking for the dag model?



##########
airflow-core/src/airflow/serialization/definitions/mappedoperator.py:
##########
@@ -481,9 +481,9 @@ def expand_start_from_trigger(self, *, context: Context) -> bool:
             return False
         # TODO (GH-52141): Implement this.
         log.warning(
-            "Starting a mapped task from triggerer is currently unsupported",
-            task_id=self.task_id,
-            dag_id=self.dag_id,
+            "Starting a mapped task '%s' from dag '%s' on triggerer is currently unsupported",

Review Comment:
   ```suggestion
            "Starting a mapped task %r from dag %r on triggerer is currently unsupported",
   ```
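   
   `%r` formats the value with `repr()`, so the quotes come for free:
   
   ```python
   >>> "Starting a mapped task %r from dag %r on triggerer is currently unsupported" % ("my_task", "my_dag")
   "Starting a mapped task 'my_task' from dag 'my_dag' on triggerer is currently unsupported"
   ```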



##########
airflow-core/src/airflow/executors/workloads.py:
##########
@@ -203,6 +203,9 @@ class RunTrigger(BaseModel):
 
     type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
 
+    dag_data: dict | None = None

Review Comment:
   Hmmmmm. The entire serialized dag could potentially be quite large. I wonder 
if we need everything, or if we could get by with "just" the one serialized 
task here?
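   
   Something along these lines could work -- a rough sketch only, and the exact serialized-dag layout (the `"dag"`/`"tasks"`/`"__var"` keys) is an assumption here that may differ between serialization versions:
   
   ```python
   def extract_task_data(dag_data: dict, task_id: str) -> dict | None:
       """Pull one task's serialized dict out of the full serialized dag (hypothetical helper)."""
       for task in dag_data.get("dag", {}).get("tasks", []):
           fields = task.get("__var", task)  # some formats wrap the task fields in "__var"
           if fields.get("task_id") == task_id:
               return task
       return None
   ```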



##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########


Review Comment:
   Please avoid combining core and provider changes in 1 PR unless it is 100% 
required (it makes provider release harder when we do it)



##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1429,6 +1429,123 @@ def update_heartbeat(self):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    @property
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task and self.task.start_from_trigger is True:
+            return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state
+    #       method of the task_instances module in the execution api when a TIDeferredStatePayload is being
+    #       processed. This is because of a TaskInstance being updated differently using SQLAlchemy.
+    #       If we use the approach from the execution api as common code in the DagRun schedule_tis method,
+    #       the side effect is the changes done to the task instance aren't picked up by the scheduler and
+    #       thus the task instance isn't processed until the scheduler is restarted.
+    @provide_session
+    def defer_task(self, session: Session = NEW_SESSION) -> bool:

Review Comment:
   When is `ti.defer_task` ever called? (This class is not used in the 
execution path)



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
         serdag.load_op_links = self.load_op_links
         if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serdag := dag_version.serialized_dag):
+                return None
+        self._read_dag(serdag)
+        return serdag
+
+    def get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
+        if serdag := self.get_dag_model(version_id=version_id, session=session):
+            return self._read_dag(serdag)

Review Comment:
   This is no longer a `serdag` -- so this variable name is wrong I'd say.



##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1537,6 +1550,17 @@ def unmap(self, resolve: None | Mapping[str, Any]) -> 
Self:
         """
         return self
 
+    def expand_start_from_trigger(self, *, context: Context) -> bool:

Review Comment:
   Do we need this function? Given the doc string I would have expected to see 
a different implementation of this for mapped operators, but it seems to be 
working without it, so maybe this fn isn't needed?



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -953,9 +984,19 @@ async def init_comms(self):
             raise RuntimeError(f"Required first message to be a 
messages.StartTriggerer, it was {msg}")
 
     async def create_triggers(self):
+        def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+            task = 
DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+            # I need to recreate a TaskInstance from task_runner before 
invoking get_template_context (airflow.executors.workloads.TaskInstance)
+            return RuntimeTaskInstance.model_construct(
+                **workload.ti.model_dump(exclude_unset=True),
+                task=task,
+            )

Review Comment:
   Isn't `workload.ti` already a RuntimeTaskInstance? Why do we need to create a whole new one?
   
   If we need a copy and can't just edit it in place then: 
   ```suggestion
               return workload.ti.model_copy(update={"task": task})
   ```
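   
   For completeness, `model_copy` is the Pydantic v2 API; `update` takes a dict of field overrides and does a shallow copy, so this should be cheap:
   
   ```python
   # copies the model, replacing only the "task" field
   new_ti = workload.ti.model_copy(update={"task": task})
   ```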



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
         serdag.load_op_links = self.load_op_links
         if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):

Review Comment:
   Ditto here -- not a serdag



##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -32,11 +32,24 @@
     model_serializer,
 )
 
+from airflow.sdk.definitions._internal.templater import Templater
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import TaskInstanceState
 
 log = structlog.get_logger(logger_name=__name__)
 
+if TYPE_CHECKING:
+    from typing import TypeAlias
+
+    import jinja2
+
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.definitions.context import Context
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    Operator: TypeAlias = MappedOperator | SerializedBaseOperator

Review Comment:
   We shouldn't need to define this ourselves here -- such a definition already exists, right @uranusjr?



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -269,6 +276,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
 
         @provide_session
         def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            if not self.task_instance:

Review Comment:
   Isn't this going to fail on Airflow 2.x? Or maybe on Airflow 3.1? It feels 
like it's going to fail sometime when it should still be supported.



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -627,6 +631,51 @@ def emit_metrics(self):
             }
         )
 
+    @provide_session
+    def create_workload(
+        self,
+        trigger: Trigger,
+        dag_bag: DBDagBag,
+        render_log_fname=log_filename_template_renderer(),
+        session: Session = NEW_SESSION,
+    ) -> workloads.RunTrigger | None:
+        if trigger.task_instance is None:
+            return workloads.RunTrigger(
+                id=trigger.id,
+                classpath=trigger.classpath,
+                encrypted_kwargs=trigger.encrypted_kwargs,
+            )
+
+        if not trigger.task_instance.dag_version_id:
+            # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none
+            log.warning(
+                "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger",
+                ti_id=trigger.task_instance.id,
+            )
+            return None

Review Comment:
   What state does this leave the TI in, in this case? (I'm just worried about it being permanently in limbo)



##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -32,11 +32,24 @@
     model_serializer,
 )
 
+from airflow.sdk.definitions._internal.templater import Templater

Review Comment:
   cc @amoghrajesh I think we were removing these, right?



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -953,9 +984,19 @@ async def init_comms(self):
             raise RuntimeError(f"Required first message to be a 
messages.StartTriggerer, it was {msg}")
 
     async def create_triggers(self):
+        def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+            task = 
DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+            # I need to recreate a TaskInstance from task_runner before 
invoking get_template_context (airflow.executors.workloads.TaskInstance)
+            return RuntimeTaskInstance.model_construct(
+                **workload.ti.model_dump(exclude_unset=True),
+                task=task,
+            )

Review Comment:
   Looking at this, I _think_ we could just set `workload.ti.task = task`; we don't need to create a new copy.



##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -526,6 +526,10 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
             # Store the args passed to init -- we need them to support task.map serialization!
             self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore
 
+            # Validate trigger kwargs
+            if hasattr(self, "_validate_start_from_trigger_kwargs"):

Review Comment:
   Given we define a `_validate_start_from_trigger_kwargs` on L1415, how can this ever be False?
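   
   Toy example (hypothetical `Demo` class) of why the guard is always truthy once the method is defined on the base class:
   
   ```python
   class Demo:
       def _validate_start_from_trigger_kwargs(self) -> None: ...
   
   # hasattr resolves through the class, so this is True for every instance,
   # including instances of subclasses that never override the method
   assert hasattr(Demo(), "_validate_start_from_trigger_kwargs")
   ```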



##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1429,6 +1429,123 @@ def update_heartbeat(self):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    @property
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task and self.task.start_from_trigger is True:
+            return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state
+    #       method of the task_instances module in the execution api when a TIDeferredStatePayload is being
+    #       processed. This is because of a TaskInstance being updated differently using SQLAlchemy.
+    #       If we use the approach from the execution api as common code in the DagRun schedule_tis method,
+    #       the side effect is the changes done to the task instance aren't picked up by the scheduler and
+    #       thus the task instance isn't processed until the scheduler is restarted.
+    @provide_session
+    def defer_task(self, session: Session = NEW_SESSION) -> bool:
+        """
+        Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised.
+
+        :meta: private
+        """
+        from airflow.models.trigger import Trigger
+
+        if TYPE_CHECKING:
+            assert isinstance(self.task, Operator)
+
+        if start_trigger_args := self.start_trigger_args:
+            trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+            timeout = start_trigger_args.timeout
+
+            # Calculate timeout too if it was passed
+            if timeout is not None:
+                self.trigger_timeout = timezone.utcnow() + timeout
+            else:
+                self.trigger_timeout = None
+
+            trigger_row = Trigger(
+                classpath=start_trigger_args.trigger_cls,
+                kwargs=trigger_kwargs,
+            )
+
+            # First, make the trigger entry
+            session.add(trigger_row)
+            session.flush()
+
+            # Then, update ourselves so it matches the deferral request
+            # Keep an eye on the logic in `check_and_change_state_before_execution()`
+            # depending on self.next_method semantics
+            self.state = TaskInstanceState.DEFERRED
+            self.trigger_id = trigger_row.id
+            self.next_method = start_trigger_args.next_method
+            self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+            # If an execution_timeout is set, set the timeout to the minimum of
+            # it and the trigger timeout
+            if execution_timeout := self.task.execution_timeout:
+                if TYPE_CHECKING:
+                    assert self.start_date
+                if self.trigger_timeout:
+                    self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
+                else:
+                    self.trigger_timeout = self.start_date + execution_timeout
+            self.start_date = timezone.utcnow()
+            if self.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+                self.try_number += 1
+            if self.test_mode:
+                _add_log(event=self.state, task_instance=self, session=session)
+            return True
+        return False
+
+    @provide_session
+    def run(
+        self,
+        verbose: bool = True,
+        ignore_all_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_ti_state: bool = False,
+        mark_success: bool = False,
+        test_mode: bool = False,
+        pool: str | None = None,
+        session: Session = NEW_SESSION,
+        raise_on_defer: bool = False,
+    ) -> None:
+        """Run TaskInstance (only kept for tests)."""
+        # This method is only used in ti.run and dag.test and task.test.
+        # So doing the s10n/de-s10n dance to operator on Serialized task for the scheduler dep check part.
+        from airflow.serialization.definitions.dag import SerializedDAG
+        from airflow.serialization.serialized_objects import DagSerialization
+
+        original_task = self.task
+        if TYPE_CHECKING:
+            assert original_task is not None
+            assert original_task.dag is not None
+
+        # We don't set up all tests well...
+        if not isinstance(original_task.dag, SerializedDAG):
+            serialized_dag = DagSerialization.from_dict(DagSerialization.to_dict(original_task.dag))
+            self.task = serialized_dag.get_task(original_task.task_id)
+
+        res = self.check_and_change_state_before_execution(
+            verbose=verbose,
+            ignore_all_deps=ignore_all_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ignore_task_deps=ignore_task_deps,
+            ignore_ti_state=ignore_ti_state,
+            mark_success=mark_success,
+            test_mode=test_mode,
+            pool=pool,
+            session=session,
+        )
+        self.task = original_task
+        if not res:
+            return
+
+        self._run_raw_task(mark_success=mark_success)

Review Comment:
   We shouldn't add this back. It was removed because model TIs are not 
runnable.



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -153,23 +156,27 @@ def get_task_instance(self, session: Session) -> TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on {self.__class__.__name__}!")
+
+        if not isinstance(self.task_instance, RuntimeTaskInstance):

Review Comment:
   This is for 3.0..3.1 right? Might be worth a comment saying when we could hit this case.
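   
   Assuming that reading is right, something like this (the wording is only a sketch) would help:
   
   ```python
   if not isinstance(self.task_instance, RuntimeTaskInstance):
       # Deferrals created before this change rehydrate the trigger with the
       # DB-backed TaskInstance, so fall back to the legacy lookup path here.
       ...
   ```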



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -150,23 +153,27 @@ def get_task_instance(self, session: Session) -> TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on {self.__class__.__name__}!")

Review Comment:
   Wei is correct. We don't want to use `AirflowException` for new use cases.
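   
   A plain built-in exception would avoid that, e.g. (minimal sketch):
   
   ```python
   if not self.task_instance:
       raise RuntimeError(f"TaskInstance not set on {self.__class__.__name__}!")
   ```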


