amoghrajesh commented on code in PR #46677:
URL: https://github.com/apache/airflow/pull/46677#discussion_r1957933941


##########
airflow/executors/workloads.py:
##########
@@ -106,4 +107,33 @@ def make(cls, ti: TIModel, dag_rel_path: Path | None = 
None) -> ExecuteTask:
         return cls(ti=ser_ti, dag_rel_path=path, token="", log_path=fname, 
bundle_info=bundle_info)
 
 
-All = Union[ExecuteTask]
+class RunTrigger(BaseModel):
+    """Execute an async "trigger" process that yields events."""
+
+    id: int
+
+    ti: TaskInstance | None
+    """
+    The task instance associated with this tirgger.

Review Comment:
   ```suggestion
       The task instance associated with this trigger.
   ```



##########
airflow/models/trigger.py:
##########
@@ -360,3 +364,82 @@ def get_sorted_triggers(cls, capacity: int, 
alive_triggerer_ids: list[int] | Sel
         # Add triggers associated to assets after triggers associated to tasks
         # It prioritizes DAGs over event driven scheduling which is fair
         return ti_triggers + asset_triggers
+
+
+@singledispatch
+def handle_event_submit(event: events.TriggerEvent, *, task_instance: 
TaskInstance, session: Session) -> None:
+    """
+    Handle the submit event for a given task instance.
+
+    This function sets the next method and next kwargs of the task instance,
+    as well as its state to scheduled. It also adds the event's payload
+    into the kwargs for the task.
+
+    :param task_instance: The task instance to handle the submit event for.
+    :param session: The session to be used for the database callback sink.
+    """
+    from airflow.utils.state import TaskInstanceState
+
+    # Get the next kwargs of the task instance, or an empty dictionary if it 
doesn't exist
+    next_kwargs = task_instance.next_kwargs or {}
+
+    # Add the event's payload into the kwargs for the task
+    next_kwargs["event"] = event.payload
+
+    # Update the next kwargs of the task instance
+    task_instance.next_kwargs = next_kwargs
+
+    # Remove ourselves as its trigger
+    task_instance.trigger_id = None
+
+    # Set the state of the task instance to scheduled
+    task_instance.state = TaskInstanceState.SCHEDULED
+    task_instance.scheduled_dttm = timezone.utcnow()
+    session.flush()

Review Comment:
   Am I missing something, or do we not add it to the session?



##########
airflow/jobs/triggerer_job_runner.py:
##########
@@ -314,95 +122,222 @@ def on_kill(self):
 
         Called when there is an external kill command (via the heartbeat 
mechanism, for example).
         """
+        # TODO: signal instead.
         self.trigger_runner.stop = True
 
-    def _kill_listener(self):
-        if self.listener:
-            for h in self.listener.handlers:
-                h.close()
-            self.listener.stop()
-
     def _exit_gracefully(self, signum, frame) -> None:
         # The first time, try to exit nicely
-        if not self.trigger_runner.stop:
+        if self.trigger_runner and not self.trigger_runner.stop:
             self.log.info("Exiting gracefully upon receiving signal %s", 
signum)
             self.trigger_runner.stop = True
-            self._kill_listener()
         else:
             self.log.warning("Forcing exit due to second exit signal %s", 
signum)
+
+            self.trigger_runner.kill(signal.SIGKILL)
             sys.exit(os.EX_SOFTWARE)
 
     def _execute(self) -> int | None:
         self.log.info("Starting the triggerer")
         try:
-            # set job_id so that it can be used in log file names
-            self.trigger_runner.job_id = self.job.id
+            # Kick off runner sub-process without DB access
+            self.trigger_runner = TriggerRunnerSupervisor.start(
+                job=self.job, capacity=self.capacity, logger=log
+            )
 
-            # Kick off runner thread
-            self.trigger_runner.start()
-            # Start our own DB loop in the main thread
-            self._run_trigger_loop()
+            # Run the main DB comms loop in this process
+            self.trigger_runner.run_db_loop()
+            return self.trigger_runner._exit_code
         except Exception:
