This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7b37a785d0 Run triggers inline with dag test (#34642)
7b37a785d0 is described below
commit 7b37a785d0b74d1e83c7ce84729febffd6e26821
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Nov 27 06:48:17 2023 -0800
Run triggers inline with dag test (#34642)
No need to have the triggerer running -- when it is not, dag test will just run the triggers inline (via asyncio).
---
airflow/models/dag.py | 68 +++++++++++++---------------
airflow/models/taskinstance.py | 3 ++
tests/cli/commands/test_dag_command.py | 81 ++++++++++++++++++++--------------
tests/models/test_mappedoperator.py | 2 +-
4 files changed, 81 insertions(+), 73 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 26c83754a8..27e8258a6d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -17,7 +17,8 @@
# under the License.
from __future__ import annotations
-import collections.abc
+import asyncio
+import collections
import copy
import functools
import itertools
@@ -82,11 +83,11 @@ from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
- AirflowSkipException,
DuplicateTaskIdFound,
FailStopDagInvalidTriggerRule,
ParamValidationError,
RemovedInAirflow3Warning,
+ TaskDeferred,
TaskNotFound,
)
from airflow.jobs.job import run_job
@@ -101,7 +102,6 @@ from airflow.models.taskinstance import (
Context,
TaskInstance,
TaskInstanceKey,
- TaskReturnCode,
clear_task_instances,
)
from airflow.secrets.local_filesystem import LocalFilesystemBackend
@@ -285,12 +285,11 @@ def get_dataset_triggered_next_run_info(
}
-class _StopDagTest(Exception):
- """
- Raise when DAG.test should stop immediately.
+def _triggerer_is_healthy():
+ from airflow.jobs.triggerer_job_runner import TriggererJobRunner
- :meta private:
- """
+ job = TriggererJobRunner.most_recent_job()
+ return job and job.is_alive()
@functools.total_ordering
@@ -2844,21 +2843,12 @@ class DAG(LoggingMixin):
if not scheduled_tis and ids_unrunnable:
self.log.warning("No tasks to run. unrunnable tasks: %s",
ids_unrunnable)
time.sleep(1)
+ triggerer_running = _triggerer_is_healthy()
for ti in scheduled_tis:
try:
add_logger_if_needed(ti)
ti.task = tasks[ti.task_id]
- ret = _run_task(ti, session=session)
- if ret is TaskReturnCode.DEFERRED:
- if not _triggerer_is_healthy():
- raise _StopDagTest(
- "Task has deferred but triggerer component is
not running. "
- "You can start the triggerer by running
`airflow triggerer` in a terminal."
- )
- except _StopDagTest:
- # Let this exception bubble out and not be swallowed by the
- # except block below.
- raise
+ _run_task(ti=ti, inline_trigger=not triggerer_running,
session=session)
except Exception:
self.log.exception("Task failed; ti=%s", ti)
if conn_file_path or variable_file_path:
@@ -3992,14 +3982,15 @@ class DagContext:
return None
-def _triggerer_is_healthy():
- from airflow.jobs.triggerer_job_runner import TriggererJobRunner
+def _run_trigger(trigger):
+ async def _run_trigger_main():
+ async for event in trigger.run():
+ return event
- job = TriggererJobRunner.most_recent_job()
- return job and job.is_alive()
+ return asyncio.run(_run_trigger_main())
-def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
+def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session:
Session):
"""
Run a single task instance, and push result to Xcom for downstream tasks.
@@ -4009,20 +4000,21 @@ def _run_task(ti: TaskInstance, session) ->
TaskReturnCode | None:
Args:
ti: TaskInstance to run
"""
- ret = None
- log.info("*****************************************************")
- if ti.map_index > 0:
- log.info("Running task %s index %d", ti.task_id, ti.map_index)
- else:
- log.info("Running task %s", ti.task_id)
- try:
- ret = ti._run_raw_task(session=session)
- session.flush()
- log.info("%s ran successfully!", ti.task_id)
- except AirflowSkipException:
- log.info("Task Skipped, continuing")
- log.info("*****************************************************")
- return ret
+ log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id,
ti.map_index)
+ while True:
+ try:
+ log.info("[DAG TEST] running task %s", ti)
+ ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
+ break
+ except TaskDeferred as e:
+ log.info("[DAG TEST] running trigger in line")
+ event = _run_trigger(e.trigger)
+ ti.next_method = e.method_name
+ ti.next_kwargs = {"event": event.payload} if event else e.kwargs
+ log.info("[DAG TEST] Trigger completed")
+ session.merge(ti)
+ session.commit()
+ log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id,
ti.map_index)
def _get_or_create_dagrun(
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 95a2f5945f..f041dcf208 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2284,6 +2284,7 @@ class TaskInstance(Base, LoggingMixin):
test_mode: bool = False,
job_id: str | None = None,
pool: str | None = None,
+ raise_on_defer: bool = False,
session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
"""
@@ -2338,6 +2339,8 @@ class TaskInstance(Base, LoggingMixin):
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
+ if raise_on_defer:
+ raise
self._defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s,
execution_date=%s, start_date=%s",
diff --git a/tests/cli/commands/test_dag_command.py
b/tests/cli/commands/test_dag_command.py
index 78b7fd4525..30b5c475ea 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -37,9 +37,10 @@ from airflow.decorators import task
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _StopDagTest
+from airflow.models.dag import _run_trigger
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.triggers.temporal import TimeDeltaTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
@@ -824,35 +825,47 @@ class TestCliDags:
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs
- def test_dag_test_no_triggerer(self, dag_maker):
- with dag_maker() as dag:
-
- @task
- def one():
- return 1
-
- @task
- def two(val):
- return val + 1
-
- class MyOp(BaseOperator):
- template_fields = ("tfield",)
-
- def __init__(self, tfield, **kwargs):
- self.tfield = tfield
- super().__init__(**kwargs)
-
- def execute(self, context, event=None):
- if event is None:
- print("I AM DEFERRING")
-
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)),
method_name="execute")
- return
- print("RESUMING")
- return self.tfield + 1
-
- task_one = one()
- task_two = two(task_one)
- op = MyOp(task_id="abc", tfield=str(task_two))
- task_two >> op
- with pytest.raises(_StopDagTest, match="Task has deferred but
triggerer component is not running"):
- dag.test()
+ def test_dag_test_run_trigger(self, dag_maker):
+ now = timezone.utcnow()
+ trigger = DateTimeTrigger(moment=now)
+ e = _run_trigger(trigger)
+ assert isinstance(e, TriggerEvent)
+ assert e.payload == now
+
+ def test_dag_test_no_triggerer_running(self, dag_maker):
+ with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger)
as mock_run:
+ with dag_maker() as dag:
+
+ @task
+ def one():
+ return 1
+
+ @task
+ def two(val):
+ return val + 1
+
+ trigger = TimeDeltaTrigger(timedelta(seconds=0))
+
+ class MyOp(BaseOperator):
+ template_fields = ("tfield",)
+
+ def __init__(self, tfield, **kwargs):
+ self.tfield = tfield
+ super().__init__(**kwargs)
+
+ def execute(self, context, event=None):
+ if event is None:
+ print("I AM DEFERRING")
+ self.defer(trigger=trigger, method_name="execute")
+ return
+ print("RESUMING")
+ return self.tfield + 1
+
+ task_one = one()
+ task_two = two(task_one)
+ op = MyOp(task_id="abc", tfield=task_two)
+ task_two >> op
+ dr = dag.test()
+ assert mock_run.call_args_list[0] == ((trigger,), {})
+ tis = dr.get_task_instances()
+ assert [x for x in tis if x.task_id == "abc"][0].state == "success"
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index 7244c55774..78f0a0d271 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -95,7 +95,7 @@ def
test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template
mapped =
CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values)
task1 >> mapped
dag.test()
- assert caplog.text.count("task_2 ran successfully") == 2
+ assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2
assert (
"Unable to check if the value of type 'UnrenderableClass' is False for
task 'task_2', field 'arg'"
in caplog.text