This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0bb56315e6 Added `overrides` parameter to CloudRunExecuteJobOperator (#34874)
0bb56315e6 is described below

commit 0bb56315e664875cd764486bb2090e0a2ef747d8
Author: Chloe Sheasby <chl...@phase2online.com>
AuthorDate: Wed Oct 25 14:21:41 2023 -0500

    Added `overrides` parameter to CloudRunExecuteJobOperator (#34874)
---
 airflow/providers/google/cloud/hooks/cloud_run.py  | 12 +++-
 .../providers/google/cloud/operators/cloud_run.py  |  7 ++-
 .../operators/cloud/cloud_run.rst                  |  9 +++
 .../providers/google/cloud/hooks/test_cloud_run.py | 15 ++++-
 .../google/cloud/operators/test_cloud_run.py       | 68 +++++++++++++++++++++-
 .../google/cloud/cloud_run/example_cloud_run.py    | 32 +++++++++-
 6 files changed, 133 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/cloud_run.py b/airflow/providers/google/cloud/hooks/cloud_run.py
index 6cc3b304dc..8741aa1d4d 100644
--- a/airflow/providers/google/cloud/hooks/cloud_run.py
+++ b/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import itertools
-from typing import TYPE_CHECKING, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
 from google.cloud.run_v2 import (
     CreateJobRequest,
@@ -113,9 +113,15 @@ class CloudRunHook(GoogleBaseHook):
 
     @GoogleBaseHook.fallback_to_default_project_id
     def execute_job(
-        self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
+        self,
+        job_name: str,
+        region: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        overrides: dict[str, Any] | None = None,
     ) -> operation.Operation:
-        run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+        run_job_request = RunJobRequest(
+            name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", 
overrides=overrides
+        )
         operation = self.get_conn().run_job(request=run_job_request)
         return operation
 
diff --git a/airflow/providers/google/cloud/operators/cloud_run.py b/airflow/providers/google/cloud/operators/cloud_run.py
index ba50ea111d..14d27810da 100644
--- a/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/airflow/providers/google/cloud/operators/cloud_run.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from google.cloud.run_v2 import Job
 
@@ -248,6 +248,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
     :param job_name: Required. The name of the job to update.
     :param job: Required. The job descriptor containing the new configuration of the job to update.
         The name field will be replaced by job_name
+    :param overrides: Optional map of override values.
     :param gcp_conn_id: The connection ID used to connect to Google Cloud.
     :param polling_period_seconds: Optional: Control the rate of the poll for the result of deferrable run.
         By default, the trigger will poll every 10 seconds.
@@ -270,6 +271,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         project_id: str,
         region: str,
         job_name: str,
+        overrides: dict[str, Any] | None = None,
         polling_period_seconds: float = 10,
         timeout_seconds: float | None = None,
         gcp_conn_id: str = "google_cloud_default",
@@ -281,6 +283,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.region = region
         self.job_name = job_name
+        self.overrides = overrides
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.polling_period_seconds = polling_period_seconds
@@ -293,7 +296,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )
         self.operation = hook.execute_job(
-            region=self.region, project_id=self.project_id, job_name=self.job_name
+            region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
         )
 
         if not self.deferrable:
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
index 7c80f86d15..cf90afde68 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
@@ -77,6 +77,15 @@ or you can define the same operator in the deferrable mode:
     :start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
     :end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]
 
+You can also specify overrides that allow you to provide a new entrypoint command for the job, along with other settings such as environment variables, task count, and timeout:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_cloud_run_execute_job_with_overrides]
+    :end-before: [END howto_operator_cloud_run_execute_job_with_overrides]
 
 
 Update a job
diff --git a/tests/providers/google/cloud/hooks/test_cloud_run.py b/tests/providers/google/cloud/hooks/test_cloud_run.py
index c91bc490f3..6a9a4fa898 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_run.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_run.py
@@ -34,7 +34,7 @@ from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook, Cl
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 
 
-class TestCloudBathHook:
+class TestCloudRunHook:
     def dummy_get_credentials(self):
         pass
 
@@ -111,9 +111,18 @@ class TestCloudBathHook:
         job_name = "job1"
         region = "region1"
         project_id = "projectid"
-        run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+        overrides = {
+            "container_overrides": [{"args": ["python", "main.py"]}],
+            "task_count": 1,
+            "timeout": "60s",
+        }
+        run_job_request = RunJobRequest(
+            name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", 
overrides=overrides
+        )
 
-        cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id)
+        cloud_run_hook.execute_job(
+            job_name=job_name, region=region, project_id=project_id, overrides=overrides
+        )
         cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)
 
     @mock.patch(
diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py b/tests/providers/google/cloud/operators/test_cloud_run.py
index 0fe7779158..152e625a23 100644
--- a/tests/providers/google/cloud/operators/test_cloud_run.py
+++ b/tests/providers/google/cloud/operators/test_cloud_run.py
@@ -96,7 +96,7 @@ class TestCloudRunExecuteJobOperator:
         operator.execute(context=mock.MagicMock())
 
         hook_mock.return_value.execute_job.assert_called_once_with(
-            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None
         )
 
     @mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -209,6 +209,72 @@ class TestCloudRunExecuteJobOperator:
         result = operator.execute_complete(mock.MagicMock(), event)
         assert result["name"] == JOB_NAME
 
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_overrides(self, hook_mock):
+        hook_mock.return_value.get_job.return_value = JOB
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)
+
+        overrides = {
+            "container_overrides": [{"args": ["python", "main.py"]}],
+            "task_count": 1,
+            "timeout": "60s",
+        }
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
+        )
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_mock.return_value.execute_job.assert_called_once_with(
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=overrides
+        )
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_overrides_with_invalid_task_count(self, hook_mock):
+        overrides = {
+            "container_overrides": [{"args": ["python", "main.py"]}],
+            "task_count": -1,
+            "timeout": "60s",
+        }
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
+        )
+
+        with pytest.raises(AirflowException):
+            operator.execute(context=mock.MagicMock())
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_overrides_with_invalid_timeout(self, hook_mock):
+        overrides = {
+            "container_overrides": [{"args": ["python", "main.py"]}],
+            "task_count": 1,
+            "timeout": "60",
+        }
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
+        )
+
+        with pytest.raises(AirflowException):
+            operator.execute(context=mock.MagicMock())
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_overrides_with_invalid_container_args(self, hook_mock):
+        overrides = {
+            "container_overrides": [{"name": "job", "args": "python main.py"}],
+            "task_count": 1,
+            "timeout": "60s",
+        }
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
+        )
+
+        with pytest.raises(AirflowException):
+            operator.execute(context=mock.MagicMock())
+
     def _mock_operation(self, task_count, succeeded_count, failed_count):
         operation = mock.MagicMock()
         operation.result.return_value = self._mock_execution(task_count, succeeded_count, failed_count)
diff --git a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
index 330789d82d..08c82d6eb0 100644
--- a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+++ b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
@@ -44,12 +44,14 @@ region = "us-central1"
 job_name_prefix = "cloudrun-system-test-job"
 job1_name = f"{job_name_prefix}1"
 job2_name = f"{job_name_prefix}2"
+job3_name = f"{job_name_prefix}3"
 
 create1_task_name = "create-job1"
 create2_task_name = "create-job2"
 
 execute1_task_name = "execute-job1"
 execute2_task_name = "execute-job2"
+execute3_task_name = "execute-job3"
 
 update_job1_task_name = "update-job1"
 
@@ -70,6 +72,9 @@ def _assert_executed_jobs_xcom(ti):
     job2_dicts = ti.xcom_pull(task_ids=[execute2_task_name], key="return_value")
     assert job2_name in job2_dicts[0]["name"]
 
+    job3_dicts = ti.xcom_pull(task_ids=[execute3_task_name], key="return_value")
+    assert job3_name in job3_dicts[0]["name"]
+
 
 def _assert_created_jobs_xcom(ti):
     job1_dicts = ti.xcom_pull(task_ids=[create1_task_name], key="return_value")
@@ -181,6 +186,31 @@ with DAG(
     )
     # [END howto_operator_cloud_run_execute_job_deferrable_mode]
 
+    # [START howto_operator_cloud_run_execute_job_with_overrides]
+    overrides = {
+        "container_overrides": [
+            {
+                "name": "job",
+                "args": ["python", "main.py"],
+                "env": [{"name": "ENV_VAR", "value": "value"}],
+                "clearArgs": False,
+            }
+        ],
+        "task_count": 1,
+        "timeout": "60s",
+    }
+
+    execute3 = CloudRunExecuteJobOperator(
+        task_id=execute3_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        overrides=overrides,
+        job_name=job3_name,
+        dag=dag,
+        deferrable=False,
+    )
+    # [END howto_operator_cloud_run_execute_job_with_overrides]
+
     assert_executed_jobs = PythonOperator(
         task_id="assert-executed-jobs", 
python_callable=_assert_executed_jobs_xcom, dag=dag
     )
@@ -237,7 +267,7 @@ with DAG(
     (
         (create1, create2)
         >> assert_created_jobs
-        >> (execute1, execute2)
+        >> (execute1, execute2, execute3)
         >> assert_executed_jobs
         >> list_jobs_limit
         >> assert_jobs_limit
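
For reference, a minimal standalone sketch (not part of the commit) of how the new `overrides` parameter might be used in a DAG, modeled on the system test above. The DAG id, project id, and job name are placeholder values, and the Cloud Run job is assumed to already exist:

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.cloud_run import CloudRunExecuteJobOperator

    # Per-execution overrides, mirroring the shape used in the tests above:
    # a replacement container command, the number of tasks, and a task timeout.
    overrides = {
        "container_overrides": [{"args": ["python", "main.py"]}],
        "task_count": 1,
        "timeout": "60s",
    }

    with DAG(
        dag_id="example_cloud_run_overrides",  # placeholder DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        execute_with_overrides = CloudRunExecuteJobOperator(
            task_id="execute-job-with-overrides",
            project_id="my-project",  # placeholder project id
            region="us-central1",
            job_name="my-job",  # placeholder: an existing Cloud Run job
            overrides=overrides,
            deferrable=False,
        )

The overrides apply only to that single execution; the job definition stored in Cloud Run is left unchanged.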