-            self.log.exception("Exception when executing 
TriggererJobRunner._run_trigger_loop")
+            self.log.exception("Exception when executing 
TriggerRunnerSupervisor.run_db_loop")
             raise
         finally:
             self.log.info("Waiting for triggers to clean up")
-            # Tell the subthread to stop and then wait for it.
+            # Tell the subtproc to stop and then wait for it.
             # If the user interrupts/terms again, _graceful_exit will allow 
them
             # to force-kill here.
-            self.trigger_runner.stop = True
-            self.trigger_runner.join(30)
+            self.trigger_runner.kill(escalation_delay=10, force=True)
             self.log.info("Exited trigger loop")
         return None
 
-    def _run_trigger_loop(self) -> None:
-        """Run synchronously and handle all database reads/writes; the 
main-thread trigger loop."""
-        while not self.trigger_runner.stop:
-            if not self.trigger_runner.is_alive():
-                self.log.error("Trigger runner thread has died! Exiting.")
+
+log: FilteringBoundLogger = structlog.get_logger(logger_name=__name__)
+
+
+# Using this as a simple namespace
+class messages:
+    class StartTriggerer(BaseModel):
+        """Tell the async trigger runner process to start, and where to send 
status update messages."""
+
+        requests_fd: int
+        kind: Literal["StartTriggerer"] = "StartTriggerer"
+
+    class CancelTriggers(BaseModel):
+        """Request to cancel running triggers."""
+
+        ids: Iterable[int]
+        kind: Literal["CancelTriggersMessage"] = "CancelTriggersMessage"
+
+    class TriggerStateChanges(BaseModel):
+        """Report state change about triggers back to the 
TriggerRunnerSupervisor."""
+
+        kind: Literal["TriggerStateChanges"] = "TriggerStateChanges"
+        events: Annotated[
+            list[tuple[int, events.DiscrimatedTriggerEvent]] | None,
+            # We have to specify a default here, as otherwise Pydantic 
struggles to deal with the discriminated
+            # union :shrug:
+            Field(default=None),
+        ]
+        # Format of list[str] is the exc traceback format
+        failures: list[tuple[int, list[str] | None]] | None = None
+        finished: list[int] | None = None
+
+
+ToAsyncProcess = Annotated[
+    Union[workloads.RunTrigger, messages.CancelTriggers, 
messages.StartTriggerer],
+    Field(discriminator="kind"),
+]
+
+
+ToSyncProcess = Annotated[
+    Union[messages.TriggerStateChanges],
+    Field(discriminator="kind"),
+]
+
+
[email protected](kw_only=True)
+class TriggerLoggingFactory:
+    log_path: str
+
+    bound_logger: WrappedLogger = attrs.field(init=False)
+
+    def __call__(self, processors: Iterable[structlog.typing.Processor]) -> 
WrappedLogger:
+        if hasattr(self, "bound_logger"):
+            return self.bound_logger
+
+        from airflow.sdk.log import init_log_file
+
+        log_file = init_log_file(self.log_path)
+
+        pretty_logs = False
+        if pretty_logs:
+            underlying_logger: WrappedLogger = 
structlog.WriteLogger(log_file.open("w", buffering=1))
+        else:
+            underlying_logger = structlog.BytesLogger(log_file.open("wb"))
+        logger = structlog.wrap_logger(underlying_logger, 
processors=processors).bind()
+        self.bound_logger = logger
+        return logger
+
+
[email protected](kw_only=True)
+class TriggerRunnerSupervisor(WatchedSubprocess):
+    """
+    TriggerRunnerSupervisor is responsible for monitoring the subprocess and 
marshalling DB access.
+
+    This class (which runs in the main process) is responsible for querying 
the DB, sending RunTrigger
+    workload messages to the subprocess, and collecting results and updating 
them in the DB.
+    """
+
+    job: Job
+    capacity: int
+
+    health_check_threshold = conf.getint("triggerer", 
"triggerer_health_check_threshold")
+
+    runner: TriggerRunner | None = None
+    stop: bool = False
+
+    decoder: ClassVar[TypeAdapter[ToSyncProcess]] = TypeAdapter(ToSyncProcess)
+
+    # Maps trigger IDs that we think are running in the sub process
+    running_triggers: set[int] = attrs.field(factory=set, init=False)
+
+    logger_cache: dict[int, TriggerLoggingFactory] = attrs.field(factory=dict, 
init=False)
+
+    # A list of triggers that we have told the async process to cancel. We 
keep them here until we receive the
+    # FinishedTriggers message
+    cancelling_triggers: set[int] = attrs.field(factory=set, init=False)
+
+    # Outbound queue of events
+    events: deque[tuple[int, events.TriggerEvent]] = 
attrs.field(factory=deque, init=False)
+
+    # Outbound queue of failed triggers
+    failed_triggers: deque[tuple[int, list[str] | None]] = 
attrs.field(factory=deque, init=False)
+
+    def is_alive(self) -> bool:
+        # Set by `_service_subprocess` in the loop
+        return self._exit_code is None
+
+    @classmethod
+    def start(  # type: ignore[override]
+        cls,
+        *,
+        job: Job,
+        logger=None,
+        **kwargs,
+    ):
+        proc = super().start(id=job.id, job=job, target=cls.run_in_process, 
logger=logger, **kwargs)
+
+        msg = messages.StartTriggerer(requests_fd=proc._requests_fd)
+        proc._send(msg)
+        return proc
+
+    def _handle_request(self, msg: ToSyncProcess, log: FilteringBoundLogger) 
-> None:  # type: ignore[override]
+        if isinstance(msg, messages.TriggerStateChanges):
+            log.debug("State change from async process", state=msg)
+            if msg.events:
+                self.events.extend(msg.events)
+            if msg.failures:
+                self.failed_triggers.extend(msg.failures)
+            for id in msg.finished or ():
+                self.running_triggers.discard(id)
+                self.cancelling_triggers.discard(id)
+                # TODO: Close logger? Or is deleting it enough
+                self.logger_cache.pop(id, None)

Review Comment:
   Should we close it here?



##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -152,14 +152,14 @@ def heartbeat(self, id: uuid.UUID, pid: int):
 
     def defer(self, id: uuid.UUID, msg):
         """Tell the API server that this TI has been deferred."""
-        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True))
+        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, 
exclude={"type"}))
 
         # Create a deferred state payload from msg
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
     def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
         """Tell the API server that this TI has been reschduled."""
