This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0631ba76e0b Add support for serverless job in Databricks operators 
(#45188)
0631ba76e0b is described below

commit 0631ba76e0bdc7a52e873f3ec85787cbaf0e0dec
Author: Hari Selvarajan <[email protected]>
AuthorDate: Tue Jan 28 12:06:46 2025 +0000

    Add support for serverless job in Databricks operators (#45188)
    
    closes: #45138
---
 .../operators/jobs_create.rst                      |   1 +
 .../operators/submit_run.rst                       |   1 +
 .../operators/task.rst                             |   7 ++
 .../providers/databricks/operators/databricks.py   |  23 ++++-
 .../databricks/operators/databricks_workflow.py    |  10 ++
 .../tests/databricks/operators/test_databricks.py  | 104 ++++++++++++++++++++-
 .../operators/test_databricks_workflow.py          |  12 +++
 .../tests/system/databricks/example_databricks.py  |  23 +++++
 8 files changed, 177 insertions(+), 4 deletions(-)

diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst 
b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
index 621423f83f3..5a79c8244a9 100644
--- a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
+++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
@@ -56,6 +56,7 @@ Currently the named parameters that 
``DatabricksCreateJobsOperator`` supports ar
   - ``max_concurrent_runs``
   - ``git_source``
   - ``access_control_list``
+  - ``environments``
 
 
 Examples
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst 
b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index 10548583cfa..7a8d13f646c 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -80,6 +80,7 @@ Currently the named parameters that 
``DatabricksSubmitRunOperator`` supports are
     - ``libraries``
     - ``run_name``
     - ``timeout_seconds``
+    - ``environments``
 
 .. code-block:: python
 
diff --git a/docs/apache-airflow-providers-databricks/operators/task.rst 
b/docs/apache-airflow-providers-databricks/operators/task.rst
index 331481d915c..47ceafe58ad 100644
--- a/docs/apache-airflow-providers-databricks/operators/task.rst
+++ b/docs/apache-airflow-providers-databricks/operators/task.rst
@@ -44,3 +44,10 @@ Running a SQL query in Databricks using 
DatabricksTaskOperator
     :language: python
     :start-after: [START howto_operator_databricks_task_sql]
     :end-before: [END howto_operator_databricks_task_sql]
+
+Running a python file in Databricks using DatabricksTaskOperator
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: 
/../../providers/tests/system/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_task_python]
+    :end-before: [END howto_operator_databricks_task_python]
diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py 
b/providers/src/airflow/providers/databricks/operators/databricks.py
index b8fde94c594..3c121c49a9e 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks.py
@@ -293,6 +293,8 @@ class DatabricksCreateJobsOperator(BaseOperator):
     :param databricks_retry_delay: Number of seconds to wait between retries 
(it
             might be a floating point number).
     :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
 
     """
 
@@ -324,6 +326,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
         databricks_retry_limit: int = 3,
         databricks_retry_delay: int = 1,
         databricks_retry_args: dict[Any, Any] | None = None,
+        environments: list[dict] | None = None,
         **kwargs,
     ) -> None:
         """Create a new ``DatabricksCreateJobsOperator``."""
@@ -360,6 +363,8 @@ class DatabricksCreateJobsOperator(BaseOperator):
             self.json["git_source"] = git_source
         if access_control_list is not None:
             self.json["access_control_list"] = access_control_list
+        if environments is not None:
+            self.json["environments"] = environments
         if self.json:
             self.json = normalise_json_content(self.json)
 
@@ -503,6 +508,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
     :param git_source: Optional specification of a remote git repository from 
which
         supported task types are retrieved.
     :param deferrable: Run operator in the deferrable mode.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
 
         .. seealso::
             
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
@@ -543,6 +550,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         wait_for_termination: bool = True,
         git_source: dict[str, str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        environments: list[dict] | None = None,
         **kwargs,
     ) -> None:
         """Create a new ``DatabricksSubmitRunOperator``."""
@@ -587,6 +595,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
             self.json["access_control_list"] = access_control_list
         if git_source is not None:
             self.json["git_source"] = git_source
+        if environments is not None:
+            self.json["environments"] = environments
 
         if "dbt_task" in self.json and "git_source" not in self.json:
             raise AirflowException("git_source is required for dbt_task")
@@ -983,6 +993,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
     :param wait_for_termination: if we should wait for termination of the job 
run. ``True`` by default.
     :param workflow_run_metadata: Metadata for the workflow run. This is used 
when the operator is used within
         a workflow. It is expected to be a dictionary containing the run_id 
and conn_id for the workflow.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
     """
 
     def __init__(
@@ -1000,6 +1012,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         polling_period_seconds: int = 5,
         wait_for_termination: bool = True,
         workflow_run_metadata: dict[str, Any] | None = None,
+        environments: list[dict] | None = None,
         **kwargs: Any,
     ):
         self.caller = caller
@@ -1015,7 +1028,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         self.polling_period_seconds = polling_period_seconds
         self.wait_for_termination = wait_for_termination
         self.workflow_run_metadata = workflow_run_metadata
-
+        self.environments = environments
         self.databricks_run_id: int | None = None
 
         super().__init__(**kwargs)
@@ -1095,8 +1108,10 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             run_json["new_cluster"] = self.new_cluster
         elif self.existing_cluster_id:
             run_json["existing_cluster_id"] = self.existing_cluster_id
+        elif self.environments:
+            run_json["environments"] = self.environments
         else:
-            raise ValueError("Must specify either existing_cluster_id or 
new_cluster.")
+            raise ValueError("Must specify either existing_cluster_id, 
new_cluster or environments.")
         return run_json
 
     def _launch_job(self, context: Context | None = None) -> int:
@@ -1400,6 +1415,8 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
     :param new_cluster: Specs for a new cluster on which this task will be run.
     :param polling_period_seconds: Controls the rate which we poll for the 
result of this notebook job run.
     :param wait_for_termination: if we should wait for termination of the job 
run. ``True`` by default.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
     """
 
     CALLER = "DatabricksTaskOperator"
@@ -1419,6 +1436,7 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
         polling_period_seconds: int = 5,
         wait_for_termination: bool = True,
         workflow_run_metadata: dict | None = None,
+        environments: list[dict] | None = None,
         **kwargs,
     ):
         self.task_config = task_config
@@ -1436,6 +1454,7 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
             polling_period_seconds=polling_period_seconds,
             wait_for_termination=wait_for_termination,
             workflow_run_metadata=workflow_run_metadata,
+            environments=environments,
             **kwargs,
         )
 
diff --git 
a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py 
b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
index 6df8e2d025c..d185d6d30cd 100644
--- 
a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
+++ 
b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
@@ -90,6 +90,8 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
         will be passed to all notebooks in the workflow.
     :param tasks_to_convert: A list of tasks to convert to a Databricks 
workflow. This list can also be
         populated after instantiation using the `add_task` method.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
     """
 
     operator_extra_links = (WorkflowJobRunLink(), 
WorkflowJobRepairAllFailedLink())
@@ -106,6 +108,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
         max_concurrent_runs: int = 1,
         notebook_params: dict | None = None,
         tasks_to_convert: list[BaseOperator] | None = None,
+        environments: list[dict] | None = None,
         **kwargs,
     ):
         self.databricks_conn_id = databricks_conn_id
@@ -117,6 +120,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
         self.tasks_to_convert = tasks_to_convert or []
         self.relevant_upstreams = [task_id]
         self.workflow_run_metadata: WorkflowRunMetadata | None = None
+        self.environments = environments
         super().__init__(task_id=task_id, **kwargs)
 
     def _get_hook(self, caller: str) -> DatabricksHook:
@@ -156,6 +160,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
             "format": "MULTI_TASK",
             "job_clusters": self.job_clusters,
             "max_concurrent_runs": self.max_concurrent_runs,
+            "environments": self.environments,
         }
         return merge(default_json, self.extra_job_params)
 
@@ -274,6 +279,8 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
         all python tasks in the workflow.
     :param spark_submit_params: A list of spark submit parameters to pass to 
the workflow. These parameters
         will be passed to all spark submit tasks.
+    :param environments: An optional list of task execution environment 
specifications
+        that can be referenced by serverless tasks of this job.
     """
 
     is_databricks = True
@@ -290,6 +297,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
         notebook_params: dict | None = None,
         python_params: list | None = None,
         spark_submit_params: list | None = None,
+        environments: list[dict] | None = None,
         **kwargs,
     ):
         self.databricks_conn_id = databricks_conn_id
@@ -302,6 +310,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
         self.notebook_params = notebook_params or {}
         self.python_params = python_params or []
         self.spark_submit_params = spark_submit_params or []
+        self.environments = environments or []
         super().__init__(**kwargs)
 
     def __exit__(
@@ -321,6 +330,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
             job_clusters=self.job_clusters,
             max_concurrent_runs=self.max_concurrent_runs,
             notebook_params=self.notebook_params,
+            environments=self.environments,
         )
 
         for task in tasks:
diff --git a/providers/tests/databricks/operators/test_databricks.py 
b/providers/tests/databricks/operators/test_databricks.py
index 51e7a765998..ac5e556e48d 100644
--- a/providers/tests/databricks/operators/test_databricks.py
+++ b/providers/tests/databricks/operators/test_databricks.py
@@ -265,6 +265,15 @@ ACCESS_CONTROL_LIST = [
         "permission_level": "CAN_MANAGE",
     }
 ]
+ENVIRONMENTS = [
+    {
+        "environment_key": "default_environment",
+        "spec": {
+            "client": "1",
+            "dependencies": ["library1"],
+        },
+    }
+]
 
 
 def mock_dict(d: dict):
@@ -307,6 +316,7 @@ class TestDatabricksCreateJobsOperator:
             max_concurrent_runs=MAX_CONCURRENT_RUNS,
             git_source=GIT_SOURCE,
             access_control_list=ACCESS_CONTROL_LIST,
+            environments=ENVIRONMENTS,
         )
         expected = utils.normalise_json_content(
             {
@@ -321,6 +331,7 @@ class TestDatabricksCreateJobsOperator:
                 "max_concurrent_runs": MAX_CONCURRENT_RUNS,
                 "git_source": GIT_SOURCE,
                 "access_control_list": ACCESS_CONTROL_LIST,
+                "environments": ENVIRONMENTS,
             }
         )
 
@@ -342,6 +353,7 @@ class TestDatabricksCreateJobsOperator:
             "max_concurrent_runs": MAX_CONCURRENT_RUNS,
             "git_source": GIT_SOURCE,
             "access_control_list": ACCESS_CONTROL_LIST,
+            "environments": ENVIRONMENTS,
         }
         op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
 
@@ -358,6 +370,7 @@ class TestDatabricksCreateJobsOperator:
                 "max_concurrent_runs": MAX_CONCURRENT_RUNS,
                 "git_source": GIT_SOURCE,
                 "access_control_list": ACCESS_CONTROL_LIST,
+                "environments": ENVIRONMENTS,
             }
         )
 
@@ -380,6 +393,7 @@ class TestDatabricksCreateJobsOperator:
         override_max_concurrent_runs = 0
         override_git_source = {}
         override_access_control_list = []
+        override_environments = []
         json = {
             "name": JOB_NAME,
             "tags": TAGS,
@@ -392,6 +406,7 @@ class TestDatabricksCreateJobsOperator:
             "max_concurrent_runs": MAX_CONCURRENT_RUNS,
             "git_source": GIT_SOURCE,
             "access_control_list": ACCESS_CONTROL_LIST,
+            "environments": ENVIRONMENTS,
         }
 
         op = DatabricksCreateJobsOperator(
@@ -408,6 +423,7 @@ class TestDatabricksCreateJobsOperator:
             max_concurrent_runs=override_max_concurrent_runs,
             git_source=override_git_source,
             access_control_list=override_access_control_list,
+            environments=override_environments,
         )
 
         expected = utils.normalise_json_content(
@@ -423,6 +439,7 @@ class TestDatabricksCreateJobsOperator:
                 "max_concurrent_runs": override_max_concurrent_runs,
                 "git_source": override_git_source,
                 "access_control_list": override_access_control_list,
+                "environments": override_environments,
             }
         )
 
@@ -466,6 +483,7 @@ class TestDatabricksCreateJobsOperator:
             "max_concurrent_runs": MAX_CONCURRENT_RUNS,
             "git_source": GIT_SOURCE,
             "access_control_list": ACCESS_CONTROL_LIST,
+            "environments": ENVIRONMENTS,
         }
         op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
         db_mock = db_mock_class.return_value
@@ -490,6 +508,7 @@ class TestDatabricksCreateJobsOperator:
                 "max_concurrent_runs": MAX_CONCURRENT_RUNS,
                 "git_source": GIT_SOURCE,
                 "access_control_list": ACCESS_CONTROL_LIST,
+                "environments": ENVIRONMENTS,
             }
         )
         db_mock_class.assert_called_once_with(
@@ -522,6 +541,7 @@ class TestDatabricksCreateJobsOperator:
             "max_concurrent_runs": MAX_CONCURRENT_RUNS,
             "git_source": GIT_SOURCE,
             "access_control_list": ACCESS_CONTROL_LIST,
+            "environments": ENVIRONMENTS,
         }
         op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
         db_mock = db_mock_class.return_value
@@ -544,6 +564,7 @@ class TestDatabricksCreateJobsOperator:
                 "max_concurrent_runs": MAX_CONCURRENT_RUNS,
                 "git_source": GIT_SOURCE,
                 "access_control_list": ACCESS_CONTROL_LIST,
+                "environments": ENVIRONMENTS,
             }
         )
         db_mock_class.assert_called_once_with(
@@ -654,6 +675,72 @@ class TestDatabricksSubmitRunOperator:
 
         assert expected == utils.normalise_json_content(op.json)
 
+    def test_init_with_serverless_spark_python_task_named_parameters(self):
+        """
+        Test the initializer when ``environments`` is passed as a named parameter alongside ``json``.
+        """
+        python_tasks = [
+            {
+                "task_key": "pythong_task_1",
+                "new_cluster": {
+                    "spark_version": "7.3.x-scala2.12",
+                    "node_type_id": "i3.xlarge",
+                    "spark_conf": {
+                        "spark.speculation": True,
+                    },
+                    "aws_attributes": {
+                        "availability": "SPOT",
+                        "zone_id": "us-west-2a",
+                    },
+                    "autoscale": {
+                        "min_workers": 2,
+                        "max_workers": 16,
+                    },
+                },
+                "spark_python_task": {"python_file": 
"/Users/[email protected]/example_file.py"},
+                "timeout_seconds": 86400,
+                "max_retries": 3,
+                "min_retry_interval_millis": 2000,
+                "retry_on_timeout": False,
+                "environment_key": "default_environment",
+            },
+        ]
+        json = {
+            "name": JOB_NAME,
+            "tags": TAGS,
+            "tasks": python_tasks,
+            "job_clusters": JOB_CLUSTERS,
+            "email_notifications": EMAIL_NOTIFICATIONS,
+            "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+            "timeout_seconds": TIMEOUT_SECONDS,
+            "schedule": SCHEDULE,
+            "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+            "git_source": GIT_SOURCE,
+        }
+        op = DatabricksSubmitRunOperator(
+            task_id=TASK_ID,
+            json=json,
+            environments=ENVIRONMENTS,
+        )
+        expected = utils.normalise_json_content(
+            {
+                "name": JOB_NAME,
+                "tags": TAGS,
+                "tasks": python_tasks,
+                "job_clusters": JOB_CLUSTERS,
+                "email_notifications": EMAIL_NOTIFICATIONS,
+                "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+                "timeout_seconds": TIMEOUT_SECONDS,
+                "schedule": SCHEDULE,
+                "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+                "git_source": GIT_SOURCE,
+                "environments": ENVIRONMENTS,
+                "run_name": TASK_ID,
+            }
+        )
+
+        assert expected == utils.normalise_json_content(op.json)
+
     def test_init_with_pipeline_name_task_named_parameters(self):
         """
         Test the initializer with the named parameters.
@@ -2121,7 +2208,7 @@ class TestDatabricksNotebookOperator:
         exception_message = "Both new_cluster and existing_cluster_id are set. 
Only one should be set."
         assert str(exc_info.value) == exception_message
 
-    def test_both_new_and_existing_cluster_unset(self):
+    def test_both_new_and_existing_cluster_unset(self, caplog):
         operator = DatabricksNotebookOperator(
             task_id="test_task",
             notebook_path="test_path",
@@ -2130,7 +2217,7 @@ class TestDatabricksNotebookOperator:
         )
         with pytest.raises(ValueError) as exc_info:
             operator._get_run_json()
-        exception_message = "Must specify either existing_cluster_id or 
new_cluster."
+        exception_message = "Must specify either existing_cluster_id, 
new_cluster or environments."
         assert str(exc_info.value) == exception_message
 
     def test_job_runs_forever_by_default(self):
@@ -2343,3 +2430,16 @@ class TestDatabricksTaskOperator:
         expected_task_key = "test_task_key"
 
         assert expected_task_key == operator.databricks_task_key
+
+    def test_get_task_base_json_serverless(self):
+        task_config = SPARK_PYTHON_TASK
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            task_config=task_config,
+            environments=ENVIRONMENTS,
+        )
+        task_base_json = operator._get_task_base_json()
+
+        assert operator.task_config == task_config
+        assert task_base_json == task_config
diff --git a/providers/tests/databricks/operators/test_databricks_workflow.py 
b/providers/tests/databricks/operators/test_databricks_workflow.py
index fbc429ed1d9..acfd9bf9a7d 100644
--- a/providers/tests/databricks/operators/test_databricks_workflow.py
+++ b/providers/tests/databricks/operators/test_databricks_workflow.py
@@ -77,9 +77,19 @@ def test_flatten_node():
 
 def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
     """Test that _CreateDatabricksWorkflowOperator.create_workflow_json 
returns the expected JSON."""
+    environments = [
+        {
+            "environment_key": "default_environment",
+            "spec": {
+                "client": "1",
+                "dependencies": ["library1"],
+            },
+        }
+    ]
     operator = _CreateDatabricksWorkflowOperator(
         task_id="test_task",
         databricks_conn_id="databricks_default",
+        environments=environments,
     )
     operator.task_group = mock_task_group
 
