Lee-W commented on code in PR #40084:
URL: https://github.com/apache/airflow/pull/40084#discussion_r1637896854
##########
airflow/exceptions.py:
##########
@@ -372,14 +372,20 @@ class TaskDeferred(BaseException):
     Signal an operator moving to deferred state.
 
     Special exception raised to signal that the operator it was raised from
-    wishes to defer until a trigger fires.
+    wishes to defer until a trigger fires. Triggers can send execution back to task or end the task instance
+    directly. If the trigger will end the task instance itself, ``method_name`` should be
+    None; otherwise, provide the name of the method that should be used when
+    resuming execution in the task.
     """
 
+    TRIGGER_EXIT = "__trigger_exit__"

Review Comment:
   I'm not sure whether we really need this value. Do you think `from airflow.utils.types import NOTSET` works?
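For illustration only, a minimal sketch of what the reviewer's `NOTSET` idea could look like as the deferral sentinel; the standalone `defer` function and its parameters are hypothetical stand-ins for the PR's actual `BaseOperator.defer` change:

```python
# Hypothetical sketch: reuse Airflow's existing NOTSET sentinel from
# airflow.utils.types instead of introducing a new TRIGGER_EXIT marker string.
from __future__ import annotations

from airflow.exceptions import TaskDeferred
from airflow.utils.types import NOTSET, ArgNotSet


def defer(*, trigger, method_name: str | ArgNotSet = NOTSET, kwargs=None, timeout=None):
    # If method_name is still NOTSET, the trigger is expected to end the task
    # itself; otherwise the worker resumes the task via method_name.
    raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
```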
+ """ + if "payload" in kwargs: + raise ValueError("Param 'payload' not supported for this class.") + super().__init__(payload=self.task_instance_state) + self.xcoms = xcoms + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION): Review Comment: ```suggestion def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None: ``` ########## airflow/triggers/base.py: ########## @@ -135,3 +148,105 @@ def __eq__(self, other): if isinstance(other, TriggerEvent): return other.payload == self.payload return False + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION): + """ + Handle the submit event for a given task instance. + + This function sets the next method and next kwargs of the task instance, + as well as its state to scheduled. It also adds the event's payload + into the kwargs for the task. + + :param task_instance: The task instance to handle the submit event for. + :param session: The session to be used for the database callback sink. + """ + # Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist + next_kwargs = task_instance.next_kwargs or {} + + # Add the event's payload into the kwargs for the task + next_kwargs["event"] = self.payload + + # Update the next kwargs of the task instance + task_instance.next_kwargs = next_kwargs + + # Remove ourselves as its trigger + task_instance.trigger_id = None + + # Set the state of the task instance to scheduled + task_instance.state = TaskInstanceState.SCHEDULED + + +class BaseTaskEndEvent(TriggerEvent): + """Base event class to end the task without resuming on worker.""" + + task_instance_state: TaskInstanceState + + def __init__(self, *, xcoms: dict[str, Any] | None = None, **kwargs) -> None: + """ + Initialize the class with the specified parameters. + + :param xcoms: A dictionary of XComs or None. + :param kwargs: Additional keyword arguments. + """ + if "payload" in kwargs: + raise ValueError("Param 'payload' not supported for this class.") + super().__init__(payload=self.task_instance_state) + self.xcoms = xcoms + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION): + """ + Submit event for the given task instance. + + Marks the task with the state `task_instance_state` and optionally pushes xcom if applicable. + + :param task_instance: The task instance to be submitted. + :param session: The session to be used for the database callback sink. 
+ """ + # Mark the task with terminal state and prevent it from resuming on worker + task_instance.trigger_id = None + task_instance.state = self.task_instance_state + + self._submit_callback_if_necessary(task_instance=task_instance, session=session) + self._push_xcoms_if_necessary(task_instance=task_instance) + + def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session): + """Submit a callback request if the task state is SUCCESS or FAILED.""" + is_failure = self.task_instance_state == TaskInstanceState.FAILED + if self.task_instance_state in [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED]: + request = TaskCallbackRequest( + full_filepath=task_instance.dag_model.fileloc, + simple_task_instance=SimpleTaskInstance.from_ti(task_instance), + is_failure_callback=is_failure, + task_callback_type=self.task_instance_state, + ) + log.warning("Sending callback: %s", request) + try: + DatabaseCallbackSink().send(callback=request, session=session) + except Exception as e: + log.error("Failed to send callback: %s", e) + + def _push_xcoms_if_necessary(self, *, task_instance: TaskInstance): Review Comment: ```suggestion def _push_xcoms_if_necessary(self, *, task_instance: TaskInstance) -> None: ``` ########## airflow/dag_processing/manager.py: ########## @@ -574,6 +576,7 @@ def _run_parsing_loop(self): pass elif isinstance(agent_signal, CallbackRequest): self._add_callback_to_queue(agent_signal) + self.log.warning("_add_callback_to_queue; agent signal; %s", agent_signal) Review Comment: Why is this a warning? Should it be info instead? ########## airflow/dag_processing/processor.py: ########## @@ -762,8 +762,29 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se if callbacks and context: DAG.execute_callback(callbacks, context, dag.dag_id) - def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session): - if not request.is_failure_callback: + def _execute_task_callbacks( + self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session + ) -> None: + """ + Execute the task callbacks. + + :param dagbag: the DagBag to use to get the task instance + :param request: the task callback request + :param session: the session to use + """ + try: + callback_type = TaskInstanceState(request.task_callback_type) + except Exception: + callback_type = None + is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED) + + # previously we ignored any request besides failures. now if given callback type directly, + # then we respect it and execute it. additionally because in this scenario the callback + # is submitted remotely, we assume there is no need to mess with state; we simply run + # the callback + + if not is_remote and not request.is_failure_callback: + self.log.debug("not failure callback: %s", request) Review Comment: I'm a bit confused by this log. What does this log want to tell us? ########## airflow/dag_processing/processor.py: ########## @@ -762,8 +762,29 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se if callbacks and context: DAG.execute_callback(callbacks, context, dag.dag_id) - def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session): - if not request.is_failure_callback: + def _execute_task_callbacks( + self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session + ) -> None: + """ + Execute the task callbacks. 
##########
airflow/triggers/base.py:
##########
@@ -135,3 +148,105 @@ def __eq__(self, other):
+                is_failure_callback=is_failure,
+                task_callback_type=self.task_instance_state,
+            )
+            log.warning("Sending callback: %s", request)

Review Comment:
   ```suggestion
               log.info("Sending callback: %s", request)
   ```

##########
airflow/models/baseoperator.py:
##########
@@ -1704,15 +1704,18 @@ def defer(
         self,
         *,
         trigger: BaseTrigger,
-        method_name: str,
+        method_name: str = TaskDeferred.TRIGGER_EXIT,
         kwargs: dict[str, Any] | None = None,
         timeout: timedelta | None = None,
     ) -> NoReturn:
         """
         Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
 
         This is achieved by raising a special exception (TaskDeferred)
-        which is caught in the main _execute_task wrapper.
+        which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end

Review Comment:
   ```suggestion
           which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end
   ```

##########
airflow/triggers/base.py:
##########
@@ -135,3 +148,105 @@ def __eq__(self, other):
         if isinstance(other, TriggerEvent):
             return other.payload == self.payload
         return False
+
+    @provide_session
+    def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION):

Review Comment:
   ```suggestion
       def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None:
   ```

##########
airflow/triggers/temporal.py:
##########
@@ -34,9 +34,13 @@ class DateTimeTrigger(BaseTrigger):
     a few seconds.
 
     The provided datetime MUST be in UTC.
+
+    :param moment: when to yield event
+    :param end_task: whether the trigger should mark the task successful after time condition
+        reached or resume the task after time condition reached.
     """
 
-    def __init__(self, moment: datetime.datetime):
+    def __init__(self, moment: datetime.datetime, *, end_task=False):

Review Comment:
   Just want to confirm it. Is `end_task` a convention we should follow if we want to use this feature?
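To ground the question, a hedged usage sketch of the feature as this PR proposes it: with `end_task=True` the trigger is meant to mark the task successful itself, so `defer()` is called without a `method_name`. The sensor class and the date are illustrative assumptions, not code from the PR:

```python
from __future__ import annotations

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone


class WaitUntilMomentSensor(BaseSensorOperator):
    """Hypothetical sensor deferring to a trigger that ends the task itself."""

    def execute(self, context):
        # end_task=True: the trigger marks the task SUCCESS when the moment is
        # reached, so no method_name is given and nothing resumes on a worker.
        self.defer(
            trigger=DateTimeTrigger(moment=timezone.datetime(2024, 1, 1, 12), end_task=True),
        )
```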
+ """ + # Mark the task with terminal state and prevent it from resuming on worker + task_instance.trigger_id = None + task_instance.state = self.task_instance_state + + self._submit_callback_if_necessary(task_instance=task_instance, session=session) + self._push_xcoms_if_necessary(task_instance=task_instance) + + def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session): + """Submit a callback request if the task state is SUCCESS or FAILED.""" + is_failure = self.task_instance_state == TaskInstanceState.FAILED + if self.task_instance_state in [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED]: + request = TaskCallbackRequest( + full_filepath=task_instance.dag_model.fileloc, + simple_task_instance=SimpleTaskInstance.from_ti(task_instance), + is_failure_callback=is_failure, + task_callback_type=self.task_instance_state, + ) + log.warning("Sending callback: %s", request) Review Comment: ```suggestion log.info("Sending callback: %s", request) ``` ########## airflow/models/baseoperator.py: ########## @@ -1704,15 +1704,18 @@ def defer( self, *, trigger: BaseTrigger, - method_name: str, + method_name: str = TaskDeferred.TRIGGER_EXIT, kwargs: dict[str, Any] | None = None, timeout: timedelta | None = None, ) -> NoReturn: """ Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. This is achieved by raising a special exception (TaskDeferred) - which is caught in the main _execute_task wrapper. + which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end Review Comment: ```suggestion which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end ``` ########## airflow/triggers/base.py: ########## @@ -135,3 +148,105 @@ def __eq__(self, other): if isinstance(other, TriggerEvent): return other.payload == self.payload return False + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION): Review Comment: ```suggestion def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None: ``` ########## airflow/triggers/temporal.py: ########## @@ -34,9 +34,13 @@ class DateTimeTrigger(BaseTrigger): a few seconds. The provided datetime MUST be in UTC. + + :param moment: when to yield event + :param end_task: whether the trigger should mark the task successful after time condition + reached or resume the task after time condition reached. """ - def __init__(self, moment: datetime.datetime): + def __init__(self, moment: datetime.datetime, *, end_task=False): Review Comment: Just want to confirm it. Is `end_task` a convention we should follow if we want to use this feature? ########## tests/sensors/test_time_sensor.py: ########## @@ -63,7 +63,6 @@ def test_task_is_deferred(self): assert isinstance(exc_info.value.trigger, DateTimeTrigger) assert exc_info.value.trigger.moment == timezone.datetime(2020, 7, 7, 10) - assert exc_info.value.method_name == "execute_complete" Review Comment: Why do we remove this `method_name` check? ########## airflow/triggers/base.py: ########## @@ -135,3 +148,105 @@ def __eq__(self, other): if isinstance(other, TriggerEvent): return other.payload == self.payload return False + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION): + """ + Handle the submit event for a given task instance. + + This function sets the next method and next kwargs of the task instance, + as well as its state to scheduled. 
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org