-        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True))
+        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True, 
exclude={"type"}))

Review Comment:
   Thanks :)



##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -152,14 +152,14 @@ def heartbeat(self, id: uuid.UUID, pid: int):
 
     def defer(self, id: uuid.UUID, msg):
         """Tell the API server that this TI has been deferred."""
-        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True))
+        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, 
exclude={"type"}))

Review Comment:
   Oh, so this was the issue.



##########
airflow/jobs/triggerer_job_runner.py:
##########
@@ -314,95 +122,222 @@ def on_kill(self):
 
         Called when there is an external kill command (via the heartbeat 
mechanism, for example).
         """
+        # TODO: signal instead.
         self.trigger_runner.stop = True
 
-    def _kill_listener(self):
-        if self.listener:
-            for h in self.listener.handlers:
-                h.close()
-            self.listener.stop()
-
     def _exit_gracefully(self, signum, frame) -> None:
         # The first time, try to exit nicely
-        if not self.trigger_runner.stop:
+        if self.trigger_runner and not self.trigger_runner.stop:
             self.log.info("Exiting gracefully upon receiving signal %s", 
signum)
             self.trigger_runner.stop = True
-            self._kill_listener()
         else:
             self.log.warning("Forcing exit due to second exit signal %s", 
signum)
+
+            self.trigger_runner.kill(signal.SIGKILL)
             sys.exit(os.EX_SOFTWARE)
 
     def _execute(self) -> int | None:
         self.log.info("Starting the triggerer")
         try:
-            # set job_id so that it can be used in log file names
-            self.trigger_runner.job_id = self.job.id
+            # Kick off runner sub-process without DB access
+            self.trigger_runner = TriggerRunnerSupervisor.start(
+                job=self.job, capacity=self.capacity, logger=log
+            )
 
-            # Kick off runner thread
-            self.trigger_runner.start()
-            # Start our own DB loop in the main thread
-            self._run_trigger_loop()
+            # Run the main DB comms loop in this process
+            self.trigger_runner.run_db_loop()
+            return self.trigger_runner._exit_code
         except Exception:
-            self.log.exception("Exception when executing 
TriggererJobRunner._run_trigger_loop")
+            self.log.exception("Exception when executing 
TriggerRunnerSupervisor.run_db_loop")
             raise
         finally:
             self.log.info("Waiting for triggers to clean up")
-            # Tell the subthread to stop and then wait for it.
+            # Tell the subtproc to stop and then wait for it.
             # If the user interrupts/terms again, _graceful_exit will allow 
them
             # to force-kill here.
-            self.trigger_runner.stop = True
-            self.trigger_runner.join(30)
+            self.trigger_runner.kill(escalation_delay=10, force=True)
             self.log.info("Exited trigger loop")
         return None
 
-    def _run_trigger_loop(self) -> None:
-        """Run synchronously and handle all database reads/writes; the 
main-thread trigger loop."""
-        while not self.trigger_runner.stop:
-            if not self.trigger_runner.is_alive():
-                self.log.error("Trigger runner thread has died! Exiting.")
+
+log: FilteringBoundLogger = structlog.get_logger(logger_name=__name__)
+
+
+# Using this as a simple namespace
+class messages:
+    class StartTriggerer(BaseModel):
+        """Tell the async trigger runner process to start, and where to send 
status update messages."""
+
+        requests_fd: int
+        kind: Literal["StartTriggerer"] = "StartTriggerer"
+
+    class CancelTriggers(BaseModel):
+        """Request to cancel running triggers."""
+
+        ids: Iterable[int]
+        kind: Literal["CancelTriggersMessage"] = "CancelTriggersMessage"
+
+    class TriggerStateChanges(BaseModel):
+        """Report state change about triggers back to the 
TriggerRunnerSupervisor."""
+
+        kind: Literal["TriggerStateChanges"] = "TriggerStateChanges"
+        events: Annotated[
+            list[tuple[int, events.DiscrimatedTriggerEvent]] | None,
+            # We have to specify a default here, as otherwise Pydantic 
struggles to deal with the discriminated
+            # union :shrug:
+            Field(default=None),
+        ]
+        # Format of list[str] is the exc traceback format
+        failures: list[tuple[int, list[str] | None]] | None = None
+        finished: list[int] | None = None
+
+
+ToAsyncProcess = Annotated[
+    Union[workloads.RunTrigger, messages.CancelTriggers, 
messages.StartTriggerer],
+    Field(discriminator="kind"),
+]
+
+
+ToSyncProcess = Annotated[
+    Union[messages.TriggerStateChanges],
+    Field(discriminator="kind"),
+]

Review Comment:
   This is cool, but can we be in line with `ToParent`, or `ToTriggerProcess`
/ `ToTriggerProcessSupervisor`?
   Or anything else you think makes more sense.



##########
airflow/jobs/triggerer_job_runner.py:
##########
@@ -415,33 +350,177 @@ def handle_failed_triggers(self):
 
         Task Instances that depend on them need failing.
         """
