This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 63c9b432c31 Make `apply_function` optional in `AwaitMessageTrigger` (#55437)
63c9b432c31 is described below

commit 63c9b432c3197e247103e6defed4340178305425
Author: Vincent <[email protected]>
AuthorDate: Wed Sep 10 10:25:06 2025 -0400

    Make `apply_function` optional in `AwaitMessageTrigger` (#55437)
---
 .../apache/kafka/triggers/await_message.py         | 28 +++++++++++++---------
 .../apache/kafka/triggers/test_await_message.py    | 11 +++++++--
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
index d4e140c151f..80d2a418e6f 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
@@ -37,10 +37,8 @@ class AwaitMessageTrigger(BaseTrigger):
     - poll the Kafka topics for a message, if no message returned, sleep
     - process the message with provided callable and commit the message offset:
 
-        - if callable returns any data, raise a TriggerEvent with the return data
-
-        - else continue to next message
-
+        - if callable is provided and returns any data, raise a TriggerEvent with the return data
+        - else raise a TriggerEvent with the original message
 
     :param kafka_config_id: The connection object to use, defaults to "kafka_default"
     :param topics: The topic (or topic regex) that should be searched for messages
@@ -59,7 +57,7 @@ class AwaitMessageTrigger(BaseTrigger):
     def __init__(
         self,
         topics: Sequence[str],
-        apply_function: str,
+        apply_function: str | None = None,
         kafka_config_id: str = "kafka_default",
         apply_function_args: Sequence[Any] | None = None,
         apply_function_kwargs: dict[Any, Any] | None = None,
@@ -97,9 +95,13 @@ class AwaitMessageTrigger(BaseTrigger):
         async_poll = sync_to_async(consumer.poll)
         async_commit = sync_to_async(consumer.commit)
 
-        processing_call = import_string(self.apply_function)
-        processing_call = partial(processing_call, *self.apply_function_args, **self.apply_function_kwargs)
-        async_message_process = sync_to_async(processing_call)
+        async_message_process = None
+        if self.apply_function:
+            processing_call = import_string(self.apply_function)
+            processing_call = partial(
+                processing_call, *self.apply_function_args, **self.apply_function_kwargs
+            )
+            async_message_process = sync_to_async(processing_call)
         while True:
             message = await async_poll(self.poll_timeout)
 
@@ -108,10 +110,14 @@ class AwaitMessageTrigger(BaseTrigger):
             elif message.error():
                 raise AirflowException(f"Error: {message.error()}")
             else:
-                rv = await async_message_process(message)
-                if rv:
+                event = (
+                    await async_message_process(message)
+                    if async_message_process
+                    else message.value().decode("utf-8")
+                )
+                if event:
                     await async_commit(message=message, asynchronous=False)
-                    yield TriggerEvent(rv)
+                    yield TriggerEvent(event)
                     break
                 else:
                     await async_commit(message=message, asynchronous=False)
diff --git a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
index 982089bb7c0..f08c17dfb5f 100644
--- a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
+++ b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
@@ -99,13 +99,20 @@ class TestTrigger:
             poll_interval=5,
         )
 
+    @pytest.mark.parametrize(
+        "apply_function",
+        [
+            "unit.apache.kafka.triggers.test_await_message.apply_function_true",
+            None,
+        ],
+    )
     @pytest.mark.asyncio
-    async def test_trigger_run_good(self, mocker):
+    async def test_trigger_run_good(self, mocker, apply_function):
         mocker.patch.object(KafkaConsumerHook, "get_consumer", return_value=MockedConsumer)
 
         trigger = AwaitMessageTrigger(
             kafka_config_id="kafka_d",
-            apply_function="unit.apache.kafka.triggers.test_await_message.apply_function_true",
+            apply_function=apply_function,
             topics=["noop"],
             poll_timeout=0.0001,
             poll_interval=5,

Reply via email to