Copilot commented on code in PR #59798:
URL: https://github.com/apache/airflow/pull/59798#discussion_r2646670893


##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
-        if fail_tasks:
-            raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:
+                raise AirflowException(f"Pool resize errors: 
{pool.resize_errors}")
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", 
len(nodes), self.batch_pool_id)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchJobTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    timeout=self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
+        # Wait for tasks to complete (synchronous path) with guaranteed 
cleanup on failure
+        sync_failed = False
+        try:
+            fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+                job_id=self.batch_job_id, timeout=self.timeout
+            )
+            if fail_tasks:
+                sync_failed = True
+                raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        finally:
+            if sync_failed:
+                # Ensure cleanup runs before exception propagates (historical 
behavior)
+                if self.should_delete_job:
+                    self.clean_up(job_id=self.batch_job_id)
+                if self.should_delete_pool:
+                    self.clean_up(self.batch_pool_id)
+                self._cleanup_done = True
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:

Review Comment:
   The `execute_complete` method lacks a docstring. Add documentation 
explaining the trigger callback, expected event structure (`status`, 
`fail_task_ids`, `message`), return value, and possible exceptions.
   ```suggestion
       def execute_complete(self, context: Context, event: dict[str, Any] | 
None = None) -> str:
           """
           Callback executed when the deferrable Azure Batch job trigger fires.
   
           This method is invoked after :meth:`defer` when running in 
deferrable mode.
           It processes the event emitted by 
:class:`~airflow.providers.microsoft.azure.triggers.batch.AzureBatchJobTrigger`
           and converts it into the final task outcome.
   
           The event is expected to be a mapping with (at least) the following 
keys:
   
           * ``status``: A string indicating the job outcome. Recognized values 
are:
             * ``"success"``: all tasks completed successfully.
             * ``"failure"``: one or more tasks failed.
             * ``"timeout"``: the job did not complete within the configured 
timeout.
             * ``"error"``: an internal error occurred while waiting for task 
completion.
           * ``fail_task_ids``: Optional iterable of task IDs that failed. 
Defaults to an empty list
             when no failures are reported.
           * ``message``: Optional human-readable description providing 
additional context for
             ``"timeout"`` or ``"error"`` statuses.
   
           :param context: Airflow task context provided by the scheduler on 
resumption.
           :param event: Event dictionary sent from 
:class:`AzureBatchJobTrigger` describing the
               terminal state of the monitored Azure Batch job.
           :returns: The Azure Batch job ID associated with this operator, on 
successful completion.
   
           :raises AirflowException: If no event is received, if the status is 
``"error"`` or
               ``"failure"``, if any failed task IDs are reported, or if an 
unexpected status is
               encountered.
           :raises TimeoutError: If the event indicates that the operation 
timed out.
           """
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
-        if fail_tasks:
-            raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:
+                raise AirflowException(f"Pool resize errors: 
{pool.resize_errors}")
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", 
len(nodes), self.batch_pool_id)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchJobTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    timeout=self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
+        # Wait for tasks to complete (synchronous path) with guaranteed 
cleanup on failure
+        sync_failed = False
+        try:
+            fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+                job_id=self.batch_job_id, timeout=self.timeout
+            )
+            if fail_tasks:
+                sync_failed = True
+                raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        finally:
+            if sync_failed:
+                # Ensure cleanup runs before exception propagates (historical 
behavior)
+                if self.should_delete_job:
+                    self.clean_up(job_id=self.batch_job_id)
+                if self.should_delete_pool:
+                    self.clean_up(self.batch_pool_id)
+                self._cleanup_done = True
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if not event:
+            raise AirflowException("No event received in trigger callback")
+
+        status = event.get("status")
+        fail_task_ids = event.get("fail_task_ids", [])
+
+        if status == "timeout":
+            raise TimeoutError(event.get("message", "Timed out waiting for 
tasks to complete"))

Review Comment:
   Using built-in `TimeoutError` is inconsistent with Airflow patterns. 
Consider using `AirflowException` with a timeout-specific message to maintain 
consistency with other exception handling in this operator.
   ```suggestion
               raise AirflowException(event.get("message", "Timed out waiting 
for tasks to complete"))
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py:
##########
@@ -296,24 +302,77 @@ def execute(self, context: Context) -> None:
         )
         # Add task to job
         self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)
