ashb commented on code in PR #55068:
URL: https://github.com/apache/airflow/pull/55068#discussion_r2693803870
##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########
Review Comment:
Given the size of the changes in `triggerer_job_runner.py`, I would have expected
to see more additions to this test file.
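For example, the new `create_workload` skip-path alone deserves something like this (rough sketch only -- the fixture and constructor details are guesses from the existing suite):
```python
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models.dagbag import DBDagBag
from airflow.models.trigger import Trigger


def test_create_workload_skips_ti_without_dag_version(session, create_task_instance):
    ti = create_task_instance(session=session)
    ti.dag_version_id = None  # simulate a row carried over from a 2.x upgrade
    ti.trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
    session.flush()

    runner = TriggererJobRunner(job=Job())
    # Per the new code, a TI without a dag_version_id should be skipped entirely
    assert runner.create_workload(ti.trigger, dag_bag=DBDagBag(), session=session) is None
```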
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -26,3 +32,76 @@
#
# Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.)
# would remain in this file, but currently no such tests exist.
+
+
+class TestDBDagBag:
Review Comment:
Does this not already exist somewhere? I'm surprised we don't have tests
for the DBDagBag already. cc @ephraimbuddy
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
    """

    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
        self.load_op_links = load_op_links

    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
        serdag.load_op_links = self.load_op_links
        if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
        return dag

-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serdag := dag_version.serialized_dag):
+                return None
+            self._read_dag(serdag)
+        return serdag
Review Comment:
Something about this function isn't quite making sense in my head.
`self._dags` now holds `SerializedDagModel` (i.e. the SQLA ORM objects), but I no
longer understand what `_read_dag` is doing here. Why do we need to eagerly
load/deserialize the serialized dag (which is, I think, what L63 -- the
`self._read_dag(serdag)` call -- is doing) when we are asking for the dag model?
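Untested, but I'd have expected the model path not to touch deserialization at all -- something like:
```python
    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
        if dag_model := self._dags.get(version_id):
            return dag_model
        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
        if not dag_version or not (dag_model := dag_version.serialized_dag):
            return None
        # Cache the ORM row only; leave deserialization to get_dag()/_read_dag()
        self._dags[version_id] = dag_model
        return dag_model
```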
##########
airflow-core/src/airflow/serialization/definitions/mappedoperator.py:
##########
@@ -481,9 +481,9 @@ def expand_start_from_trigger(self, *, context: Context) -> bool:
            return False
        # TODO (GH-52141): Implement this.
        log.warning(
-            "Starting a mapped task from triggerer is currently unsupported",
-            task_id=self.task_id,
-            dag_id=self.dag_id,
+            "Starting a mapped task '%s' from dag '%s' on triggerer is currently unsupported",
Review Comment:
```suggestion
            "Starting a mapped task %r from dag %r on triggerer is currently unsupported",
```
##########
airflow-core/src/airflow/executors/workloads.py:
##########
@@ -203,6 +203,9 @@ class RunTrigger(BaseModel):
    type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")

+    dag_data: dict | None = None
Review Comment:
Hmmmmm. The entire serialized dag could potentially be quite large. I wonder
if we need everything, or if we could get by with "just" the one serialized
task here?
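For illustration, something along these lines could trim the payload (sketch only -- the key layout of the serialized dict is an assumption here, so check it against `DagSerialization` first):
```python
def extract_task_data(dag_data: dict, task_id: str) -> dict | None:
    # Assumes the serialized layout {"dag": {"tasks": [{"__var": {...}}, ...]}}.
    for encoded_task in dag_data.get("dag", {}).get("tasks", []):
        fields = encoded_task.get("__var", encoded_task)
        if fields.get("task_id") == task_id:
            return encoded_task
    return None
```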
##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
Review Comment:
Please avoid combining core and provider changes in 1 PR unless it is 100%
required (it makes provider releases harder when we do it)
##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1429,6 +1429,123 @@ def update_heartbeat(self):
            .values(last_heartbeat_at=timezone.utcnow())
        )

+    @property
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task and self.task.start_from_trigger is True:
+            return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state
+    # method of the task_instances module in the execution api when a TIDeferredStatePayload is being
+    # processed. This is because of a TaskInstance being updated differently using SQLAlchemy.
+    # If we use the approach from the execution api as common code in the DagRun schedule_tis method,
+    # the side effect is the changes done to the task instance aren't picked up by the scheduler and
+    # thus the task instance isn't processed until the scheduler is restarted.
+    @provide_session
+    def defer_task(self, session: Session = NEW_SESSION) -> bool:
Review Comment:
When is `ti.defer_task` ever called? (This class is not used in the
execution path)
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
    """

    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
        self.load_op_links = load_op_links

    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
        serdag.load_op_links = self.load_op_links
        if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
        return dag

-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version or not (serdag := dag_version.serialized_dag):
+                return None
+            self._read_dag(serdag)
+        return serdag
+
+    def get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
+        if serdag := self.get_dag_model(version_id=version_id, session=session):
+            return self._read_dag(serdag)
Review Comment:
This is no longer a `serdag` -- so this variable name is wrong I'd say.
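i.e. just a rename, something like (untested):
```python
    def get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
        if dag_model := self.get_dag_model(version_id=version_id, session=session):
            return self._read_dag(dag_model)
        return None
```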
##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1537,6 +1550,17 @@ def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
        """
        return self

+    def expand_start_from_trigger(self, *, context: Context) -> bool:
Review Comment:
Do we need this function? Given the docstring I would have expected to see
a different implementation of this for mapped operators, but it seems to be
working without it, so maybe this fn isn't needed?
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -953,9 +984,19 @@ async def init_comms(self):
            raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")

    async def create_triggers(self):
+        def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+            task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+            # I need to recreate a TaskInstance from task_runner before invoking
+            # get_template_context (airflow.executors.workloads.TaskInstance)
+            return RuntimeTaskInstance.model_construct(
+                **workload.ti.model_dump(exclude_unset=True),
+                task=task,
+            )
Review Comment:
Isn't `workload.ti` already a RuntimeTaskInstance? Why do we need to create
a whole new one?
If we need a copy and can't just edit it in place then:
```suggestion
            return workload.ti.model_copy(update={"task": task})
```
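(For what it's worth, `model_copy(update=...)` is the pydantic v2 API for clone-with-overrides and does a shallow copy, so it should be cheap.)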
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,27 @@ class DBDagBag:
    """

    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
        self.load_op_links = load_op_links

    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
        serdag.load_op_links = self.load_op_links
        if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
        return dag

-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
Review Comment:
Ditto here -- not a serdag
##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -32,11 +32,24 @@
    model_serializer,
)

+from airflow.sdk.definitions._internal.templater import Templater
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

log = structlog.get_logger(logger_name=__name__)

+if TYPE_CHECKING:
+    from typing import TypeAlias
+
+    import jinja2
+
+    from airflow.models.mappedoperator import MappedOperator
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.definitions.context import Context
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    Operator: TypeAlias = MappedOperator | SerializedBaseOperator
Review Comment:
We shouldn't need to define this ourselves here -- such a definition already
exists, right @uranusjr?
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -269,6 +276,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
    @provide_session
    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: Sqlalchemy session
+        """
+        if not self.task_instance:
Review Comment:
Isn't this going to fail on Airflow 2.x? Or maybe on Airflow 3.1? It feels
like it's going to fail sometime when it should still be supported.
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -627,6 +631,51 @@ def emit_metrics(self):
            }
        )

+    @provide_session
+    def create_workload(
+        self,
+        trigger: Trigger,
+        dag_bag: DBDagBag,
+        render_log_fname=log_filename_template_renderer(),
+        session: Session = NEW_SESSION,
+    ) -> workloads.RunTrigger | None:
+        if trigger.task_instance is None:
+            return workloads.RunTrigger(
+                id=trigger.id,
+                classpath=trigger.classpath,
+                encrypted_kwargs=trigger.encrypted_kwargs,
+            )
+
+        if not trigger.task_instance.dag_version_id:
+            # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none
+            log.warning(
+                "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger",
+                ti_id=trigger.task_instance.id,
+            )
+            return None
Review Comment:
What state does this leave the TI in, in this case? (I'm just worried about
it being permanently in limbo.)
##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -32,11 +32,24 @@
    model_serializer,
)

+from airflow.sdk.definitions._internal.templater import Templater
Review Comment:
cc @amoghrajesh I think we were removing these, right?
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -953,9 +984,19 @@ async def init_comms(self):
            raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")

    async def create_triggers(self):
+        def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+            task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+            # I need to recreate a TaskInstance from task_runner before invoking
+            # get_template_context (airflow.executors.workloads.TaskInstance)
+            return RuntimeTaskInstance.model_construct(
+                **workload.ti.model_dump(exclude_unset=True),
+                task=task,
+            )
Review Comment:
Looking at this, I _think_ we could just set `workload.ti.task = task` instead;
I don't think we need to create a new copy.
##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -526,6 +526,10 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
            # Store the args passed to init -- we need them to support task.map serialization!
            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

+            # Validate trigger kwargs
+            if hasattr(self, "_validate_start_from_trigger_kwargs"):
Review Comment:
Given we define a `_validate_start_from_trigger_kwargs` method on L1415, how can
this ever be False?
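For illustration, since the method is defined on the class itself the `hasattr` guard is a tautology:
```python
# Any method defined on the class makes hasattr() True for every instance,
# so the `if hasattr(...)` branch above is always taken.
class Example:
    def _validate_start_from_trigger_kwargs(self):
        pass

assert hasattr(Example(), "_validate_start_from_trigger_kwargs")  # always True
```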
##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1429,6 +1429,123 @@ def update_heartbeat(self):
.values(last_heartbeat_at=timezone.utcnow())
)
+ @property
+ def start_trigger_args(self) -> StartTriggerArgs | None:
+ if self.task and self.task.start_from_trigger is True:
+ return self.task.start_trigger_args
+ return None
+
+ # TODO: We have some code duplication here and in the
_create_ti_state_update_query_and_update_state
+ # method of the task_instances module in the execution api when a
TIDeferredStatePayload is being
+ # processed. This is because of a TaskInstance being updated
differently using SQLAlchemy.
+ # If we use the approach from the execution api as common code in
the DagRun schedule_tis method,
+ # the side effect is the changes done to the task instance aren't
picked up by the scheduler and
+ # thus the task instance isn't processed until the scheduler is
restarted.
+ @provide_session
+ def defer_task(self, session: Session = NEW_SESSION) -> bool:
+ """
+ Mark the task as deferred and sets up the trigger that is needed to
resume it when TaskDeferred is raised.
+
+ :meta: private
+ """
+ from airflow.models.trigger import Trigger
+
+ if TYPE_CHECKING:
+ assert isinstance(self.task, Operator)
+
+ if start_trigger_args := self.start_trigger_args:
+ trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+ timeout = start_trigger_args.timeout
+
+ # Calculate timeout too if it was passed
+ if timeout is not None:
+ self.trigger_timeout = timezone.utcnow() + timeout
+ else:
+ self.trigger_timeout = None
+
+ trigger_row = Trigger(
+ classpath=start_trigger_args.trigger_cls,
+ kwargs=trigger_kwargs,
+ )
+
+ # First, make the trigger entry
+ session.add(trigger_row)
+ session.flush()
+
+ # Then, update ourselves so it matches the deferral request
+ # Keep an eye on the logic in
`check_and_change_state_before_execution()`
+ # depending on self.next_method semantics
+ self.state = TaskInstanceState.DEFERRED
+ self.trigger_id = trigger_row.id
+ self.next_method = start_trigger_args.next_method
+ self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if execution_timeout := self.task.execution_timeout:
+ if TYPE_CHECKING:
+ assert self.start_date
+ if self.trigger_timeout:
+ self.trigger_timeout = min(self.start_date +
execution_timeout, self.trigger_timeout)
+ else:
+ self.trigger_timeout = self.start_date + execution_timeout
+ self.start_date = timezone.utcnow()
+ if self.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+ self.try_number += 1
+ if self.test_mode:
+ _add_log(event=self.state, task_instance=self, session=session)
+ return True
+ return False
+
+ @provide_session
+ def run(
+ self,
+ verbose: bool = True,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ pool: str | None = None,
+ session: Session = NEW_SESSION,
+ raise_on_defer: bool = False,
+ ) -> None:
+ """Run TaskInstance (only kept for tests)."""
+ # This method is only used in ti.run and dag.test and task.test.
+ # So doing the s10n/de-s10n dance to operator on Serialized task for
the scheduler dep check part.
+ from airflow.serialization.definitions.dag import SerializedDAG
+ from airflow.serialization.serialized_objects import DagSerialization
+
+ original_task = self.task
+ if TYPE_CHECKING:
+ assert original_task is not None
+ assert original_task.dag is not None
+
+ # We don't set up all tests well...
+ if not isinstance(original_task.dag, SerializedDAG):
+ serialized_dag =
DagSerialization.from_dict(DagSerialization.to_dict(original_task.dag))
+ self.task = serialized_dag.get_task(original_task.task_id)
+
+ res = self.check_and_change_state_before_execution(
+ verbose=verbose,
+ ignore_all_deps=ignore_all_deps,
+ ignore_depends_on_past=ignore_depends_on_past,
+
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+ ignore_task_deps=ignore_task_deps,
+ ignore_ti_state=ignore_ti_state,
+ mark_success=mark_success,
+ test_mode=test_mode,
+ pool=pool,
+ session=session,
+ )
+ self.task = original_task
+ if not res:
+ return
+
+ self._run_raw_task(mark_success=mark_success)
Review Comment:
We shouldn't add this back. It was removed because model TIs are not
runnable.
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -153,23 +156,27 @@ def get_task_instance(self, session: Session) -> TaskInstance:
    async def get_task_state(self):
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

-        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on {self.__class__.__name__}!")
+
+        if not isinstance(self.task_instance, RuntimeTaskInstance):
Review Comment:
This is for 3.0..3.1, right? Might be worth a comment saying when we
could hit this case.
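e.g. something like (sketch -- the exact version boundary is a guess, please check when this actually changed):
```python
        if not isinstance(self.task_instance, RuntimeTaskInstance):
            # Triggerers on older Airflow versions hand us a database TaskInstance
            # here rather than a RuntimeTaskInstance; fall back to the DB lookup.
```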
##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -150,23 +153,27 @@ def get_task_instance(self, session: Session) -> TaskInstance:
    async def get_task_state(self):
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

-        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on {self.__class__.__name__}!")
Review Comment:
Wei is correct. We don't want to use `AirflowException` for new use cases.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]