This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26964f8a8e feat(providers/dbt): add reuse_existing_run for allowing 
DbtCloudRunJobOperator to reuse existing run (#37474)
26964f8a8e is described below

commit 26964f8a8e740115d40c608b153fa28d6f5979bf
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Tue Feb 20 09:37:57 2024 +0800

    feat(providers/dbt): add reuse_existing_run for allowing 
DbtCloudRunJobOperator to reuse existing run (#37474)
---
 airflow/providers/dbt/cloud/hooks/dbt.py        | 15 +++++++
 airflow/providers/dbt/cloud/operators/dbt.py    | 40 ++++++++++++-----
 tests/providers/dbt/cloud/hooks/test_dbt.py     | 23 ++++++++--
 tests/providers/dbt/cloud/operators/test_dbt.py | 57 +++++++++++++++++++++++++
 4 files changed, 122 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py 
b/airflow/providers/dbt/cloud/hooks/dbt.py
index a375aa27c7..85eba8da04 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -109,6 +109,7 @@ class DbtCloudJobRunStatus(Enum):
     SUCCESS = 10
     ERROR = 20
     CANCELLED = 30
+    NON_TERMINAL_STATUSES = (QUEUED, STARTING, RUNNING)
     TERMINAL_STATUSES = (SUCCESS, ERROR, CANCELLED)
 
     @classmethod
@@ -460,6 +461,20 @@ class DbtCloudHook(HttpHook):
             paginate=True,
         )
 
+    @fallback_to_default_account
+    def get_job_runs(self, account_id: int | None = None, payload: dict[str, 
Any] | None = None) -> Response:
+        """
+        Retrieve metadata for runs of a dbt Cloud job within an account.
+
+        :param account_id: Optional. The ID of a dbt Cloud account.
+        :param payload: Optional. Query parameters.
+        :return: The request response.
+        """
+        return self._run_and_get_response(
+            endpoint=f"{account_id}/runs/",
+            payload=payload,
+        )
+
     @fallback_to_default_account
     def get_job_run(
         self, run_id: int, account_id: int | None = None, include_related: 
list[str] | None = None
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py 
b/airflow/providers/dbt/cloud/operators/dbt.py
index 0b56e88e01..0b31ba3014 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -73,6 +73,8 @@ class DbtCloudRunJobOperator(BaseOperator):
         Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
     :param additional_run_config: Optional. Any additional parameters that 
should be included in the API
         request when triggering the job.
+    :param reuse_existing_run: Flag to determine whether to reuse an existing non-terminal job run. If set to
+        True and non-terminal job runs are found, it uses the latest run without triggering a new job run.
     :param deferrable: Run operator in the deferrable mode
     :return: The ID of the triggered dbt Cloud job run.
     """
@@ -102,6 +104,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         timeout: int = 60 * 60 * 24 * 7,
         check_interval: int = 60,
         additional_run_config: dict[str, Any] | None = None,
+        reuse_existing_run: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ) -> None:
@@ -117,6 +120,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         self.check_interval = check_interval
         self.additional_run_config = additional_run_config or {}
         self.run_id: int | None = None
+        self.reuse_existing_run = reuse_existing_run
         self.deferrable = deferrable
 
     def execute(self, context: Context):
@@ -125,16 +129,32 @@ class DbtCloudRunJobOperator(BaseOperator):
                 f"Triggered via Apache Airflow by task {self.task_id!r} in the 
{self.dag.dag_id} DAG."
             )
 
-        trigger_job_response = self.hook.trigger_job_run(
-            account_id=self.account_id,
-            job_id=self.job_id,
-            cause=self.trigger_reason,
-            steps_override=self.steps_override,
-            schema_override=self.schema_override,
-            additional_run_config=self.additional_run_config,
-        )
-        self.run_id = trigger_job_response.json()["data"]["id"]
-        job_run_url = trigger_job_response.json()["data"]["href"]
+        non_terminal_runs = None
+        if self.reuse_existing_run:
+            non_terminal_runs = self.hook.get_job_runs(
+                account_id=self.account_id,
+                payload={
+                    "job_definition_id": self.job_id,
+                    "status": DbtCloudJobRunStatus.NON_TERMINAL_STATUSES,
+                    "order_by": "-created_at",
+                },
+            ).json()["data"]
+            if non_terminal_runs:
+                self.run_id = non_terminal_runs[0]["id"]
+                job_run_url = non_terminal_runs[0]["href"]
+
+        if not self.reuse_existing_run or not non_terminal_runs:
+            trigger_job_response = self.hook.trigger_job_run(
+                account_id=self.account_id,
+                job_id=self.job_id,
+                cause=self.trigger_reason,
+                steps_override=self.steps_override,
+                schema_override=self.schema_override,
+                additional_run_config=self.additional_run_config,
+            )
+            self.run_id = trigger_job_response.json()["data"]["id"]
+            job_run_url = trigger_job_response.json()["data"]["href"]
+
         # Push the ``job_run_url`` value to XCom regardless of what happens 
during execution so that the job
         # run can be monitored via the operator link.
         context["ti"].xcom_push(key="job_run_url", value=job_run_url)
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py 
b/tests/providers/dbt/cloud/hooks/test_dbt.py
index 20e1313965..39d31a444b 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -439,6 +439,21 @@ class TestDbtCloudHook:
             },
         )
 
+    @pytest.mark.parametrize(
+        argnames="conn_id, account_id",
+        argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    @patch.object(DbtCloudHook, "run")
+    def test_get_job_runs(self, mock_http_run, conn_id, account_id):
+        hook = DbtCloudHook(conn_id)
+        hook.get_job_runs(account_id=account_id)
+
+        assert hook.method == "GET"
+
+        _account_id = account_id or DEFAULT_ACCOUNT_ID
+        hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/", 
data=None)
+
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
         argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
@@ -493,9 +508,11 @@ class TestDbtCloudHook:
         argnames=("job_run_status", "expected_status", "expected_output"),
         argvalues=wait_for_job_run_status_test_args,
         ids=[
-            f"run_status_{argval[0]}_expected_{argval[1]}"
-            if isinstance(argval[1], int)
-            else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+            (
+                f"run_status_{argval[0]}_expected_{argval[1]}"
+                if isinstance(argval[1], int)
+                else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+            )
             for argval in wait_for_job_run_status_test_args
         ],
     )
diff --git a/tests/providers/dbt/cloud/operators/test_dbt.py 
b/tests/providers/dbt/cloud/operators/test_dbt.py
index b4c1aa89e7..90465602dc 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt.py
@@ -307,6 +307,63 @@ class TestDbtCloudRunJobOperator:
 
             mock_get_job_run.assert_not_called()
 
+    @patch.object(DbtCloudHook, "get_job_runs")
+    @patch.object(DbtCloudHook, "trigger_job_run")
+    @pytest.mark.parametrize(
+        "conn_id, account_id",
+        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    def test_execute_no_wait_for_termination_and_reuse_existing_run(
+        self, mock_run_job, mock_get_jobs_run, conn_id, account_id
+    ):
+        mock_get_jobs_run.return_value.json.return_value = {
+            "data": [
+                {
+                    "id": 10000,
+                    "status": 1,
+                    "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
+                        account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID, 
run_id=RUN_ID
+                    ),
+                },
+                {
+                    "id": 10001,
+                    "status": 2,
+                    "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
+                        account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID, 
run_id=RUN_ID
+                    ),
+                },
+            ]
+        }
+
+        operator = DbtCloudRunJobOperator(
+            task_id=TASK_ID,
+            dbt_cloud_conn_id=conn_id,
+            account_id=account_id,
+            trigger_reason=None,
+            dag=self.dag,
+            wait_for_termination=False,
+            reuse_existing_run=True,
+            **self.config,
+        )
+
+        assert operator.dbt_cloud_conn_id == conn_id
+        assert operator.job_id == self.config["job_id"]
+        assert operator.account_id == account_id
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert not operator.wait_for_termination
+        assert operator.steps_override == self.config["steps_override"]
+        assert operator.schema_override == self.config["schema_override"]
+        assert operator.additional_run_config == 
self.config["additional_run_config"]
+
+        with patch.object(DbtCloudHook, "get_job_run") as mock_get_job_run:
+            operator.execute(context=self.mock_context)
+
+            mock_run_job.assert_not_called()
+
+            mock_get_job_run.assert_not_called()
+
     @patch.object(DbtCloudHook, "trigger_job_run")
     @pytest.mark.parametrize(
         "conn_id, account_id",

Reply via email to