This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bf3c3515895 Share one poll loop across sibling event triggers (#66584)
bf3c3515895 is described below
commit bf3c3515895dde6ac64feebe39bf356c19bcea89
Author: Wei Lee <[email protected]>
AuthorDate: Mon May 25 18:27:28 2026 +0800
Share one poll loop across sibling event triggers (#66584)
---
.../authoring-and-scheduling/event-scheduling.rst | 140 +++++
airflow-core/newsfragments/66584.feature.rst | 1 +
.../src/airflow/config_templates/config.yml | 11 +
.../example_dags/example_asset_with_watchers.py | 44 +-
.../src/airflow/jobs/triggerer_job_runner.py | 57 +-
airflow-core/src/airflow/triggers/base.py | 122 +++-
airflow-core/src/airflow/triggers/shared_stream.py | 387 ++++++++++++
airflow-core/tests/unit/jobs/test_triggerer_job.py | 64 ++
.../tests/unit/triggers/test_base_trigger.py | 96 ++-
.../tests/unit/triggers/test_shared_stream.py | 685 +++++++++++++++++++++
.../airflow/providers/standard/triggers/file.py | 114 +++-
.../tests/unit/standard/triggers/test_file.py | 211 ++++++-
12 files changed, 1920 insertions(+), 12 deletions(-)
diff --git a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst
b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst
index 4cd72edc63b..b0118bf0d8b 100644
--- a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst
+++ b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst
@@ -64,6 +64,146 @@ event-driven scheduling, then a new trigger must be created.
This new trigger must inherit ``BaseEventTrigger`` and ensure it properly
works with event-driven scheduling.
It might inherit from the existing trigger as well if both triggers share some
common code.
+Sharing one poll across sibling triggers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 3.3
+
+When several ``AssetWatcher`` instances on different assets back triggers that
read from the **same upstream resource**
+— a directory of flag files, a polling REST endpoint, and similar idempotent or
+subscriber-side-effect sources — the triggerer would otherwise spin up one
independent poll loop per trigger. For a
+shared source with twenty subscribers that means twenty poll loops, twenty
connections, and twenty sets of API calls per
+poll interval. See "Suitable upstreams" below for the precise scope.
+
+``BaseEventTrigger`` supports an opt-in path so that sibling triggers share a
single underlying poll, while each
+trigger keeps its own DB row, its own ``run_trigger`` task, and its own
per-instance filtering. To participate, a
+subclass overrides three hooks:
+
+* :py:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key` —
return a key identifying the shared
+ upstream (typically a tuple of strings). Triggers whose key compares equal
will share one poll. Returning ``None``
+ (the default) opts out — the trigger runs its own independent ``run()``
loop, exactly as before. The return value
+ is read **once** when the triggerer starts this trigger; changing it
mid-lifetime has no effect on group
+ membership, so siblings that should share a poll must return the same key
from the outset.
+ The key must be deterministic — derive it from configuration fields, never
from per-call values such as
+ ``time.time()`` or ``uuid.uuid4()``, because the comparison must be stable
across the lifetime of the group.
+
+* :py:meth:`~airflow.triggers.base.BaseEventTrigger.open_shared_stream` — a
``@classmethod`` coroutine the triggerer
+ drives **once per shared-stream group** to yield raw events from the
upstream. Because the triggerer reuses one
+ trigger's kwargs to drive the shared poll, only rely on fields whose values
participate in ``shared_stream_key``.
+
+* :py:meth:`~airflow.triggers.base.BaseEventTrigger.filter_shared_stream` — an
instance method that consumes the
+ broadcast raw stream and yields the ``TriggerEvent`` instances this trigger
should fire. Per-trigger filtering
+ (e.g. only events matching this instance's ``filename``) lives here.
+
+Example: a ``DirectoryFileDeleteTrigger`` that fires when a per-asset flag
file appears in a shared inbox directory:
+
+.. code-block:: python
+
+ from collections.abc import AsyncIterator, Hashable
+ from typing import Any
+
+ from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+
+
+ class DirectoryFileDeleteTrigger(BaseEventTrigger):
+ def __init__(self, *, directory, filename, poke_interval=5.0):
+ super().__init__()
+ self.directory = directory
+ self.filename = filename
+ self.poke_interval = poke_interval
+
+ def shared_stream_key(self) -> Hashable | None:
+ # All triggers on the same directory + cadence share one scan.
+ return ("directory-scan", self.directory, self.poke_interval)
+
+ @classmethod
+ async def open_shared_stream(cls, kwargs: dict[str, Any]) ->
AsyncIterator[Any]:
+ # Drives one directory listing loop per group.
+ ...
+
+ async def filter_shared_stream(self, shared_stream:
AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]:
+ # Each instance fires only for its own filename.
+ async for snapshot in shared_stream:
+ if self.filename in snapshot["names"]:
+ yield TriggerEvent(...)
+ return
+
+A complete example using this trigger ships in
+``airflow.example_dags.example_asset_with_watchers``, where two sibling
+``DirectoryFileDeleteTrigger`` watchers share one directory scan alongside
+a standalone ``FileDeleteTrigger`` watcher in the same Dag.
+
+What is and isn't shared
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The sharing is narrower than the name might suggest:
+
+* **Shared** (one per ``shared_stream_key``): the ``open_shared_stream`` async
generator and its upstream I/O — for
+ example, the actual ``iterdir`` calls on the directory or polling REST API
calls.
+
+* **Not shared** (one per trigger): the ``Trigger`` DB row, the trigger
instance, the ``run_trigger``
+ asyncio task, and the ``filter_shared_stream`` async generator. Each
``AssetWatcher`` still appears as its own
+ trigger in the UI and in the metadata database.
+
+In other words, the savings are at the poll-loop and upstream-I/O layer, not at
the persistence or scheduling layer.
+
+Suitable upstreams
+^^^^^^^^^^^^^^^^^^
+
+The shared-stream channel is **one-way** today: events flow from
+``open_shared_stream`` out to each subscriber's ``filter_shared_stream``,
+and there is no way for a subscriber to tell the producer "I accepted /
+dropped / committed this event". That restricts the pattern to upstreams
+whose consumption does **not** depend on a side effect on a handle that
+only the producer holds. Good fits:
+
+* Idempotent / read-only reads — directory scans, polling REST APIs.
+* Subscriber-side-effect cleanup, where the trigger's per-event action
+ (``unlink``, local marking, …) goes through APIs the subscriber owns
+ independently of the shared producer handle.
+
+Currently **not** in scope: Kafka consumers (regardless of commit mode),
+SQS with delete-on-process or visibility extension, and any source where
+progress on the producer's handle is tied to the subscriber's accept /
+reject decision. These sources need a way for the subscriber to signal
+acceptance back to the producer, which the current shared-stream API does
+not provide.
+
+Verifying that sharing is active
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The triggerer logs the creation of each shared-stream group, and names the
poll task after its key:
+
+.. code-block:: text
+
+ Shared stream group started key=('directory-scan', '/tmp/region-flags',
5.0)
+
+.. code-block:: text
+
+ asyncio task name: shared-stream-poll[('directory-scan',
'/tmp/region-flags', 5.0)]
+
+If sharing is active you should see exactly one ``Shared stream group
started`` line per distinct key, regardless of
+how many subscribers join it. If you see one log line per subscriber instead,
the keys probably do not compare equal
+— verify that ``shared_stream_key`` returns identical values across the
siblings.
+
+Slow-subscriber overflow
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Each subscriber in a shared-stream group has a bounded in-memory queue. If the
poll loop
+produces events faster than a subscriber's ``filter_shared_stream`` can
consume them, the
+queue fills and that trigger is failed with ``_SubscriberOverflow`` — a
deliberate fail-fast
+rather than unbounded memory growth.
+
+If subscribers repeatedly overflow, there are two ways to address this:
+
+* Raise ``[triggerer] shared_stream_subscriber_queue_size`` to give the
+ filter more slack before the overflow threshold is reached.
+* Redesign
:py:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key` so fewer
+ sibling triggers share a single group — a narrower group reduces the rate at
which any
+ one subscriber needs to consume events.
+
+Both reduce the mismatch between producer throughput and per-subscriber
consume rate.
+
Avoid infinite scheduling
~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/airflow-core/newsfragments/66584.feature.rst
b/airflow-core/newsfragments/66584.feature.rst
new file mode 100644
index 00000000000..e8a547d10ba
--- /dev/null
+++ b/airflow-core/newsfragments/66584.feature.rst
@@ -0,0 +1 @@
+Sibling ``BaseEventTrigger`` instances on different ``AssetWatcher`` objects can now
share a single underlying poll loop in the triggerer by overriding
``shared_stream_key``, ``open_shared_stream``, and ``filter_shared_stream``.
Triggers that opt out (the default) keep their existing independent ``run()``
loop behavior.
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index f7f958bb84b..a6f3de25e21 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -2773,6 +2773,17 @@ triggerer:
type: boolean
example: ~
default: "False"
+ shared_stream_subscriber_queue_size:
+ description: |
+ Per-subscriber buffer size for shared-stream triggers (triggers that
opt into a shared poll loop
+ via ``BaseEventTrigger.shared_stream_key``). Each subscribing trigger
keeps an in-memory queue of
+ raw events produced by the shared poll; if a slow subscriber fills its
queue, only that subscriber
+ fails and sibling subscribers are unaffected. Increase if a slow
subscriber must tolerate bursts from
+ a fast upstream.
+ version_added: 3.3.0
+ type: integer
+ example: ~
+ default: "1024"
kerberos:
description: ~
options:
diff --git
a/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py
b/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py
index 228b437dc53..49166c01e35 100644
--- a/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py
+++ b/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py
@@ -15,23 +15,55 @@
# specific language governing permissions and limitations
# under the License.
"""
-Example DAG for demonstrating the usage of event driven scheduling using
assets and triggers.
+Example Dag for event-driven scheduling using Assets and AssetWatchers.
+
+Three watchers demonstrate the two trigger patterns in one place:
+
+* The first watcher uses ``FileDeleteTrigger`` for a single specific path —
+ one watcher, one independent poll loop in the triggerer.
+* The other two use ``DirectoryFileDeleteTrigger`` with a matching
+ ``shared_stream_key`` of ``("directory-scan", directory, poke_interval)``;
+ the triggerer runs **one** directory listing loop for the pair and
+ broadcasts the result to both. Each still fires only for its own filename.
+
+The Dag runs when any of the three watchers' assets is updated. Touch
+``/tmp/test``, ``/tmp/region-flags/us.flag``, or ``/tmp/region-flags/eu.flag``
+to trigger a run.
"""
from __future__ import annotations
-from airflow.providers.standard.triggers.file import FileDeleteTrigger
+from airflow.providers.standard.triggers.file import (
+ DirectoryFileDeleteTrigger,
+ FileDeleteTrigger,
+)
from airflow.sdk import DAG, Asset, AssetWatcher, chain, task
-file_path = "/tmp/test"
+# Independent single-file watcher — has its own poll loop in the triggerer.
+single_file_trigger = FileDeleteTrigger(filepath="/tmp/test")
+single_file_asset = Asset(
+ "example_asset",
+ watchers=[AssetWatcher(name="test_asset_watcher",
trigger=single_file_trigger)],
+)
-trigger = FileDeleteTrigger(filepath=file_path)
-asset = Asset("example_asset",
watchers=[AssetWatcher(name="test_asset_watcher", trigger=trigger)])
+# Shared-stream watchers — same directory + poke interval, so the triggerer
+# runs one scan for both. Each watcher's ``filter_shared_stream`` matches on
+# its own filename and ``unlink``s the flag file as a subscriber-side effect.
+us_trigger = DirectoryFileDeleteTrigger(directory="/tmp/region-flags",
filename="us.flag", poke_interval=5.0)
+eu_trigger = DirectoryFileDeleteTrigger(directory="/tmp/region-flags",
filename="eu.flag", poke_interval=5.0)
+us_asset = Asset(
+ "region_us_flag",
+ watchers=[AssetWatcher(name="us_flag_watcher", trigger=us_trigger)],
+)
+eu_asset = Asset(
+ "region_eu_flag",
+ watchers=[AssetWatcher(name="eu_flag_watcher", trigger=eu_trigger)],
+)
with DAG(
dag_id="example_asset_with_watchers",
- schedule=[asset],
+ schedule=[single_file_asset, us_asset, eu_asset],
catchup=False,
tags=["example"],
):
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 6f8f7baae84..c73dfd857bd 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -27,7 +27,7 @@ import sys
import threading
import time
from collections import deque
-from collections.abc import Callable, Generator, Iterable, Iterator
+from collections.abc import Callable, Generator, Hashable, Iterable, Iterator
from contextlib import contextmanager, suppress
from datetime import datetime
from socket import socket
@@ -109,6 +109,7 @@ from airflow.sdk.execution_time.supervisor import
WatchedSubprocess, make_buffer
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.serialization.serialized_objects import DagSerialization
from airflow.triggers.base import BaseEventTrigger, BaseTrigger,
DiscrimatedTriggerEvent, TriggerEvent
+from airflow.triggers.shared_stream import SharedStreamManager
from airflow.utils.helpers import log_filename_template_renderer
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import create_session, provide_session
@@ -1069,6 +1070,10 @@ class TriggerRunner:
self.failed_triggers = deque()
self.job_id = None
self._stop_event = None
+ self._shared_streams = SharedStreamManager(
+ log=self.log,
+ max_subscriber_queue=conf.getint("triggerer",
"shared_stream_subscriber_queue_size"),
+ )
self.blocked_main_thread_warning_threshold = conf.getfloat(
"triggerer", "blocked_main_thread_warning_threshold"
)
@@ -1136,6 +1141,12 @@ class TriggerRunner:
reader_task.cancel()
with suppress(asyncio.CancelledError):
await reader_task
+ # Safety net: cancel any shared-stream poll tasks whose group
+ # survived per-trigger cleanup. The normal eviction path is
+ # ``SharedStreamManager.unsubscribe`` in ``run_trigger``'s
+ # finally; this call only matters when that path was bypassed
+ # (e.g. the unsubscribe coroutine raised and was swallowed).
+ await self._shared_streams.stop_all()
# Wait for supporting tasks to complete
await watchdog
@@ -1437,12 +1448,43 @@ class TriggerRunner:
name = self.triggers[trigger_id]["name"]
self.log.info("trigger %s starting", name)
+
+ # Triggers that opt into a shared underlying I/O stream
+ # (BaseEventTrigger.shared_stream_key returns non-None) consume a
+ # broadcast stream produced by SharedStreamManager and convert it
+ # via filter_shared_stream(). Everything else stays on the original
+ # standalone-run() path. The key is computed after
+ # render_template_fields so any templated attributes are already
+ # resolved when the key is constructed.
+ event_trigger: BaseEventTrigger | None = None
+ if isinstance(trigger, BaseEventTrigger):
+ event_trigger = trigger
+ shared_key: Hashable | None = None
+
with _make_trigger_span(ti=trigger.task_instance,
trigger_id=trigger_id, name=name) as span:
try:
if context is not None:
trigger.render_template_fields(context=context)
- async for event in trigger.run():
+ if event_trigger is not None:
+ try:
+ shared_key = event_trigger.shared_stream_key()
+ except Exception:
+ self.log.exception(
+ "shared_stream_key() raised; falling back to
standalone run",
+ trigger_id=trigger_id,
+ )
+ shared_key = None
+
+ if shared_key is not None and event_trigger is not None:
+ shared_stream = self._shared_streams.subscribe(
+ trigger_id=trigger_id, trigger=event_trigger,
key=shared_key
+ )
+ event_stream =
event_trigger.filter_shared_stream(shared_stream)
+ else:
+ event_stream = trigger.run()
+
+ async for event in event_stream:
await self.log.ainfo(
"Trigger fired event",
name=self.triggers[trigger_id]["name"], result=event
)
@@ -1486,6 +1528,17 @@ class TriggerRunner:
# fine, the cleanup process will understand that, but we want
to
# allow triggers a chance to cleanup, either in that case or if
# they exit cleanly. Exception from cleanup methods are
ignored.
+ if shared_key is not None:
+ try:
+ await self._shared_streams.unsubscribe(trigger_id,
shared_key)
+ except Exception:
+ # Best-effort cleanup, but log the failure so that bugs in cancel
+ # propagation or ``_handle_poll_terminate`` are not silently lost.
+ self.log.exception(
+ "Failed to unsubscribe trigger from shared stream",
+ trigger_id=trigger_id,
+ key=shared_key,
+ )
with suppress(Exception):
await trigger.cleanup()
diff --git a/airflow-core/src/airflow/triggers/base.py
b/airflow-core/src/airflow/triggers/base.py
index f39b62facf7..7cda8cc4272 100644
--- a/airflow-core/src/airflow/triggers/base.py
+++ b/airflow-core/src/airflow/triggers/base.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import abc
import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Hashable
from dataclasses import dataclass
from datetime import timedelta
from typing import TYPE_CHECKING, Annotated, Any
@@ -251,6 +251,48 @@ class BaseEventTrigger(BaseTrigger):
``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify
triggers compatible with
event-driven scheduling.
+
+ **Sharing an underlying I/O stream between triggers**
+
+ A subclass that polls an upstream resource which can be safely consumed
+ by multiple sibling triggers (e.g. a directory scan, a polling REST API)
+ may opt in to having the triggerer run a single underlying poll loop
+ and fan its raw events out to every trigger in the group. To do so,
+ override:
+
+ * :meth:`shared_stream_key` — return a key identifying the
+ shared stream (a ``tuple`` of strings is a common choice). Triggers
+ whose key compares equal share one poll.
+ * :meth:`open_shared_stream` — open the shared stream and yield raw
+ events. Called once per group in the triggerer.
+ * :meth:`filter_shared_stream` — convert the shared raw stream into this
+ trigger's own ``TriggerEvent`` instances, applying any per-trigger
+ filtering or transformation.
+
+ Triggers whose ``shared_stream_key`` returns ``None`` (the default)
+ keep the existing behavior: each trigger gets its own poll loop via
+ :meth:`run`.
+
+ **Suitable upstreams**
+
+ The shared-stream channel is **one-way** today: events flow from the
+ producer (``open_shared_stream``) to each subscriber's
+ ``filter_shared_stream``, with no path back to tell the producer that a
+ subscriber accepted, dropped, or finished processing an event. That
+ restricts the pattern to upstreams whose consumption does **not** depend
+ on a side effect on a handle that only the producer holds:
+
+ * Idempotent / read-only reads (filesystem listings, polling REST APIs).
+ * Subscriber-side-effect cleanup, where the trigger's per-event action
+ (``unlink``, local marking, …) operates through APIs the subscriber
+ already owns, independent of the shared producer handle.
+
+ Upstreams **not** in scope include Kafka consumers (regardless of
+ commit mode), SQS with delete-on-process or visibility extension,
+ and any source where progress on the producer's handle is tied to
+ the subscriber's accept / reject decision. These sources need a way
+ for the subscriber to signal acceptance back to the producer, which
+ the current shared-stream API does not provide.
"""
supports_triggerer_queue: bool = False
@@ -269,6 +311,84 @@ class BaseEventTrigger(BaseTrigger):
normalized = encode_trigger({"classpath": classpath, "kwargs":
kwargs})["kwargs"]
return hash((classpath,
json.dumps(BaseSerialization.serialize(normalized)).encode("utf-8")))
+ def shared_stream_key(self) -> Hashable | None:
+ """
+ Identify an underlying I/O stream that can be shared with sibling
triggers.
+
+ Two trigger instances whose ``shared_stream_key()`` return values
+ compare equal (and are not ``None``) will share a single underlying
+ poll loop in the triggerer. Each instance still receives the events
+ it cares about through its own :meth:`filter_shared_stream` call.
+
+ Returning ``None`` (the default) opts out of sharing — the trigger
+ runs its own independent poll loop via :meth:`run`, exactly as today.
+
+ The return value is read **once** when ``run_trigger`` first starts
+ this trigger; any change to the key afterwards has no effect on
+ group membership for this instance. To share one poll across a set
+ of sibling triggers, ensure every trigger in the set returns the
+ same key from the outset.
+
+ The key must be deterministic — derive it from configuration fields,
+ never from per-call values such as ``time.time()`` or ``uuid.uuid4()``,
+ because the comparison must be stable across the lifetime of the group.
+
+ .. note::
+
+ This method is called **after** :meth:`render_template_fields`,
+ so any templated attribute (for example a ``directory`` derived
+ from a Jinja expression) is already resolved when the key is
+ constructed. Two sibling triggers that render to the same path
+ will correctly share their poll.
+ """
+ return None
+
+ @classmethod
+ async def open_shared_stream(cls, kwargs: dict[str, Any]) ->
AsyncIterator[Any]:
+ """
+ Open the shared underlying stream and yield raw events.
+
+ Called **once per shared-stream group** in the triggerer. ``kwargs``
+ is taken from one trigger in the group; implementations should rely
+ only on fields whose values participate in :meth:`shared_stream_key`,
+ because other fields may differ between siblings in the group.
+
+ Implementations are expected to run for the lifetime of the group —
+ the triggerer drives the iterator from a single task and cancels it
+ when the last subscriber leaves. Returning without raising (e.g.
+ because the upstream resource closed) is treated as an error and
+ propagated to every subscriber, so the contract is "yield forever, or
+ raise". If an upstream EOF is a meaningful end-of-life condition,
+ raise an exception that conveys it.
+
+ Declared as a classmethod (not staticmethod) so subclasses can
+ compose via ``super().open_shared_stream(kwargs)`` and reach
+ ``cls`` for class-scoped state or diagnostics.
+
+ Required only when :meth:`shared_stream_key` returns non-``None``.
+ """
+ raise NotImplementedError(
+ f"{cls.__name__} declares a shared_stream_key but does not
implement open_shared_stream"
+ )
+ yield # pragma: no cover - convince mypy this is an async iterator
+
+ async def filter_shared_stream(self, shared_stream: AsyncIterator[Any]) ->
AsyncIterator[TriggerEvent]:
+ """
+ Transform the shared raw event stream into this trigger's events.
+
+ The triggerer calls this method (instead of :meth:`run`) when this
+ trigger participates in a shared-stream group. Iterate
+ ``shared_stream`` to receive raw events from the shared poll, and
+ ``yield`` a :class:`TriggerEvent` for each one that should fire this
+ trigger.
+
+ Required only when :meth:`shared_stream_key` returns non-``None``.
+ """
+ raise NotImplementedError(
+ f"{type(self).__name__} declares a shared_stream_key but does not
implement filter_shared_stream"
+ )
+ yield # pragma: no cover - convince mypy this is an async iterator
+
class TriggerEvent(BaseModel):
"""
diff --git a/airflow-core/src/airflow/triggers/shared_stream.py
b/airflow-core/src/airflow/triggers/shared_stream.py
new file mode 100644
index 00000000000..f347a265524
--- /dev/null
+++ b/airflow-core/src/airflow/triggers/shared_stream.py
@@ -0,0 +1,387 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Shared underlying I/O between :class:`BaseEventTrigger` instances in the
triggerer.
+
+When multiple triggers declare the same non-``None``
+:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key`, the
+triggerer routes them through :class:`SharedStreamManager` so that one
+underlying poll loop produces raw events that are broadcast to every
+participating trigger. Each trigger then runs
+:meth:`~airflow.triggers.base.BaseEventTrigger.filter_shared_stream` to
+convert the broadcast into its own :class:`~airflow.triggers.base.TriggerEvent`
+instances. Triggers that opt out (the default) keep their independent
+``run()``-based poll loops untouched.
+
+Scope and the missing ack channel
+---------------------------------
+
+The shared-stream channel is **one-way**: events flow from
+``open_shared_stream`` out to each subscriber's ``filter_shared_stream``,
+with no path back. Subscribers cannot tell the producer "I accepted this
+event; please advance / commit / ack". The pattern is therefore only safe
+for upstreams whose consumption does not need a producer-side side effect
+tied to a subscriber's accept / reject decision:
+
+* Idempotent / read-only reads (filesystem listings, polling REST APIs).
+* Subscriber-side-effect cleanup (``unlink``, local marking, …) where the
+ per-event action goes through APIs the subscriber owns independently.
+
+Kafka consumers (regardless of commit mode), SQS delete-on-process /
+visibility extension, and similar message-broker patterns where progress is
+per-message and tied to the subscriber's decision are **not** in scope here
+today. A
+producer-side ack channel to cover them is a follow-up that should be
+designed against a concrete Kafka or SQS consumer rather than against an
+abstract API. See :class:`~airflow.triggers.base.BaseEventTrigger` for the
+matching subclass-facing notes.
+
+Lifecycle invariants
+--------------------
+
+The manager and groups cooperate to keep a single invariant true at every
+``await``-point:
+
+ A key is present in :attr:`SharedStreamManager._groups` only while its
+ group's poll task is alive and accepting new subscribers.
+
+This rules out the late-subscriber races that the naive design admits — a
+new subscriber for a key whose poll has died or is in the middle of being
+torn down always falls through to "create a fresh group" rather than
+attaching to a dead one and hanging on an empty queue. The invariant is
+maintained synchronously:
+
+* When ``_poll`` ends for any reason other than cancellation (the upstream
+ iterator raised, or returned), the group's ``finally`` block evicts the
+ key from ``_groups`` and broadcasts a terminal sentinel to current
+ subscribers — all without yielding, so no other coroutine can interleave.
+* When the last subscriber leaves, :meth:`SharedStreamManager.unsubscribe`
+ evicts the key from ``_groups`` *before* awaiting ``group.stop()``, so a
+ new subscriber arriving while we wait for cancellation creates a fresh
+ group.
+* :meth:`SharedStreamManager.stop_all` clears ``_groups`` in one synchronous
+ step before awaiting any stop, applying the same rule to shutdown.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncGenerator, AsyncIterator, Callable, Hashable
+from contextlib import suppress
+from typing import TYPE_CHECKING, Any
+
+import structlog
+
+if TYPE_CHECKING:
+ from structlog.stdlib import BoundLogger
+
+ from airflow.triggers.base import BaseEventTrigger
+
+log = structlog.get_logger(__name__)
+
+DEFAULT_SUBSCRIBER_QUEUE_MAX = 1024
+"""Default per-subscriber queue size for shared streams.
+
+The :class:`SharedStreamManager` admits up to this many unconsumed raw events
+per subscriber before treating the subscriber as too slow to keep up — at
+which point the subscriber's trigger is failed with
+:class:`_SubscriberOverflow` rather than the queue growing without bound.
+
+Used as the fallback when no value is passed to ``SharedStreamManager``;
+in the triggerer this is overridden from the
+``[triggerer] shared_stream_subscriber_queue_size`` config option.
+"""
+
+
+class _PollTerminated(Exception):
+ """
+ Raised inside subscribers when ``open_shared_stream`` returns without
yielding more events.
+
+ Implementations are expected to run for the lifetime of the group; an
+ early return would otherwise leave subscribers waiting forever on an
+ empty queue.
+ """
+
+
+class _SubscriberOverflow(Exception):
+ """
+ Raised in a subscriber whose queue exceeded its maxsize.
+
+ Surfaces the slow subscriber loudly through the standard trigger-failure
+ path (rather than silently dropping events) so Airflow's retry / failure
+ semantics apply. Other subscribers in the same group are unaffected.
+ """
+
+
+class _PollFailure:
+ """Sentinel propagated through subscriber queues when the shared poll
ends."""
+
+ __slots__ = ("exc",)
+
+ def __init__(self, exc: BaseException) -> None:
+ self.exc = exc
+
+
+async def _drain(queue: asyncio.Queue) -> AsyncGenerator[Any, None]:
+ """
+ Yield items from ``queue`` until a poll termination sentinel arrives.
+
+ Subscribers exit either by their consuming task being cancelled
+ (Airflow's standard idiom — :class:`CancelledError` propagates through
+ ``queue.get()``) or by the shared poll ending, in which case the
+ :class:`_PollFailure` sentinel re-raises here.
+ """
+ while True:
+ item = await queue.get()
+ if isinstance(item, _PollFailure):
+ raise item.exc
+ yield item
+
+
+class _SharedStreamGroup:
+    """
+    One shared poll loop broadcasting raw events to N subscriber queues.
+
+    A single task drives ``trigger_class.open_shared_stream(kwargs)`` and copies
+    every raw event into one bounded :class:`asyncio.Queue` per subscriber. A
+    subscriber whose queue is full is failed on its own; a poll that raises or
+    returns fails every remaining subscriber and evicts the group through the
+    ``on_poll_terminate`` callback.
+    """
+
+    def __init__(
+        self,
+        *,
+        key: Hashable,
+        trigger_class: type[BaseEventTrigger],
+        kwargs: dict[str, Any],
+        on_poll_terminate: Callable[[_SharedStreamGroup], None],
+        max_subscriber_queue: int,
+        log: BoundLogger,
+    ) -> None:
+        self.key = key
+        self.trigger_class = trigger_class
+        self.kwargs = kwargs
+        self.log = log
+        self._on_poll_terminate = on_poll_terminate
+        self._max_subscriber_queue = max_subscriber_queue
+        # trigger_id -> that subscriber's private event queue.
+        self._subscribers: dict[int, asyncio.Queue] = {}
+        # trigger_ids already failed with ``_SubscriberOverflow``; they receive
+        # no further events or sentinels.
+        self._overflowed: set[int] = set()
+        self._poll_task: asyncio.Task | None = None
+
+    def start(self) -> None:
+        """
+        Start the underlying poll loop. Call exactly once per group.
+
+        :raises RuntimeError: if the group has already been started.
+        """
+        if self._poll_task is not None:
+            raise RuntimeError(f"Shared stream group {self.key!r} already started")
+        self._poll_task = asyncio.create_task(
+            self._poll(),
+            name=f"shared-stream-poll[{self.key!r}]",
+        )
+
+    async def _poll(self) -> None:
+        """
+        Drive ``open_shared_stream`` and fan every raw event out to the subscribers.
+
+        Cancellation is re-raised untouched. Any other way of ending (an
+        exception, or the stream returning) runs the terminate path in
+        ``finally``: evict the group, then wake the subscribers with a failure.
+        """
+        terminal_exc: BaseException | None = None
+        try:
+            async for raw_event in self.trigger_class.open_shared_stream(self.kwargs):
+                for trigger_id, queue in self._subscribers.items():
+                    if trigger_id in self._overflowed:
+                        # Subscriber has been force-failed on a previous
+                        # overflow; the failure sentinel is already in its
+                        # queue and unsubscribe will drop it on next pass.
+                        continue
+                    try:
+                        queue.put_nowait(raw_event)
+                    except asyncio.QueueFull:
+                        self._fail_overflowed_subscriber(trigger_id, queue)
+            terminal_exc = _PollTerminated(
+                f"open_shared_stream for {self.key!r} returned without raising; "
+                "shared streams are expected to run for the lifetime of the group"
+            )
+        except asyncio.CancelledError:
+            # ``stop()`` initiated this; the manager has already evicted the
+            # group and is awaiting our exit. Do not run the terminate path.
+            raise
+        except Exception as exc:
+            terminal_exc = exc
+            self.log.exception("Shared stream poll failed; propagating to subscribers", key=self.key)
+        finally:
+            if terminal_exc is not None:
+                # Synchronous: evict from the manager and broadcast the
+                # sentinel before returning to the loop, so no coroutine can
+                # observe ``_groups[key]`` pointing at a dead poll.
+                self._on_poll_terminate(self)
+                failure = _PollFailure(terminal_exc)
+                for trigger_id, queue in self._subscribers.items():
+                    if trigger_id in self._overflowed:
+                        # Its queue already holds the ``_SubscriberOverflow``
+                        # sentinel. Draining here would replace that root
+                        # cause with the unrelated poll failure.
+                        continue
+                    # Drain stale events then put the failure sentinel so every
+                    # subscriber wakes up even if its queue was at capacity.
+                    self._drain_and_offer_failure(queue, failure)
+
+    def subscribe(self, trigger_id: int) -> AsyncIterator[Any]:
+        """
+        Register ``trigger_id`` as a subscriber and return its raw event stream.
+
+        :raises RuntimeError: if ``trigger_id`` is already subscribed.
+        """
+        if trigger_id in self._subscribers:
+            raise RuntimeError(f"Trigger {trigger_id} already subscribed to shared stream {self.key!r}")
+        queue: asyncio.Queue = asyncio.Queue(maxsize=self._max_subscriber_queue)
+        self._subscribers[trigger_id] = queue
+        return _drain(queue)
+
+    def unsubscribe(self, trigger_id: int) -> None:
+        """Forget ``trigger_id``; unknown ids are ignored."""
+        # Active subscribers exit through their consuming task being cancelled
+        # (Airflow's standard idiom); dropping the queue is enough here.
+        self._subscribers.pop(trigger_id, None)
+        self._overflowed.discard(trigger_id)
+
+    def _fail_overflowed_subscriber(self, trigger_id: int, queue: asyncio.Queue) -> None:
+        """
+        Force a slow subscriber to fail with :class:`_SubscriberOverflow`.
+
+        The broadcast hit ``QueueFull`` for this subscriber's queue, which
+        means the subscriber's :meth:`filter_shared_stream` is falling behind
+        the upstream cadence. Rather than dropping events silently — which
+        would invisibly violate Asset event-driven semantics — we drain
+        whatever stale events are pending and replace them with a
+        :class:`_PollFailure` so the subscriber's ``run_trigger`` sees the
+        error on its next ``__anext__``. Other subscribers in the same group
+        are unaffected.
+        """
+        self.log.warning(
+            "Shared stream subscriber overflowed; failing this trigger",
+            key=self.key,
+            trigger_id=trigger_id,
+            queue_maxsize=queue.maxsize,
+        )
+        self._drain_and_offer_failure(
+            queue,
+            _PollFailure(
+                _SubscriberOverflow(
+                    f"shared stream {self.key!r} fell behind for trigger {trigger_id}: "
+                    f"subscriber queue exceeded maxsize={queue.maxsize}"
+                )
+            ),
+        )
+        self._overflowed.add(trigger_id)
+
+    def _drain_and_offer_failure(self, queue: asyncio.Queue, failure: _PollFailure) -> None:
+        """
+        Drain ``queue`` and put ``failure`` so the subscriber wakes on the failure.
+
+        The drain releases capacity so the subsequent ``put_nowait`` cannot raise
+        ``QueueFull``; this is the single point that both the terminal-broadcast
+        and the per-subscriber overflow path go through.
+        """
+        # ``QueueEmpty`` is the loop's only exit: pop until nothing is left.
+        with suppress(asyncio.QueueEmpty):
+            while True:
+                queue.get_nowait()
+        queue.put_nowait(failure)
+
+    def is_empty(self) -> bool:
+        """Return ``True`` when no subscriber is registered."""
+        return not self._subscribers
+
+    async def stop(self) -> None:
+        """Cancel the poll task if it is still running and wait for it to exit."""
+        if self._poll_task is None or self._poll_task.done():
+            return
+        self._poll_task.cancel()
+        # NOTE(review): this also swallows a cancellation aimed at the caller
+        # while it waits here — confirm that is acceptable for every caller.
+        with suppress(asyncio.CancelledError):
+            await self._poll_task
+
+
+class SharedStreamManager:
+    """
+    Coordinate :class:`BaseEventTrigger` instances that share underlying I/O.
+
+    The manager owns one :class:`_SharedStreamGroup` per distinct
+    ``shared_stream_key``. Each group runs a single async task that drives
+    ``open_shared_stream``; subscribers receive raw events through their own
+    asyncio queues and convert them to :class:`TriggerEvent` instances
+    independently.
+
+    The manager is single-event-loop and not thread-safe. The triggerer's
+    ``TriggerRunner`` is its sole owner.
+    """
+
+    def __init__(
+        self,
+        *,
+        log: BoundLogger | None = None,
+        max_subscriber_queue: int = DEFAULT_SUBSCRIBER_QUEUE_MAX,
+    ) -> None:
+        self.log = log or structlog.get_logger(__name__)
+        self._max_subscriber_queue = max_subscriber_queue
+        # shared_stream_key -> the live group polling for that key.
+        self._groups: dict[Hashable, _SharedStreamGroup] = {}
+
+    def subscribe(
+        self,
+        *,
+        trigger_id: int,
+        trigger: BaseEventTrigger,
+        key: Hashable,
+    ) -> AsyncIterator[Any]:
+        """
+        Subscribe a trigger to the shared stream identified by ``key``.
+
+        On first subscriber for a given key the group is created and the
+        underlying poll loop is started. Returns an async iterator of raw
+        events the trigger should feed into ``filter_shared_stream``.
+
+        :raises ValueError: if ``key`` is ``None``.
+        :raises RuntimeError: if ``trigger_id`` is already subscribed to the group.
+        """
+        if key is None:
+            raise ValueError("shared stream key must not be None")
+        if (group := self._groups.get(key)) is None:
+            _, kwargs = trigger.serialize()
+            group = _SharedStreamGroup(
+                key=key,
+                trigger_class=type(trigger),
+                kwargs=kwargs,
+                on_poll_terminate=self._handle_poll_terminate,
+                max_subscriber_queue=self._max_subscriber_queue,
+                log=self.log,
+            )
+            # Start before registering: if ``start()`` raises, ``_groups``
+            # must not keep a never-polled group that later subscribers would
+            # attach to. No ``await`` sits between the two statements, so the
+            # poll task cannot observe the unregistered state.
+            group.start()
+            self._groups[key] = group
+            self.log.debug("Shared stream group started", key=key)
+        return group.subscribe(trigger_id)
+
+    async def unsubscribe(self, trigger_id: int, key: Hashable) -> None:
+        """
+        Remove a subscriber.
+
+        When the last subscriber for ``key`` leaves, the key is evicted from
+        ``_groups`` synchronously and the underlying poll task is cancelled.
+        Eviction happens *before* awaiting ``stop()`` so that a new subscriber
+        arriving while we wait for cancellation builds a fresh group rather
+        than attaching to the dying one.
+        """
+        group = self._groups.get(key)
+        if group is None:
+            return
+        group.unsubscribe(trigger_id)
+        if group.is_empty():
+            del self._groups[key]
+            await group.stop()
+            self.log.debug("Shared stream group stopped", key=key)
+
+    async def stop_all(self) -> None:
+        """Cancel every active group; used during triggerer shutdown."""
+        groups = list(self._groups.values())
+        self._groups.clear()
+        # Stop the groups together so that one ``stop()`` raising cannot
+        # prevent the remaining poll tasks from being cancelled.
+        await asyncio.gather(*(group.stop() for group in groups))
+
+    def _handle_poll_terminate(self, group: _SharedStreamGroup) -> None:
+        """
+        Evict a group synchronously when its poll task ends on its own.
+
+        Invoked from ``_SharedStreamGroup._poll``'s ``finally`` before any
+        ``await`` hands control to another coroutine, so the eviction races no
+        ``subscribe`` call. The ``is`` check is defensive — under normal flow
+        a group only enters this path while it is still the live entry for
+        its key.
+        """
+        if self._groups.get(group.key) is group:
+            del self._groups[group.key]
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py
b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 0501783b992..73643ab3867 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -776,6 +776,70 @@ class TestTriggerRunner:
mock_trigger.on_kill.assert_awaited_once()
mock_trigger.cleanup.assert_awaited_once()
+    def test_run_trigger_routes_shared_stream_trigger_through_manager(self, session) -> None:
+        """A BaseEventTrigger that opts into a shared stream consumes filter_shared_stream()."""
+        from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+
+        class _SharedTrigger(BaseEventTrigger):
+            def __init__(self, queue_url: str, region: str | None = None):
+                super().__init__()
+                self.queue_url = queue_url
+                self.region = region
+
+            def serialize(self):
+                return (
+                    f"{type(self).__module__}.{type(self).__qualname__}",
+                    {"queue_url": self.queue_url, "region": self.region},
+                )
+
+            def shared_stream_key(self):
+                return ("queue", self.queue_url)
+
+            @classmethod
+            async def open_shared_stream(cls, kwargs):
+                yield {"region": "us"}
+                yield {"region": "eu"}
+                # Stay alive so the manager can tear us down on unsubscribe.
+                await asyncio.Event().wait()
+
+            async def filter_shared_stream(self, shared_stream):
+                # Keep only raw events for this instance's region (all when region is None).
+                async for raw in shared_stream:
+                    if self.region is None or raw["region"] == self.region:
+                        yield TriggerEvent(raw)
+
+            async def run(self):  # pragma: no cover - replaced by filter_shared_stream
+                yield TriggerEvent({})
+
+        trigger_runner = TriggerRunner()
+        trigger_runner.triggers = {
+            1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": True, "name": "us", "events": 0}
+        }
+        trigger = _SharedTrigger(queue_url="https://q", region="us")
+        trigger.task_instance = MagicMock()
+        trigger.task_instance.map_index = -1
+
+        async def _drive():
+            run_task = asyncio.create_task(trigger_runner.run_trigger(1, trigger))
+            # Wait until the "us" event has been pushed onto the outbound queue,
+            # then cancel the trigger so the test can exit deterministically.
+            # The poll is bounded: 100 iterations of 10 ms each.
+            for _ in range(100):
+                await asyncio.sleep(0.01)
+                if trigger_runner.events:
+                    break
+            run_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await run_task
+
+        asyncio.run(_drive())
+
+        # The stream yields "us" then "eu"; only "us" passes this trigger's filter.
+        events = list(trigger_runner.events)
+        assert len(events) == 1
+        trigger_id, event = events[0]
+        assert trigger_id == 1
+        assert event.payload == {"region": "us"}
+        # Group is torn down on unsubscribe.
+        assert trigger_runner._shared_streams._groups == {}
+
def test_run_trigger_on_kill_timeout_does_not_block_cleanup(self, session)
-> None:
"""A hanging on_kill() is interrupted after the timeout and cleanup
still runs."""
trigger_runner = TriggerRunner()
diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py
b/airflow-core/tests/unit/triggers/test_base_trigger.py
index d9e38385a20..8e5690a3ef9 100644
--- a/airflow-core/tests/unit/triggers/test_base_trigger.py
+++ b/airflow-core/tests/unit/triggers/test_base_trigger.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import pytest
from airflow.sdk.bases.operator import BaseOperator
-from airflow.triggers.base import BaseTrigger, StartTriggerArgs
+from airflow.triggers.base import BaseEventTrigger, BaseTrigger, StartTriggerArgs, TriggerEvent
class DummyOperator(BaseOperator):
@@ -138,3 +138,97 @@ def
test_render_template_fields_empty_when_no_trigger_kwargs(create_task_instanc
# Rendering with empty template_fields is a no-op
trigger.render_template_fields(context={"name": "world"})
assert trigger.name == "Hello {{ name }}"
+
+
+class _PlainEventTrigger(BaseEventTrigger):
+    """Event trigger that overrides none of the shared-stream hooks."""
+
+    def __init__(self, name: str = "plain"):
+        super().__init__()
+        self.name = name
+
+    def serialize(self):
+        own_type = type(self)
+        classpath = f"{own_type.__module__}.{own_type.__qualname__}"
+        return (classpath, {"name": self.name})
+
+    async def run(self):
+        payload = {"name": self.name}
+        yield TriggerEvent(payload)
+
+
+class _SharedQueueTrigger(BaseEventTrigger):
+    """Event trigger that shares one stream per ``queue_url`` and filters by ``region``."""
+
+    def __init__(self, queue_url: str, region: str | None = None):
+        super().__init__()
+        self.queue_url = queue_url
+        self.region = region
+
+    def serialize(self):
+        own_type = type(self)
+        init_kwargs = {"queue_url": self.queue_url, "region": self.region}
+        return (f"{own_type.__module__}.{own_type.__qualname__}", init_kwargs)
+
+    def shared_stream_key(self):
+        # Only the queue URL is part of the key; ``region`` is applied per instance.
+        return ("shared-queue", self.queue_url)
+
+    @classmethod
+    async def open_shared_stream(cls, kwargs):
+        for region in ("us", "eu", "us"):
+            yield {"queue_url": kwargs["queue_url"], "region": region}
+
+    async def filter_shared_stream(self, shared_stream):
+        async for raw in shared_stream:
+            # Skip events addressed to another region; ``None`` accepts all.
+            if self.region is not None and raw["region"] != self.region:
+                continue
+            yield TriggerEvent(raw)
+
+    async def run(self):  # pragma: no cover - replaced by filter_shared_stream
+        yield TriggerEvent({})
+
+
+def test_base_event_trigger_defaults_no_sharing():
+    """A trigger that does not override ``shared_stream_key`` reports ``None``."""
+    assert _PlainEventTrigger().shared_stream_key() is None
+
+
+async def _drain_async_iter(it):
+    """Exhaust the async iterable ``it``, discarding every item it produces."""
+    iterator = it.__aiter__()
+    while True:
+        try:
+            await iterator.__anext__()
+        except StopAsyncIteration:
+            return
+
+
+@pytest.mark.asyncio
+async def test_base_event_trigger_default_open_shared_stream_raises():
+    """Consuming the un-overridden ``open_shared_stream`` raises ``NotImplementedError``."""
+    with pytest.raises(NotImplementedError, match="open_shared_stream"):
+        await _drain_async_iter(_PlainEventTrigger.open_shared_stream({}))
+
+
+@pytest.mark.asyncio
+async def test_base_event_trigger_default_filter_shared_stream_raises():
+    """Consuming the un-overridden ``filter_shared_stream`` raises ``NotImplementedError``."""
+    trigger = _PlainEventTrigger()
+
+    async def empty_stream():
+        # The unreachable ``yield`` makes this an async generator that produces nothing.
+        if False:
+            yield  # pragma: no cover
+
+    with pytest.raises(NotImplementedError, match="filter_shared_stream"):
+        await _drain_async_iter(trigger.filter_shared_stream(empty_stream()))
+
+
+def test_subclass_can_declare_shared_stream_key():
+    """Triggers on the same queue share a key; a different queue gets a different key."""
+    same_queue_us = _SharedQueueTrigger(queue_url="https://q", region="us")
+    same_queue_eu = _SharedQueueTrigger(queue_url="https://q", region="eu")
+    other_queue = _SharedQueueTrigger(queue_url="https://other", region="us")
+
+    shared_key = same_queue_us.shared_stream_key()
+    assert shared_key == same_queue_eu.shared_stream_key()
+    assert shared_key != other_queue.shared_stream_key()
+
+
+@pytest.mark.asyncio
+async def test_subclass_filter_shared_stream_applies_per_instance_match():
+    """``filter_shared_stream`` only lets through raw events matching the instance's region."""
+    us = _SharedQueueTrigger(queue_url="https://q", region="us")
+
+    async def stream():
+        for region in ("us", "eu", "us"):
+            yield {"queue_url": "https://q", "region": region}
+
+    regions = []
+    async for event in us.filter_shared_stream(stream()):
+        regions.append(event.payload["region"])
+    assert regions == ["us", "us"]
diff --git a/airflow-core/tests/unit/triggers/test_shared_stream.py
b/airflow-core/tests/unit/triggers/test_shared_stream.py
new file mode 100644
index 00000000000..28577ccbb52
--- /dev/null
+++ b/airflow-core/tests/unit/triggers/test_shared_stream.py
@@ -0,0 +1,685 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from contextlib import suppress
+
+import pytest
+
+from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+from airflow.triggers.shared_stream import (
+ SharedStreamManager,
+ _PollFailure,
+ _SharedStreamGroup,
+ _SubscriberOverflow,
+)
+
+
+class _ProgrammableSharedStreamTrigger(BaseEventTrigger):
+    """
+    Base test trigger: key and region filter are fixed, the shared poll is not.
+
+    Each test creates its own subclass and attaches the ``open_shared_stream``
+    it needs, so poll behavior never leaks from one scenario into another.
+    """
+
+    queue_url: str = "https://q"
+
+    def __init__(self, queue_url: str = "https://q", region: str | None = None):
+        super().__init__()
+        self.queue_url = queue_url
+        self.region = region
+
+    def serialize(self):
+        own_type = type(self)
+        init_kwargs = {"queue_url": self.queue_url, "region": self.region}
+        return (f"{own_type.__module__}.{own_type.__qualname__}", init_kwargs)
+
+    def shared_stream_key(self):
+        # ``region`` is deliberately not part of the key.
+        return ("queue", self.queue_url)
+
+    async def filter_shared_stream(self, shared_stream):
+        async for raw in shared_stream:
+            # ``None`` accepts every region; otherwise require an exact match.
+            if self.region is not None and raw["region"] != self.region:
+                continue
+            yield TriggerEvent(raw)
+
+    async def run(self):  # pragma: no cover - replaced by filter_shared_stream
+        yield TriggerEvent({})
+
+
+def _events_then_block(events: list[dict]):
+    """Return an ``open_shared_stream`` classmethod that yields ``events`` and then never ends."""
+
+    async def _open_shared_stream(cls, kwargs):
+        for raw_event in events:
+            yield raw_event
+        # Park on an event nobody sets, so tests can observe the broadcast
+        # and then tear the poll down themselves.
+        never_set = asyncio.Event()
+        await never_set.wait()
+
+    return classmethod(_open_shared_stream)
+
+
+def _make_trigger_class(open_shared_stream):
+    """
+    Return a fresh subclass with the given open_shared_stream classmethod.
+
+    A new class object is created on every call, so the attribute assigned
+    here is never shared between callers.
+    """
+
+    class _Trigger(_ProgrammableSharedStreamTrigger):
+        pass
+
+    # Assigned after the class statement: inside the class body the name
+    # ``open_shared_stream`` would not resolve to this function's argument.
+    _Trigger.open_shared_stream = open_shared_stream
+    return _Trigger
+
+
+async def _collect(stream, *, n: int, timeout: float = 1.0) -> list:
+    """Return the next ``n`` items of ``stream``, waiting at most ``timeout`` seconds for each."""
+    iterator = stream.__aiter__()
+    return [await asyncio.wait_for(iterator.__anext__(), timeout=timeout) for _ in range(n)]
+
+
+@pytest.mark.asyncio
+async def test_single_subscriber_receives_broadcast_events():
+    """A lone subscriber gets the broadcast events that pass its own filter."""
+    raw_events = [{"region": "us"}, {"region": "eu"}]
+    trigger = _make_trigger_class(_events_then_block(raw_events))(region="us")
+    key = trigger.shared_stream_key()
+    manager = SharedStreamManager()
+    try:
+        stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+        received = await _collect(trigger.filter_shared_stream(stream), n=1)
+        assert [event.payload["region"] for event in received] == ["us"]
+    finally:
+        await manager.unsubscribe(1, key)
+
+
+@pytest.mark.asyncio
+async def test_two_subscribers_share_one_poll_and_filter_independently():
+    """Two triggers with the same key share one group, yet each sees only its own region."""
+    cls = _make_trigger_class(
+        _events_then_block(
+            [
+                {"region": "us"},
+                {"region": "eu"},
+                {"region": "us"},
+            ]
+        )
+    )
+    us, eu = cls(region="us"), cls(region="eu")
+    key = us.shared_stream_key()
+    assert key == eu.shared_stream_key()
+
+    manager = SharedStreamManager()
+    try:
+        us_stream = manager.subscribe(trigger_id=1, trigger=us, key=key)
+        eu_stream = manager.subscribe(trigger_id=2, trigger=eu, key=key)
+
+        # The shared group is created exactly once.
+        assert len(manager._groups) == 1
+
+        # Consume both streams concurrently: two "us" events, one "eu" event.
+        us_events, eu_events = await asyncio.gather(
+            _collect(us.filter_shared_stream(us_stream), n=2),
+            _collect(eu.filter_shared_stream(eu_stream), n=1),
+        )
+        assert [e.payload["region"] for e in us_events] == ["us", "us"]
+        assert [e.payload["region"] for e in eu_events] == ["eu"]
+    finally:
+        await manager.unsubscribe(1, key)
+        await manager.unsubscribe(2, key)
+
+
+@pytest.mark.asyncio
+async def test_group_is_torn_down_when_last_subscriber_leaves():
+    """The manager drops a key's group as soon as its only subscriber unsubscribes."""
+    trigger_cls = _make_trigger_class(_events_then_block([{"region": "us"}]))
+    trigger = trigger_cls(region="us")
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+
+    manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+    registered_before = key in manager._groups
+    assert registered_before
+
+    await manager.unsubscribe(1, key)
+    registered_after = key in manager._groups
+    assert not registered_after
+
+
+@pytest.mark.asyncio
+async def test_independent_keys_use_independent_groups():
+    """Triggers with different keys each get a group of their own."""
+    trigger_cls = _make_trigger_class(_events_then_block([{"region": "us"}]))
+    a = trigger_cls(queue_url="https://a")
+    b = trigger_cls(queue_url="https://b")
+    key_a, key_b = a.shared_stream_key(), b.shared_stream_key()
+    manager = SharedStreamManager()
+
+    manager.subscribe(trigger_id=1, trigger=a, key=key_a)
+    manager.subscribe(trigger_id=2, trigger=b, key=key_b)
+    try:
+        assert set(manager._groups) == {key_a, key_b}
+    finally:
+        await manager.unsubscribe(1, key_a)
+        await manager.unsubscribe(2, key_b)
+
+
+@pytest.mark.asyncio
+async def test_poll_failure_propagates_to_subscribers_and_evicts_group():
+    """An exception raised by the shared poll reaches the subscriber and removes the group."""
+
+    async def _open_shared_stream(cls, kwargs):
+        raise RuntimeError("boom")
+        # Unreachable; its presence makes this function an async generator.
+        yield  # pragma: no cover
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    trigger = cls()
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+    try:
+        stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+        with pytest.raises(RuntimeError, match="boom"):
+            await asyncio.wait_for(_collect(trigger.filter_shared_stream(stream), n=1), timeout=1.0)
+        # The failing poll evicts its own group from the manager in _poll's
+        # finally, before any subscriber resumes — so by the time the
+        # subscriber observes "boom" the manager already has no group for
+        # this key. A late subscriber arriving here would create a fresh
+        # group rather than attaching to a dead one.
+        assert key not in manager._groups
+    finally:
+        await manager.unsubscribe(1, key)
+
+
+@pytest.mark.asyncio
+async def test_subscribe_rejects_none_key():
+    """``subscribe`` refuses a ``None`` key with ``ValueError``."""
+    trigger = _make_trigger_class(_events_then_block([]))()
+    manager = SharedStreamManager()
+    with pytest.raises(ValueError, match="must not be None"):
+        manager.subscribe(trigger_id=1, trigger=trigger, key=None)
+
+
+@pytest.mark.asyncio
+async def test_double_subscribe_same_id_is_rejected():
+    """Subscribing the same trigger id twice to one key raises ``RuntimeError``."""
+    trigger = _make_trigger_class(_events_then_block([]))()
+    key = trigger.shared_stream_key()
+    manager = SharedStreamManager()
+    try:
+        manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+        with pytest.raises(RuntimeError, match="already subscribed"):
+            manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+    finally:
+        await manager.unsubscribe(1, key)
+
+
+@pytest.mark.asyncio
+async def test_stop_all_clears_every_group():
+    """``stop_all`` leaves the manager with no groups."""
+    trigger_cls = _make_trigger_class(_events_then_block([]))
+    manager = SharedStreamManager()
+
+    first = trigger_cls(queue_url="https://a")
+    manager.subscribe(trigger_id=1, trigger=first, key=first.shared_stream_key())
+    second = trigger_cls(queue_url="https://b")
+    manager.subscribe(trigger_id=2, trigger=second, key=second.shared_stream_key())
+    assert len(manager._groups) == 2
+
+    await manager.stop_all()
+    assert manager._groups == {}
+
+
+@pytest.mark.asyncio
+async def test_late_subscriber_after_poll_failure_gets_fresh_group():
+    """The first call's open_shared_stream raises; a subsequent subscribe for the same key should
+    start a brand new poll rather than attach to the dead group.
+    """
+    # One entry per ``open_shared_stream`` call; its length is the call index.
+    invocations: list[int] = []
+
+    async def _open_shared_stream(cls, kwargs):
+        n = len(invocations)
+        invocations.append(n)
+        if n == 0:
+            raise RuntimeError("first invocation fails")
+        yield {"region": "us"}
+        await asyncio.Event().wait()
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    trigger = cls()
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+
+    # First subscription: the poll fails and the error reaches the subscriber.
+    stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+    with pytest.raises(RuntimeError, match="first invocation fails"):
+        await asyncio.wait_for(
+            _collect(trigger.filter_shared_stream(stream1), n=1),
+            timeout=1.0,
+        )
+    await manager.unsubscribe(1, key)
+
+    # Second subscription for the same key must be served by a new poll.
+    stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key)
+    try:
+        events = await asyncio.wait_for(
+            _collect(trigger.filter_shared_stream(stream2), n=1),
+            timeout=1.0,
+        )
+        assert [e.payload["region"] for e in events] == ["us"]
+    finally:
+        await manager.unsubscribe(2, key)
+
+    assert invocations == [0, 1], "open_shared_stream should be called twice (failed, then fresh)"
+
+
+@pytest.mark.asyncio
+async def test_late_subscriber_during_poll_failure_window_does_not_attach_to_dead_group():
+    """Reproduce the race the lifecycle rewrite closes: a new subscriber arriving after _poll has
+    raised but before the original subscriber has finished propagating the failure must see no
+    existing group and create a fresh one — otherwise it would attach to a queue nothing will ever
+    put events on.
+    """
+    # One entry per ``open_shared_stream`` call; its length is the call index.
+    invocations: list[int] = []
+
+    async def _open_shared_stream(cls, kwargs):
+        n = len(invocations)
+        invocations.append(n)
+        if n == 0:
+            raise RuntimeError("boom")
+        yield {"region": "fresh"}
+        await asyncio.Event().wait()
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    trigger = cls()
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+
+    stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+
+    # Wait for the poll task to finish its lifecycle — including the synchronous self-eviction in
+    # its finally block — but do NOT consume the _PollFailure from stream1 yet. This simulates the
+    # "broadcast done, subscriber not yet unwound" window described in the bug report.
+    poll_task = manager._groups[key]._poll_task
+    assert poll_task is not None
+    with suppress(RuntimeError):
+        await poll_task
+
+    assert key not in manager._groups, (
+        "the failing poll must evict its group synchronously in _poll's finally, so this window "
+        "is closed before any other coroutine can subscribe"
+    )
+
+    stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key)
+    try:
+        events = await asyncio.wait_for(
+            _collect(trigger.filter_shared_stream(stream2), n=1),
+            timeout=1.0,
+        )
+        assert events[0].payload == {"region": "fresh"}
+    finally:
+        # Original subscriber still has _PollFailure waiting for it.
+        with pytest.raises(RuntimeError, match="boom"):
+            await asyncio.wait_for(
+                _collect(trigger.filter_shared_stream(stream1), n=1),
+                timeout=1.0,
+            )
+        await manager.unsubscribe(1, key)
+        await manager.unsubscribe(2, key)
+
+    assert invocations == [0, 1]
+
+
+@pytest.mark.asyncio
+async def test_resubscribe_during_last_unsubscribe_creates_fresh_group():
+    """If the last subscriber leaves and the manager is mid-``await group.stop()``, a concurrent
+    subscribe for the same key must build a new group instead of attaching to the dying one.
+    """
+    # One entry per ``open_shared_stream`` call; each poll yields its own call index.
+    invocations: list[int] = []
+
+    async def _open_shared_stream(cls, kwargs):
+        n = len(invocations)
+        invocations.append(n)
+        yield {"n": n}
+        await asyncio.Event().wait()
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    trigger = cls()
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+
+    stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+    await asyncio.wait_for(
+        _collect(trigger.filter_shared_stream(stream1), n=1),
+        timeout=1.0,
+    )
+
+    unsub_task = asyncio.create_task(manager.unsubscribe(1, key))
+    # One tick: unsubscribe runs synchronously through the pop-from-_groups step, then yields at
+    # `await group.stop()`. After this yield returns to us, _groups is already cleared.
+    await asyncio.sleep(0)
+    assert key not in manager._groups, (
+        "manager.unsubscribe must evict the group from _groups before awaiting stop(), so a "
+        "racing subscribe sees no group and creates a fresh one"
+    )
+
+    stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key)
+    try:
+        events = await asyncio.wait_for(
+            _collect(trigger.filter_shared_stream(stream2), n=1),
+            timeout=1.0,
+        )
+        # Second invocation (index 1) — proves stream2 is bound to a fresh poll, not the dying one.
+        assert events[0].payload == {"n": 1}
+    finally:
+        await unsub_task
+        await manager.unsubscribe(2, key)
+
+    assert invocations == [0, 1]
+
+
+@pytest.mark.asyncio
+async def test_open_shared_stream_returning_naturally_propagates_as_failure():
+    """A shared poll that exhausts its iterator instead of running indefinitely would otherwise
+    leave subscribers blocked on queue.get() forever; the manager surfaces it as an error.
+    """
+
+    async def _open_shared_stream(cls, kwargs):
+        # Yields once and then returns, i.e. the stream ends on its own.
+        yield {"region": "us"}
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    trigger = cls()
+    manager = SharedStreamManager()
+    key = trigger.shared_stream_key()
+
+    stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+    # Asking for two items forces the consumer past the single real event.
+    with pytest.raises(Exception, match="expected to run for the lifetime of the group"):
+        await asyncio.wait_for(
+            _collect(trigger.filter_shared_stream(stream), n=2),
+            timeout=1.0,
+        )
+
+    assert key not in manager._groups, "natural exhaustion should evict the group like a failure"
+    await manager.unsubscribe(1, key)
+
+
+@pytest.mark.asyncio
+async def test_slow_subscriber_overflow_fails_only_that_subscriber():
+    """A subscriber whose ``filter_shared_stream`` lags behind the upstream cadence enough to
+    overflow its bounded queue must fail loudly with ``_SubscriberOverflow`` — silent drops are
+    unacceptable for Asset event-driven semantics. Sibling subscribers in the same group keep
+    receiving events.
+    """
+
+    async def _open_shared_stream(cls, kwargs):
+        for i in range(5):
+            yield {"i": i}
+            # Yield to the loop so the fast consumer gets a chance to drain;
+            # the slow consumer never runs while sleep(0) ticks pass, so its
+            # queue fills up.
+            await asyncio.sleep(0)
+        await asyncio.Event().wait()
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    slow_trigger = cls()
+    fast_trigger = cls()
+    # Five events against a queue bound of two: an unread queue must overflow.
+    manager = SharedStreamManager(max_subscriber_queue=2)
+    key = slow_trigger.shared_stream_key()
+
+    slow_stream = manager.subscribe(trigger_id=1, trigger=slow_trigger, key=key)
+    fast_stream = manager.subscribe(trigger_id=2, trigger=fast_trigger, key=key)
+
+    async def drain_fast():
+        out = []
+        async for ev in fast_trigger.filter_shared_stream(fast_stream):
+            out.append(ev)
+            if len(out) >= 5:
+                break
+        return out
+
+    # Start fast first so it drains its queue as the producer broadcasts.
+    fast_task = asyncio.create_task(drain_fast())
+
+    # Hand control back so the producer can broadcast all 5 events. The fast
+    # consumer keeps its queue around 1; the slow consumer has no task yet,
+    # so its queue fills past maxsize=2 and the overflow handler swaps the
+    # backlog for a failure sentinel.
+    fast_events = await asyncio.wait_for(fast_task, timeout=2.0)
+
+    # Slow consumer starts after the overflow; first event should be the failure.
+    with pytest.raises(_SubscriberOverflow, match="exceeded maxsize"):
+        await asyncio.wait_for(
+            _collect(slow_trigger.filter_shared_stream(slow_stream), n=1),
+            timeout=2.0,
+        )
+
+    assert [e.payload["i"] for e in fast_events] == [0, 1, 2, 3, 4], (
+        "fast subscriber must not be affected by the slow subscriber's overflow"
+    )
+    # The group is still alive — only the slow subscriber was failed; fast is still subscribed.
+    assert key in manager._groups
+    assert 1 in manager._groups[key]._overflowed
+
+    await manager.unsubscribe(1, key)
+    await manager.unsubscribe(2, key)
+
+
+@pytest.mark.asyncio
+async def test_concurrent_unsubscribes_tear_down_group_cleanly():
+    """N subscribers leaving at once via concurrent ``unsubscribe`` must end with the group fully
+    torn down and the poll task cancelled — mirrors a triggerer cancelling many deferred tasks in
+    the same tick.
+    """
+    cls = _make_trigger_class(_events_then_block([]))
+    n_subscribers = 8
+    triggers = [cls() for _ in range(n_subscribers)]
+    key = triggers[0].shared_stream_key()
+    manager = SharedStreamManager()
+
+    # The list index doubles as the trigger id (0..n_subscribers-1).
+    for trigger_id, trigger in enumerate(triggers):
+        manager.subscribe(trigger_id=trigger_id, trigger=trigger, key=key)
+    assert len(manager._groups[key]._subscribers) == n_subscribers
+    # Keep a handle on the task now; the group is gone from the manager afterwards.
+    poll_task = manager._groups[key]._poll_task
+    assert poll_task is not None
+
+    await asyncio.gather(*(manager.unsubscribe(i, key) for i in range(n_subscribers)))
+
+    assert manager._groups == {}, "every subscriber gone means the group is gone"
+    assert poll_task.done(), "the poll task must exit when the last subscriber leaves"
+    assert poll_task.cancelled()
+
+
+@pytest.mark.asyncio
+async def test_stop_all_with_blocked_consumer_does_not_inject_failure_sentinel():
+    """A consumer blocked on ``queue.get()`` when ``stop_all`` runs must not be woken with a
+    poison sentinel. The poll task's ``CancelledError`` path explicitly skips the terminate
+    broadcast, leaving the standard idiom — the trigger's consuming task is cancelled separately
+    — as the only exit. Verifies the asymmetry between cancel-driven and failure-driven teardown.
+    """
+    cls = _make_trigger_class(_events_then_block([]))  # never yields; consumer always blocks
+    trigger = cls()
+    key = trigger.shared_stream_key()
+    manager = SharedStreamManager()
+
+    stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key)
+
+    async def consume():
+        # Returns the first event, or ``None`` if the stream ends without one.
+        async for event in trigger.filter_shared_stream(stream):
+            return event
+        return None
+
+    consumer = asyncio.create_task(consume())
+    # Let the consumer reach ``await queue.get()``.
+    await asyncio.sleep(0)
+    assert not consumer.done()
+
+    poll_task = manager._groups[key]._poll_task
+    assert poll_task is not None
+
+    await manager.stop_all()
+
+    assert manager._groups == {}
+    assert poll_task.done()
+    assert poll_task.cancelled()
+    # No sentinel was injected — the consumer is still parked on queue.get().
+    # ``shield`` keeps the timeout from cancelling the consumer itself.
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(asyncio.shield(consumer), timeout=0.05)
+
+    consumer.cancel()
+    with suppress(asyncio.CancelledError):
+        await consumer
+
+
+@pytest.mark.asyncio
+async def test_sibling_non_key_kwargs_diverge_first_subscriber_wins():
+    """Siblings sharing a ``shared_stream_key`` reuse the first subscriber's kwargs.
+
+    When two siblings agree on the key but disagree on non-key kwargs, the group
+    is built from the **first** subscriber's kwargs and the second subscriber's
+    non-key kwargs are silently ignored. That is the documented contract; this
+    test pins it so that any future change (e.g. adding a runtime warning) is a
+    deliberate decision rather than a regression.
+    """
+    seen_kwargs: list[dict] = []
+
+    async def _open_shared_stream(cls, kwargs):
+        seen_kwargs.append(kwargs)
+        yield {"region": kwargs.get("region")}
+        # Block forever after the single event (the Event is never set).
+        await asyncio.Event().wait()
+
+    trigger_cls = _make_trigger_class(classmethod(_open_shared_stream))
+    first = trigger_cls(region="us")
+    # Same queue_url (key) as ``first``, different region (non-key).
+    second = trigger_cls(region="eu")
+    key = first.shared_stream_key()
+    assert key == second.shared_stream_key()
+
+    manager = SharedStreamManager()
+    try:
+        first_stream = manager.subscribe(trigger_id=1, trigger=first, key=key)
+        manager.subscribe(trigger_id=2, trigger=second, key=key)
+
+        # The raw event carries the first subscriber's region, so the first
+        # subscriber accepts it (region="us") while the second's filter rejects
+        # it. Verify by consuming from the first subscriber.
+        received = await _collect(first.filter_shared_stream(first_stream), n=1)
+        payloads = [item.payload for item in received]
+        assert payloads == [{"region": "us"}]
+
+        assert len(seen_kwargs) == 1, "open_shared_stream must be called exactly once per group"
+        assert seen_kwargs[0]["region"] == "us", (
+            "first subscriber's non-key kwargs become the group's kwargs"
+        )
+    finally:
+        for trigger_id in (1, 2):
+            await manager.unsubscribe(trigger_id, key)
+
+
+@pytest.mark.asyncio
+async def test_serialize_failure_in_subscribe_leaves_groups_clean():
+    """A ``serialize()`` failure during ``subscribe`` must not leave an orphan group.
+
+    If ``trigger.serialize()`` raises while a fresh group is being built,
+    ``subscribe`` has to propagate the exception and keep ``_groups`` free of any
+    entry for that key, so a subsequent subscribe for the same key builds a clean
+    group.
+    """
+    trigger_cls = _make_trigger_class(_events_then_block([{"region": "us"}]))
+
+    class _BrokenSerializeTrigger(trigger_cls):
+        def serialize(self):
+            raise RuntimeError("serialize boom")
+
+    manager = SharedStreamManager()
+    broken = _BrokenSerializeTrigger()
+    key = broken.shared_stream_key()
+
+    with pytest.raises(RuntimeError, match="serialize boom"):
+        manager.subscribe(trigger_id=1, trigger=broken, key=key)
+
+    assert key not in manager._groups, "failed subscribe must not leave an orphan group entry"
+
+    # A well-behaved trigger on the same key must now get a working group.
+    healthy = trigger_cls()
+    healthy_stream = manager.subscribe(trigger_id=2, trigger=healthy, key=key)
+    try:
+        events = await _collect(healthy.filter_shared_stream(healthy_stream), n=1)
+        first_event = events[0]
+        assert first_event.payload == {"region": "us"}
+        assert key in manager._groups
+    finally:
+        await manager.unsubscribe(2, key)
+
+
+@pytest.mark.asyncio
+async def test_terminal_failure_reaches_every_subscriber_even_with_full_queues():
+    """The terminal failure sentinel must reach every subscriber, even with full queues.
+
+    When the shared poll raises right after a broadcast that filled every
+    subscriber's queue, the terminal :class:`_PollFailure` sentinel must still
+    reach all of them. Without draining each queue before the terminal
+    ``put_nowait``, the first overflowed subscriber would raise ``QueueFull``,
+    abort the broadcast loop, and silently strand the remaining subscribers on
+    ``queue.get()`` forever.
+    """
+
+    async def _open_shared_stream(cls, kwargs):
+        yield {"region": "us"}
+        raise RuntimeError("upstream died")
+
+    cls = _make_trigger_class(classmethod(_open_shared_stream))
+    first = cls()
+    second = cls()
+    manager = SharedStreamManager(max_subscriber_queue=1)
+    key = first.shared_stream_key()
+
+    first_stream = manager.subscribe(trigger_id=1, trigger=first, key=key)
+    second_stream = manager.subscribe(trigger_id=2, trigger=second, key=key)
+
+    # Teardown lives in ``finally`` so both subscriptions are released even when
+    # an assertion or the timeout below fails.
+    try:
+        # Both queues sit at maxsize=1 with the broadcast event unread when the
+        # terminal _PollFailure goes out. The fix must drain each queue so the
+        # sentinel lands; both consumers should observe the same RuntimeError.
+        with pytest.raises(RuntimeError, match="upstream died"):
+            await asyncio.wait_for(_collect(first.filter_shared_stream(first_stream), n=2), timeout=2.0)
+        with pytest.raises(RuntimeError, match="upstream died"):
+            await asyncio.wait_for(_collect(second.filter_shared_stream(second_stream), n=2), timeout=2.0)
+    finally:
+        await manager.unsubscribe(1, key)
+        await manager.unsubscribe(2, key)
+
+
+@pytest.mark.asyncio
+async def test_fail_overflowed_subscriber_drains_full_queue_before_putting_sentinel():
+    """``_fail_overflowed_subscriber`` must drain the backlog, then queue the sentinel.
+
+    White-box invariant: starting from a queue that is already at capacity, the
+    call must leave exactly one item behind, a :class:`_PollFailure` wrapping a
+    :class:`_SubscriberOverflow`, no matter how many stale events were queued.
+
+    Were the drain loop to run *after* the ``put_nowait``, the put would raise
+    :exc:`asyncio.QueueFull` before anything was drained and the subscriber would
+    never receive its failure sentinel.
+    """
+    import structlog
+
+    capacity = 3
+    backlog: asyncio.Queue = asyncio.Queue(maxsize=capacity)
+    # Fill the queue to capacity with stale events.
+    for stale_index in range(capacity):
+        backlog.put_nowait({"stale": stale_index})
+
+    assert backlog.full(), "pre-condition: queue must be full before the call"
+
+    group = _SharedStreamGroup(
+        key="test-key",
+        trigger_class=_ProgrammableSharedStreamTrigger,
+        kwargs={},
+        on_poll_terminate=lambda g: None,
+        max_subscriber_queue=capacity,
+        log=structlog.get_logger("test"),
+    )
+    subscriber_id = 42
+    group._subscribers[subscriber_id] = backlog
+
+    group._fail_overflowed_subscriber(subscriber_id, backlog)
+
+    # Post-conditions that pin the drain-before-put ordering.
+    assert backlog.qsize() == 1, "exactly one item must remain: the failure sentinel"
+    sentinel = backlog.get_nowait()
+    assert isinstance(sentinel, _PollFailure), "sentinel must be a _PollFailure"
+    assert isinstance(sentinel.exc, _SubscriberOverflow), "the wrapped exception must be _SubscriberOverflow"
+    assert subscriber_id in group._overflowed, "trigger_id must be recorded in _overflowed"
diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py
b/providers/standard/src/airflow/providers/standard/triggers/file.py
index 699be775ffa..8c7b8894c26 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/file.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/file.py
@@ -18,8 +18,9 @@ from __future__ import annotations
import asyncio
import datetime
+import logging
import os
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Hashable
from glob import glob
from typing import Any
@@ -36,6 +37,8 @@ else:
TriggerEvent,
)
+log = logging.getLogger(__name__)
+
class FileTrigger(BaseTrigger):
"""
@@ -132,3 +135,112 @@ class FileDeleteTrigger(BaseEventTrigger):
yield TriggerEvent(True)
return
await asyncio.sleep(self.poke_interval)
+
+
+class DirectoryFileDeleteTrigger(BaseEventTrigger):
+    """
+    Fire once when ``filename`` appears in ``directory``, then delete it.
+
+    Functionally equivalent to ``FileDeleteTrigger`` for a single file, but
+    sibling triggers that point at the same ``directory`` and ``poke_interval``
+    share a single underlying directory scan in the triggerer; each instance
+    only fires for its own ``filename``. This is useful when many assets are
+    driven by per-flag-file events landing in a shared inbox directory.
+
+    :param directory: Directory to scan.
+    :param filename: File name (without directory) whose appearance fires this
+        trigger. The matched file is deleted before the event is yielded.
+    :param poke_interval: Time to wait between scans.
+    :raises ValueError: If ``filename`` is empty, is ``.`` or ``..``, or contains
+        a directory component.
+    """
+
+    def __init__(self, *, directory: str, filename: str, poke_interval: float = 5.0) -> None:
+        super().__init__()
+        # Snapshots hold bare entry names only (``p.name`` of each directory
+        # entry), so a name that is empty, ``.``/``..``, or carries a directory
+        # component could never match and the trigger would wait forever without
+        # any visible error. Fail fast on that misconfiguration instead.
+        if filename in ("", ".", "..") or os.path.basename(filename) != filename:
+            raise ValueError(
+                f"filename must be a bare file name without any directory component, got {filename!r}"
+            )
+        self.directory = directory
+        self.filename = filename
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize DirectoryFileDeleteTrigger arguments and classpath."""
+        return (
+            "airflow.providers.standard.triggers.file.DirectoryFileDeleteTrigger",
+            {
+                "directory": self.directory,
+                "filename": self.filename,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    def shared_stream_key(self) -> Hashable | None:
+        """All triggers on the same directory + cadence share one scan."""
+        # Use realpath so trivial path variants all resolve to the same canonical
+        # path: trailing slashes (``/tmp/flags`` vs ``/tmp/flags/``), relative vs
+        # absolute paths (``./flags`` vs ``/tmp/flags``), and symlinks vs their
+        # targets all key to the same group instead of running N independent scans.
+        return ("directory-scan", os.path.realpath(self.directory), self.poke_interval)
+
+    @classmethod
+    async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]:
+        """
+        Drive one directory-listing loop and broadcast each snapshot.
+
+        Missing directories yield an empty snapshot so subscribers keep
+        polling for the file to appear. Configuration-class failures
+        (``PermissionError``, ``NotADirectoryError``, ``IsADirectoryError``)
+        propagate — these are almost always permanent (wrong mount, wrong
+        mode, path points at a file), so silently retrying just hides the
+        misconfiguration from the operator; surfacing them as a
+        ``_PollFailure`` makes the trigger visibly fail in the UI, where it
+        can be diagnosed and restarted after the operator corrects the
+        config. Other ``OSError`` subclasses (transient I/O blips, NFS
+        hiccups, etc.) are logged at warning and the snapshot is skipped for
+        this cadence, since those may self-heal.
+
+        :param kwargs: Serialized trigger kwargs; ``directory`` and
+            ``poke_interval`` are read from it.
+        :return: Snapshots shaped ``{"directory": str, "names": set[str]}``.
+        """
+        directory = anyio.Path(kwargs["directory"])
+        poke_interval: float = kwargs["poke_interval"]
+        while True:
+            try:
+                names = {p.name async for p in directory.iterdir()}
+            except FileNotFoundError:
+                names = set()
+            except (PermissionError, NotADirectoryError, IsADirectoryError):
+                # These are ``OSError`` subclasses, so this clause must precede
+                # the generic handler below or they would be retried as well.
+                raise
+            except OSError:
+                log.warning(
+                    "Failed to list %s; retrying after %ss",
+                    directory,
+                    poke_interval,
+                    exc_info=True,
+                )
+                await asyncio.sleep(poke_interval)
+                continue
+            yield {"directory": str(directory), "names": names}
+            await asyncio.sleep(poke_interval)
+
+    async def filter_shared_stream(self, shared_stream: AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]:
+        """
+        Fire once for this instance's own filename and delete the file.
+
+        :param shared_stream: Snapshots shaped like those produced by
+            ``open_shared_stream``.
+        :return: At most one ``TriggerEvent`` whose payload is
+            ``{"filepath": <path of the deleted file>}``.
+        """
+        async for snapshot in shared_stream:
+            if self.filename not in snapshot["names"]:
+                continue
+            filepath = anyio.Path(snapshot["directory"]) / self.filename
+            try:
+                await filepath.unlink()
+            except FileNotFoundError:
+                # Lost a race with a sibling, or the file disappeared between
+                # snapshot and unlink. Wait for the next scan.
+                continue
+            self.log.info("File %s has been deleted", filepath)
+            yield TriggerEvent({"filepath": str(filepath)})
+            return
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Standalone fallback when the shared-stream manager is unavailable.
+
+        Mirrors the shared path so the trigger remains usable in unit tests
+        and on Airflow versions without the manager wired in. It does not
+        deduplicate I/O — that requires the triggerer to drive the shared
+        stream.
+        """
+        kwargs = self.serialize()[1]
+        async for event in self.filter_shared_stream(type(self).open_shared_stream(kwargs)):
+            yield event
diff --git a/providers/standard/tests/unit/standard/triggers/test_file.py
b/providers/standard/tests/unit/standard/triggers/test_file.py
index 793f0aeb628..309e922bbfe 100644
--- a/providers/standard/tests/unit/standard/triggers/test_file.py
+++ b/providers/standard/tests/unit/standard/triggers/test_file.py
@@ -21,7 +21,11 @@ import asyncio
import anyio
import pytest
-from airflow.providers.standard.triggers.file import FileDeleteTrigger,
FileTrigger
+from airflow.providers.standard.triggers.file import (
+ DirectoryFileDeleteTrigger,
+ FileDeleteTrigger,
+ FileTrigger,
+)
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
@@ -106,3 +110,208 @@ class TestFileDeleteTrigger:
# returns, so once the task is done, the file is guaranteed gone.
await asyncio.wait_for(task, timeout=5.0)
assert await anyio.Path(p).exists() is False
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
+class TestDirectoryFileDeleteTrigger:
+ DIRECTORY = "/data/flags"
+
+ def test_serialization(self):
+ trigger = DirectoryFileDeleteTrigger(
+ directory=self.DIRECTORY, filename="orders_us.flag",
poke_interval=5
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.standard.triggers.file.DirectoryFileDeleteTrigger"
+ assert kwargs == {
+ "directory": self.DIRECTORY,
+ "filename": "orders_us.flag",
+ "poke_interval": 5,
+ }
+
+ def test_shared_stream_key_groups_same_directory_and_cadence(self):
+ a = DirectoryFileDeleteTrigger(directory=self.DIRECTORY,
filename="us.flag", poke_interval=1.0)
+ b = DirectoryFileDeleteTrigger(directory=self.DIRECTORY,
filename="eu.flag", poke_interval=1.0)
+ c = DirectoryFileDeleteTrigger(directory=self.DIRECTORY,
filename="us.flag", poke_interval=2.0)
+ d = DirectoryFileDeleteTrigger(directory="/other", filename="us.flag",
poke_interval=1.0)
+
+ assert a.shared_stream_key() == b.shared_stream_key()
+ assert a.shared_stream_key() != c.shared_stream_key()
+ assert a.shared_stream_key() != d.shared_stream_key()
+
+ @pytest.mark.parametrize(
+ ("first", "second"),
+ [
+ ("/data/flags", "/data/flags/"),
+ ("/data/flags", "/data//flags"),
+ ("/data/flags", "/data/./flags"),
+ ("/data/parent/../flags", "/data/flags"),
+ ],
+ )
+ def test_shared_stream_key_normalises_trivial_path_variants(self, first,
second):
+ a = DirectoryFileDeleteTrigger(directory=first, filename="us.flag",
poke_interval=1.0)
+ b = DirectoryFileDeleteTrigger(directory=second, filename="us.flag",
poke_interval=1.0)
+ assert a.shared_stream_key() == b.shared_stream_key()
+
+ def test_shared_stream_key_realpath_trailing_slash(self, tmp_path):
+ """Trailing slash variant keys to the same group as the plain path."""
+ real_dir = str(tmp_path / "flags")
+ a = DirectoryFileDeleteTrigger(directory=real_dir, filename="f",
poke_interval=1.0)
+ b = DirectoryFileDeleteTrigger(directory=real_dir + "/", filename="f",
poke_interval=1.0)
+ assert a.shared_stream_key() == b.shared_stream_key()
+
+ def test_shared_stream_key_realpath_relative_vs_absolute(self, tmp_path,
monkeypatch):
+ """A relative path resolves to the same key as its absolute
equivalent."""
+ monkeypatch.chdir(tmp_path)
+ a = DirectoryFileDeleteTrigger(directory=".", filename="f",
poke_interval=1.0)
+ b = DirectoryFileDeleteTrigger(directory=str(tmp_path), filename="f",
poke_interval=1.0)
+ assert a.shared_stream_key() == b.shared_stream_key()
+
+ @pytest.mark.skipif(
+ not hasattr(__import__("os"), "symlink"),
+ reason="symlinks not supported on this platform",
+ )
+ def test_shared_stream_key_realpath_symlink_vs_target(self, tmp_path):
+ """A symlink and its target resolve to the same key."""
+
+ real_dir = tmp_path / "real"
+ real_dir.mkdir()
+ link_dir = tmp_path / "link"
+ link_dir.symlink_to(real_dir)
+ a = DirectoryFileDeleteTrigger(directory=str(real_dir), filename="f",
poke_interval=1.0)
+ b = DirectoryFileDeleteTrigger(directory=str(link_dir), filename="f",
poke_interval=1.0)
+ assert a.shared_stream_key() == b.shared_stream_key()
+
+ @pytest.mark.asyncio
+ async def test_filter_shared_stream_fires_only_for_own_filename(self,
tmp_path):
+ directory = tmp_path / "flags"
+ await anyio.Path(directory).mkdir()
+ await (anyio.Path(directory) / "us.flag").touch()
+
+ async def stream():
+ yield {"directory": str(directory), "names": {"us.flag",
"eu.flag"}}
+
+ us = DirectoryFileDeleteTrigger(directory=str(directory),
filename="us.flag", poke_interval=1.0)
+ events = [event async for event in us.filter_shared_stream(stream())]
+
+ assert len(events) == 1
+ assert events[0].payload == {"filepath": str(directory / "us.flag")}
+ assert await (anyio.Path(directory) / "us.flag").exists() is False
+
+ @pytest.mark.asyncio
+ async def test_filter_shared_stream_skips_other_filenames(self, tmp_path):
+ directory = tmp_path / "flags"
+ await anyio.Path(directory).mkdir()
+ await (anyio.Path(directory) / "eu.flag").touch()
+
+ async def stream():
+ yield {"directory": str(directory), "names": {"eu.flag"}}
+
+ us = DirectoryFileDeleteTrigger(directory=str(directory),
filename="us.flag", poke_interval=1.0)
+ events = [event async for event in us.filter_shared_stream(stream())]
+
+ # Did not fire, did not delete the unrelated file.
+ assert events == []
+ assert await (anyio.Path(directory) / "eu.flag").exists() is True
+
+ @pytest.mark.asyncio
+ async def
test_filter_shared_stream_recovers_when_sibling_unlinks_first(self, tmp_path):
+ directory = tmp_path / "flags"
+ await anyio.Path(directory).mkdir()
+
+ async def stream():
+ # Snapshot says the file is there; in reality a sibling already
+ # consumed it, so unlink raises FileNotFoundError. We must keep
+ # iterating, not crash. After the snapshot drops the filename,
+ # we exit the iterator without firing.
+ yield {"directory": str(directory), "names": {"us.flag"}}
+ yield {"directory": str(directory), "names": set()}
+
+ us = DirectoryFileDeleteTrigger(directory=str(directory),
filename="us.flag", poke_interval=1.0)
+ events = [event async for event in us.filter_shared_stream(stream())]
+
+ assert events == []
+
+ @pytest.mark.asyncio
+ async def test_open_shared_stream_handles_missing_directory(self,
tmp_path):
+ missing = tmp_path / "does_not_exist"
+ snapshots = []
+
+ async def consume():
+ it = DirectoryFileDeleteTrigger.open_shared_stream(
+ {"directory": str(missing), "poke_interval": 0.01}
+ ).__aiter__()
+ for _ in range(2):
+ snapshots.append(await it.__anext__())
+
+ await asyncio.wait_for(consume(), timeout=1.0)
+
+ assert all(s["names"] == set() for s in snapshots)
+
+ @pytest.mark.parametrize(
+ "exc_cls",
+ [PermissionError, NotADirectoryError, IsADirectoryError],
+ )
+ @pytest.mark.asyncio
+ async def test_open_shared_stream_raises_on_config_bug_oserror(self,
mocker, tmp_path, exc_cls):
+ """PermissionError, NotADirectoryError, and IsADirectoryError must
propagate rather than spin."""
+
+ async def _iterdir(self):
+ raise exc_cls("config bug")
+ if False:
+ yield # pragma: no cover - sentinel for async generator typing
+
+ mocker.patch.object(anyio.Path, "iterdir", _iterdir)
+
+ directory = tmp_path / "flags"
+ gen = DirectoryFileDeleteTrigger.open_shared_stream(
+ {"directory": str(directory), "poke_interval": 0.01}
+ )
+ with pytest.raises(exc_cls):
+ await gen.__anext__()
+
+ @pytest.mark.asyncio
+ async def test_open_shared_stream_swallows_transient_oserror(self,
tmp_path, mocker):
+ """A generic OSError is logged and retried; the snapshot from the next
call is yielded."""
+ call_count = 0
+
+ async def _iterdir(self):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ raise OSError("transient blip")
+ if False:
+ yield # pragma: no cover - sentinel for async generator typing
+
+ mocker.patch.object(anyio.Path, "iterdir", _iterdir)
+
+ async def _noop_sleep(_duration):
+ pass
+
+ mocker.patch("asyncio.sleep", side_effect=_noop_sleep)
+
+ directory = tmp_path / "flags"
+ gen = DirectoryFileDeleteTrigger.open_shared_stream(
+ {"directory": str(directory), "poke_interval": 0.01}
+ )
+ snapshot = await gen.__anext__()
+
+ assert snapshot == {"directory": str(directory), "names": set()}
+ assert call_count == 2
+
+ @pytest.mark.asyncio
+ async def test_run_standalone_fallback_polls_until_filename_appears(self,
tmp_path):
+ directory = tmp_path / "flags"
+ await anyio.Path(directory).mkdir()
+ target = anyio.Path(directory) / "us.flag"
+
+ trigger = DirectoryFileDeleteTrigger(directory=str(directory),
filename="us.flag", poke_interval=0.05)
+ task = asyncio.create_task(trigger.run().__anext__())
+
+ await asyncio.sleep(0.2)
+ assert task.done() is False
+
+ await target.touch()
+ event = await asyncio.wait_for(task, timeout=1.0)
+
+ assert event.payload == {"filepath": str(target)}
+ assert await target.exists() is False