amoghrajesh commented on code in PR #50300:
URL: https://github.com/apache/airflow/pull/50300#discussion_r2081377947
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -1316,6 +1316,29 @@ def notify_dagrun_state_changed(self, msg: str = ""):
# we can't get all the state changes on SchedulerJob,
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
+ def handle_dag_callback(self, dag, success=True, reason="success"):
Review Comment:
Can we add type hints here?
##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -70,7 +70,7 @@
ti_id_router = VersionedAPIRouter(
dependencies=[
# This checks that the UUID in the url matches the one in the token
for us.
Review Comment:
Let's remove this.
##########
task-sdk/src/airflow/sdk/definitions/dag.py:
##########
@@ -1014,6 +1016,281 @@ def _validate_owner_links(self, _, owner_links):
f"Bad formatted links are: {wrong_links}"
)
+ def test(
+ self,
+ run_after: datetime | None = None,
+ logical_date: datetime | None = None,
+ run_conf: dict[str, Any] | None = None,
+ conn_file_path: str | None = None,
+ variable_file_path: str | None = None,
+ use_executor: bool = False,
+ mark_success_pattern: Pattern | str | None = None,
+ ):
+ """
+ Execute one single DagRun for a given DAG and logical date.
+
+ :param run_after: the datetime before which to Dag cannot run.
+ :param logical_date: logical date for the DAG run
+ :param run_conf: configuration to pass to newly created dagrun
+ :param conn_file_path: file path to a connection file in either yaml
or json
+ :param variable_file_path: file path to a variable file in either yaml
or json
+ :param use_executor: if set, uses an executor to test the DAG
+ :param mark_success_pattern: regex of task_ids to mark as success
instead of running
+ """
+ import re
+ import time
+ from contextlib import ExitStack
+
+ from airflow import settings
+ from airflow.configuration import secrets_backend_list
+ from airflow.models.dag import DAG as SchedulerDAG,
_get_or_create_dagrun
+ from airflow.models.dagrun import DagRun
+ from airflow.secrets.local_filesystem import LocalFilesystemBackend
+ from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.utils import timezone
+ from airflow.utils.state import DagRunState, State, TaskInstanceState
+ from airflow.utils.types import DagRunTriggeredByType, DagRunType
+
+ if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
+
+ def add_logger_if_needed(ti: TaskInstance):
+ """
+ Add a formatted logger to the task instance.
+
+ This allows all logs to surface to the command line, instead of
into
+ a task file. Since this is a local test run, it is much better for
+ the user to see logs in the command line, rather than needing to
+ search for a log file.
+
+ :param ti: The task instance that will receive a logger.
+ """
+ format = logging.Formatter("[%(asctime)s]
{%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
+ handler = logging.StreamHandler(sys.stdout)
+ handler.level = logging.INFO
+ handler.setFormatter(format)
+ # only add log handler once
+ if not any(isinstance(h, logging.StreamHandler) for h in
ti.log.handlers):
+ log.debug("Adding Streamhandler to taskinstance %s",
ti.task_id)
+ ti.log.addHandler(handler)
+
+ exit_stack = ExitStack()
+
+ if conn_file_path or variable_file_path:
+ local_secrets = LocalFilesystemBackend(
+ variables_file_path=variable_file_path,
connections_file_path=conn_file_path
+ )
+ secrets_backend_list.insert(0, local_secrets)
+ exit_stack.callback(lambda: secrets_backend_list.pop(0))
+
+ session = settings.Session()
+
+ with exit_stack:
+ self.validate()
+ log.debug("Clearing existing task instances for logical date %s",
logical_date)
+ # TODO: Replace with calling client.dag_run.clear in Execution API
at some point
+ SchedulerDAG.clear_dags(
+ dags=[self],
+ start_date=logical_date,
+ end_date=logical_date,
+ dag_run_state=False, # type: ignore
+ )
+
+ log.debug("Getting dagrun for dag %s", self.dag_id)
+ logical_date = timezone.coerce_datetime(logical_date)
+ run_after = timezone.coerce_datetime(run_after) or
timezone.coerce_datetime(timezone.utcnow())
+ data_interval = (
+
self.timetable.infer_manual_data_interval(run_after=logical_date) if
logical_date else None
+ )
+ scheduler_dag =
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self)) # type:
ignore[arg-type]
+
+ dr: DagRun = _get_or_create_dagrun(
+ dag=scheduler_dag,
+ start_date=logical_date or run_after,
+ logical_date=logical_date,
+ data_interval=data_interval,
+ run_after=run_after,
+ run_id=DagRun.generate_run_id(
+ run_type=DagRunType.MANUAL,
+ logical_date=logical_date,
+ run_after=run_after,
+ ),
+ session=session,
+ conf=run_conf,
+ triggered_by=DagRunTriggeredByType.TEST,
+ )
+ # Start a mock span so that one is present and not started
downstream. We
+ # don't care about otel in dag.test and starting the span during
dagrun update
+ # is not functioning properly in this context anyway.
+ dr.start_dr_spans_if_needed(tis=[])
+ dr.dag = self # type: ignore[assignment]
+
+ tasks = self.task_dict
+ log.debug("starting dagrun")
+ # Instead of starting a scheduler, we run the minimal loop
possible to check
+ # for task readiness and dependency management.
+ # Instead of starting a scheduler, we run the minimal loop
possible to check
+ # for task readiness and dependency management.
+
+ # ``Dag.test()`` works in two different modes depending on
``use_executor``:
+ # - if ``use_executor`` is False, runs the task locally with no
executor using ``_run_task``
+ # - if ``use_executor`` is True, sends the task instances to the
executor with
+ # ``BaseExecutor.queue_task_instance``
+ if use_executor:
+ from pathlib import Path
+
+ from airflow.executors.base_executor import ExecutorLoader
+ from airflow.executors.workloads import BundleInfo
+
+ executor = ExecutorLoader.get_default_executor()
+ executor.start()
+
+ while dr.state == DagRunState.RUNNING:
+ session.expire_all()
+ schedulable_tis, _ = dr.update_state(session=session)
+ for s in schedulable_tis:
+ if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+ s.try_number += 1
+ s.state = TaskInstanceState.SCHEDULED
+ s.scheduled_dttm = timezone.utcnow()
+ session.commit()
+ # triggerer may mark tasks scheduled so we read from DB
+ all_tis = set(dr.get_task_instances(session=session))
+ scheduled_tis = {x for x in all_tis if x.state ==
TaskInstanceState.SCHEDULED}
+ ids_unrunnable = {x for x in all_tis if x.state not in
State.finished} - scheduled_tis
+ if not scheduled_tis and ids_unrunnable:
+ log.warning("No tasks to run. unrunnable tasks: %s",
ids_unrunnable)
+ time.sleep(1)
+
+ for ti in scheduled_tis:
+ ti.task = tasks[ti.task_id]
+
+ mark_success = (
+ re.compile(mark_success_pattern).fullmatch(ti.task_id)
is not None
+ if mark_success_pattern is not None
+ else False
+ )
+
+ if use_executor:
+ if executor.has_task(ti):
+ continue
+ # TODO: Task-SDK: This check is transitionary. Remove
once all executors are ported over.
+ from airflow.executors import workloads
+ from airflow.executors.base_executor import
BaseExecutor
+
+ if executor.queue_workload.__func__ is not
BaseExecutor.queue_workload: # type: ignore[attr-defined]
+ workload = workloads.ExecuteTask.make(
+ ti,
+ dag_rel_path=Path(self.fileloc),
+ generator=executor.jwt_generator,
+ # For the system test/debug purpose, we use
the default bundle which uses
+ # local file system. If it turns out to be a
feature people want, we could
+ # plumb the Bundle to use as a parameter to
dag.test
+ bundle_info=BundleInfo(name="dags-folder"),
+ )
+ executor.queue_workload(workload, session=session)
+ ti.state = TaskInstanceState.QUEUED
+ session.commit()
+ else:
+ # Send the task to the executor
+ executor.queue_task_instance(ti,
ignore_ti_state=True)
Review Comment:
The Task SDK is not supposed to run with Airflow 2, right? Do we need this check?
##########
airflow-core/src/airflow/models/dag.py:
##########
@@ -777,89 +766,6 @@ def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}
- @staticmethod
- @provide_session
- def fetch_callback(
- dag: DAG,
- run_id: str,
- success: bool = True,
- reason: str | None = None,
- *,
- session: Session = NEW_SESSION,
- ) -> tuple[list[TaskStateChangeCallback], Context] | None:
- """
- Fetch the appropriate callbacks depending on the value of success.
-
- This method gets the context of a single TaskInstance part of this
DagRun and returns it along
- the list of callbacks.
-
- :param dag: DAG object
- :param run_id: The DAG run ID
- :param success: Flag to specify if failure or success callback should
be called
- :param reason: Completion reason
- :param session: Database session
- """
- callbacks = dag.on_success_callback if success else
dag.on_failure_callback
- if callbacks:
- dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=run_id,
session=session)
- callbacks = callbacks if isinstance(callbacks, list) else
[callbacks]
- tis = dagrun.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
- if dag.partial:
- tis = [ti for ti in tis if not ti.state == State.NONE]
- # filter out removed tasks
- tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
- ti = tis[-1] # get first TaskInstance of DagRun
- ti.task = dag.get_task(ti.task_id)
- context = ti.get_template_context(session=session)
- context["reason"] = reason
- return callbacks, context
- return None
-
- @provide_session
- def handle_callback(self, dagrun: DagRun, success=True, reason=None,
session=NEW_SESSION):
- """
- Triggers on_failure_callback or on_success_callback as appropriate.
-
- This method gets the context of a single TaskInstance part of this
DagRun
- and passes that to the callable along with a 'reason', primarily to
- differentiate DagRun failures.
-
- .. note: The logs end up in
- ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log``
-
- :param dagrun: DagRun object
- :param success: Flag to specify if failure or success callback should
be called
- :param reason: Completion reason
- :param session: Database session
- """
- callbacks, context = DAG.fetch_callback(
- dag=self, run_id=dagrun.run_id, success=success, reason=reason,
session=session
- ) or (None, None)
-
- DAG.execute_callback(callbacks, context, self.dag_id)
-
- @classmethod
- def execute_callback(cls, callbacks: list[Callable] | None, context:
Context | None, dag_id: str):
- """
- Triggers the callbacks with the given context.
-
- :param callbacks: List of callbacks to call
- :param context: Context to pass to all callbacks
- :param dag_id: The dag_id of the DAG to find.
- """
- if callbacks and context:
- for callback in callbacks:
- cls.logger().info("Executing dag callback function: %s",
callback)
- try:
- callback(context)
- except Exception:
- cls.logger().exception("failed to invoke dag state update
callback")
- Stats.incr("dag.callback_exceptions", tags={"dag_id":
dag_id})
-
Review Comment:
Ah, the cleanup 😄
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -837,6 +839,15 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1
+ self.update_task_state_if_needed()
+
+ # Now at the last possible moment, when all logs and comms with the
subprocess has finished, lets
+ # upload the remote logs
+ self._upload_logs()
Review Comment:
Can we do this in `finalize` instead?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]