This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new d9d09638a5 BUGFIX: Make sure XComs work correctly in MSGraphAsyncOperator with paged results and dynamic task mapping (#40301) d9d09638a5 is described below commit d9d09638a5a57ec48e2ed791f248a55202f29869 Author: David Blain <i...@dabla.be> AuthorDate: Thu Jun 20 11:09:14 2024 +0200 BUGFIX: Make sure XComs work correctly in MSGraphAsyncOperator with paged results and dynamic task mapping (#40301) --------- Co-authored-by: David Blain <david.bl...@infrabel.be> --- .../providers/microsoft/azure/operators/msgraph.py | 57 +++++++++++++--------- tests/providers/microsoft/conftest.py | 4 +- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 39ca32d2b6..cd38795473 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -178,8 +178,9 @@ class MSGraphAsyncOperator(BaseOperator): event["response"] = result try: - self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__) + self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: + self.results = self.pull_xcom(context=context) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -188,7 +189,6 @@ class MSGraphAsyncOperator(BaseOperator): raise exception self.append_result(result=result) - self.log.debug("results: %s", self.results) return self.results return None @@ -198,8 +198,6 @@ class MSGraphAsyncOperator(BaseOperator): result: Any, append_result_as_list_if_absent: bool = False, ): - self.log.debug("value: %s", result) - if isinstance(self.results, list): if isinstance(result, list): self.results.extend(result) @@ -214,30 +212,43 @@ class MSGraphAsyncOperator(BaseOperator): else: self.results = result - def push_xcom(self, context: Context, value) -> None: - self.log.debug("do_xcom_push: %s", self.do_xcom_push) - if self.do_xcom_push: - self.log.info("Pushing XCom with key '%s': %s", self.key, value) - self.xcom_push(context=context, key=self.key, value=value) - - def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any: - self.results = list( - self.xcom_pull( - context=context, + def pull_xcom(self, context: Context) -> list: + map_index = context["ti"].map_index + value = list( + context["ti"].xcom_pull( + key=self.key, task_ids=self.task_id, dag_id=self.dag_id, - key=self.key, + map_indexes=map_index, ) or [] ) - self.log.info( - "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", - self.task_id, - self.dag_id, - self.key, - self.results, - ) - return self.execute_complete(context, event) + + if map_index: + self.log.info( + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s", + self.task_id, + self.dag_id, + self.key, + map_index, + value, + ) + else: + self.log.info( + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", + self.task_id, + self.dag_id, + self.key, + value, + ) + + return value + + def push_xcom(self, context: Context, value) -> None: + self.log.debug("do_xcom_push: %s", self.do_xcom_push) + if self.do_xcom_push: + self.log.info("Pushing XCom with key '%s': %s", self.key, value) + self.xcom_push(context=context, key=self.key, value=value) @staticmethod def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]: diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index b2db8c44ba..ecd19d8865 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -143,6 +143,8 @@ def mock_context(task) -> Context: map_indexes: Iterable[int] | int | None = None, default: Any | None = None, ) -> Any: + if map_indexes: + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") def xcom_push( @@ -152,7 +154,7 @@ def mock_context(task) -> Context: execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: - values[f"{self.task_id}_{self.dag_id}_{key}"] = value + values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value values["ti"] = MockedTaskInstance(task=task)