This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new faf847c59d0 Add support for async callables in ``PythonOperator`` (#60268)
faf847c59d0 is described below

commit faf847c59d0cfdcc58851a45e0a80aa72bed93d0
Author: David Blain <[email protected]>
AuthorDate: Thu Jan 15 19:06:59 2026 +0100

    Add support for async callables in ``PythonOperator`` (#60268)
    
    
    This PR follows the discussion I started on the [devlist](https://lists.apache.org/thread/ztnfsqolow4v1zsv4pkpnxc1fk0hbf2p) and allows you to natively execute async code in ``PythonOperator``.
    
    There is also an AIP for this: 
https://cwiki.apache.org/confluence/display/AIRFLOW/%5BWIP%5D+AIP-98%3A+Rethinking+deferrable+operators%2C+async+hooks+and+performance+in+Airflow+3
    
    Below is an example showing how it can be used with async hooks:
    
    ```
    @task(show_return_value_in_logs=False)
    async def load_xml_files(files):
        import asyncio
        from io import BytesIO
        from more_itertools import chunked
        from os import cpu_count
        from tenacity import retry, stop_after_attempt, wait_fixed
    
        from airflow.providers.sftp.hooks.sftp import SFTPClientPool
    
        print("number of files:", len(files))
    
        async with SFTPClientPool(sftp_conn_id=sftp_conn, pool_size=cpu_count()) as pool:
            @retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
            async def download_file(file):
                async with pool.get_sftp_client() as sftp:
                    print("downloading:", file)
                    buffer = BytesIO()
                    async with sftp.open(file, encoding=xml_encoding) as remote_file:
                        data = await remote_file.read()
                        buffer.write(data.encode(xml_encoding))
                        buffer.seek(0)
                    return buffer
    
            for batch in chunked(files, cpu_count() * 2):
                tasks = [asyncio.create_task(download_file(f)) for f in batch]
    
                # Wait for this batch to finish before starting the next
                for task in asyncio.as_completed(tasks):
                    result = await task
                    # Do something with the result, or accumulate it and return it as an XCom
    ```
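    
    For reference, the simplest form this enables is a plain ``async def`` passed straight to the
    operator. A minimal sketch, condensed from the example DAGs added in this commit (DAG context
    and dependency wiring omitted):
    
    ```
    import asyncio
    
    from airflow.providers.standard.operators.python import PythonOperator
    
    
    async def my_async_sleeping_function(random_base):
        """An async callable handed directly to the operator."""
        await asyncio.sleep(random_base)
    
    
    async_sleeping_task = PythonOperator(
        task_id="async_sleep",
        python_callable=my_async_sleeping_function,  # async def is accepted natively
        op_kwargs={"random_base": 0.1},
    )
    ```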
    
    This PR also addresses the remarks made by @kaxil on the original [PR](https://github.com/apache/airflow/pull/59087), which has been reverted.
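    
    Under the hood, ``BaseAsyncOperator.execute()`` simply drives ``aexecute()`` to completion on an
    event loop on the worker. A condensed sketch of the pattern introduced in the task-sdk diff below
    (the full version also honours ``execution_timeout`` via ``asyncio.wait_for``):
    
    ```
    import asyncio
    
    
    class BaseAsyncOperator(BaseOperator):
        async def aexecute(self, context):
            """Async counterpart of execute(); subclasses implement this."""
            raise NotImplementedError()
    
        def execute(self, context):
            # Reuse the current event loop if usable, otherwise create a fresh one;
            # event_loop() is the helper context manager added in bases/operator.py.
            with event_loop() as loop:
                return loop.run_until_complete(self.aexecute(context))
    ```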
---
 airflow-core/newsfragments/60268.improvement.rst   |   1 +
 .../src/airflow/jobs/triggerer_job_runner.py       |   4 +-
 .../core_api/routes/public/test_task_instances.py  |  58 ++++-----
 .../providers/common/compat/standard/operators.py  |  56 ++++++++
 .../providers/common/compat/version_compat.py      |   2 +
 providers/standard/docs/operators/python.rst       |  31 +++++
 providers/standard/pyproject.toml                  |   2 +-
 .../example_dags/example_python_decorator.py       |  17 +++
 .../example_dags/example_python_operator.py        |  18 +++
 .../airflow/providers/standard/operators/python.py |  65 +++++++++-
 .../tests/unit/standard/decorators/test_python.py  |  36 +++++-
 .../tests/unit/standard/operators/test_python.py   |  25 +++-
 task-sdk/docs/api.rst                              |   4 +-
 task-sdk/docs/index.rst                            |   1 +
 task-sdk/src/airflow/sdk/__init__.py               |  10 +-
 task-sdk/src/airflow/sdk/__init__.pyi              |   2 +
 task-sdk/src/airflow/sdk/bases/decorator.py        |  49 ++++++-
 task-sdk/src/airflow/sdk/bases/operator.py         |  55 +++++++-
 .../sdk/definitions/_internal/abstractoperator.py  |   4 +
 .../airflow/sdk/execution_time/callback_runner.py  |  67 +++++++++-
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  72 ++++++++---
 task-sdk/tests/task_sdk/bases/test_decorator.py    | 144 ++++++++++++++++++++-
 22 files changed, 653 insertions(+), 70 deletions(-)

diff --git a/airflow-core/newsfragments/60268.improvement.rst b/airflow-core/newsfragments/60268.improvement.rst
new file mode 100644
index 00000000000..8c7e92b8f0d
--- /dev/null
+++ b/airflow-core/newsfragments/60268.improvement.rst
@@ -0,0 +1 @@
+The ``PythonOperator`` parameter ``python_callable`` now also supports async callables in Airflow 3.2, allowing users to run ``async def`` functions without manually managing an event loop.
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 532b0faf1de..2508e4a281a 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -789,8 +789,6 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
         factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
     )
 
-    _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
-
     def _read_frame(self):
         from asgiref.sync import async_to_sync
 
@@ -825,7 +823,7 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
         frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
         bytes = frame.as_bytes()
 
-        async with self._lock:
+        async with self._async_lock:
             self._async_writer.write(bytes)
 
             return await self._aget_response(frame.id)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 2b53501bee5..91622ce11e2 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -214,7 +214,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -372,7 +372,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -436,7 +436,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -492,7 +492,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -612,7 +612,7 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
                 "pid": 100,
                 "pool": "default_pool",
                 "pool_slots": 1,
-                "priority_weight": 9,
+                "priority_weight": 14,
                 "queue": "default_queue",
                 "queued_when": None,
                 "scheduled_when": None,
@@ -1404,7 +1404,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 False,
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "example_python_operator"},
-                9,  # Based on test failure - example_python_operator creates 9 task instances
+                14,  # Based on test failure - example_python_operator creates 14 task instances
                 3,
                 id="test dag_id_pattern exact match",
             ),
@@ -1413,7 +1413,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 False,
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "example_%"},
-                17,  # Based on test failure - both DAGs together create 17 task instances
+                22,  # Based on test failure - both DAGs together create 22 task instances
                 3,
                 id="test dag_id_pattern wildcard prefix",
             ),
@@ -1927,8 +1927,8 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         [
             pytest.param(
                 {"dag_ids": ["example_python_operator", "example_skip_dag"]},
-                17,
-                17,
+                22,
+                22,
                 id="with dag filter",
             ),
         ],
@@ -2037,7 +2037,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         assert len(response_batch2.json()["task_instances"]) > 0
 
         # Match
-        ti_count = 9
+        ti_count = 10
        assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count
         assert (num_entries_batch1 + num_entries_batch2) == ti_count
         assert response_batch1 != response_batch2
@@ -2076,7 +2076,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -2122,7 +2122,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -2199,7 +2199,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
                 "pid": 100,
                 "pool": "default_pool",
                 "pool_slots": 1,
-                "priority_weight": 9,
+                "priority_weight": 14,
                 "queue": "default_queue",
                 "queued_when": None,
                 "scheduled_when": None,
@@ -2271,7 +2271,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -2318,7 +2318,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
             "pid": 100,
             "pool": "default_pool",
             "pool_slots": 1,
-            "priority_weight": 9,
+            "priority_weight": 14,
             "queue": "default_queue",
             "queued_when": None,
             "scheduled_when": None,
@@ -3162,7 +3162,7 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
                 "pid": 100,
                 "pool": "default_pool",
                 "pool_slots": 1,
-                "priority_weight": 9,
+                "priority_weight": 14,
                 "queue": "default_queue",
                 "queued_when": None,
                 "scheduled_when": None,
@@ -3534,7 +3534,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -3571,7 +3571,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -3642,7 +3642,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -3725,7 +3725,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
                         "pid": 100,
                         "pool": "default_pool",
                         "pool_slots": 1,
-                        "priority_weight": 9,
+                        "priority_weight": 14,
                         "queue": "default_queue",
                         "queued_when": None,
                         "scheduled_when": None,
@@ -3762,7 +3762,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
                         "pid": 100,
                         "pool": "default_pool",
                         "pool_slots": 1,
-                        "priority_weight": 9,
+                        "priority_weight": 14,
                         "queue": "default_queue",
                         "queued_when": None,
                         "scheduled_when": None,
@@ -4002,7 +4002,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -4276,7 +4276,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                             "pid": 100,
                             "pool": "default_pool",
                             "pool_slots": 1,
-                            "priority_weight": 9,
+                            "priority_weight": 14,
                             "queue": "default_queue",
                             "queued_when": None,
                             "scheduled_when": None,
@@ -4410,7 +4410,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -4471,7 +4471,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -4550,7 +4550,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                         "pid": 100,
                         "pool": "default_pool",
                         "pool_slots": 1,
-                        "priority_weight": 9,
+                        "priority_weight": 14,
                         "queue": "default_queue",
                         "queued_when": None,
                         "scheduled_when": None,
@@ -4631,7 +4631,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                 "pid": 100,
                 "pool": "default_pool",
                 "pool_slots": 1,
-                "priority_weight": 9,
+                "priority_weight": 14,
                 "queue": "default_queue",
                 "queued_when": None,
                 "scheduled_when": None,
@@ -4749,7 +4749,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
                     "pid": 100,
                     "pool": "default_pool",
                     "pool_slots": 1,
-                    "priority_weight": 9,
+                    "priority_weight": 14,
                     "queue": "default_queue",
                     "queued_when": None,
                     "scheduled_when": None,
@@ -5035,7 +5035,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
                             "pid": 100,
                             "pool": "default_pool",
                             "pool_slots": 1,
-                            "priority_weight": 9,
+                            "priority_weight": 14,
                             "queue": "default_queue",
                             "queued_when": None,
                             "scheduled_when": None,
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index 6b77db3e4a9..866d94ac792 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -17,18 +17,74 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from airflow.providers.common.compat._compat_utils import create_module_getattr
+from airflow.providers.common.compat.version_compat import (
+    AIRFLOW_V_3_0_PLUS,
+    AIRFLOW_V_3_1_PLUS,
+    AIRFLOW_V_3_2_PLUS,
+)
 
 _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
     # Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
     "BaseOperator": "airflow.providers.common.compat.sdk",
+    "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
     "get_current_context": "airflow.providers.common.compat.sdk",
+    "is_async_callable": "airflow.providers.common.compat.sdk",
     # Standard provider items with direct fallbacks
     "PythonOperator": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
     "ShortCircuitOperator": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
     "_SERIALIZERS": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
 }
 
+if TYPE_CHECKING:
+    from airflow.sdk.bases.decorator import is_async_callable
+    from airflow.sdk.bases.operator import BaseAsyncOperator
+elif AIRFLOW_V_3_2_PLUS:
+    from airflow.sdk.bases.decorator import is_async_callable
+    from airflow.sdk.bases.operator import BaseAsyncOperator
+else:
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.sdk import BaseOperator
+    else:
+        from airflow.models import BaseOperator
+
+    def is_async_callable(func) -> bool:
+        """Detect if a callable is an async function."""
+        import inspect
+        from functools import partial
+
+        while isinstance(func, partial):
+            func = func.func
+        return inspect.iscoroutinefunction(func)
+
+    class BaseAsyncOperator(BaseOperator):
+        """Stub for Airflow < 3.2 that raises a clear error."""
+
+        @property
+        def is_async(self) -> bool:
+            return True
+
+        if not AIRFLOW_V_3_1_PLUS:
+
+            @property
+            def xcom_push(self) -> bool:
+                return self.do_xcom_push
+
+            @xcom_push.setter
+            def xcom_push(self, value: bool):
+                self.do_xcom_push = value
+
+        async def aexecute(self, context):
+            raise NotImplementedError()
+
+        def execute(self, context):
+            raise RuntimeError(
+                "Async operators require Airflow 3.2+. Upgrade Airflow or use 
a synchronous callable."
+            )
+
+
 __getattr__ = create_module_getattr(import_map=_IMPORT_MAP)
 
 __all__ = sorted(_IMPORT_MAP.keys())
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index 4142937bd2a..e3fd1e55f14 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0)
 AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
 
 # BaseOperator removed from version_compat to avoid circular imports
 # Import it directly in files that need it instead
@@ -41,4 +42,5 @@ AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
 __all__ = [
     "AIRFLOW_V_3_0_PLUS",
     "AIRFLOW_V_3_1_PLUS",
+    "AIRFLOW_V_3_2_PLUS",
 ]
diff --git a/providers/standard/docs/operators/python.rst b/providers/standard/docs/operators/python.rst
index 2e5e63ea437..a87762c11cd 100644
--- a/providers/standard/docs/operators/python.rst
+++ b/providers/standard/docs/operators/python.rst
@@ -72,6 +72,37 @@ Pass extra arguments to the ``@task`` decorated function as you would with a nor
             :start-after: [START howto_operator_python_kwargs]
             :end-before: [END howto_operator_python_kwargs]
 
+Async Python functions
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. versionadded:: 3.2
+
+Async Python callables are now supported out of the box. This means you no longer need to manage the
+event loop yourself, and you can easily invoke async Python code and async Airflow hooks that are not
+always available through deferrable operators.
+
+As opposed to deferrable operators, which are executed on the triggerer, async operators are executed on the workers.
+
+.. tab-set::
+
+    .. tab-item:: @task
+        :sync: taskflow
+
+        .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START howto_async_operator_python_kwargs]
+            :end-before: [END howto_async_operator_python_kwargs]
+
+    .. tab-item:: PythonOperator
+        :sync: operator
+
+        .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START howto_async_operator_python_kwargs]
+            :end-before: [END howto_async_operator_python_kwargs]
+
 Templating
 ^^^^^^^^^^
 
diff --git a/providers/standard/pyproject.toml b/providers/standard/pyproject.toml
index 05635ff6a0d..5ad02e5efcf 100644
--- a/providers/standard/pyproject.toml
+++ b/providers/standard/pyproject.toml
@@ -59,7 +59,7 @@ requires-python = ">=3.10"
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.11.0",
-    "apache-airflow-providers-common-compat>=1.12.0",
+    "apache-airflow-providers-common-compat>=1.12.0", # use next version
 ]
 
 # The optional dependencies should be modified in place in the generated file
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
index ac9938d92ea..086d67f831c 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
@@ -22,6 +22,7 @@ virtual environment.
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import sys
 import time
@@ -75,6 +76,22 @@ def example_python_decorator():
         run_this >> log_the_sql >> sleeping_task
     # [END howto_operator_python_kwargs]
 
+    # [START howto_async_operator_python_kwargs]
+    # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
+    # Asynchronous callables are natively supported since Airflow 3.2+
+    @task
+    async def my_async_sleeping_function(random_base):
+        """This is a function that will run within the DAG execution"""
+        await asyncio.sleep(random_base)
+
+    for i in range(5):
+        async_sleeping_task = my_async_sleeping_function.override(task_id=f"async_sleep_for_{i}")(
+            random_base=i / 10
+        )
+
+        run_this >> log_the_sql >> async_sleeping_task
+    # [END howto_async_operator_python_kwargs]
+
     # [START howto_operator_python_venv]
     @task.virtualenv(
         task_id="virtualenv_python", requirements=["colorama==0.4.0"], 
system_site_packages=False
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
index 18aa8f207e3..a9378938873 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
@@ -22,6 +22,7 @@ within a virtual environment.
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import sys
 import time
@@ -88,6 +89,23 @@ with DAG(
         run_this >> log_the_sql >> sleeping_task
     # [END howto_operator_python_kwargs]
 
+    # [START howto_async_operator_python_kwargs]
+    # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
+    # Asynchronous callables are natively supported since Airflow 3.2+
+    async def my_async_sleeping_function(random_base):
+        """This is a function that will run within the DAG execution"""
+        await asyncio.sleep(random_base)
+
+    for i in range(5):
+        async_sleeping_task = PythonOperator(
+            task_id=f"async_sleep_for_{i}",
+            python_callable=my_async_sleeping_function,
+            op_kwargs={"random_base": i / 10},
+        )
+
+        run_this >> log_the_sql >> async_sleeping_task
+    # [END howto_async_operator_python_kwargs]
+
     # [START howto_operator_python_venv]
     def callable_virtualenv():
         """
diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py
index ac8862f2923..7e6bbd166a6 100644
--- a/providers/standard/src/airflow/providers/standard/operators/python.py
+++ b/providers/standard/src/airflow/providers/standard/operators/python.py
@@ -48,13 +48,17 @@ from airflow.exceptions import (
 )
 from airflow.models.variable import Variable
 from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
+from airflow.providers.common.compat.standard.operators import (
+    BaseAsyncOperator,
+    is_async_callable,
+)
 from airflow.providers.standard.hooks.package_index import PackageIndexHook
 from airflow.providers.standard.utils.python_virtualenv import (
     _execute_in_subprocess,
     prepare_virtualenv,
     write_python_script,
 )
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
 from airflow.utils import hashlib_wrapper
 from airflow.utils.file import get_unique_dag_module_name
 from airflow.utils.operator_helpers import KeywordParameters
@@ -75,7 +79,10 @@ if TYPE_CHECKING:
     from pendulum.datetime import DateTime
 
     from airflow.providers.common.compat.sdk import Context
-    from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
+    from airflow.sdk.execution_time.callback_runner import (
+        AsyncExecutionCallableRunner,
+        ExecutionCallableRunner,
+    )
     from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol
 
     _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
@@ -115,9 +122,9 @@ class _PythonVersionInfo(NamedTuple):
         return cls(*_parse_version_info(result.strip()))
 
 
-class PythonOperator(BaseOperator):
+class PythonOperator(BaseAsyncOperator):
     """
-    Executes a Python callable.
+    Base class for all Python operators.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -192,7 +199,14 @@ class PythonOperator(BaseOperator):
             self.template_ext = templates_exts
         self.show_return_value_in_logs = show_return_value_in_logs
 
-    def execute(self, context: Context) -> Any:
+    @property
+    def is_async(self) -> bool:
+        return is_async_callable(self.python_callable)
+
+    def execute(self, context) -> Any:
+        if self.is_async:
+            return BaseAsyncOperator.execute(self, context)
+
         context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
         self.op_kwargs = self.determine_kwargs(context)
 
@@ -236,6 +250,47 @@ class PythonOperator(BaseOperator):
         runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
         return runner.run(*self.op_args, **self.op_kwargs)
 
+    if AIRFLOW_V_3_2_PLUS:
+
+        async def aexecute(self, context):
+            context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
+            self.op_kwargs = self.determine_kwargs(context)
+
+            # This needs to be lazy because subclasses may implement execute_callable
+            # by running a separate process that can't use the eager result.
+            def __prepare_execution() -> (
+                tuple[AsyncExecutionCallableRunner, OutletEventAccessorsProtocol] | None
+            ):
+                from airflow.sdk.execution_time.callback_runner import create_async_executable_runner
+                from airflow.sdk.execution_time.context import context_get_outlet_events
+
+                return (
+                    cast("AsyncExecutionCallableRunner", 
create_async_executable_runner),
+                    context_get_outlet_events(context),
+                )
+
+            self.__prepare_execution = __prepare_execution
+
+            return_value = await self.aexecute_callable()
+            if self.show_return_value_in_logs:
+                self.log.info("Done. Returned value was: %s", return_value)
+            else:
+                self.log.info("Done. Returned value not shown")
+
+            return return_value
+
+        async def aexecute_callable(self) -> Any:
+            """
+            Call the python callable with the given arguments.
+
+            :return: the return value of the call.
+            """
+            if (execution_preparation := self.__prepare_execution()) is None:
+                return await self.python_callable(*self.op_args, **self.op_kwargs)
+            create_execution_runner, asset_events = execution_preparation
+            runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
+            return await runner.run(*self.op_args, **self.op_kwargs)
+
 
 class BranchPythonOperator(BaseBranchOperator, PythonOperator):
     """
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index 32b2fc2615d..d2aa38ec544 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -37,8 +37,16 @@ from tests_common.test_utils.version_compat import (
 from unit.standard.operators.test_python import BasePythonTest
 
 if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg, setup, task as task_decorator, teardown
-    from airflow.sdk.bases.decorator import DecoratedMappedOperator
+    from airflow.sdk import (
+        DAG,
+        BaseOperator,
+        TaskGroup,
+        XComArg,
+        setup,
+        task as task_decorator,
+        teardown,
+    )
+    from airflow.sdk.bases.decorator import DecoratedMappedOperator, _TaskDecorator
     from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput
 else:
     from airflow.decorators import (  # type: ignore[attr-defined,no-redef]
@@ -46,7 +54,7 @@ else:
         task as task_decorator,
         teardown,
     )
-    from airflow.decorators.base import DecoratedMappedOperator  # type: ignore[no-redef]
+    from airflow.decorators.base import DecoratedMappedOperator, _TaskDecorator  # type: ignore[no-redef]
     from airflow.models.baseoperator import BaseOperator  # type: ignore[no-redef]
     from airflow.models.dag import DAG  # type: ignore[assignment,no-redef]
     from airflow.models.expandinput import DictOfListsExpandInput  # type: ignore[attr-defined,no-redef]
@@ -658,9 +666,9 @@ class TestAirflowTaskDecorator(BasePythonTest):
                 hello.override(pool="my_pool", priority_weight=i)()
 
         weights = []
-        for task in self.dag_non_serialized.tasks:
-            assert task.pool == "my_pool"
-            weights.append(task.priority_weight)
+        for _task in self.dag_non_serialized.tasks:
+            assert _task.pool == "my_pool"
+            weights.append(_task.priority_weight)
         assert weights == [0, 1, 2]
 
     def test_python_callable_args_work_as_well_as_baseoperator_args(self, dag_maker):
@@ -1142,3 +1150,19 @@ def test_teardown_trigger_rule_override_behavior(dag_maker, session):
         my_teardown()
     assert work_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
     assert setup_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
+
+
+async def async_fn():
+    return 42
+
+
+def test_python_task():
+    from airflow.providers.standard.decorators.python import _PythonDecoratedOperator, python_task
+
+    decorator = python_task(async_fn)
+
+    assert isinstance(decorator, _TaskDecorator)
+    assert decorator.function == async_fn
+    assert decorator.operator_class == _PythonDecoratedOperator
+    assert not decorator.multiple_outputs
+    assert decorator.kwargs == {"task_id": "async_fn"}
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index a59c33b29dc..3818f42cb3c 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import copy
 import logging
 import os
@@ -43,7 +44,7 @@ from slugify import slugify
 from airflow.exceptions import AirflowProviderDeprecationWarning, DeserializingResultError
 from airflow.models.connection import Connection
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import (
     BranchExternalPythonOperator,
@@ -69,15 +70,14 @@ from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
     AIRFLOW_V_3_0_PLUS,
     AIRFLOW_V_3_1_PLUS,
+    AIRFLOW_V_3_2_PLUS,
     NOTSET,
 )
 
 if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseOperator
     from airflow.sdk.execution_time.context import set_current_context
     from airflow.serialization.serialized_objects import LazyDeserializedDAG
 else:
-    from airflow.models.baseoperator import BaseOperator  # type: ignore[no-redef]
     from airflow.models.taskinstance import set_current_context  # type: ignore[attr-defined,no-redef]
 
 if TYPE_CHECKING:
@@ -2465,6 +2465,25 @@ class TestShortCircuitWithTeardown:
         assert set(actual_skipped) == {op3}
 
 
+class TestPythonAsyncOperator(TestPythonOperator):
+    def test_run_async_task(self, caplog):
+        caplog.set_level(logging.INFO, logger=LOGGER_NAME)
+
+        async def say_hello(name: str) -> str:
+            await asyncio.sleep(1)
+            return f"Hello {name}!"
+
+        if AIRFLOW_V_3_2_PLUS:
+            self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
+            assert "Done. Returned value was: Hello world!" in caplog.messages
+        else:
+            with pytest.raises(
+                RuntimeError,
+                match=r"Async operators require Airflow 3\.2\+\. Upgrade 
Airflow or use a synchronous callable\.",
+            ):
+                self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
+
+
 @pytest.mark.parametrize(
     ("text_input", "expected_tuple"),
     [
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 05476cbe562..66fe1637486 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -62,6 +62,8 @@ Task Decorators:
 
 Bases
 -----
+.. autoapiclass:: airflow.sdk.BaseAsyncOperator
+
 .. autoapiclass:: airflow.sdk.BaseOperator
 
 .. autoapiclass:: airflow.sdk.BaseSensorOperator
@@ -183,7 +185,7 @@ Everything else
 .. autoapimodule:: airflow.sdk
   :members:
   :special-members: __version__
-  :exclude-members: BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context
+  :exclude-members: BaseAsyncOperator, BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context
   :undoc-members:
   :imported-members:
   :no-index:
diff --git a/task-sdk/docs/index.rst b/task-sdk/docs/index.rst
index 819f637676b..f3258ea8243 100644
--- a/task-sdk/docs/index.rst
+++ b/task-sdk/docs/index.rst
@@ -78,6 +78,7 @@ Why use ``airflow.sdk``?
 **Classes**
 
 - :class:`airflow.sdk.Asset`
+- :class:`airflow.sdk.BaseAsyncOperator`
 - :class:`airflow.sdk.BaseHook`
 - :class:`airflow.sdk.BaseNotifier`
 - :class:`airflow.sdk.BaseOperator`
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 4dbda282d08..034a6379430 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -26,6 +26,7 @@ __all__ = [
     "AssetAny",
     "AssetOrTimeSchedule",
     "AssetWatcher",
+    "BaseAsyncOperator",
     "BaseHook",
     "BaseNotifier",
     "BaseOperator",
@@ -76,7 +77,13 @@ if TYPE_CHECKING:
     from airflow.sdk.api.datamodels._generated import DagRunState, TaskInstanceState, TriggerRule, WeightRule
     from airflow.sdk.bases.hook import BaseHook
     from airflow.sdk.bases.notifier import BaseNotifier
-    from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear, cross_downstream
+    from airflow.sdk.bases.operator import (
+        BaseAsyncOperator,
+        BaseOperator,
+        chain,
+        chain_linear,
+        cross_downstream,
+    )
     from airflow.sdk.bases.operatorlink import BaseOperatorLink
     from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
     from airflow.sdk.configuration import AirflowSDKConfigParser
@@ -117,6 +124,7 @@ __lazy_imports: dict[str, str] = {
     "AssetAny": ".definitions.asset",
     "AssetOrTimeSchedule": ".definitions.timetables.assets",
     "AssetWatcher": ".definitions.asset",
+    "BaseAsyncOperator": ".bases.operator",
     "BaseHook": ".bases.hook",
     "BaseNotifier": ".bases.notifier",
     "BaseOperator": ".bases.operator",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index eede7ff806a..b035f49226c 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -24,6 +24,7 @@ from airflow.sdk.api.datamodels._generated import (
 from airflow.sdk.bases.hook import BaseHook as BaseHook
 from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier
 from airflow.sdk.bases.operator import (
+    BaseAsyncOperator as BaseAsyncOperator,
     BaseOperator as BaseOperator,
     chain as chain,
     chain_linear as chain_linear,
@@ -83,6 +84,7 @@ __all__ = [
     "AssetAny",
     "AssetOrTimeSchedule",
     "AssetWatcher",
+    "BaseAsyncOperator",
     "BaseHook",
     "BaseNotifier",
     "BaseOperator",
diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py
index bde898c1696..131750ae8c9 100644
--- a/task-sdk/src/airflow/sdk/bases/decorator.py
+++ b/task-sdk/src/airflow/sdk/bases/decorator.py
@@ -22,7 +22,8 @@ import re
 import textwrap
 import warnings
 from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
-from functools import cached_property, update_wrapper
+from contextlib import suppress
+from functools import cached_property, partial, update_wrapper
 from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol, TypeVar, cast, overload
 
 import attr
@@ -149,6 +150,48 @@ def get_unique_task_id(
     return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
 
 
+def unwrap_partial(fn: Callable) -> Callable:
+    while isinstance(fn, partial):
+        fn = fn.func
+    return fn
+
+
+def unwrap_callable(func):
+    from airflow.sdk.definitions.mappedoperator import OperatorPartial
+
+    if isinstance(func, (_TaskDecorator, OperatorPartial)):
+        func = getattr(func, "function", getattr(func, "_func", func))
+
+    func = unwrap_partial(func)
+
+    with suppress(Exception):
+        func = inspect.unwrap(func)
+
+    return func
+
+
+def is_async_callable(func):
+    """Detect if a callable (possibly wrapped) is an async function."""
+    func = unwrap_callable(func)
+
+    if not callable(func):
+        return False
+
+    # Direct async function
+    if inspect.iscoroutinefunction(func):
+        return True
+
+    # Callable object with async __call__
+    if not inspect.isfunction(func):
+        call = type(func).__call__  # Bandit-safe
+        with suppress(Exception):
+            call = inspect.unwrap(call)
+        if inspect.iscoroutinefunction(call):
+            return True
+
+    return False
+
+
 class DecoratedOperator(BaseOperator):
     """
     Wraps a Python callable and captures args/kwargs when called for execution.
@@ -243,6 +286,10 @@ class DecoratedOperator(BaseOperator):
         self.op_kwargs = op_kwargs
         super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
 
+    @property
+    def is_async(self) -> bool:
+        return is_async_callable(self.python_callable)
+
     def execute(self, context: Context):
         # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
         #  as well
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 0c97df00ef0..5f07a188362 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -18,13 +18,15 @@
 from __future__ import annotations
 
 import abc
+import asyncio
 import collections.abc
 import contextlib
 import copy
 import inspect
 import sys
 import warnings
-from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
+from asyncio import AbstractEventLoop
+from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
@@ -101,6 +103,7 @@ if TYPE_CHECKING:
     TaskPostExecuteHook = Callable[[Context, Any], None]
 
 __all__ = [
+    "BaseAsyncOperator",
     "BaseOperator",
     "chain",
     "chain_linear",
@@ -196,6 +199,27 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
     return Resources(**resources)
 
 
[email protected]
+def event_loop() -> Generator[AbstractEventLoop]:
+    new_event_loop = False
+    loop = None
+    try:
+        try:
+            loop = asyncio.get_event_loop()
+            if loop.is_closed():
+                raise RuntimeError
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            new_event_loop = True
+        yield loop
+    finally:
+        if new_event_loop and loop is not None:
+            with contextlib.suppress(AttributeError):
+                loop.close()
+                asyncio.set_event_loop(None)
+
+
 class _PartialDescriptor:
     """A descriptor that guards against ``.partial`` being called on Task 
objects."""
 
@@ -1670,6 +1694,35 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         return bool(self.on_skipped_callback)
 
 
+class BaseAsyncOperator(BaseOperator):
+    """
+    Base class for async-capable operators.
+
+    As opposed to deferrable operators, which are executed on the triggerer, async operators are executed
+    on the worker.
+    """
+
+    @property
+    def is_async(self) -> bool:
+        return True
+
+    async def aexecute(self, context):
+        """Async version of execute(). Subclasses should implement this."""
+        raise NotImplementedError()
+
+    def execute(self, context):
+        """Run `aexecute()` inside an event loop."""
+        with event_loop() as loop:
+            if self.execution_timeout:
+                return loop.run_until_complete(
+                    asyncio.wait_for(
+                        self.aexecute(context),
+                        timeout=self.execution_timeout.total_seconds(),
+                    )
+                )
+            return loop.run_until_complete(self.aexecute(context))
+
+
 def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
     r"""
     Given a number of tasks, builds a dependency chain.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index 6c99a72b220..e7e5ebe8b9a 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -137,6 +137,10 @@ class AbstractOperator(Templater, DAGNode):
         )
     )
 
+    @property
+    def is_async(self) -> bool:
+        return False
+
     @property
     def task_type(self) -> str:
         raise NotImplementedError()
diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
index 316c3d38e99..322e4bc9780 100644
--- a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
@@ -20,8 +20,8 @@ from __future__ import annotations
 
 import inspect
 import logging
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast
 
 from typing_extensions import ParamSpec
 
@@ -39,6 +39,11 @@ class _ExecutionCallableRunner(Generic[P, R]):
     def run(*args: P.args, **kwargs: P.kwargs) -> R: ...  # type: ignore[empty-body]
 
 
+class _AsyncExecutionCallableRunner(Generic[P, R]):
+    @staticmethod
+    async def run(*args: P.args, **kwargs: P.kwargs) -> R: ...  # type: ignore[empty-body]
+
+
 class ExecutionCallableRunner(Protocol):
     def __call__(
         self,
@@ -49,6 +54,16 @@ class ExecutionCallableRunner(Protocol):
     ) -> _ExecutionCallableRunner[P, R]: ...
 
 
+class AsyncExecutionCallableRunner(Protocol):
+    def __call__(
+        self,
+        func: Callable[P, R],
+        outlet_events: OutletEventAccessorsProtocol,
+        *,
+        logger: logging.Logger | Logger,
+    ) -> _AsyncExecutionCallableRunner[P, R]: ...
+
+
 def create_executable_runner(
     func: Callable[P, R],
     outlet_events: OutletEventAccessorsProtocol,
@@ -109,3 +124,51 @@ def create_executable_runner(
             return result  # noqa: F821  # Ruff is not smart enough to know this is always set in _run().
 
     return cast("_ExecutionCallableRunner[P, R]", _ExecutionCallableRunnerImpl)
+
+
+def create_async_executable_runner(
+    func: Callable[P, Awaitable[R] | AsyncIterator],
+    outlet_events: OutletEventAccessorsProtocol,
+    *,
+    logger: logging.Logger | Logger,
+) -> _AsyncExecutionCallableRunner[P, R]:
+    """
+    Run an async execution callable against a task context and given arguments.
+
+    If the callable is a simple function, this simply calls it with the supplied
+    arguments (including the context). If the callable is a generator function,
+    the generator is exhausted here, with the yielded values getting fed back
+    into the task context automatically for execution.
+
+    This convoluted implementation of inner class with closure is so *all*
+    arguments passed to ``run()`` can be forwarded to the wrapped function. This
+    is particularly important for the argument "self", which some use cases
+    need to receive. This is not possible if this is implemented as a normal
+    class, where "self" needs to point to the runner object, not the object
+    bounded to the inner callable.
+
+    :meta private:
+    """
+
+    class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner):
+        @staticmethod
+        async def run(*args: P.args, **kwargs: P.kwargs) -> R:
+            from airflow.sdk.definitions.asset.metadata import Metadata
+
+            if not inspect.isasyncgenfunction(func):
+                result = cast("Awaitable[R]", func(*args, **kwargs))
+                return await result
+
+            results: list[Any] = []
+
+            async for result in func(*args, **kwargs):
+                if isinstance(result, Metadata):
+                    outlet_events[result.asset].extra.update(result.extra)
+                    if result.alias:
+                        outlet_events[result.alias].add(result.asset, extra=result.extra)
+
+                results.append(result)
+
+            return cast("R", results)
+
+    return cast("_AsyncExecutionCallableRunner[P, R]", 
_AsyncExecutionCallableRunnerImpl)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 52a96d0b665..15755e640d9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -48,7 +48,9 @@ Execution API server is because:
 
 from __future__ import annotations
 
+import asyncio
 import itertools
+import threading
 from collections.abc import Iterator
 from datetime import datetime
 from functools import cached_property
@@ -185,31 +187,69 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
 
     err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)
 
+    # Threading lock for sync operations
+    _thread_lock: threading.Lock = attrs.field(factory=threading.Lock, repr=False)
+    # Async lock for async operations
+    _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
+
     def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
         """Send a request to the parent and block until the response is 
received."""
         frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
         frame_bytes = frame.as_bytes()
 
-        self.socket.sendall(frame_bytes)
-        if isinstance(msg, ResendLoggingFD):
-            if recv_fds is None:
-                return None
-            # We need special handling here! The server can't send us the fd number, as the number on the
-            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
-            frame, fds = self._read_frame(maxfds=1)
-            resp = self._from_frame(frame)
-            if TYPE_CHECKING:
-                assert isinstance(resp, SentFDs)
-            resp.fds = fds
-            # Since we know this is an expliclt SendFDs, and since this class is generic SendFDs might not
-            # always be in the return type union
-            return resp  # type: ignore[return-value]
+        # We must make sure sockets aren't intermixed between sync and async calls,
+        # thus we need a dual locking mechanism to ensure that.
+        with self._thread_lock:
+            self.socket.sendall(frame_bytes)
+            if isinstance(msg, ResendLoggingFD):
+                if recv_fds is None:
+                    return None
+                # We need special handling here! The server can't send us the fd number, as the number on the
+                # supervisor will be different to in this process, so we have to mutate the message ourselves here.
+                frame, fds = self._read_frame(maxfds=1)
+                resp = self._from_frame(frame)
+                if TYPE_CHECKING:
+                    assert isinstance(resp, SentFDs)
+                resp.fds = fds
+                # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
+                # always be in the return type union
+                return resp  # type: ignore[return-value]
 
         return self._get_response()
 
     async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
-        """Send a request to the parent without blocking."""
-        raise NotImplementedError
+        """
+        Send a request to the parent without blocking.
+
+        Uses async lock for coroutine safety and thread lock for socket safety.
+        """
+        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+        frame_bytes = frame.as_bytes()
+
+        async with self._async_lock:
+            # Acquire the threading lock without blocking the event loop
+            loop = asyncio.get_running_loop()
+            await loop.run_in_executor(None, self._thread_lock.acquire)
+            try:
+                # Async write to socket
+                await loop.sock_sendall(self.socket, frame_bytes)
+
+                if isinstance(msg, ResendLoggingFD):
+                    if recv_fds is None:
+                        return None
+                    # Blocking read in a thread
+                    frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1)
+                    resp = self._from_frame(frame)
+                    if TYPE_CHECKING:
+                        assert isinstance(resp, SentFDs)
+                    resp.fds = fds
+                    return resp  # type: ignore[return-value]
+
+                # Normal blocking read in a thread
+                frame = await asyncio.to_thread(self._read_frame)
+                return self._from_frame(frame)
+            finally:
+                self._thread_lock.release()
 
     @overload
     def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
diff --git a/task-sdk/tests/task_sdk/bases/test_decorator.py b/task-sdk/tests/task_sdk/bases/test_decorator.py
index a0202680530..913e4c6e5f0 100644
--- a/task-sdk/tests/task_sdk/bases/test_decorator.py
+++ b/task-sdk/tests/task_sdk/bases/test_decorator.py
@@ -17,11 +17,15 @@
 from __future__ import annotations
 
 import ast
+import functools
 import importlib.util
 import textwrap
 from pathlib import Path
 
-from airflow.sdk.bases.decorator import DecoratedOperator
+import pytest
+
+from airflow.sdk import task
+from airflow.sdk.bases.decorator import DecoratedOperator, is_async_callable
 
 RAW_CODE = """
 from airflow.sdk import task
@@ -63,3 +67,141 @@ class TestBaseDecorator:
         # Returned source must be valid Python
         ast.parse(cleaned)
         assert cleaned.lstrip().splitlines()[0].startswith("def a_task")
+
+
+def simple_decorator(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+def decorator_without_wraps(fn):
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+async def async_fn():
+    return 42
+
+
+def sync_fn():
+    return 42
+
+
+@simple_decorator
+async def wrapped_async_fn():
+    return 42
+
+
+@simple_decorator
+def wrapped_sync_fn():
+    return 42
+
+
+@decorator_without_wraps
+async def wrapped_async_fn_no_wraps():
+    return 42
+
+
+@simple_decorator
+@simple_decorator
+async def multi_wrapped_async_fn():
+    return 42
+
+
+async def async_with_args(x, y):
+    return x + y
+
+
+def sync_with_args(x, y):
+    return x + y
+
+
+class AsyncCallable:
+    async def __call__(self):
+        return 42
+
+
+class SyncCallable:
+    def __call__(self):
+        return 42
+
+
+class WrappedAsyncCallable:
+    @simple_decorator
+    async def __call__(self):
+        return 42
+
+
+class TestAsyncCallable:
+    def test_plain_async_function(self):
+        assert is_async_callable(async_fn)
+
+    def test_plain_sync_function(self):
+        assert not is_async_callable(sync_fn)
+
+    def test_wrapped_async_function_with_wraps(self):
+        assert is_async_callable(wrapped_async_fn)
+
+    def test_wrapped_sync_function_with_wraps(self):
+        assert not is_async_callable(wrapped_sync_fn)
+
+    def test_wrapped_async_function_without_wraps(self):
+        """
+        Without functools.wraps, inspect.unwrap cannot recover the coroutine.
+        This documents expected behavior.
+        """
+        assert not is_async_callable(wrapped_async_fn_no_wraps)
+
+    def test_multi_wrapped_async_function(self):
+        assert is_async_callable(multi_wrapped_async_fn)
+
+    def test_partial_async_function(self):
+        fn = functools.partial(async_with_args, 1)
+        assert is_async_callable(fn)
+
+    def test_partial_sync_function(self):
+        fn = functools.partial(sync_with_args, 1)
+        assert not is_async_callable(fn)
+
+    def test_nested_partial_async_function(self):
+        fn = functools.partial(
+            functools.partial(async_with_args, 1),
+            2,
+        )
+        assert is_async_callable(fn)
+
+    def test_async_callable_class(self):
+        assert is_async_callable(AsyncCallable())
+
+    def test_sync_callable_class(self):
+        assert not is_async_callable(SyncCallable())
+
+    def test_wrapped_async_callable_class(self):
+        assert is_async_callable(WrappedAsyncCallable())
+
+    def test_partial_callable_class(self):
+        fn = functools.partial(AsyncCallable())
+        assert is_async_callable(fn)
+
+    @pytest.mark.parametrize("value", [None, 42, "string", object()])
+    def test_non_callable(self, value):
+        assert not is_async_callable(value)
+
+    def test_task_decorator_async_function(self):
+        @task
+        async def async_task_fn():
+            return 42
+
+        assert is_async_callable(async_task_fn)
+
+    def test_task_decorator_sync_function(self):
+        @task
+        def sync_task_fn():
+            return 42
+
+        assert not is_async_callable(sync_task_fn)