-        while self.trigger_runner.failed_triggers:
+        while self.failed_triggers:
             # Tell the model to fail this trigger's deps
-            trigger_id, saved_exc = 
self.trigger_runner.failed_triggers.popleft()
+            trigger_id, saved_exc = self.failed_triggers.popleft()
             Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc)
             # Emit stat event
             Stats.incr("triggers.failed")
 
-    @add_span
     def emit_metrics(self):
-        Stats.gauge(f"triggers.running.{self.job.hostname}", 
len(self.trigger_runner.triggers))
-        Stats.gauge(
-            "triggers.running", len(self.trigger_runner.triggers), 
tags={"hostname": self.job.hostname}
-        )
+        Stats.gauge(f"triggers.running.{self.job.hostname}", 
len(self.running_triggers))
+        Stats.gauge("triggers.running", len(self.running_triggers), 
tags={"hostname": self.job.hostname})
 
-        capacity_left = self.capacity - len(self.trigger_runner.triggers)
+        capacity_left = self.capacity - len(self.running_triggers)
         Stats.gauge(f"triggerer.capacity_left.{self.job.hostname}", 
capacity_left)
         Stats.gauge("triggerer.capacity_left", capacity_left, 
tags={"hostname": self.job.hostname})
 
         span = Trace.get_current_span()
         span.set_attributes(
             {
                 "trigger host": self.job.hostname,
-                "triggers running": len(self.trigger_runner.triggers),
+                "triggers running": len(self.running_triggers),
                 "capacity left": capacity_left,
             }
         )
 
