This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d9d09638a5 BUGFIX: Make sure XComs work correctly in MSGraphAsyncOperator with paged results and dynamic task mapping (#40301)
d9d09638a5 is described below

commit d9d09638a5a57ec48e2ed791f248a55202f29869
Author: David Blain <i...@dabla.be>
AuthorDate: Thu Jun 20 11:09:14 2024 +0200

    BUGFIX: Make sure XComs work correctly in MSGraphAsyncOperator with paged results and dynamic task mapping (#40301)
    
    
    
    ---------
    
    Co-authored-by: David Blain <david.bl...@infrabel.be>
---
 .../providers/microsoft/azure/operators/msgraph.py | 57 +++++++++++++---------
 tests/providers/microsoft/conftest.py              |  4 +-
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
index 39ca32d2b6..cd38795473 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -178,8 +178,9 @@ class MSGraphAsyncOperator(BaseOperator):
                 event["response"] = result
 
                 try:
-                    self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
+                    self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
                 except TaskDeferred as exception:
+                    self.results = self.pull_xcom(context=context)
                     self.append_result(
                         result=result,
                         append_result_as_list_if_absent=True,
@@ -188,7 +189,6 @@ class MSGraphAsyncOperator(BaseOperator):
                     raise exception
 
                 self.append_result(result=result)
-                self.log.debug("results: %s", self.results)
 
                 return self.results
         return None
@@ -198,8 +198,6 @@ class MSGraphAsyncOperator(BaseOperator):
         result: Any,
         append_result_as_list_if_absent: bool = False,
     ):
-        self.log.debug("value: %s", result)
-
         if isinstance(self.results, list):
             if isinstance(result, list):
                 self.results.extend(result)
@@ -214,30 +212,43 @@ class MSGraphAsyncOperator(BaseOperator):
             else:
                 self.results = result
 
-    def push_xcom(self, context: Context, value) -> None:
-        self.log.debug("do_xcom_push: %s", self.do_xcom_push)
-        if self.do_xcom_push:
-            self.log.info("Pushing XCom with key '%s': %s", self.key, value)
-            self.xcom_push(context=context, key=self.key, value=value)
-
-    def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
-        self.results = list(
-            self.xcom_pull(
-                context=context,
+    def pull_xcom(self, context: Context) -> list:
+        map_index = context["ti"].map_index
+        value = list(
+            context["ti"].xcom_pull(
+                key=self.key,
                 task_ids=self.task_id,
                 dag_id=self.dag_id,
-                key=self.key,
+                map_indexes=map_index,
             )
             or []
         )
-        self.log.info(
-            "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
-            self.task_id,
-            self.dag_id,
-            self.key,
-            self.results,
-        )
-        return self.execute_complete(context, event)
+
+        if map_index:
+            self.log.info(
+                "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s",
+                self.task_id,
+                self.dag_id,
+                self.key,
+                map_index,
+                value,
+            )
+        else:
+            self.log.info(
+                "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
+                self.task_id,
+                self.dag_id,
+                self.key,
+                value,
+            )
+
+        return value
+
+    def push_xcom(self, context: Context, value) -> None:
+        self.log.debug("do_xcom_push: %s", self.do_xcom_push)
+        if self.do_xcom_push:
+            self.log.info("Pushing XCom with key '%s': %s", self.key, value)
+            self.xcom_push(context=context, key=self.key, value=value)
 
     @staticmethod
     def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index b2db8c44ba..ecd19d8865 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -143,6 +143,8 @@ def mock_context(task) -> Context:
             map_indexes: Iterable[int] | int | None = None,
             default: Any | None = None,
         ) -> Any:
+            if map_indexes:
+                return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
             return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
 
         def xcom_push(
@@ -152,7 +154,7 @@ def mock_context(task) -> Context:
             execution_date: datetime | None = None,
             session: Session = NEW_SESSION,
         ) -> None:
-            values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+            values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
 
     values["ti"] = MockedTaskInstance(task=task)
 

Reply via email to