This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 55049c50d5 Add deferrable mode to `DbtCloudRunJobOperator` (#29014)
55049c50d5 is described below

commit 55049c50d52323e242c2387f285f0591ea38cde7
Author: Phani Kumar <94376113+phanik...@users.noreply.github.com>
AuthorDate: Mon Jan 23 17:36:24 2023 +0530

    Add deferrable mode to `DbtCloudRunJobOperator` (#29014)
    
    This PR donates the `DbtCloudRunJobOperatorAsync` from the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repository.
---
 airflow/providers/dbt/cloud/hooks/dbt.py           | 110 ++++++++++++++++++-
 airflow/providers/dbt/cloud/operators/dbt.py       |  67 +++++++++---
 airflow/providers/dbt/cloud/provider.yaml          |   2 +
 airflow/providers/dbt/cloud/triggers/__init__.py   |  16 +++
 airflow/providers/dbt/cloud/triggers/dbt.py        | 119 +++++++++++++++++++++
 .../operators.rst                                  |  12 +++
 generated/provider_dependencies.json               |   4 +-
 7 files changed, 314 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 4b6ac2151a..3ddeeb222b 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -22,8 +22,11 @@ import warnings
 from enum import Enum
 from functools import wraps
 from inspect import signature
-from typing import Any, Callable, Sequence, Set
+from typing import Any, Callable, Sequence, Set, TypeVar, cast
 
+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
 from requests import PreparedRequest, Session
 from requests.auth import AuthBase
 from requests.models import Response
@@ -125,6 +128,34 @@ class DbtCloudJobRunException(AirflowException):
     """An exception that indicates a job run failed to complete."""
 
 
+T = TypeVar("T", bound=Any)
+
+
+def provide_account_id(func: T) -> T:
+    """
+    Decorator which provides a fallback value for ``account_id``. If the 
``account_id`` is None or not passed
+    to the decorated function, the value will be taken from the configured dbt 
Cloud Airflow Connection.
+    """
+    function_signature = signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = function_signature.bind(*args, **kwargs)
+
+        if bound_args.arguments.get("account_id") is None:
+            self = args[0]
+            if self.dbt_cloud_conn_id:
+                connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
+                default_account_id = connection.login
+                if not default_account_id:
+                    raise AirflowException("Could not determine the dbt Cloud account.")
+                bound_args.arguments["account_id"] = int(default_account_id)
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
 class DbtCloudHook(HttpHook):
     """
     Interact with dbt Cloud using the V2 API.
@@ -150,6 +181,83 @@ class DbtCloudHook(HttpHook):
         super().__init__(auth_type=TokenAuth)
         self.dbt_cloud_conn_id = dbt_cloud_conn_id
 
+    @staticmethod
+    def get_request_url_params(
+        tenant: str, endpoint: str, include_related: list[str] | None = None
+    ) -> tuple[str, dict[str, Any]]:
+        """
+        Form URL from base url and endpoint url
+
+        :param tenant: The tenant name which is need to be replaced in base 
url.
+        :param endpoint: Endpoint url to be requested.
+        :param include_related: Optional. List of related fields to pull with 
the run.
+            Valid values are "trigger", "job", "repository", and "environment".
+        """
+        data: dict[str, Any] = {}
+        base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/";
+        if include_related:
+            data = {"include_related": include_related}
+        if base_url and not base_url.endswith("/") and endpoint and not 
endpoint.startswith("/"):
+            url = base_url + "/" + endpoint
+        else:
+            url = (base_url or "") + (endpoint or "")
+        return url, data
+
+    async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
+        """Get the headers and the tenant from the connection details."""
+        headers: dict[str, Any] = {}
+        connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
+        tenant: str = connection.schema if connection.schema else "cloud"
+        package_name, provider_version = _get_provider_info()
+        headers["User-Agent"] = f"{package_name}-v{provider_version}"
+        headers["Content-Type"] = "application/json"
+        headers["Authorization"] = f"Token {connection.password}"
+        return headers, tenant
+
+    @provide_account_id
+    async def get_job_details(
+        self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
+    ) -> Any:
+        """
+        Uses an async HTTP call to retrieve metadata for a specific run of a dbt Cloud job.
+
+        :param run_id: The ID of a dbt Cloud job run.
+        :param account_id: Optional. The ID of a dbt Cloud account.
+        :param include_related: Optional. List of related fields to pull with the run.
+            Valid values are "trigger", "job", "repository", and "environment".
+        """
+        endpoint = f"{account_id}/runs/{run_id}/"
+        headers, tenant = await self.get_headers_tenants_from_connection()
+        url, params = self.get_request_url_params(tenant, endpoint, include_related)
+        async with aiohttp.ClientSession(headers=headers) as session:
+            async with session.get(url, params=params) as response:
+                try:
+                    response.raise_for_status()
+                    return await response.json()
+                except ClientResponseError as e:
+                    raise AirflowException(str(e.status) + ":" + e.message)
+
+    async def get_job_status(
+        self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
+    ) -> int:
+        """
+        Retrieves the status for a specific run of a dbt Cloud job.
+
+        :param run_id: The ID of a dbt Cloud job run.
+        :param account_id: Optional. The ID of a dbt Cloud account.
+        :param include_related: Optional. List of related fields to pull with the run.
+            Valid values are "trigger", "job", "repository", and "environment".
+        """
+        self.log.info("Getting the status of job run %s.", str(run_id))
+        response = await self.get_job_details(run_id, account_id=account_id, include_related=include_related)
+        job_run_status: int = response["data"]["status"]
+        return job_run_status
+
     @cached_property
     def connection(self) -> Connection:
         _connection = self.get_connection(self.dbt_cloud_conn_id)
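
For illustration, a minimal sketch of driving the new async hook methods directly, e.g. from a scratch script; the connection id and run id below are hypothetical, and the account id falls back to the connection login via the @provide_account_id decorator on get_job_details:

    import asyncio

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

    async def check_run() -> None:
        hook = DbtCloudHook(dbt_cloud_conn_id="dbt_cloud_default")  # hypothetical connection id
        # account_id is omitted, so it is resolved from the connection login.
        status = await hook.get_job_status(run_id=1234)  # hypothetical run id
        print(f"dbt Cloud run status code: {status}")

    asyncio.run(check_run())
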
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index 472b2ffa7f..f65ce077d3 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -17,10 +17,14 @@
 from __future__ import annotations
 
 import json
+import time
+import warnings
 from typing import TYPE_CHECKING, Any
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
+from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -63,6 +67,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
     :param additional_run_config: Optional. Any additional parameters that should be included in the API
         request when triggering the job.
+    :param deferrable: Run the operator in deferrable mode.
     :return: The ID of the triggered dbt Cloud job run.
     """
 
@@ -91,6 +96,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         timeout: int = 60 * 60 * 24 * 7,
         check_interval: int = 60,
         additional_run_config: dict[str, Any] | None = None,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -106,8 +112,9 @@ class DbtCloudRunJobOperator(BaseOperator):
         self.additional_run_config = additional_run_config or {}
         self.hook: DbtCloudHook
         self.run_id: int
+        self.deferrable = deferrable
 
-    def execute(self, context: Context) -> int:
+    def execute(self, context: Context):
         if self.trigger_reason is None:
             self.trigger_reason = (
                 f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
@@ -129,20 +136,52 @@ class DbtCloudRunJobOperator(BaseOperator):
         context["ti"].xcom_push(key="job_run_url", value=job_run_url)
 
         if self.wait_for_termination:
-            self.log.info("Waiting for job run %s to terminate.", 
str(self.run_id))
-
-            if self.hook.wait_for_job_run_status(
-                run_id=self.run_id,
-                account_id=self.account_id,
-                expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
-                check_interval=self.check_interval,
-                timeout=self.timeout,
-            ):
-                self.log.info("Job run %s has completed successfully.", 
str(self.run_id))
+            if self.deferrable is False:
+                self.log.info("Waiting for job run %s to terminate.", 
str(self.run_id))
+
+                if self.hook.wait_for_job_run_status(
+                    run_id=self.run_id,
+                    account_id=self.account_id,
+                    expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
+                    check_interval=self.check_interval,
+                    timeout=self.timeout,
+                ):
+                    self.log.info("Job run %s has completed successfully.", 
str(self.run_id))
+                else:
+                    raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
+
+                return self.run_id
             else:
-                raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
-
-        return self.run_id
+                end_time = time.time() + self.timeout
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=DbtCloudRunJobTrigger(
+                        conn_id=self.dbt_cloud_conn_id,
+                        run_id=self.run_id,
+                        end_time=end_time,
+                        account_id=self.account_id,
+                        poll_interval=self.check_interval,
+                    ),
+                    method_name="execute_complete",
+                )
+        else:
+            if self.deferrable is True:
+                warnings.warn(
+                    "Argument `wait_for_termination` is False and `deferrable` is True, hence "
+                    "the `deferrable` parameter has no effect",
+                )
+            return self.run_id
+
+    def execute_complete(self, context: "Context", event: dict[str, Any]) -> int:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes execution was successful.
+        """
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info(event["message"])
+        return int(event["run_id"])
 
     def on_kill(self) -> None:
         if self.run_id:
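
As context for reviewers, a minimal usage sketch of the new mode might look like the following; the connection id and job id are hypothetical:

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    run_dbt_job = DbtCloudRunJobOperator(
        task_id="run_dbt_job",
        dbt_cloud_conn_id="dbt_cloud_default",  # hypothetical connection id
        job_id=1234,  # hypothetical dbt Cloud job id
        wait_for_termination=True,  # deferral only applies while waiting
        deferrable=True,  # poll from the Triggerer instead of the worker
        check_interval=30,
    )
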
diff --git a/airflow/providers/dbt/cloud/provider.yaml b/airflow/providers/dbt/cloud/provider.yaml
index ad2817eb8e..4315f9c272 100644
--- a/airflow/providers/dbt/cloud/provider.yaml
+++ b/airflow/providers/dbt/cloud/provider.yaml
@@ -34,6 +34,8 @@ versions:
 dependencies:
   - apache-airflow>=2.3.0
   - apache-airflow-providers-http
+  - asgiref
+  - aiohttp
 
 integrations:
   - integration-name: dbt Cloud
diff --git a/airflow/providers/dbt/cloud/triggers/__init__.py b/airflow/providers/dbt/cloud/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/dbt/cloud/triggers/dbt.py b/airflow/providers/dbt/cloud/triggers/dbt.py
new file mode 100644
index 0000000000..9bad789a52
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/dbt.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunStatus
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class DbtCloudRunJobTrigger(BaseTrigger):
+    """
+    DbtCloudRunJobTrigger is triggered with run id and account id, makes async 
Http call to dbt and
+    get the status for the submitted job with run id in polling interval of 
time.
+
+    :param conn_id: The connection identifier for connecting to Dbt.
+    :param run_id: The ID of a dbt Cloud job.
+    :param end_time: Time in seconds to wait for a job run to reach a terminal 
status. Defaults to 7 days.
+    :param account_id: The ID of a dbt Cloud account.
+    :param poll_interval:  polling period in seconds to check for the status.
+    """
+
+    def __init__(
+        self,
+        conn_id: str,
+        run_id: int,
+        end_time: float,
+        poll_interval: float,
+        account_id: int | None,
+    ):
+        super().__init__()
+        self.run_id = run_id
+        self.account_id = account_id
+        self.conn_id = conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes DbtCloudRunJobTrigger arguments and classpath."""
+        return (
+            "airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger",
+            {
+                "run_id": self.run_id,
+                "account_id": self.account_id,
+                "conn_id": self.conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """Make async connection to Dbt, polls for the pipeline run status"""
+        hook = DbtCloudHook(self.conn_id)
+        try:
+            while await self.is_still_running(hook):
+                if self.end_time < time.time():
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job run {self.run_id} has not reached a terminal status "
+                            "before the timeout deadline.",
+                            "run_id": self.run_id,
+                        }
+                    )
+                    return
+                await asyncio.sleep(self.poll_interval)
+            job_run_status = await hook.get_job_status(self.run_id, self.account_id)
+            if job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": f"Job run {self.run_id} has completed 
successfully.",
+                        "run_id": self.run_id,
+                    }
+                )
+            elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value:
+                yield TriggerEvent(
+                    {
+                        "status": "cancelled",
+                        "message": f"Job run {self.run_id} has been 
cancelled.",
+                        "run_id": self.run_id,
+                    }
+                )
+            else:
+                yield TriggerEvent(
+                    {
+                        "status": "error",
+                        "message": f"Job run {self.run_id} has failed.",
+                        "run_id": self.run_id,
+                    }
+                )
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
+
+    async def is_still_running(self, hook: DbtCloudHook) -> bool:
+        """
+        Async function to check whether the job is submitted via async API is 
in
+        running state and returns True if it is still running else
+        return False
+        """
+        job_run_status = await hook.get_job_status(self.run_id, 
self.account_id)
+        if not DbtCloudJobRunStatus.is_terminal(job_run_status):
+            return True
+        return False
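
For local testing, one way to exercise the trigger is to drive its async generator directly; the connection id and run id are hypothetical:

    import asyncio
    import time

    from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

    async def main() -> None:
        trigger = DbtCloudRunJobTrigger(
            conn_id="dbt_cloud_default",  # hypothetical connection id
            run_id=1234,  # hypothetical run id
            end_time=time.time() + 600,  # allow ten minutes
            poll_interval=10,
            account_id=None,  # fall back to the connection login
        )
        # In production the Triggerer consumes this generator; here we read one event.
        async for event in trigger.run():
            print(event.payload)
            break

    asyncio.run(main())
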
diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst b/docs/apache-airflow-providers-dbt-cloud/operators.rst
index de5b0b8060..1f7b27b280 100644
--- a/docs/apache-airflow-providers-dbt-cloud/operators.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -40,6 +40,18 @@ execution time. This functionality is controlled by the ``wait_for_termination``
 :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`). Setting ``wait_for_termination`` to
 False is a good approach for long-running dbt Cloud jobs.
 
+The ``deferrable`` parameter, together with ``wait_for_termination``, controls whether the operator
+polls the job status on the worker or defers to the Triggerer.
+When ``wait_for_termination`` is True and ``deferrable`` is False, we submit the job and poll
+for its status on the worker. This keeps the worker slot occupied until the job execution is done.
+When ``wait_for_termination`` is True and ``deferrable`` is True, we submit the job and defer
+using the Triggerer. This releases the worker slot, reducing resource utilization while the job is running.
+
+When ``wait_for_termination`` is False and ``deferrable`` is False, we just submit the job and can only
+track the job status with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`.
+
+
 While ``schema_override`` and ``steps_override`` are explicit, optional parameters for the
 ``DbtCloudRunJobOperator``, custom run configurations can also be passed to the operator using the
 ``additional_run_config`` dictionary. This parameter can be used to initialize additional runtime
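
For reference, a sketch of that last combination, pairing the operator with the sensor; the task ids and job id are hypothetical:

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
    from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor

    # Submit the job without waiting; the worker slot is freed immediately.
    trigger_job = DbtCloudRunJobOperator(
        task_id="trigger_dbt_job",
        job_id=1234,  # hypothetical dbt Cloud job id
        wait_for_termination=False,
    )

    # Track the run separately, using the run id returned by the operator.
    wait_for_job = DbtCloudJobRunSensor(
        task_id="wait_for_dbt_job",
        run_id=trigger_job.output,
        timeout=3600,
    )

    trigger_job >> wait_for_job
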
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 90f878796b..8f42cded79 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -241,8 +241,10 @@
   },
   "dbt.cloud": {
     "deps": [
+      "aiohttp",
       "apache-airflow-providers-http",
-      "apache-airflow>=2.3.0"
+      "apache-airflow>=2.3.0",
+      "asgiref"
     ],
     "cross-providers-deps": [
       "http"