@@ -96,6 +106,7 @@ def test_create_workflow_json(mock_databricks_hook, context, 
mock_task_group):
     assert workflow_json["job_clusters"] == []
     assert workflow_json["max_concurrent_runs"] == 1
     assert workflow_json["timeout_seconds"] == 0
+    assert workflow_json["environments"] == environments
 
 
 def test_create_or_reset_job_existing(mock_databricks_hook, context, 
mock_task_group):
@@ -216,6 +227,7 @@ def 
test_task_group_exit_creates_operator(mock_databricks_workflow_operator):
         task_group=task_group,
         task_id="launch",
         databricks_conn_id="databricks_conn",
+        environments=[],
         existing_clusters=[],
         extra_job_params={},
         job_clusters=[],
diff --git a/providers/tests/system/databricks/example_databricks.py 
b/providers/tests/system/databricks/example_databricks.py
index 999cebb6742..360645cc30d 100644
--- a/providers/tests/system/databricks/example_databricks.py
+++ b/providers/tests/system/databricks/example_databricks.py
@@ -238,6 +238,29 @@ with DAG(
     )
     # [END howto_operator_databricks_task_sql]
 
+    # [START howto_operator_databricks_task_python]
+    environments = [
+        {
+            "environment_key": "default_environment",
+            "spec": {
+                "client": "1",
+                "dependencies": ["library1"],
+            },
+        }
+    ]
+    task_operator_python_query = DatabricksTaskOperator(
+        task_id="python_task",
+        databricks_conn_id="databricks_conn",
+        task_config={
+            "spark_python_task": {
+                "python_file": "/Users/[email protected]/example_file.py",
+            },
+            "environment_key": "default_environment",
+        },
+        environments=environments,
+    )
+    # [END howto_operator_databricks_task_python]
+
     from tests_common.test_utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure

Reply via email to