-        # Wait for tasks to complete
-        fail_tasks = 
self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, 
timeout=self.timeout)
-        # Clean up
-        if self.should_delete_job:
-            # delete job first
-            self.clean_up(job_id=self.batch_job_id)
-        if self.should_delete_pool:
-            self.clean_up(self.batch_pool_id)
-        # raise exception if any task fail
-        if fail_tasks:
-            raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        if self.deferrable:
+            # Verify pool and nodes are in terminal state before deferral
+            pool = self.hook.connection.pool.get(self.batch_pool_id)
+            nodes = 
list(self.hook.connection.compute_node.list(self.batch_pool_id))
+            if pool.resize_errors:
+                raise AirflowException(f"Pool resize errors: 
{pool.resize_errors}")
+            self.log.debug("Deferral pre-check: %d nodes present in pool %s", 
len(nodes), self.batch_pool_id)
+
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=AzureBatchJobTrigger(
+                    job_id=self.batch_job_id,
+                    azure_batch_conn_id=self.azure_batch_conn_id,
+                    timeout=self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+            return
+
+        # Wait for tasks to complete (synchronous path) with guaranteed 
cleanup on failure
+        sync_failed = False
+        try:
+            fail_tasks = self.hook.wait_for_job_tasks_to_complete(
+                job_id=self.batch_job_id, timeout=self.timeout
+            )
+            if fail_tasks:
+                sync_failed = True
+                raise AirflowException(f"Job fail. The failed task are: 
{fail_tasks}")
+        finally:
+            if sync_failed:
+                # Ensure cleanup runs before exception propagates (historical 
behavior)
+                if self.should_delete_job:
+                    self.clean_up(job_id=self.batch_job_id)
+                if self.should_delete_pool:
+                    self.clean_up(self.batch_pool_id)
+                self._cleanup_done = True
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if not event:
+            raise AirflowException("No event received in trigger callback")
+
+        status = event.get("status")
+        fail_task_ids = event.get("fail_task_ids", [])
+
+        if status == "timeout":
+            raise TimeoutError(event.get("message", "Timed out waiting for 
tasks to complete"))
+        if status == "error":
+            raise AirflowException(event.get("message", "Unknown error while 
waiting for tasks"))
+        if status == "failure" or fail_task_ids:
+            raise AirflowException(f"Job failed. Failed tasks: 
{fail_task_ids}")
+        if status != "success":
+            raise AirflowException(f"Unexpected event status: {event}")

Review Comment:
   The error message could be more actionable. Consider clarifying what 
statuses are expected and providing the actual status received, e.g., 
`f\"Unexpected event status '{status}'. Expected 'success', 'failure', 
'timeout', or 'error'.\"`
   ```suggestion
               raise AirflowException(
                   f"Unexpected event status '{status}'. Expected 'success', 
'failure', 'timeout', or 'error'. "
                   f"Full event: {event}"
               )
   ```



##########
providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py:
##########
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from datetime import timedelta
+from typing import Any
+
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils import timezone
+
+
+class AzureBatchJobTrigger(BaseTrigger):
+    """
+    Poll Azure Batch for task completion for a given job.
+
+    :param job_id: The Azure Batch job identifier to poll.
+    :param azure_batch_conn_id: Connection id for Azure Batch.
+    :param timeout: Maximum wait time in minutes.
+    :param poll_interval: Seconds to sleep between polls.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        azure_batch_conn_id: str = "azure_batch_default",
+        timeout: int = 25,
+        poll_interval: int = 15,
+    ) -> None:
+        super().__init__()
+        self.job_id = job_id
+        self.azure_batch_conn_id = azure_batch_conn_id
+        self.timeout = timeout
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger configuration."""
+        return (
+            
"airflow.providers.microsoft.azure.triggers.batch.AzureBatchJobTrigger",
+            {
+                "job_id": self.job_id,
+                "azure_batch_conn_id": self.azure_batch_conn_id,
+                "timeout": self.timeout,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = AzureBatchHook(self.azure_batch_conn_id)
+        timeout_time = timezone.utcnow() + timedelta(minutes=self.timeout)

Review Comment:
   The name `utcnow` echoes the deprecated `datetime.utcnow()` pattern. Keep 
using `timezone.utcnow()` if it is the project standard for producing 
timezone-aware UTC timestamps, but consider that `datetime.now(timezone.utc)` 
is the recommended approach in modern Python.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to