This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0631ba76e0b Add support for serverless job in Databricks operators (#45188)
0631ba76e0b is described below
commit 0631ba76e0bdc7a52e873f3ec85787cbaf0e0dec
Author: Hari Selvarajan <[email protected]>
AuthorDate: Tue Jan 28 12:06:46 2025 +0000
Add support for serverless job in Databricks operators (#45188)
closes: #45138
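
A minimal usage sketch of the new ``environments`` parameter (the dag id, connection id, task key and python_file path below are illustrative placeholders; the environment spec mirrors the example added in this change):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

    # Serverless execution environments, referenced by tasks via "environment_key".
    environments = [
        {
            "environment_key": "default_environment",
            "spec": {"client": "1", "dependencies": ["library1"]},
        }
    ]

    with DAG(dag_id="example_databricks_serverless", start_date=datetime(2025, 1, 1), schedule=None):
        DatabricksSubmitRunOperator(
            task_id="serverless_python_run",  # placeholder task id
            databricks_conn_id="databricks_default",  # placeholder connection
            json={
                "tasks": [
                    {
                        "task_key": "python_task_1",
                        # No new_cluster / existing_cluster_id: the task runs serverless
                        # and picks up the environment defined above via its key.
                        "environment_key": "default_environment",
                        "spark_python_task": {"python_file": "/Shared/example_file.py"},
                    }
                ],
            },
            environments=environments,
        )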
---
.../operators/jobs_create.rst | 1 +
.../operators/submit_run.rst | 1 +
.../operators/task.rst | 7 ++
.../providers/databricks/operators/databricks.py | 23 ++++-
.../databricks/operators/databricks_workflow.py | 10 ++
.../tests/databricks/operators/test_databricks.py | 104 ++++++++++++++++++++-
.../operators/test_databricks_workflow.py | 12 +++
.../tests/system/databricks/example_databricks.py | 23 +++++
8 files changed, 177 insertions(+), 4 deletions(-)
diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
index 621423f83f3..5a79c8244a9 100644
--- a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
+++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
@@ -56,6 +56,7 @@ Currently the named parameters that ``DatabricksCreateJobsOperator`` supports ar
- ``max_concurrent_runs``
- ``git_source``
- ``access_control_list``
+ - ``environments``
Examples
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index 10548583cfa..7a8d13f646c 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -80,6 +80,7 @@ Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
- ``libraries``
- ``run_name``
- ``timeout_seconds``
+ - ``environments``
.. code-block:: python
diff --git a/docs/apache-airflow-providers-databricks/operators/task.rst b/docs/apache-airflow-providers-databricks/operators/task.rst
index 331481d915c..47ceafe58ad 100644
--- a/docs/apache-airflow-providers-databricks/operators/task.rst
+++ b/docs/apache-airflow-providers-databricks/operators/task.rst
@@ -44,3 +44,10 @@ Running a SQL query in Databricks using DatabricksTaskOperator
:language: python
:start-after: [START howto_operator_databricks_task_sql]
:end-before: [END howto_operator_databricks_task_sql]
+
+Running a Python file in Databricks using DatabricksTaskOperator
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: /../../providers/tests/system/databricks/example_databricks.py
+ :language: python
+ :start-after: [START howto_operator_databricks_task_python]
+ :end-before: [END howto_operator_databricks_task_python]
diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py
index b8fde94c594..3c121c49a9e 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks.py
@@ -293,6 +293,8 @@ class DatabricksCreateJobsOperator(BaseOperator):
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
"""
@@ -324,6 +326,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
+ environments: list[dict] | None = None,
**kwargs,
) -> None:
"""Create a new ``DatabricksCreateJobsOperator``."""
@@ -360,6 +363,8 @@ class DatabricksCreateJobsOperator(BaseOperator):
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
+ if environments is not None:
+ self.json["environments"] = environments
if self.json:
self.json = normalise_json_content(self.json)
@@ -503,6 +508,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param git_source: Optional specification of a remote git repository from which
supported task types are retrieved.
:param deferrable: Run operator in the deferrable mode.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
.. seealso::
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
@@ -543,6 +550,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
wait_for_termination: bool = True,
git_source: dict[str, str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ environments: list[dict] | None = None,
**kwargs,
) -> None:
"""Create a new ``DatabricksSubmitRunOperator``."""
@@ -587,6 +595,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source
+ if environments is not None:
+ self.json["environments"] = environments
if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
@@ -983,6 +993,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
"""
def __init__(
@@ -1000,6 +1012,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
polling_period_seconds: int = 5,
wait_for_termination: bool = True,
workflow_run_metadata: dict[str, Any] | None = None,
+ environments: list[dict] | None = None,
**kwargs: Any,
):
self.caller = caller
@@ -1015,7 +1028,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
self.polling_period_seconds = polling_period_seconds
self.wait_for_termination = wait_for_termination
self.workflow_run_metadata = workflow_run_metadata
-
+ self.environments = environments
self.databricks_run_id: int | None = None
super().__init__(**kwargs)
@@ -1095,8 +1108,10 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
run_json["new_cluster"] = self.new_cluster
elif self.existing_cluster_id:
run_json["existing_cluster_id"] = self.existing_cluster_id
+ elif self.environments:
+ run_json["environments"] = self.environments
else:
- raise ValueError("Must specify either existing_cluster_id or new_cluster.")
+ raise ValueError("Must specify either existing_cluster_id, new_cluster or environments.")
return run_json
def _launch_job(self, context: Context | None = None) -> int:
@@ -1400,6 +1415,8 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
:param new_cluster: Specs for a new cluster on which this task will be run.
:param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
"""
CALLER = "DatabricksTaskOperator"
@@ -1419,6 +1436,7 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
polling_period_seconds: int = 5,
wait_for_termination: bool = True,
workflow_run_metadata: dict | None = None,
+ environments: list[dict] | None = None,
**kwargs,
):
self.task_config = task_config
@@ -1436,6 +1454,7 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
polling_period_seconds=polling_period_seconds,
wait_for_termination=wait_for_termination,
workflow_run_metadata=workflow_run_metadata,
+ environments=environments,
**kwargs,
)
diff --git a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
index 6df8e2d025c..d185d6d30cd 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py
@@ -90,6 +90,8 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
will be passed to all notebooks in the workflow.
:param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
populated after instantiation using the `add_task` method.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
"""
operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
@@ -106,6 +108,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
max_concurrent_runs: int = 1,
notebook_params: dict | None = None,
tasks_to_convert: list[BaseOperator] | None = None,
+ environments: list[dict] | None = None,
**kwargs,
):
self.databricks_conn_id = databricks_conn_id
@@ -117,6 +120,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
self.tasks_to_convert = tasks_to_convert or []
self.relevant_upstreams = [task_id]
self.workflow_run_metadata: WorkflowRunMetadata | None = None
+ self.environments = environments
super().__init__(task_id=task_id, **kwargs)
def _get_hook(self, caller: str) -> DatabricksHook:
@@ -156,6 +160,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
"format": "MULTI_TASK",
"job_clusters": self.job_clusters,
"max_concurrent_runs": self.max_concurrent_runs,
+ "environments": self.environments,
}
return merge(default_json, self.extra_job_params)
@@ -274,6 +279,8 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
all python tasks in the workflow.
:param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters
will be passed to all spark submit tasks.
+ :param environments: An optional list of task execution environment specifications
+ that can be referenced by serverless tasks of this job.
"""
is_databricks = True
@@ -290,6 +297,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
notebook_params: dict | None = None,
python_params: list | None = None,
spark_submit_params: list | None = None,
+ environments: list[dict] | None = None,
**kwargs,
):
self.databricks_conn_id = databricks_conn_id
@@ -302,6 +310,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
self.notebook_params = notebook_params or {}
self.python_params = python_params or []
self.spark_submit_params = spark_submit_params or []
+ self.environments = environments or []
super().__init__(**kwargs)
def __exit__(
@@ -321,6 +330,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
job_clusters=self.job_clusters,
max_concurrent_runs=self.max_concurrent_runs,
notebook_params=self.notebook_params,
+ environments=self.environments,
)
for task in tasks:
diff --git a/providers/tests/databricks/operators/test_databricks.py b/providers/tests/databricks/operators/test_databricks.py
index 51e7a765998..ac5e556e48d 100644
--- a/providers/tests/databricks/operators/test_databricks.py
+++ b/providers/tests/databricks/operators/test_databricks.py
@@ -265,6 +265,15 @@ ACCESS_CONTROL_LIST = [
"permission_level": "CAN_MANAGE",
}
]
+ENVIRONMENTS = [
+ {
+ "environment_key": "default_environment",
+ "spec": {
+ "client": "1",
+ "dependencies": ["library1"],
+ },
+ }
+]
def mock_dict(d: dict):
@@ -307,6 +316,7 @@ class TestDatabricksCreateJobsOperator:
max_concurrent_runs=MAX_CONCURRENT_RUNS,
git_source=GIT_SOURCE,
access_control_list=ACCESS_CONTROL_LIST,
+ environments=ENVIRONMENTS,
)
expected = utils.normalise_json_content(
{
@@ -321,6 +331,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
)
@@ -342,6 +353,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
@@ -358,6 +370,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
)
@@ -380,6 +393,7 @@ class TestDatabricksCreateJobsOperator:
override_max_concurrent_runs = 0
override_git_source = {}
override_access_control_list = []
+ override_environments = []
json = {
"name": JOB_NAME,
"tags": TAGS,
@@ -392,6 +406,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
op = DatabricksCreateJobsOperator(
@@ -408,6 +423,7 @@ class TestDatabricksCreateJobsOperator:
max_concurrent_runs=override_max_concurrent_runs,
git_source=override_git_source,
access_control_list=override_access_control_list,
+ environments=override_environments,
)
expected = utils.normalise_json_content(
@@ -423,6 +439,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": override_max_concurrent_runs,
"git_source": override_git_source,
"access_control_list": override_access_control_list,
+ "environments": override_environments,
}
)
@@ -466,6 +483,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
db_mock = db_mock_class.return_value
@@ -490,6 +508,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
)
db_mock_class.assert_called_once_with(
@@ -522,6 +541,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
db_mock = db_mock_class.return_value
@@ -544,6 +564,7 @@ class TestDatabricksCreateJobsOperator:
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
+ "environments": ENVIRONMENTS,
}
)
db_mock_class.assert_called_once_with(
@@ -654,6 +675,72 @@ class TestDatabricksSubmitRunOperator:
assert expected == utils.normalise_json_content(op.json)
+ def test_init_with_serverless_spark_python_task_named_parameters(self):
+ """
+ Test the initializer with the named parameters.
+ """
+ python_tasks = [
+ {
+ "task_key": "pythong_task_1",
+ "new_cluster": {
+ "spark_version": "7.3.x-scala2.12",
+ "node_type_id": "i3.xlarge",
+ "spark_conf": {
+ "spark.speculation": True,
+ },
+ "aws_attributes": {
+ "availability": "SPOT",
+ "zone_id": "us-west-2a",
+ },
+ "autoscale": {
+ "min_workers": 2,
+ "max_workers": 16,
+ },
+ },
+ "spark_python_task": {"python_file":
"/Users/[email protected]/example_file.py"},
+ "timeout_seconds": 86400,
+ "max_retries": 3,
+ "min_retry_interval_millis": 2000,
+ "retry_on_timeout": False,
+ "environment_key": "default_environment",
+ },
+ ]
+ json = {
+ "name": JOB_NAME,
+ "tags": TAGS,
+ "tasks": python_tasks,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ }
+ op = DatabricksSubmitRunOperator(
+ task_id=TASK_ID,
+ json=json,
+ environments=ENVIRONMENTS,
+ )
+ expected = utils.normalise_json_content(
+ {
+ "name": JOB_NAME,
+ "tags": TAGS,
+ "tasks": python_tasks,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ "environments": ENVIRONMENTS,
+ "run_name": TASK_ID,
+ }
+ )
+
+ assert expected == utils.normalise_json_content(op.json)
+
def test_init_with_pipeline_name_task_named_parameters(self):
"""
Test the initializer with the named parameters.
@@ -2121,7 +2208,7 @@ class TestDatabricksNotebookOperator:
exception_message = "Both new_cluster and existing_cluster_id are set.
Only one should be set."
assert str(exc_info.value) == exception_message
- def test_both_new_and_existing_cluster_unset(self):
+ def test_both_new_and_existing_cluster_unset(self, caplog):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
@@ -2130,7 +2217,7 @@ class TestDatabricksNotebookOperator:
)
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
- exception_message = "Must specify either existing_cluster_id or new_cluster."
+ exception_message = "Must specify either existing_cluster_id, new_cluster or environments."
assert str(exc_info.value) == exception_message
def test_job_runs_forever_by_default(self):
@@ -2343,3 +2430,16 @@ class TestDatabricksTaskOperator:
expected_task_key = "test_task_key"
assert expected_task_key == operator.databricks_task_key
+
+ def test_get_task_base_json_serverless(self):
+ task_config = SPARK_PYTHON_TASK
+ operator = DatabricksTaskOperator(
+ task_id="test_task",
+ databricks_conn_id="test_conn_id",
+ task_config=task_config,
+ environments=ENVIRONMENTS,
+ )
+ task_base_json = operator._get_task_base_json()
+
+ assert operator.task_config == task_config
+ assert task_base_json == task_config
diff --git a/providers/tests/databricks/operators/test_databricks_workflow.py b/providers/tests/databricks/operators/test_databricks_workflow.py
index fbc429ed1d9..acfd9bf9a7d 100644
--- a/providers/tests/databricks/operators/test_databricks_workflow.py
+++ b/providers/tests/databricks/operators/test_databricks_workflow.py
@@ -77,9 +77,19 @@ def test_flatten_node():
def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
"""Test that _CreateDatabricksWorkflowOperator.create_workflow_json
returns the expected JSON."""
+ environments = [
+ {
+ "environment_key": "default_environment",
+ "spec": {
+ "client": "1",
+ "dependencies": ["library1"],
+ },
+ }
+ ]
operator = _CreateDatabricksWorkflowOperator(
task_id="test_task",
databricks_conn_id="databricks_default",
+ environments=environments,
)
operator.task_group = mock_task_group
@@ -96,6 +106,7 @@ def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
assert workflow_json["job_clusters"] == []
assert workflow_json["max_concurrent_runs"] == 1
assert workflow_json["timeout_seconds"] == 0
+ assert workflow_json["environments"] == environments
def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_group):
@@ -216,6 +227,7 @@ def test_task_group_exit_creates_operator(mock_databricks_workflow_operator):
task_group=task_group,
task_id="launch",
databricks_conn_id="databricks_conn",
+ environments=[],
existing_clusters=[],
extra_job_params={},
job_clusters=[],
diff --git a/providers/tests/system/databricks/example_databricks.py b/providers/tests/system/databricks/example_databricks.py
index 999cebb6742..360645cc30d 100644
--- a/providers/tests/system/databricks/example_databricks.py
+++ b/providers/tests/system/databricks/example_databricks.py
@@ -238,6 +238,29 @@ with DAG(
)
# [END howto_operator_databricks_task_sql]
+ # [START howto_operator_databricks_task_python]
+ environments = [
+ {
+ "environment_key": "default_environment",
+ "spec": {
+ "client": "1",
+ "dependencies": ["library1"],
+ },
+ }
+ ]
+ task_operator_python_query = DatabricksTaskOperator(
+ task_id="python_task",
+ databricks_conn_id="databricks_conn",
+ task_config={
+ "spark_python_task": {
+ "python_file": "/Users/[email protected]/example_file.py",
+ },
+ "environment_key": "default_environment",
+ },
+ environments=environments,
+ )
+ # [END howto_operator_databricks_task_python]
+
from tests_common.test_utils.watcher import watcher
# This test needs watcher in order to properly mark success/failure