kaxil commented on code in PR #60804:
URL: https://github.com/apache/airflow/pull/60804#discussion_r3083386785
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
class DBDagBag:
"""
- Internal class for retrieving and caching dags in the scheduler.
+ Internal class for retrieving dags from the database.
+
+ Optionally supports LRU+TTL caching when cache_size is provided.
+ The scheduler uses this without caching, while the API server can
+ enable caching via configuration.
:meta private:
"""
-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[UUID, SerializedDagModel] = {}  # dag_version_id to dag
-        self.load_op_links = load_op_links
+ def __init__(
+ self,
+ load_op_links: bool = True,
+ cache_size: int | None = None,
+ cache_ttl: int | None = None,
+ ) -> None:
+ """
+ Initialize DBDagBag.
-    def _read_dag(self, serialized_dag_model: SerializedDagModel) -> SerializedDAG | None:
-        serialized_dag_model.load_op_links = self.load_op_links
-        if dag := serialized_dag_model.dag:
-            self._dags[serialized_dag_model.dag_version_id] = serialized_dag_model
+        :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, uses unbounded dict (no eviction).
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL (LRU only).
+        """
+ self.load_op_links = load_op_links
+ self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+ self._use_cache = False
+
+ # Initialize bounded cache if cache_size is provided and > 0
+ if cache_size and cache_size > 0:
+ if cache_ttl and cache_ttl > 0:
+ self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+ else:
+ self._dags = LRUCache(maxsize=cache_size)
+ self._use_cache = True
+
+        # Lock required for bounded caches: cachetools caches are NOT thread-safe
+        # (LRU reordering and TTL cleanup mutate internal linked lists).
+        # nullcontext for unbounded dict avoids lock overhead in the scheduler path.
+        self._lock: RLock | nullcontext = RLock() if self._use_cache else nullcontext()
+
+    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+        """Read and optionally cache a SerializedDAG from a SerializedDagModel."""
+ serdag.load_op_links = self.load_op_links
+ dag = serdag.dag
+ if not dag:
+ return None
+ with self._lock:
+ self._dags[serdag.dag_version_id] = dag
+ cache_size = len(self._dags)
+ if self._use_cache:
+ Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
return dag
-    def get_serialized_dag_model(self, version_id: UUID, session: Session) -> SerializedDagModel | None:
+    def _get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG | None:
+ # Check cache first
+ with self._lock:
+ dag = self._dags.get(version_id)
+
+ if dag:
+ if self._use_cache:
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+
+        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
+ if not dag_version:
+ return None
+ if not (serdag := dag_version.serialized_dag):
+ return None
+
+        # Double-checked locking: another thread may have cached it while we queried DB.
+        # Only emit the miss metric after confirming no other thread cached it, to avoid
+        # counting a single lookup as both a miss and a hit.
+ if self._use_cache:
+ with self._lock:
+ if dag := self._dags.get(version_id):
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+ Stats.incr("api_server.dag_bag.cache_miss")
+ return self._read_dag(serdag)
+
+    def get_dag(self, version_id: UUID | str, session: Session) -> SerializedDAG | None:
+ """Get a dag by its version id, using cache if enabled."""
+ return self._get_dag(version_id=version_id, session=session)
+
+    def get_serialized_dag_model(self, version_id: UUID | str, session: Session) -> SerializedDagModel | None:
"""
Return the SerializedDagModel for a given dag version id.
-        This will first consult the in-memory cache keyed by the dag version id. If the
-        model is not cached, the database is queried for a corresponding :class:`DagVersion`
-        and its associated :class:`SerializedDagModel`.
+ Uses a separate plain dict cache (not the LRU/TTL cache, which stores
+ deserialized SerializedDAG objects). The triggerer needs the full model
+ for ``serialized_dag_model.data``.
+ """
+ if serdag := self._dag_models.get(version_id):
Review Comment:
Good catch. Removed `_dag_models` entirely. `get_serialized_dag_model()` now
always queries the DB. The only production caller is the triggerer, which
creates a fresh `DBDagBag()` per batch, so the within-batch deduplication was
marginal. An unbounded dict in a PR fixing unbounded memory growth was a
contradiction.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]