This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 4ee07a9a1cf Fix: Use get instead of hasattr for task_result in BulkStateFetcher (#52839)
4ee07a9a1cf is described below
commit 4ee07a9a1cf0b57528d5c411d13fdc8d155aebdb
Author: Wei-Yu Chen <[email protected]>
AuthorDate: Sat Sep 20 00:32:29 2025 -0400
Fix: Use get instead of hasattr for task_result in BulkStateFetcher (#52839)
* Fix: Use get instead of hasattr for task_result in BulkStateFetcher
* add type annotation for task_results_by_task_id
* add type annotation for params in methods of BulkStateFetcher
* retain type annotation only at the param level
* add mock value for sync_parallelism in test
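Background on the fix (a minimal sketch, not taken from the patch itself): task_result here is a plain dict, as the new dict[str, dict[str, Any]] annotation spells out, so hasattr(task_result, "info") checks for an object attribute rather than a key and is effectively always False, which silently discarded the info field. dict.get is the correct lookup:

    # Illustration with a made-up result dict; only the hasattr-vs-get
    # behaviour is the point, the values are hypothetical.
    task_result = {"status": "SUCCESS", "info": "worker traceback or metadata"}

    hasattr(task_result, "info")   # False -- "info" is a dict key, not an attribute
    task_result["status"]          # "SUCCESS"
    task_result.get("info")        # the stored value, or None if the key is missing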
---
 .../celery/executors/celery_executor_utils.py   | 26 +++++++++++++---------
 .../integration/celery/test_celery_executor.py  |  6 ++---
 2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
index 76202dd139f..8ccbc9b56dd 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
@@ -30,7 +30,7 @@ import subprocess
 import sys
 import traceback
 import warnings
-from collections.abc import Mapping, MutableMapping, Sequence
+from collections.abc import Collection, Mapping, MutableMapping, Sequence
 from concurrent.futures import ProcessPoolExecutor
 from typing import TYPE_CHECKING, Any
@@ -323,14 +323,14 @@ class BulkStateFetcher(LoggingMixin):
     Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
     """
 
-    def __init__(self, sync_parallelism=None):
+    def __init__(self, sync_parallelism: int):
         super().__init__()
         self._sync_parallelism = sync_parallelism
 
-    def _tasks_list_to_task_ids(self, async_tasks) -> set[str]:
+    def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> set[str]:
         return {a.task_id for a in async_tasks}
 
-    def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
+    def get_many(self, async_results: Collection[AsyncResult]) -> Mapping[str, EventBufferValueType]:
         """Get status for many Celery tasks using the best method available."""
         if isinstance(app.backend, BaseKeyValueStoreBackend):
             result = self._get_many_from_kv_backend(async_results)
@@ -341,7 +341,9 @@ class BulkStateFetcher(LoggingMixin):
         self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
         return result
 
-    def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+    def _get_many_from_kv_backend(
+        self, async_tasks: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         task_ids = self._tasks_list_to_task_ids(async_tasks)
         keys = [app.backend.get_key_for_task(k) for k in task_ids]
         values = app.backend.mget(keys)
@@ -351,13 +353,15 @@ class BulkStateFetcher(LoggingMixin):
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
 
     @retry
-    def _query_task_cls_from_db_backend(self, task_ids, **kwargs):
+    def _query_task_cls_from_db_backend(self, task_ids: set[str], **kwargs):
         session = app.backend.ResultSession()
         task_cls = getattr(app.backend, "task_cls", TaskDb)
         with session_cleanup(session):
             return session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()
 
-    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+    def _get_many_from_db_backend(
+        self, async_tasks: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         task_ids = self._tasks_list_to_task_ids(async_tasks)
         tasks = self._query_task_cls_from_db_backend(task_ids)
         task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
@@ -367,21 +371,23 @@ class BulkStateFetcher(LoggingMixin):
 
     @staticmethod
     def _prepare_state_and_info_by_task_dict(
-        task_ids, task_results_by_task_id
+        task_ids: set[str], task_results_by_task_id: dict[str, dict[str, Any]]
     ) -> Mapping[str, EventBufferValueType]:
         state_info: MutableMapping[str, EventBufferValueType] = {}
         for task_id in task_ids:
             task_result = task_results_by_task_id.get(task_id)
             if task_result:
                 state = task_result["status"]
-                info = None if not hasattr(task_result, "info") else task_result["info"]
+                info = task_result.get("info")
             else:
                 state = celery_states.PENDING
                 info = None
             state_info[task_id] = state, info
         return state_info
 
-    def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
+    def _get_many_using_multiprocessing(
+        self, async_results: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)
 
         with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 874fbce3cbc..d16a3435810 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -335,7 +335,7 @@ class TestBulkStateFetcher:
             "airflow.providers.celery.executors.celery_executor_utils.Celery.backend", mock_backend
         ):
             caplog.clear()
-            fetcher = celery_executor_utils.BulkStateFetcher()
+            fetcher = celery_executor_utils.BulkStateFetcher(1)
             result = fetcher.get_many(
                 [
                     mock.MagicMock(task_id="123"),
@@ -367,7 +367,7 @@ class TestBulkStateFetcher:
             mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
         ]
 
-        fetcher = celery_executor_utils.BulkStateFetcher()
+        fetcher = celery_executor_utils.BulkStateFetcher(1)
         result = fetcher.get_many(
             [
                 mock.MagicMock(task_id="123"),
@@ -401,7 +401,7 @@ class TestBulkStateFetcher:
             mock_retry_db_result.return_value,
         ]
 
-        fetcher = celery_executor_utils.BulkStateFetcher()
+        fetcher = celery_executor_utils.BulkStateFetcher(1)
         result = fetcher.get_many(
             [
                 mock.MagicMock(task_id="123"),
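
Why the tests now pass an explicit value (a reading of the diff, not stated in the commit message): sync_parallelism is now a required int rather than defaulting to None, and None could not have worked on the multiprocessing path anyway, since _get_many_using_multiprocessing computes min(len(async_results), self._sync_parallelism):

    # Hypothetical repro of the old default on the multiprocessing path.
    async_results = [object(), object()]
    sync_parallelism = None
    min(len(async_results), sync_parallelism)  # raises TypeError: '<' not supported between 'NoneType' and 'int'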