This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 0bb56315e6 Added `overrides` parameter to CloudRunExecuteJobOperator (#34874) 0bb56315e6 is described below commit 0bb56315e664875cd764486bb2090e0a2ef747d8 Author: Chloe Sheasby <chl...@phase2online.com> AuthorDate: Wed Oct 25 14:21:41 2023 -0500 Added `overrides` parameter to CloudRunExecuteJobOperator (#34874) --- airflow/providers/google/cloud/hooks/cloud_run.py | 12 +++- .../providers/google/cloud/operators/cloud_run.py | 7 ++- .../operators/cloud/cloud_run.rst | 9 +++ .../providers/google/cloud/hooks/test_cloud_run.py | 15 ++++- .../google/cloud/operators/test_cloud_run.py | 68 +++++++++++++++++++++- .../google/cloud/cloud_run/example_cloud_run.py | 32 +++++++++- 6 files changed, 133 insertions(+), 10 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/cloud_run.py b/airflow/providers/google/cloud/hooks/cloud_run.py index 6cc3b304dc..8741aa1d4d 100644 --- a/airflow/providers/google/cloud/hooks/cloud_run.py +++ b/airflow/providers/google/cloud/hooks/cloud_run.py @@ -18,7 +18,7 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Iterable, Sequence from google.cloud.run_v2 import ( CreateJobRequest, @@ -113,9 +113,15 @@ class CloudRunHook(GoogleBaseHook): @GoogleBaseHook.fallback_to_default_project_id def execute_job( - self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID + self, + job_name: str, + region: str, + project_id: str = PROVIDE_PROJECT_ID, + overrides: dict[str, Any] | None = None, ) -> operation.Operation: - run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}") + run_job_request = RunJobRequest( + name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides + ) operation = self.get_conn().run_job(request=run_job_request) return operation diff --git a/airflow/providers/google/cloud/operators/cloud_run.py b/airflow/providers/google/cloud/operators/cloud_run.py index ba50ea111d..14d27810da 100644 --- a/airflow/providers/google/cloud/operators/cloud_run.py +++ b/airflow/providers/google/cloud/operators/cloud_run.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence from google.cloud.run_v2 import Job @@ -248,6 +248,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): :param job_name: Required. The name of the job to update. :param job: Required. The job descriptor containing the new configuration of the job to update. The name field will be replaced by job_name + :param overrides: Optional map of override values. :param gcp_conn_id: The connection ID used to connect to Google Cloud. :param polling_period_seconds: Optional: Control the rate of the poll for the result of deferrable run. By default, the trigger will poll every 10 seconds. @@ -270,6 +271,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): project_id: str, region: str, job_name: str, + overrides: dict[str, Any] | None = None, polling_period_seconds: float = 10, timeout_seconds: float | None = None, gcp_conn_id: str = "google_cloud_default", @@ -281,6 +283,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): self.project_id = project_id self.region = region self.job_name = job_name + self.overrides = overrides self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.polling_period_seconds = polling_period_seconds @@ -293,7 +296,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) self.operation = hook.execute_job( - region=self.region, project_id=self.project_id, job_name=self.job_name + region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides ) if not self.deferrable: diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst index 7c80f86d15..cf90afde68 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst @@ -77,6 +77,15 @@ or you can define the same operator in the deferrable mode: :start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode] :end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode] +You can also specify overrides that allow you to give a new entrypoint command to the job and more: + +:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator` + +.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_cloud_run_execute_job_with_overrides] + :end-before: [END howto_operator_cloud_run_execute_job_with_overrides] Update a job diff --git a/tests/providers/google/cloud/hooks/test_cloud_run.py b/tests/providers/google/cloud/hooks/test_cloud_run.py index c91bc490f3..6a9a4fa898 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_run.py +++ b/tests/providers/google/cloud/hooks/test_cloud_run.py @@ -34,7 +34,7 @@ from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook, Cl from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id -class TestCloudBathHook: +class TestCloudRunHook: def dummy_get_credentials(self): pass @@ -111,9 +111,18 @@ class TestCloudBathHook: job_name = "job1" region = "region1" project_id = "projectid" - run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}") + overrides = { + "container_overrides": [{"args": ["python", "main.py"]}], + "task_count": 1, + "timeout": "60s", + } + run_job_request = RunJobRequest( + name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides + ) - cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id) + cloud_run_hook.execute_job( + job_name=job_name, region=region, project_id=project_id, overrides=overrides + ) cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request) @mock.patch( diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py b/tests/providers/google/cloud/operators/test_cloud_run.py index 0fe7779158..152e625a23 100644 --- a/tests/providers/google/cloud/operators/test_cloud_run.py +++ b/tests/providers/google/cloud/operators/test_cloud_run.py @@ -96,7 +96,7 @@ class TestCloudRunExecuteJobOperator: operator.execute(context=mock.MagicMock()) hook_mock.return_value.execute_job.assert_called_once_with( - job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID + job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None ) @mock.patch(CLOUD_RUN_HOOK_PATH) @@ -209,6 +209,72 @@ class TestCloudRunExecuteJobOperator: result = operator.execute_complete(mock.MagicMock(), event) assert result["name"] == JOB_NAME + @mock.patch(CLOUD_RUN_HOOK_PATH) + def test_execute_overrides(self, hook_mock): + hook_mock.return_value.get_job.return_value = JOB + hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0) + + overrides = { + "container_overrides": [{"args": ["python", "main.py"]}], + "task_count": 1, + "timeout": "60s", + } + + operator = CloudRunExecuteJobOperator( + task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides + ) + + operator.execute(context=mock.MagicMock()) + + hook_mock.return_value.execute_job.assert_called_once_with( + job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=overrides + ) + + @mock.patch(CLOUD_RUN_HOOK_PATH) + def test_execute_overrides_with_invalid_task_count(self, hook_mock): + overrides = { + "container_overrides": [{"args": ["python", "main.py"]}], + "task_count": -1, + "timeout": "60s", + } + + operator = CloudRunExecuteJobOperator( + task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides + ) + + with pytest.raises(AirflowException): + operator.execute(context=mock.MagicMock()) + + @mock.patch(CLOUD_RUN_HOOK_PATH) + def test_execute_overrides_with_invalid_timeout(self, hook_mock): + overrides = { + "container_overrides": [{"args": ["python", "main.py"]}], + "task_count": 1, + "timeout": "60", + } + + operator = CloudRunExecuteJobOperator( + task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides + ) + + with pytest.raises(AirflowException): + operator.execute(context=mock.MagicMock()) + + @mock.patch(CLOUD_RUN_HOOK_PATH) + def test_execute_overrides_with_invalid_container_args(self, hook_mock): + overrides = { + "container_overrides": [{"name": "job", "args": "python main.py"}], + "task_count": 1, + "timeout": "60s", + } + + operator = CloudRunExecuteJobOperator( + task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides + ) + + with pytest.raises(AirflowException): + operator.execute(context=mock.MagicMock()) + def _mock_operation(self, task_count, succeeded_count, failed_count): operation = mock.MagicMock() operation.result.return_value = self._mock_execution(task_count, succeeded_count, failed_count) diff --git a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py index 330789d82d..08c82d6eb0 100644 --- a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py +++ b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py @@ -44,12 +44,14 @@ region = "us-central1" job_name_prefix = "cloudrun-system-test-job" job1_name = f"{job_name_prefix}1" job2_name = f"{job_name_prefix}2" +job3_name = f"{job_name_prefix}3" create1_task_name = "create-job1" create2_task_name = "create-job2" execute1_task_name = "execute-job1" execute2_task_name = "execute-job2" +execute3_task_name = "execute-job3" update_job1_task_name = "update-job1" @@ -70,6 +72,9 @@ def _assert_executed_jobs_xcom(ti): job2_dicts = ti.xcom_pull(task_ids=[execute2_task_name], key="return_value") assert job2_name in job2_dicts[0]["name"] + job3_dicts = ti.xcom_pull(task_ids=[execute3_task_name], key="return_value") + assert job3_name in job3_dicts[0]["name"] + def _assert_created_jobs_xcom(ti): job1_dicts = ti.xcom_pull(task_ids=[create1_task_name], key="return_value") @@ -181,6 +186,31 @@ with DAG( ) # [END howto_operator_cloud_run_execute_job_deferrable_mode] + # [START howto_operator_cloud_run_execute_job_with_overrides] + overrides = { + "container_overrides": [ + { + "name": "job", + "args": ["python", "main.py"], + "env": [{"name": "ENV_VAR", "value": "value"}], + "clearArgs": False, + } + ], + "task_count": 1, + "timeout": "60s", + } + + execute3 = CloudRunExecuteJobOperator( + task_id=execute3_task_name, + project_id=PROJECT_ID, + region=region, + overrides=overrides, + job_name=job3_name, + dag=dag, + deferrable=False, + ) + # [END howto_operator_cloud_run_execute_job_with_overrides] + assert_executed_jobs = PythonOperator( task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag ) @@ -237,7 +267,7 @@ with DAG( ( (create1, create2) >> assert_created_jobs - >> (execute1, execute2) + >> (execute1, execute2, execute3) >> assert_executed_jobs >> list_jobs_limit >> assert_jobs_limit