+    def _send(self, msg: BaseModel):
+        self.stdin.write(msg.model_dump_json().encode("utf-8") + b"\n")
+
+    def update_triggers(self, requested_trigger_ids: set[int]):
+        """
+        Request that we update what triggers we're running.
+
+        Works out the differences - ones to add, and ones to remove - then
+        adds them to the deques so the subprocess can actually mutate the 
running
+        trigger set.
+        """
+        render_log_fname = log_filename_template_renderer()
+
+        known_trigger_ids = (
+            self.running_triggers.union(x[0] for x in self.events)
+            .union(self.cancelling_triggers)
+            # .union(x.id for x in self.to_create)
+            .union(trigger[0] for trigger in self.failed_triggers)
+        )
+        # Work out the two difference sets
+        new_trigger_ids = requested_trigger_ids - known_trigger_ids
+        cancel_trigger_ids = self.running_triggers - requested_trigger_ids
+        # Bulk-fetch new trigger records
+        new_triggers = Trigger.bulk_fetch(new_trigger_ids)
+        triggers_with_assets = Trigger.fetch_trigger_ids_with_asset()
+        to_create: list[workloads.RunTrigger] = []
+        # Add in new triggers
+        for new_id in new_trigger_ids:
+            # Check it didn't vanish in the meantime
+            if new_id not in new_triggers:
+                log.warning("Trigger disappeared before we could start it", 
id=new_id)
+                continue
+
+            new_trigger_orm = new_triggers[new_id]
+
+            # If the trigger is not associated to a task or an asset, this 
means the TaskInstance
+            # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
+            # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
+            # in a High-Availability setup.
+            if new_trigger_orm.task_instance is None and new_id not in 
triggers_with_assets:
+                log.info(
+                    (
+                        "TaskInstance Trigger is None. It was likely updated 
by another trigger job. "
+                        "Skipping trigger instantiation."
+                    ),
+                    id=new_id,
+                )
+                continue
+
+            workload = workloads.RunTrigger(
+                classpath=new_trigger_orm.classpath,
+                id=new_id,
+                encrypted_kwargs=new_trigger_orm.encrypted_kwargs,

Review Comment:
   Is this a change made now, or do we always encrypt and store it in the DB?



##########
airflow/jobs/triggerer_job_runner.py:
##########
@@ -503,46 +588,136 @@ async def arun(self):
         Actual triggers run in their own separate coroutines.
         """
         watchdog = asyncio.create_task(self.block_watchdog())
-        last_status = time.time()
+        ready_event = asyncio.Event()
+        read_workloads = asyncio.create_task(self.read_workloads(ready_event))
+
+        await ready_event.wait()
+        last_status = time.monotonic()
         try:
             while not self.stop:
+                # Raise exceptions from the tasks
+                if read_workloads.done():
+                    read_workloads.result()
+                if watchdog.done():
+                    watchdog.result()
+
                 # Run core logic
                 await self.create_triggers()
                 await self.cancel_triggers()
-                await self.cleanup_finished_triggers()
+                finished_ids = await self.cleanup_finished_triggers()
+                await self.sync_state_to_supervisor(finished_ids)
                 # Sleep for a bit
                 await asyncio.sleep(1)
                 # Every minute, log status
-                if time.time() - last_status >= 60:
+                if (now := time.monotonic()) - last_status >= 60:
                     count = len(self.triggers)
                     self.log.info("%i triggers currently running", count)
-                    last_status = time.time()
+                    last_status = now
+
         except Exception:
+            log.exception("Trigger runner failed")
             self.stop = True
             raise
-        # Wait for watchdog to complete
+        read_workloads.cancel()
+        # Wait for supporting tasks to complete
         await watchdog
+        await read_workloads
+
+    async def read_workloads(self, ready_event: asyncio.Event):
+        """
+        Read the triggers to run on stdin.
+
+        This reads-and-decodes the JSON lines send by the 
TriggerRunnerSupervisor to us on our stdint
+        """
+        loop = asyncio.get_event_loop()
+
+        task = asyncio.current_task(loop=loop)
+        if TYPE_CHECKING:
+            assert task
+        # Set the event on done callback so that this FN fails the arun wakes 
up and we catch the exception
+        task.add_done_callback(lambda _: ready_event.set())
+
+        async def connect_stdin() -> asyncio.StreamReader:
+            reader = asyncio.StreamReader()
+            protocol = asyncio.StreamReaderProtocol(reader)
+            await loop.connect_read_pipe(lambda: protocol, sys.stdin)
+            return reader
+
+        stdin = await connect_stdin()
+
+        # The first message must be this type, else we can't operate
+        line = await stdin.readline()
+
+        decoder = TypeAdapter[ToAsyncProcess](ToAsyncProcess)
+        msg = decoder.validate_json(line)
+        if not isinstance(msg, messages.StartTriggerer) or msg.requests_fd <= 
0:
+            raise RuntimeError(f"First message to triggerer must be 
{messages.StartTriggerer.__name__}")
+
+        writer_transport, writer_protocol = await loop.connect_write_pipe(
+            lambda: asyncio.streams.FlowControlMixin(loop=loop),
+            os.fdopen(msg.requests_fd, "wb"),
+        )
+        self.requests_sock = asyncio.streams.StreamWriter(writer_transport, 
writer_protocol, None, loop)
+
+        # Tell `arun` it can start the main loop now
+        ready_event.set()
+
+        async for line in stdin:
+            msg = decoder.validate_json(line)
+
+            if isinstance(msg, workloads.RunTrigger):
+                self.to_create.append(msg)
+            elif isinstance(msg, messages.CancelTriggers):
+                self.to_cancel.extend(msg.ids)
+            else:
+                raise ValueError(f"Unknown workload type {type(msg)}")
 
     async def create_triggers(self):
         """Drain the to_create queue and create all new triggers that have 
been requested in the DB."""
         while self.to_create:
-            trigger_id, trigger_instance = self.to_create.popleft()
-            if trigger_id not in self.triggers:
-                ti: TaskInstance | None = trigger_instance.task_instance
-                trigger_name = (
-                    
f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID 
{trigger_id})"
-                    if ti
-                    else f"ID {trigger_id}"
-                )
-                self.triggers[trigger_id] = {
-                    "task": asyncio.create_task(self.run_trigger(trigger_id, 
trigger_instance)),
-                    "name": trigger_name,
-                    "events": 0,
-                }
-            else:
+            await asyncio.sleep(0)
+            workload = self.to_create.popleft()
+            trigger_id = workload.id
+            if trigger_id in self.triggers:
                 self.log.warning("Trigger %s had insertion attempted twice", 
trigger_id)
+                continue
+
+            try:
+                trigger_class = 
self.get_trigger_by_classpath(workload.classpath)
+            except BaseException as e:
+                # Either the trigger code or the path to it is bad. Fail the 
trigger.
+                self.log.error("Trigger failed to load code", error=e, 
classpath=workload.classpath)
+                self.failed_triggers.append((trigger_id, e))
+                continue
+
+            # Loading the trigger class could have been expensive. Lets give 
other things a chance to run!
             await asyncio.sleep(0)
 
+            try:
+                kwargs = Trigger._decrypt_kwargs(workload.encrypted_kwargs)

Review Comment:
   Ah, OK — it seems like we have been storing it that way.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to