This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new faf847c59d0 Add support for async callables in ``PythonOperator`` (#60268)
faf847c59d0 is described below
commit faf847c59d0cfdcc58851a45e0a80aa72bed93d0
Author: David Blain <[email protected]>
AuthorDate: Thu Jan 15 19:06:59 2026 +0100
Add support for async callables in ``PythonOperator`` (#60268)
This PR is related to the discussion I started on the
[devlist](https://lists.apache.org/thread/ztnfsqolow4v1zsv4pkpnxc1fk0hbf2p)
and allows you to natively execute async code in ``PythonOperator``.
There is also an AIP for this:
https://cwiki.apache.org/confluence/display/AIRFLOW/%5BWIP%5D+AIP-98%3A+Rethinking+deferrable+operators%2C+async+hooks+and+performance+in+Airflow+3
Below is an example that shows how it can be used with async hooks:
```
@task(show_return_value_in_logs=False)
async def load_xml_files(files):
    import asyncio
    from io import BytesIO
    from more_itertools import chunked
    from os import cpu_count
    from tenacity import retry, stop_after_attempt, wait_fixed
    from airflow.providers.sftp.hooks.sftp import SFTPClientPool

    print("number of files:", len(files))
    async with SFTPClientPool(sftp_conn_id=sftp_conn, pool_size=cpu_count()) as pool:

        @retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
        async def download_file(file):
            async with pool.get_sftp_client() as sftp:
                print("downloading:", file)
                buffer = BytesIO()
                async with sftp.open(file, encoding=xml_encoding) as remote_file:
                    data = await remote_file.read()
                    buffer.write(data.encode(xml_encoding))
                buffer.seek(0)
                return buffer

        for batch in chunked(files, cpu_count() * 2):
            tasks = [asyncio.create_task(download_file(f)) for f in batch]
            # Wait for this batch to finish before starting the next
            for task in asyncio.as_completed(tasks):
                result = await task
                # Do something with the result, or accumulate it and return it as an XCom
```
This PR also fixes the additional remarks made by @kaxil on the original
[PR](https://github.com/apache/airflow/pull/59087), which has been reverted.
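For reference, a minimal sketch of the simplest case, mirroring the example
DAGs added in this commit (the ``dag_id`` and task names are illustrative):
a plain async function passed directly as ``python_callable``:

```
import asyncio

from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import DAG


async def my_async_sleeping_function(random_base):
    # The operator detects the coroutine function via is_async_callable()
    # and runs it in an event loop on the worker.
    await asyncio.sleep(random_base)


with DAG(dag_id="example_async_python"):
    async_sleeping_task = PythonOperator(
        task_id="async_sleep",
        python_callable=my_async_sleeping_function,
        op_kwargs={"random_base": 0.1},
    )
```

The ``@task`` decorator behaves the same way, since ``DecoratedOperator``
gains a matching ``is_async`` property.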
---
airflow-core/newsfragments/60268.improvement.rst | 1 +
.../src/airflow/jobs/triggerer_job_runner.py | 4 +-
.../core_api/routes/public/test_task_instances.py | 58 ++++-----
.../providers/common/compat/standard/operators.py | 56 ++++++++
.../providers/common/compat/version_compat.py | 2 +
providers/standard/docs/operators/python.rst | 31 +++++
providers/standard/pyproject.toml | 2 +-
.../example_dags/example_python_decorator.py | 17 +++
.../example_dags/example_python_operator.py | 18 +++
.../airflow/providers/standard/operators/python.py | 65 +++++++++-
.../tests/unit/standard/decorators/test_python.py | 36 +++++-
.../tests/unit/standard/operators/test_python.py | 25 +++-
task-sdk/docs/api.rst | 4 +-
task-sdk/docs/index.rst | 1 +
task-sdk/src/airflow/sdk/__init__.py | 10 +-
task-sdk/src/airflow/sdk/__init__.pyi | 2 +
task-sdk/src/airflow/sdk/bases/decorator.py | 49 ++++++-
task-sdk/src/airflow/sdk/bases/operator.py | 55 +++++++-
.../sdk/definitions/_internal/abstractoperator.py | 4 +
.../airflow/sdk/execution_time/callback_runner.py | 67 +++++++++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 72 ++++++++---
task-sdk/tests/task_sdk/bases/test_decorator.py | 144 ++++++++++++++++++++-
22 files changed, 653 insertions(+), 70 deletions(-)
diff --git a/airflow-core/newsfragments/60268.improvement.rst b/airflow-core/newsfragments/60268.improvement.rst
new file mode 100644
index 00000000000..8c7e92b8f0d
--- /dev/null
+++ b/airflow-core/newsfragments/60268.improvement.rst
@@ -0,0 +1 @@
+The ``PythonOperator`` parameter ``python_callable`` now also supports async callables in Airflow 3.2, allowing users to run ``async def`` functions without manually managing an event loop.
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 532b0faf1de..2508e4a281a 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -789,8 +789,6 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
)
- _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
-
def _read_frame(self):
from asgiref.sync import async_to_sync
@@ -825,7 +823,7 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
bytes = frame.as_bytes()
- async with self._lock:
+ async with self._async_lock:
self._async_writer.write(bytes)
return await self._aget_response(frame.id)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 2b53501bee5..91622ce11e2 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -214,7 +214,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -372,7 +372,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -436,7 +436,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -492,7 +492,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -612,7 +612,7 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -1404,7 +1404,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_python_operator"},
- 9, # Based on test failure - example_python_operator creates 9 task instances
+ 14, # Based on test failure - example_python_operator creates 14 task instances
3,
id="test dag_id_pattern exact match",
),
@@ -1413,7 +1413,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_%"},
- 17, # Based on test failure - both DAGs together create 17 task instances
+ 22, # Based on test failure - both DAGs together create 22 task instances
3,
id="test dag_id_pattern wildcard prefix",
),
@@ -1927,8 +1927,8 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
[
pytest.param(
{"dag_ids": ["example_python_operator", "example_skip_dag"]},
- 17,
- 17,
+ 22,
+ 22,
id="with dag filter",
),
],
@@ -2037,7 +2037,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
assert len(response_batch2.json()["task_instances"]) > 0
# Match
- ti_count = 9
+ ti_count = 10
assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count
assert (num_entries_batch1 + num_entries_batch2) == ti_count
assert response_batch1 != response_batch2
@@ -2076,7 +2076,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2122,7 +2122,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2199,7 +2199,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2271,7 +2271,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2318,7 +2318,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3162,7 +3162,7 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3534,7 +3534,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3571,7 +3571,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3642,7 +3642,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3725,7 +3725,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3762,7 +3762,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4002,7 +4002,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4276,7 +4276,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4410,7 +4410,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4471,7 +4471,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4550,7 +4550,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4631,7 +4631,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4749,7 +4749,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -5035,7 +5035,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 9,
+ "priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index 6b77db3e4a9..866d94ac792 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -17,18 +17,74 @@
from __future__ import annotations
+from typing import TYPE_CHECKING
+
from airflow.providers.common.compat._compat_utils import create_module_getattr
+from airflow.providers.common.compat.version_compat import (
+ AIRFLOW_V_3_0_PLUS,
+ AIRFLOW_V_3_1_PLUS,
+ AIRFLOW_V_3_2_PLUS,
+)
_IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
# Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
"BaseOperator": "airflow.providers.common.compat.sdk",
+ "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
"get_current_context": "airflow.providers.common.compat.sdk",
+ "is_async_callable": "airflow.providers.common.compat.sdk",
# Standard provider items with direct fallbacks
"PythonOperator": ("airflow.providers.standard.operators.python",
"airflow.operators.python"),
"ShortCircuitOperator": ("airflow.providers.standard.operators.python",
"airflow.operators.python"),
"_SERIALIZERS": ("airflow.providers.standard.operators.python",
"airflow.operators.python"),
}
+if TYPE_CHECKING:
+ from airflow.sdk.bases.decorator import is_async_callable
+ from airflow.sdk.bases.operator import BaseAsyncOperator
+elif AIRFLOW_V_3_2_PLUS:
+ from airflow.sdk.bases.decorator import is_async_callable
+ from airflow.sdk.bases.operator import BaseAsyncOperator
+else:
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import BaseOperator
+ else:
+ from airflow.models import BaseOperator
+
+ def is_async_callable(func) -> bool:
+ """Detect if a callable is an async function."""
+ import inspect
+ from functools import partial
+
+ while isinstance(func, partial):
+ func = func.func
+ return inspect.iscoroutinefunction(func)
+
+ class BaseAsyncOperator(BaseOperator):
+ """Stub for Airflow < 3.2 that raises a clear error."""
+
+ @property
+ def is_async(self) -> bool:
+ return True
+
+ if not AIRFLOW_V_3_1_PLUS:
+
+ @property
+ def xcom_push(self) -> bool:
+ return self.do_xcom_push
+
+ @xcom_push.setter
+ def xcom_push(self, value: bool):
+ self.do_xcom_push = value
+
+ async def aexecute(self, context):
+ raise NotImplementedError()
+
+ def execute(self, context):
+ raise RuntimeError(
+ "Async operators require Airflow 3.2+. Upgrade Airflow or use
a synchronous callable."
+ )
+
+
__getattr__ = create_module_getattr(import_map=_IMPORT_MAP)
__all__ = sorted(_IMPORT_MAP.keys())
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index 4142937bd2a..e3fd1e55f14 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
# BaseOperator removed from version_compat to avoid circular imports
# Import it directly in files that need it instead
@@ -41,4 +42,5 @@ AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
+ "AIRFLOW_V_3_2_PLUS",
]
diff --git a/providers/standard/docs/operators/python.rst b/providers/standard/docs/operators/python.rst
index 2e5e63ea437..a87762c11cd 100644
--- a/providers/standard/docs/operators/python.rst
+++ b/providers/standard/docs/operators/python.rst
@@ -72,6 +72,37 @@ Pass extra arguments to the ``@task`` decorated function as you would with a nor
:start-after: [START howto_operator_python_kwargs]
:end-before: [END howto_operator_python_kwargs]
+Async Python functions
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. versionadded:: 3.2
+
+Async Python callables are now also supported out of the box. This means you don't need to manage the event loop yourself, and it allows you to easily invoke async Python code and async Airflow hooks that are not always available through deferred operators.
+
+As opposed to deferred operators, which are executed on the triggerer, async operators are executed on the workers.
+
+.. tab-set::
+
+    .. tab-item:: @task
+        :sync: taskflow
+
+        .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START howto_async_operator_python_kwargs]
+            :end-before: [END howto_async_operator_python_kwargs]
+
+    .. tab-item:: PythonOperator
+        :sync: operator
+
+        .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START howto_async_operator_python_kwargs]
+            :end-before: [END howto_async_operator_python_kwargs]
+
Templating
^^^^^^^^^^
diff --git a/providers/standard/pyproject.toml b/providers/standard/pyproject.toml
index 05635ff6a0d..5ad02e5efcf 100644
--- a/providers/standard/pyproject.toml
+++ b/providers/standard/pyproject.toml
@@ -59,7 +59,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.11.0",
- "apache-airflow-providers-common-compat>=1.12.0",
+ "apache-airflow-providers-common-compat>=1.12.0", # use next version
]
# The optional dependencies should be modified in place in the generated file
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
index ac9938d92ea..086d67f831c 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
@@ -22,6 +22,7 @@ virtual environment.
from __future__ import annotations
+import asyncio
import logging
import sys
import time
@@ -75,6 +76,22 @@ def example_python_decorator():
run_this >> log_the_sql >> sleeping_task
# [END howto_operator_python_kwargs]
+ # [START howto_async_operator_python_kwargs]
+ # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
+ # Asynchronous callables are natively supported since Airflow 3.2+
+ @task
+ async def my_async_sleeping_function(random_base):
+ """This is a function that will run within the DAG execution"""
+ await asyncio.sleep(random_base)
+
+ for i in range(5):
+ async_sleeping_task = my_async_sleeping_function.override(task_id=f"async_sleep_for_{i}")(
+ random_base=i / 10
+ )
+
+ run_this >> log_the_sql >> async_sleeping_task
+ # [END howto_async_operator_python_kwargs]
+
# [START howto_operator_python_venv]
@task.virtualenv(
task_id="virtualenv_python", requirements=["colorama==0.4.0"],
system_site_packages=False
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
index 18aa8f207e3..a9378938873 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
@@ -22,6 +22,7 @@ within a virtual environment.
from __future__ import annotations
+import asyncio
import logging
import sys
import time
@@ -88,6 +89,23 @@ with DAG(
run_this >> log_the_sql >> sleeping_task
# [END howto_operator_python_kwargs]
+ # [START howto_async_operator_python_kwargs]
+ # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
+ # Asynchronous callables are natively supported since Airflow 3.2+
+ async def my_async_sleeping_function(random_base):
+ """This is a function that will run within the DAG execution"""
+ await asyncio.sleep(random_base)
+
+ for i in range(5):
+ async_sleeping_task = PythonOperator(
+ task_id=f"async_sleep_for_{i}",
+ python_callable=my_async_sleeping_function,
+ op_kwargs={"random_base": i / 10},
+ )
+
+ run_this >> log_the_sql >> async_sleeping_task
+ # [END howto_async_operator_python_kwargs]
+
# [START howto_operator_python_venv]
def callable_virtualenv():
"""
diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py
index ac8862f2923..7e6bbd166a6 100644
--- a/providers/standard/src/airflow/providers/standard/operators/python.py
+++ b/providers/standard/src/airflow/providers/standard/operators/python.py
@@ -48,13 +48,17 @@ from airflow.exceptions import (
)
from airflow.models.variable import Variable
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
+from airflow.providers.common.compat.standard.operators import (
+ BaseAsyncOperator,
+ is_async_callable,
+)
from airflow.providers.standard.hooks.package_index import PackageIndexHook
from airflow.providers.standard.utils.python_virtualenv import (
_execute_in_subprocess,
prepare_virtualenv,
write_python_script,
)
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.utils import hashlib_wrapper
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
@@ -75,7 +79,10 @@ if TYPE_CHECKING:
from pendulum.datetime import DateTime
from airflow.providers.common.compat.sdk import Context
- from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
+ from airflow.sdk.execution_time.callback_runner import (
+ AsyncExecutionCallableRunner,
+ ExecutionCallableRunner,
+ )
from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol
_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
@@ -115,9 +122,9 @@ class _PythonVersionInfo(NamedTuple):
return cls(*_parse_version_info(result.strip()))
-class PythonOperator(BaseOperator):
+class PythonOperator(BaseAsyncOperator):
"""
- Executes a Python callable.
+ Base class for all Python operators.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -192,7 +199,14 @@ class PythonOperator(BaseOperator):
self.template_ext = templates_exts
self.show_return_value_in_logs = show_return_value_in_logs
- def execute(self, context: Context) -> Any:
+ @property
+ def is_async(self) -> bool:
+ return is_async_callable(self.python_callable)
+
+ def execute(self, context) -> Any:
+ if self.is_async:
+ return BaseAsyncOperator.execute(self, context)
+
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
@@ -236,6 +250,47 @@ class PythonOperator(BaseOperator):
runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
return runner.run(*self.op_args, **self.op_kwargs)
+ if AIRFLOW_V_3_2_PLUS:
+
+ async def aexecute(self, context):
+ context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
+ self.op_kwargs = self.determine_kwargs(context)
+
+ # This needs to be lazy because subclasses may implement execute_callable
+ # by running a separate process that can't use the eager result.
+ def __prepare_execution() -> (
+ tuple[AsyncExecutionCallableRunner, OutletEventAccessorsProtocol] | None
+ ):
+ from airflow.sdk.execution_time.callback_runner import create_async_executable_runner
+ from airflow.sdk.execution_time.context import context_get_outlet_events
+
+ return (
+ cast("AsyncExecutionCallableRunner",
create_async_executable_runner),
+ context_get_outlet_events(context),
+ )
+
+ self.__prepare_execution = __prepare_execution
+
+ return_value = await self.aexecute_callable()
+ if self.show_return_value_in_logs:
+ self.log.info("Done. Returned value was: %s", return_value)
+ else:
+ self.log.info("Done. Returned value not shown")
+
+ return return_value
+
+ async def aexecute_callable(self) -> Any:
+ """
+ Call the python callable with the given arguments.
+
+ :return: the return value of the call.
+ """
+ if (execution_preparation := self.__prepare_execution()) is None:
+ return await self.python_callable(*self.op_args, **self.op_kwargs)
+ create_execution_runner, asset_events = execution_preparation
+ runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
+ return await runner.run(*self.op_args, **self.op_kwargs)
+
class BranchPythonOperator(BaseBranchOperator, PythonOperator):
"""
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index 32b2fc2615d..d2aa38ec544 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -37,8 +37,16 @@ from tests_common.test_utils.version_compat import (
from unit.standard.operators.test_python import BasePythonTest
if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg, setup, task as task_decorator, teardown
- from airflow.sdk.bases.decorator import DecoratedMappedOperator
+ from airflow.sdk import (
+ DAG,
+ BaseOperator,
+ TaskGroup,
+ XComArg,
+ setup,
+ task as task_decorator,
+ teardown,
+ )
+ from airflow.sdk.bases.decorator import DecoratedMappedOperator, _TaskDecorator
from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput
else:
from airflow.decorators import ( # type: ignore[attr-defined,no-redef]
@@ -46,7 +54,7 @@ else:
task as task_decorator,
teardown,
)
- from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef]
+ from airflow.decorators.base import DecoratedMappedOperator, _TaskDecorator # type: ignore[no-redef]
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.models.dag import DAG # type: ignore[assignment,no-redef]
from airflow.models.expandinput import DictOfListsExpandInput # type: ignore[attr-defined,no-redef]
@@ -658,9 +666,9 @@ class TestAirflowTaskDecorator(BasePythonTest):
hello.override(pool="my_pool", priority_weight=i)()
weights = []
- for task in self.dag_non_serialized.tasks:
- assert task.pool == "my_pool"
- weights.append(task.priority_weight)
+ for _task in self.dag_non_serialized.tasks:
+ assert _task.pool == "my_pool"
+ weights.append(_task.priority_weight)
assert weights == [0, 1, 2]
def test_python_callable_args_work_as_well_as_baseoperator_args(self, dag_maker):
@@ -1142,3 +1150,19 @@ def test_teardown_trigger_rule_override_behavior(dag_maker, session):
my_teardown()
assert work_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
assert setup_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
+
+
+async def async_fn():
+ return 42
+
+
+def test_python_task():
+ from airflow.providers.standard.decorators.python import _PythonDecoratedOperator, python_task
+
+ decorator = python_task(async_fn)
+
+ assert isinstance(decorator, _TaskDecorator)
+ assert decorator.function == async_fn
+ assert decorator.operator_class == _PythonDecoratedOperator
+ assert not decorator.multiple_outputs
+ assert decorator.kwargs == {"task_id": "async_fn"}
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index a59c33b29dc..3818f42cb3c 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import asyncio
import copy
import logging
import os
@@ -43,7 +44,7 @@ from slugify import slugify
from airflow.exceptions import AirflowProviderDeprecationWarning, DeserializingResultError
from airflow.models.connection import Connection
from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import (
BranchExternalPythonOperator,
@@ -69,15 +70,14 @@ from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_1,
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS,
+ AIRFLOW_V_3_2_PLUS,
NOTSET,
)
if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.context import set_current_context
from airflow.serialization.serialized_objects import LazyDeserializedDAG
else:
- from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.models.taskinstance import set_current_context # type: ignore[attr-defined,no-redef]
if TYPE_CHECKING:
@@ -2465,6 +2465,25 @@ class TestShortCircuitWithTeardown:
assert set(actual_skipped) == {op3}
+class TestPythonAsyncOperator(TestPythonOperator):
+ def test_run_async_task(self, caplog):
+ caplog.set_level(logging.INFO, logger=LOGGER_NAME)
+
+ async def say_hello(name: str) -> str:
+ await asyncio.sleep(1)
+ return f"Hello {name}!"
+
+ if AIRFLOW_V_3_2_PLUS:
+ self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
+ assert "Done. Returned value was: Hello world!" in caplog.messages
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match=r"Async operators require Airflow 3\.2\+\. Upgrade
Airflow or use a synchronous callable\.",
+ ):
+ self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
+
+
@pytest.mark.parametrize(
("text_input", "expected_tuple"),
[
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 05476cbe562..66fe1637486 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -62,6 +62,8 @@ Task Decorators:
Bases
-----
+.. autoapiclass:: airflow.sdk.BaseAsyncOperator
+
.. autoapiclass:: airflow.sdk.BaseOperator
.. autoapiclass:: airflow.sdk.BaseSensorOperator
@@ -183,7 +185,7 @@ Everything else
.. autoapimodule:: airflow.sdk
:members:
:special-members: __version__
- :exclude-members: BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context
+ :exclude-members: BaseAsyncOperator, BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context
:undoc-members:
:imported-members:
:no-index:
diff --git a/task-sdk/docs/index.rst b/task-sdk/docs/index.rst
index 819f637676b..f3258ea8243 100644
--- a/task-sdk/docs/index.rst
+++ b/task-sdk/docs/index.rst
@@ -78,6 +78,7 @@ Why use ``airflow.sdk``?
**Classes**
- :class:`airflow.sdk.Asset`
+- :class:`airflow.sdk.BaseAsyncOperator`
- :class:`airflow.sdk.BaseHook`
- :class:`airflow.sdk.BaseNotifier`
- :class:`airflow.sdk.BaseOperator`
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 4dbda282d08..034a6379430 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -26,6 +26,7 @@ __all__ = [
"AssetAny",
"AssetOrTimeSchedule",
"AssetWatcher",
+ "BaseAsyncOperator",
"BaseHook",
"BaseNotifier",
"BaseOperator",
@@ -76,7 +77,13 @@ if TYPE_CHECKING:
from airflow.sdk.api.datamodels._generated import DagRunState, TaskInstanceState, TriggerRule, WeightRule
from airflow.sdk.bases.hook import BaseHook
from airflow.sdk.bases.notifier import BaseNotifier
- from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear, cross_downstream
+ from airflow.sdk.bases.operator import (
+ BaseAsyncOperator,
+ BaseOperator,
+ chain,
+ chain_linear,
+ cross_downstream,
+ )
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
from airflow.sdk.configuration import AirflowSDKConfigParser
@@ -117,6 +124,7 @@ __lazy_imports: dict[str, str] = {
"AssetAny": ".definitions.asset",
"AssetOrTimeSchedule": ".definitions.timetables.assets",
"AssetWatcher": ".definitions.asset",
+ "BaseAsyncOperator": ".bases.operator",
"BaseHook": ".bases.hook",
"BaseNotifier": ".bases.notifier",
"BaseOperator": ".bases.operator",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index eede7ff806a..b035f49226c 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -24,6 +24,7 @@ from airflow.sdk.api.datamodels._generated import (
from airflow.sdk.bases.hook import BaseHook as BaseHook
from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier
from airflow.sdk.bases.operator import (
+ BaseAsyncOperator as BaseAsyncOperator,
BaseOperator as BaseOperator,
chain as chain,
chain_linear as chain_linear,
@@ -83,6 +84,7 @@ __all__ = [
"AssetAny",
"AssetOrTimeSchedule",
"AssetWatcher",
+ "BaseAsyncOperator",
"BaseHook",
"BaseNotifier",
"BaseOperator",
diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py
index bde898c1696..131750ae8c9 100644
--- a/task-sdk/src/airflow/sdk/bases/decorator.py
+++ b/task-sdk/src/airflow/sdk/bases/decorator.py
@@ -22,7 +22,8 @@ import re
import textwrap
import warnings
from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
-from functools import cached_property, update_wrapper
+from contextlib import suppress
+from functools import cached_property, partial, update_wrapper
from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol, TypeVar, cast, overload
import attr
@@ -149,6 +150,48 @@ def get_unique_task_id(
return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
+def unwrap_partial(fn: Callable) -> Callable:
+ while isinstance(fn, partial):
+ fn = fn.func
+ return fn
+
+
+def unwrap_callable(func):
+ from airflow.sdk.definitions.mappedoperator import OperatorPartial
+
+ if isinstance(func, (_TaskDecorator, OperatorPartial)):
+ func = getattr(func, "function", getattr(func, "_func", func))
+
+ func = unwrap_partial(func)
+
+ with suppress(Exception):
+ func = inspect.unwrap(func)
+
+ return func
+
+
+def is_async_callable(func):
+ """Detect if a callable (possibly wrapped) is an async function."""
+ func = unwrap_callable(func)
+
+ if not callable(func):
+ return False
+
+ # Direct async function
+ if inspect.iscoroutinefunction(func):
+ return True
+
+ # Callable object with async __call__
+ if not inspect.isfunction(func):
+ call = type(func).__call__ # Bandit-safe
+ with suppress(Exception):
+ call = inspect.unwrap(call)
+ if inspect.iscoroutinefunction(call):
+ return True
+
+ return False
+
+
class DecoratedOperator(BaseOperator):
"""
Wraps a Python callable and captures args/kwargs when called for execution.
@@ -243,6 +286,10 @@ class DecoratedOperator(BaseOperator):
self.op_kwargs = op_kwargs
super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
+ @property
+ def is_async(self) -> bool:
+ return is_async_callable(self.python_callable)
+
def execute(self, context: Context):
# todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
# as well
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 0c97df00ef0..5f07a188362 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -18,13 +18,15 @@
from __future__ import annotations
import abc
+import asyncio
import collections.abc
import contextlib
import copy
import inspect
import sys
import warnings
-from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
+from asyncio import AbstractEventLoop
+from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -101,6 +103,7 @@ if TYPE_CHECKING:
TaskPostExecuteHook = Callable[[Context, Any], None]
__all__ = [
+ "BaseAsyncOperator",
"BaseOperator",
"chain",
"chain_linear",
@@ -196,6 +199,27 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
return Resources(**resources)
[email protected]
+def event_loop() -> Generator[AbstractEventLoop]:
+ new_event_loop = False
+ loop = None
+ try:
+ try:
+ loop = asyncio.get_event_loop()
+ if loop.is_closed():
+ raise RuntimeError
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ new_event_loop = True
+ yield loop
+ finally:
+ if new_event_loop and loop is not None:
+ with contextlib.suppress(AttributeError):
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
class _PartialDescriptor:
"""A descriptor that guards against ``.partial`` being called on Task
objects."""
@@ -1670,6 +1694,35 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return bool(self.on_skipped_callback)
+class BaseAsyncOperator(BaseOperator):
+ """
+ Base class for async-capable operators.
+
+ As opposed to deferred operators, which are executed on the triggerer, async operators are executed
+ on the worker.
+ """
+
+ @property
+ def is_async(self) -> bool:
+ return True
+
+ async def aexecute(self, context):
+ """Async version of execute(). Subclasses should implement this."""
+ raise NotImplementedError()
+
+ def execute(self, context):
+ """Run `aexecute()` inside an event loop."""
+ with event_loop() as loop:
+ if self.execution_timeout:
+ return loop.run_until_complete(
+ asyncio.wait_for(
+ self.aexecute(context),
+ timeout=self.execution_timeout.total_seconds(),
+ )
+ )
+ return loop.run_until_complete(self.aexecute(context))
+
+
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
r"""
Given a number of tasks, builds a dependency chain.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index 6c99a72b220..e7e5ebe8b9a 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -137,6 +137,10 @@ class AbstractOperator(Templater, DAGNode):
)
)
+ @property
+ def is_async(self) -> bool:
+ return False
+
@property
def task_type(self) -> str:
raise NotImplementedError()
diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
index 316c3d38e99..322e4bc9780 100644
--- a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
@@ -20,8 +20,8 @@ from __future__ import annotations
import inspect
import logging
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast
from typing_extensions import ParamSpec
@@ -39,6 +39,11 @@ class _ExecutionCallableRunner(Generic[P, R]):
def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body]
+class _AsyncExecutionCallableRunner(Generic[P, R]):
+ @staticmethod
+ async def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body]
+
+
class ExecutionCallableRunner(Protocol):
def __call__(
self,
@@ -49,6 +54,16 @@ class ExecutionCallableRunner(Protocol):
) -> _ExecutionCallableRunner[P, R]: ...
+class AsyncExecutionCallableRunner(Protocol):
+ def __call__(
+ self,
+ func: Callable[P, R],
+ outlet_events: OutletEventAccessorsProtocol,
+ *,
+ logger: logging.Logger | Logger,
+ ) -> _AsyncExecutionCallableRunner[P, R]: ...
+
+
def create_executable_runner(
func: Callable[P, R],
outlet_events: OutletEventAccessorsProtocol,
@@ -109,3 +124,51 @@ def create_executable_runner(
return result # noqa: F821 # Ruff is not smart enough to know this is always set in _run().
return cast("_ExecutionCallableRunner[P, R]", _ExecutionCallableRunnerImpl)
+
+
+def create_async_executable_runner(
+ func: Callable[P, Awaitable[R] | AsyncIterator],
+ outlet_events: OutletEventAccessorsProtocol,
+ *,
+ logger: logging.Logger | Logger,
+) -> _AsyncExecutionCallableRunner[P, R]:
+ """
+ Run an async execution callable against a task context and given arguments.
+
+ If the callable is a simple function, this simply calls it with the supplied
+ arguments (including the context). If the callable is a generator function,
+ the generator is exhausted here, with the yielded values getting fed back
+ into the task context automatically for execution.
+
+ This convoluted implementation of inner class with closure is so *all*
+ arguments passed to ``run()`` can be forwarded to the wrapped function. This
+ is particularly important for the argument "self", which some use cases
+ need to receive. This is not possible if this is implemented as a normal
+ class, where "self" needs to point to the runner object, not the object
+ bound to the inner callable.
+
+ :meta private:
+ """
+
+ class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner):
+ @staticmethod
+ async def run(*args: P.args, **kwargs: P.kwargs) -> R:
+ from airflow.sdk.definitions.asset.metadata import Metadata
+
+ if not inspect.isasyncgenfunction(func):
+ result = cast("Awaitable[R]", func(*args, **kwargs))
+ return await result
+
+ results: list[Any] = []
+
+ async for result in func(*args, **kwargs):
+ if isinstance(result, Metadata):
+ outlet_events[result.asset].extra.update(result.extra)
+ if result.alias:
+ outlet_events[result.alias].add(result.asset, extra=result.extra)
+
+ results.append(result)
+
+ return cast("R", results)
+
+ return cast("_AsyncExecutionCallableRunner[P, R]",
_AsyncExecutionCallableRunnerImpl)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 52a96d0b665..15755e640d9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -48,7 +48,9 @@ Execution API server is because:
from __future__ import annotations
+import asyncio
import itertools
+import threading
from collections.abc import Iterator
from datetime import datetime
from functools import cached_property
@@ -185,31 +187,69 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)
+ # Threading lock for sync operations
+ _thread_lock: threading.Lock = attrs.field(factory=threading.Lock, repr=False)
+ # Async lock for async operations
+ _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
+
def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
"""Send a request to the parent and block until the response is
received."""
frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
frame_bytes = frame.as_bytes()
- self.socket.sendall(frame_bytes)
- if isinstance(msg, ResendLoggingFD):
- if recv_fds is None:
- return None
- # We need special handling here! The server can't send us the fd number, as the number on the
- # supervisor will be different to in this process, so we have to mutate the message ourselves here.
- frame, fds = self._read_frame(maxfds=1)
- resp = self._from_frame(frame)
- if TYPE_CHECKING:
- assert isinstance(resp, SentFDs)
- resp.fds = fds
- # Since we know this is an expliclt SendFDs, and since this class is generic SendFDs might not
- # always be in the return type union
- return resp # type: ignore[return-value]
+ # We must make sure sockets aren't intermixed between sync and async calls,
+ # thus we need a dual locking mechanism to ensure that.
+ with self._thread_lock:
+ self.socket.sendall(frame_bytes)
+ if isinstance(msg, ResendLoggingFD):
+ if recv_fds is None:
+ return None
+ # We need special handling here! The server can't send us the fd number, as the number on the
+ # supervisor will be different to in this process, so we have to mutate the message ourselves here.
+ frame, fds = self._read_frame(maxfds=1)
+ resp = self._from_frame(frame)
+ if TYPE_CHECKING:
+ assert isinstance(resp, SentFDs)
+ resp.fds = fds
+ # Since we know this is an expliclt SendFDs, and since this class is generic SendFDs might not
+ # always be in the return type union
+ return resp # type: ignore[return-value]
return self._get_response()
async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
- """Send a request to the parent without blocking."""
- raise NotImplementedError
+ """
+ Send a request to the parent without blocking.
+
+ Uses async lock for coroutine safety and thread lock for socket safety.
+ """
+ frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+ frame_bytes = frame.as_bytes()
+
+ async with self._async_lock:
+ # Acquire the threading lock without blocking the event loop
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(None, self._thread_lock.acquire)
+ try:
+ # Async write to socket
+ await loop.sock_sendall(self.socket, frame_bytes)
+
+ if isinstance(msg, ResendLoggingFD):
+ if recv_fds is None:
+ return None
+ # Blocking read in a thread
+ frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1)
+ resp = self._from_frame(frame)
+ if TYPE_CHECKING:
+ assert isinstance(resp, SentFDs)
+ resp.fds = fds
+ return resp # type: ignore[return-value]
+
+ # Normal blocking read in a thread
+ frame = await asyncio.to_thread(self._read_frame)
+ return self._from_frame(frame)
+ finally:
+ self._thread_lock.release()
@overload
def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
diff --git a/task-sdk/tests/task_sdk/bases/test_decorator.py b/task-sdk/tests/task_sdk/bases/test_decorator.py
index a0202680530..913e4c6e5f0 100644
--- a/task-sdk/tests/task_sdk/bases/test_decorator.py
+++ b/task-sdk/tests/task_sdk/bases/test_decorator.py
@@ -17,11 +17,15 @@
from __future__ import annotations
import ast
+import functools
import importlib.util
import textwrap
from pathlib import Path
-from airflow.sdk.bases.decorator import DecoratedOperator
+import pytest
+
+from airflow.sdk import task
+from airflow.sdk.bases.decorator import DecoratedOperator, is_async_callable
RAW_CODE = """
from airflow.sdk import task
@@ -63,3 +67,141 @@ class TestBaseDecorator:
# Returned source must be valid Python
ast.parse(cleaned)
assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
+
+
+def simple_decorator(fn):
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+
+def decorator_without_wraps(fn):
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+
+async def async_fn():
+ return 42
+
+
+def sync_fn():
+ return 42
+
+
+@simple_decorator
+async def wrapped_async_fn():
+ return 42
+
+
+@simple_decorator
+def wrapped_sync_fn():
+ return 42
+
+
+@decorator_without_wraps
+async def wrapped_async_fn_no_wraps():
+ return 42
+
+
+@simple_decorator
+@simple_decorator
+async def multi_wrapped_async_fn():
+ return 42
+
+
+async def async_with_args(x, y):
+ return x + y
+
+
+def sync_with_args(x, y):
+ return x + y
+
+
+class AsyncCallable:
+ async def __call__(self):
+ return 42
+
+
+class SyncCallable:
+ def __call__(self):
+ return 42
+
+
+class WrappedAsyncCallable:
+ @simple_decorator
+ async def __call__(self):
+ return 42
+
+
+class TestAsyncCallable:
+ def test_plain_async_function(self):
+ assert is_async_callable(async_fn)
+
+ def test_plain_sync_function(self):
+ assert not is_async_callable(sync_fn)
+
+ def test_wrapped_async_function_with_wraps(self):
+ assert is_async_callable(wrapped_async_fn)
+
+ def test_wrapped_sync_function_with_wraps(self):
+ assert not is_async_callable(wrapped_sync_fn)
+
+ def test_wrapped_async_function_without_wraps(self):
+ """
+ Without functools.wraps, inspect.unwrap cannot recover the coroutine.
+ This documents expected behavior.
+ """
+ assert not is_async_callable(wrapped_async_fn_no_wraps)
+
+ def test_multi_wrapped_async_function(self):
+ assert is_async_callable(multi_wrapped_async_fn)
+
+ def test_partial_async_function(self):
+ fn = functools.partial(async_with_args, 1)
+ assert is_async_callable(fn)
+
+ def test_partial_sync_function(self):
+ fn = functools.partial(sync_with_args, 1)
+ assert not is_async_callable(fn)
+
+ def test_nested_partial_async_function(self):
+ fn = functools.partial(
+ functools.partial(async_with_args, 1),
+ 2,
+ )
+ assert is_async_callable(fn)
+
+ def test_async_callable_class(self):
+ assert is_async_callable(AsyncCallable())
+
+ def test_sync_callable_class(self):
+ assert not is_async_callable(SyncCallable())
+
+ def test_wrapped_async_callable_class(self):
+ assert is_async_callable(WrappedAsyncCallable())
+
+ def test_partial_callable_class(self):
+ fn = functools.partial(AsyncCallable())
+ assert is_async_callable(fn)
+
+ @pytest.mark.parametrize("value", [None, 42, "string", object()])
+ def test_non_callable(self, value):
+ assert not is_async_callable(value)
+
+ def test_task_decorator_async_function(self):
+ @task
+ async def async_task_fn():
+ return 42
+
+ assert is_async_callable(async_task_fn)
+
+ def test_task_decorator_sync_function(self):
+ @task
+ def sync_task_fn():
+ return 42
+
+ assert not is_async_callable(sync_task_fn)