This is an automated email from the ASF dual-hosted git repository. phanikumv pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 26964f8a8e feat(providers/dbt): add reuse_existing_run for allowing DbtCloudRunJobOperator to reuse existing run (#37474) 26964f8a8e is described below commit 26964f8a8e740115d40c608b153fa28d6f5979bf Author: Wei Lee <weilee...@gmail.com> AuthorDate: Tue Feb 20 09:37:57 2024 +0800 feat(providers/dbt): add reuse_existing_run for allowing DbtCloudRunJobOperator to reuse existing run (#37474) --- airflow/providers/dbt/cloud/hooks/dbt.py | 15 +++++++ airflow/providers/dbt/cloud/operators/dbt.py | 40 ++++++++++++----- tests/providers/dbt/cloud/hooks/test_dbt.py | 23 ++++++++-- tests/providers/dbt/cloud/operators/test_dbt.py | 57 +++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 13 deletions(-) diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index a375aa27c7..85eba8da04 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -109,6 +109,7 @@ class DbtCloudJobRunStatus(Enum): SUCCESS = 10 ERROR = 20 CANCELLED = 30 + NON_TERMINAL_STATUSES = (QUEUED, STARTING, RUNNING) TERMINAL_STATUSES = (SUCCESS, ERROR, CANCELLED) @classmethod @@ -460,6 +461,20 @@ class DbtCloudHook(HttpHook): paginate=True, ) + @fallback_to_default_account + def get_job_runs(self, account_id: int | None = None, payload: dict[str, Any] | None = None) -> Response: + """ + Retrieve metadata for runs of a dbt Cloud job, filtered by the given query parameters. + + :param account_id: Optional. The ID of a dbt Cloud account. + :param payload: Optional. Query parameters. + :return: The request response. 
+ """ + return self._run_and_get_response( + endpoint=f"{account_id}/runs/", + payload=payload, + ) + @fallback_to_default_account def get_job_run( self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index 0b56e88e01..0b31ba3014 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -73,6 +73,8 @@ class DbtCloudRunJobOperator(BaseOperator): Used only if ``wait_for_termination`` is True. Defaults to 60 seconds. :param additional_run_config: Optional. Any additional parameters that should be included in the API request when triggering the job. + :param reuse_existing_run: Flag to determine whether to reuse an existing non-terminal job run. If set to + True and non-terminal job runs are found, it uses the latest run without triggering a new job run. :param deferrable: Run operator in the deferrable mode :return: The ID of the triggered dbt Cloud job run. """ @@ -102,6 +104,7 @@ class DbtCloudRunJobOperator(BaseOperator): timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, additional_run_config: dict[str, Any] | None = None, + reuse_existing_run: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: @@ -117,6 +120,7 @@ self.check_interval = check_interval self.additional_run_config = additional_run_config or {} self.run_id: int | None = None + self.reuse_existing_run = reuse_existing_run self.deferrable = deferrable def execute(self, context: Context): @@ -125,16 +129,32 @@ f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." 
) - trigger_job_response = self.hook.trigger_job_run( - account_id=self.account_id, - job_id=self.job_id, - cause=self.trigger_reason, - steps_override=self.steps_override, - schema_override=self.schema_override, - additional_run_config=self.additional_run_config, - ) - self.run_id = trigger_job_response.json()["data"]["id"] - job_run_url = trigger_job_response.json()["data"]["href"] + non_terminal_runs = None + if self.reuse_existing_run: + non_terminal_runs = self.hook.get_job_runs( + account_id=self.account_id, + payload={ + "job_definition_id": self.job_id, + "status": DbtCloudJobRunStatus.NON_TERMINAL_STATUSES, + "order_by": "-created_at", + }, + ).json()["data"] + if non_terminal_runs: + self.run_id = non_terminal_runs[0]["id"] + job_run_url = non_terminal_runs[0]["href"] + + if not self.reuse_existing_run or not non_terminal_runs: + trigger_job_response = self.hook.trigger_job_run( + account_id=self.account_id, + job_id=self.job_id, + cause=self.trigger_reason, + steps_override=self.steps_override, + schema_override=self.schema_override, + additional_run_config=self.additional_run_config, + ) + self.run_id = trigger_job_response.json()["data"]["id"] + job_run_url = trigger_job_response.json()["data"]["href"] + # Push the ``job_run_url`` value to XCom regardless of what happens during execution so that the job # run can be monitored via the operator link. 
context["ti"].xcom_push(key="job_run_url", value=job_run_url) diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py b/tests/providers/dbt/cloud/hooks/test_dbt.py index 20e1313965..39d31a444b 100644 --- a/tests/providers/dbt/cloud/hooks/test_dbt.py +++ b/tests/providers/dbt/cloud/hooks/test_dbt.py @@ -439,6 +439,21 @@ class TestDbtCloudHook: }, ) + @pytest.mark.parametrize( + argnames="conn_id, account_id", + argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + @patch.object(DbtCloudHook, "run") + def test_get_job_runs(self, mock_http_run, conn_id, account_id): + hook = DbtCloudHook(conn_id) + hook.get_job_runs(account_id=account_id) + + assert hook.method == "GET" + + _account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/", data=None) + @pytest.mark.parametrize( argnames="conn_id, account_id", argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], @@ -493,9 +508,11 @@ class TestDbtCloudHook: argnames=("job_run_status", "expected_status", "expected_output"), argvalues=wait_for_job_run_status_test_args, ids=[ - f"run_status_{argval[0]}_expected_{argval[1]}" - if isinstance(argval[1], int) - else f"run_status_{argval[0]}_expected_AnyTerminalStatus" + ( + f"run_status_{argval[0]}_expected_{argval[1]}" + if isinstance(argval[1], int) + else f"run_status_{argval[0]}_expected_AnyTerminalStatus" + ) for argval in wait_for_job_run_status_test_args ], ) diff --git a/tests/providers/dbt/cloud/operators/test_dbt.py b/tests/providers/dbt/cloud/operators/test_dbt.py index b4c1aa89e7..90465602dc 100644 --- a/tests/providers/dbt/cloud/operators/test_dbt.py +++ b/tests/providers/dbt/cloud/operators/test_dbt.py @@ -307,6 +307,63 @@ class TestDbtCloudRunJobOperator: mock_get_job_run.assert_not_called() + @patch.object(DbtCloudHook, "get_job_runs") + @patch.object(DbtCloudHook, "trigger_job_run") + @pytest.mark.parametrize( + "conn_id, 
account_id", + [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], + ids=["default_account", "explicit_account"], + ) + def test_execute_no_wait_for_termination_and_reuse_existing_run( + self, mock_run_job, mock_get_jobs_run, conn_id, account_id + ): + mock_get_jobs_run.return_value.json.return_value = { + "data": [ + { + "id": 10000, + "status": 1, + "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( + account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID, run_id=RUN_ID + ), + }, + { + "id": 10001, + "status": 2, + "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( + account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID, run_id=RUN_ID + ), + }, + ] + } + + operator = DbtCloudRunJobOperator( + task_id=TASK_ID, + dbt_cloud_conn_id=conn_id, + account_id=account_id, + trigger_reason=None, + dag=self.dag, + wait_for_termination=False, + reuse_existing_run=True, + **self.config, + ) + + assert operator.dbt_cloud_conn_id == conn_id + assert operator.job_id == self.config["job_id"] + assert operator.account_id == account_id + assert operator.check_interval == self.config["check_interval"] + assert operator.timeout == self.config["timeout"] + assert not operator.wait_for_termination + assert operator.steps_override == self.config["steps_override"] + assert operator.schema_override == self.config["schema_override"] + assert operator.additional_run_config == self.config["additional_run_config"] + + with patch.object(DbtCloudHook, "get_job_run") as mock_get_job_run: + operator.execute(context=self.mock_context) + + mock_run_job.assert_not_called() + + mock_get_job_run.assert_not_called() + @patch.object(DbtCloudHook, "trigger_job_run") @pytest.mark.parametrize( "conn_id, account_id",