This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ecf7fa07d Add DatabricksWorkflowTaskGroup (#39771)
2ecf7fa07d is described below

commit 2ecf7fa07d6d681c73ae4831801f9d98db298d89
Author: Pankaj Koti <pankajkoti...@gmail.com>
AuthorDate: Thu May 30 14:49:37 2024 +0530

    Add DatabricksWorkflowTaskGroup (#39771)
    
    This pull request introduces the 
[DatabricksWorkflowTaskGroup](https://github.com/astronomer/astro-provider-databricks/blob/main/src/astro_databricks/operators/workflow.py#L226)
    to the Airflow Databricks provider from the 
[astro-provider-databricks](https://github.com/astronomer/astro-provider-databricks/tree/main)
    repository.
    It marks another pull request aimed at contributing
    operators and features from that repository into the Airflow
    Databricks provider, the previous PR being 
https://github.com/apache/airflow/pull/39178.
    
    The task group launches a [Databricks 
Workflow](https://docs.databricks.com/en/workflows/index.html)
    and runs the notebook jobs from within it, resulting in a
    [75% cost reduction](https://www.databricks.com/product/pricing) ($0.40/DBU 
for all-purpose compute,
    $0.07/DBU for Jobs compute) when compared to executing
    ``DatabricksNotebookOperator`` outside of ``DatabricksWorkflowTaskGroup``.
    
    ---------
    Co-authored-by: Daniel Imberman <daniel.imber...@gmail.com>
    Co-authored-by: Tatiana Al-Chueyr <tatiana.alchu...@gmail.com>
    Co-authored-by: Wei Lee <weilee...@gmail.com>
---
 airflow/providers/databricks/hooks/databricks.py   |  18 ++
 .../providers/databricks/operators/databricks.py   | 141 ++++++++--
 .../databricks/operators/databricks_workflow.py    | 312 +++++++++++++++++++++
 airflow/providers/databricks/provider.yaml         |  10 +
 ...icks_workflow_task_group_airflow_graph_view.png | Bin 0 -> 188072 bytes
 .../img/workflow_run_databricks_graph_view.png     | Bin 0 -> 405053 bytes
 .../operators/workflow.rst                         |  71 +++++
 generated/provider_dependencies.json               |   1 +
 .../databricks/operators/test_databricks.py        | 119 ++++++++
 .../operators/test_databricks_workflow.py          | 233 +++++++++++++++
 .../databricks/example_databricks_workflow.py      | 118 ++++++++
 11 files changed, 998 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py 
b/airflow/providers/databricks/hooks/databricks.py
index 710074d239..ee5349add0 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -29,6 +29,7 @@ or the ``api/2.1/jobs/runs/submit``
 from __future__ import annotations
 
 import json
+from enum import Enum
 from typing import Any
 
 from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@ WORKSPACE_GET_STATUS_ENDPOINT = ("GET", 
"api/2.0/workspace/get-status")
 SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
 
 
+class RunLifeCycleState(Enum):
+    """Enum for the run life cycle state concept of Databricks runs.
+
+    See more information at: 
https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
+    """
+
+    BLOCKED = "BLOCKED"
+    INTERNAL_ERROR = "INTERNAL_ERROR"
+    PENDING = "PENDING"
+    QUEUED = "QUEUED"
+    RUNNING = "RUNNING"
+    SKIPPED = "SKIPPED"
+    TERMINATED = "TERMINATED"
+    TERMINATING = "TERMINATING"
+    WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
+
+
 class RunState:
     """Utility class for the run state concept of Databricks runs."""
 
diff --git a/airflow/providers/databricks/operators/databricks.py 
b/airflow/providers/databricks/operators/databricks.py
index ff8de10132..d6118f247f 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -29,13 +29,18 @@ from deprecated import deprecated
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook, 
RunState
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, 
RunLifeCycleState, RunState
+from airflow.providers.databricks.operators.databricks_workflow import (
+    DatabricksWorkflowTaskGroup,
+    WorkflowRunMetadata,
+)
 from airflow.providers.databricks.triggers.databricks import 
DatabricksExecutionTrigger
 from airflow.providers.databricks.utils.databricks import 
normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
+    from airflow.utils.task_group import TaskGroup
 
 DEFER_METHOD_NAME = "execute_complete"
 XCOM_RUN_ID_KEY = "run_id"
@@ -926,7 +931,10 @@ class DatabricksNotebookOperator(BaseOperator):
     :param deferrable: Run operator in the deferrable mode.
     """
 
-    template_fields = ("notebook_params",)
+    template_fields = (
+        "notebook_params",
+        "workflow_run_metadata",
+    )
     CALLER = "DatabricksNotebookOperator"
 
     def __init__(
@@ -944,6 +952,7 @@ class DatabricksNotebookOperator(BaseOperator):
         databricks_retry_args: dict[Any, Any] | None = None,
         wait_for_termination: bool = True,
         databricks_conn_id: str = "databricks_default",
+        workflow_run_metadata: dict | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs: Any,
     ):
@@ -962,6 +971,10 @@ class DatabricksNotebookOperator(BaseOperator):
         self.databricks_conn_id = databricks_conn_id
         self.databricks_run_id: int | None = None
         self.deferrable = deferrable
+
+        # This is used to store the metadata of the Databricks job run when 
the job is launched from within DatabricksWorkflowTaskGroup.
+        self.workflow_run_metadata: dict | None = workflow_run_metadata
+
         super().__init__(**kwargs)
 
     @cached_property
@@ -1016,6 +1029,79 @@ class DatabricksNotebookOperator(BaseOperator):
         """Get the databricks task ID using dag_id and task_id. Removes 
illegal characters."""
         return f"{self.dag_id}__{task_id.replace('.', '__')}"
 
+    @property
+    def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | 
None:
+        """
+        Traverse up parent TaskGroups until the `is_databricks` flag 
associated with the root DatabricksWorkflowTaskGroup is found.
+
+        If found, returns the task group. Otherwise, return None.
+        """
+        parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = 
self.task_group
+
+        while parent_tg:
+            if getattr(parent_tg, "is_databricks", False):
+                return parent_tg  # type: ignore[return-value]
+
+            if getattr(parent_tg, "task_group", None):
+                parent_tg = parent_tg.task_group
+            else:
+                return None
+
+        return None
+
+    def _extend_workflow_notebook_packages(
+        self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
+    ) -> None:
+        """Extend the task group packages into the notebook's packages, 
without adding any duplicates."""
+        for task_group_package in 
databricks_workflow_task_group.notebook_packages:
+            exists = any(
+                task_group_package == existing_package for existing_package in 
self.notebook_packages
+            )
+            if not exists:
+                self.notebook_packages.append(task_group_package)
+
+    def _convert_to_databricks_workflow_task(
+        self, relevant_upstreams: list[BaseOperator], context: Context | None 
= None
+    ) -> dict[str, object]:
+        """Convert the operator to a Databricks workflow task that can be a 
task in a workflow."""
+        databricks_workflow_task_group = self._databricks_workflow_task_group
+        if not databricks_workflow_task_group:
+            raise AirflowException(
+                "Calling `_convert_to_databricks_workflow_task` without a 
parent TaskGroup."
+            )
+
+        if hasattr(databricks_workflow_task_group, "notebook_packages"):
+            
self._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+        if hasattr(databricks_workflow_task_group, "notebook_params"):
+            self.notebook_params = {
+                **self.notebook_params,
+                **databricks_workflow_task_group.notebook_params,
+            }
+
+        base_task_json = self._get_task_base_json()
+        result = {
+            "task_key": self._get_databricks_task_id(self.task_id),
+            "depends_on": [
+                {"task_key": self._get_databricks_task_id(task_id)}
+                for task_id in self.upstream_task_ids
+                if task_id in relevant_upstreams
+            ],
+            **base_task_json,
+        }
+
+        if self.existing_cluster_id and self.job_cluster_key:
+            raise ValueError(
+                "Both existing_cluster_id and job_cluster_key are set. Only 
one can be set per task."
+            )
+
+        if self.existing_cluster_id:
+            result["existing_cluster_id"] = self.existing_cluster_id
+        elif self.job_cluster_key:
+            result["job_cluster_key"] = self.job_cluster_key
+
+        return result
+
     def _get_run_json(self) -> dict[str, Any]:
         """Get run json to be used for task submissions."""
         run_json = {
@@ -1039,6 +1125,17 @@ class DatabricksNotebookOperator(BaseOperator):
         self.log.info("Check the job run in Databricks: %s", url)
         return self.databricks_run_id
 
+    def _handle_terminal_run_state(self, run_state: RunState) -> None:
+        if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
+            raise AirflowException(
+                f"Databricks job failed with state 
{run_state.life_cycle_state}. Message: {run_state.state_message}"
+            )
+        if not run_state.is_successful:
+            raise AirflowException(
+                f"Task failed. Final state {run_state.result_state}. Reason: 
{run_state.state_message}"
+            )
+        self.log.info("Task succeeded. Final state %s.", 
run_state.result_state)
+
     def monitor_databricks_job(self) -> None:
         if self.databricks_run_id is None:
             raise ValueError("Databricks job not yet launched. Please run 
launch_notebook_job first.")
@@ -1063,34 +1160,28 @@ class DatabricksNotebookOperator(BaseOperator):
             run = self._hook.get_run(self.databricks_run_id)
             run_state = RunState(**run["state"])
             self.log.info(
-                "task %s %s", self._get_databricks_task_id(self.task_id), 
run_state.life_cycle_state
-            )
-            self.log.info("Current state of the job: %s", 
run_state.life_cycle_state)
-        if run_state.life_cycle_state != "TERMINATED":
-            raise AirflowException(
-                f"Databricks job failed with state 
{run_state.life_cycle_state}. "
-                f"Message: {run_state.state_message}"
+                "Current state of the databricks task %s is %s",
+                self._get_databricks_task_id(self.task_id),
+                run_state.life_cycle_state,
             )
-        if not run_state.is_successful:
-            raise AirflowException(
-                f"Task failed. Final state {run_state.result_state}. Reason: 
{run_state.state_message}"
-            )
-        self.log.info("Task succeeded. Final state %s.", 
run_state.result_state)
+        self._handle_terminal_run_state(run_state)
 
     def execute(self, context: Context) -> None:
-        self.launch_notebook_job()
+        if self._databricks_workflow_task_group:
+            # If we are in a DatabricksWorkflowTaskGroup, we should have an 
upstream task launched.
+            if not self.workflow_run_metadata:
+                launch_task_id = next(task for task in self.upstream_task_ids 
if task.endswith(".launch"))
+                self.workflow_run_metadata = 
context["ti"].xcom_pull(task_ids=launch_task_id)
+            workflow_run_metadata = WorkflowRunMetadata(  # type: 
ignore[arg-type]
+                **self.workflow_run_metadata
+            )
+            self.databricks_run_id = workflow_run_metadata.run_id
+            self.databricks_conn_id = workflow_run_metadata.conn_id
+        else:
+            self.launch_notebook_job()
         if self.wait_for_termination:
             self.monitor_databricks_job()
 
     def execute_complete(self, context: dict | None, event: dict) -> None:
         run_state = RunState.from_json(event["run_state"])
-        if run_state.life_cycle_state != "TERMINATED":
-            raise AirflowException(
-                f"Databricks job failed with state 
{run_state.life_cycle_state}. "
-                f"Message: {run_state.state_message}"
-            )
-        if not run_state.is_successful:
-            raise AirflowException(
-                f"Task failed. Final state {run_state.result_state}. Reason: 
{run_state.state_message}"
-            )
-        self.log.info("Task succeeded. Final state %s.", 
run_state.result_state)
+        self._handle_terminal_run_state(run_state)
diff --git a/airflow/providers/databricks/operators/databricks_workflow.py 
b/airflow/providers/databricks/operators/databricks_workflow.py
new file mode 100644
index 0000000000..8203145314
--- /dev/null
+++ b/airflow/providers/databricks/operators/databricks_workflow.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import time
+from dataclasses import dataclass
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from mergedeep import merge
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, 
RunLifeCycleState
+from airflow.utils.task_group import TaskGroup
+
+if TYPE_CHECKING:
+    from types import TracebackType
+
+    from airflow.models.taskmixin import DAGNode
+    from airflow.utils.context import Context
+
+
+@dataclass
+class WorkflowRunMetadata:
+    """
+    Metadata for a Databricks workflow run.
+
+    :param run_id: The ID of the Databricks workflow run.
+    :param job_id: The ID of the Databricks workflow job.
+    :param conn_id: The connection ID used to connect to Databricks.
+    """
+
+    conn_id: str
+    job_id: str
+    run_id: int
+
+
+def _flatten_node(
+    node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None 
= None
+) -> list[BaseOperator]:
+    """Flatten a node (either a TaskGroup or Operator) to a list of nodes."""
+    if tasks is None:
+        tasks = []
+    if isinstance(node, BaseOperator):
+        return [node]
+
+    if isinstance(node, TaskGroup):
+        new_tasks = []
+        for _, child in node.children.items():
+            new_tasks += _flatten_node(child, tasks)
+
+        return tasks + new_tasks
+
+    return tasks
+
+
+class _CreateDatabricksWorkflowOperator(BaseOperator):
+    """
+    Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified 
in a DAG.
+
+    :param task_id: The task_id of the operator
+    :param databricks_conn_id: The connection ID to use when connecting to 
Databricks.
+    :param existing_clusters: A list of existing clusters to use for the 
workflow.
+    :param extra_job_params: A dictionary of extra properties which will 
override the default Databricks
+        Workflow Job definitions.
+    :param job_clusters: A list of job clusters to use for the workflow.
+    :param max_concurrent_runs: The maximum number of concurrent runs for the 
workflow.
+    :param notebook_params: A dictionary of notebook parameters to pass to the 
workflow. These parameters
+        will be passed to all notebooks in the workflow.
+    :param tasks_to_convert: A list of tasks to convert to a Databricks 
workflow. This list can also be
+        populated after instantiation using the `add_task` method.
+    """
+
+    template_fields = ("notebook_params",)
+    caller = "_CreateDatabricksWorkflowOperator"
+
+    def __init__(
+        self,
+        task_id: str,
+        databricks_conn_id: str,
+        existing_clusters: list[str] | None = None,
+        extra_job_params: dict[str, Any] | None = None,
+        job_clusters: list[dict[str, object]] | None = None,
+        max_concurrent_runs: int = 1,
+        notebook_params: dict | None = None,
+        tasks_to_convert: list[BaseOperator] | None = None,
+        **kwargs,
+    ):
+        self.databricks_conn_id = databricks_conn_id
+        self.existing_clusters = existing_clusters or []
+        self.extra_job_params = extra_job_params or {}
+        self.job_clusters = job_clusters or []
+        self.max_concurrent_runs = max_concurrent_runs
+        self.notebook_params = notebook_params or {}
+        self.tasks_to_convert = tasks_to_convert or []
+        self.relevant_upstreams = [task_id]
+        super().__init__(task_id=task_id, **kwargs)
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            caller=caller,
+        )
+
+    @cached_property
+    def _hook(self) -> DatabricksHook:
+        return self._get_hook(caller=self.caller)
+
+    def add_task(self, task: BaseOperator) -> None:
+        """Add a task to the list of tasks to convert to a Databricks 
workflow."""
+        self.tasks_to_convert.append(task)
+
+    @property
+    def job_name(self) -> str:
+        if not self.task_group:
+            raise AirflowException("Task group must be set before accessing 
job_name")
+        return f"{self.dag_id}.{self.task_group.group_id}"
+
+    def create_workflow_json(self, context: Context | None = None) -> 
dict[str, object]:
+        """Create a workflow json to be used in the Databricks API."""
+        task_json = [
+            task._convert_to_databricks_workflow_task(  # type: 
ignore[attr-defined]
+                relevant_upstreams=self.relevant_upstreams, context=context
+            )
+            for task in self.tasks_to_convert
+        ]
+
+        default_json = {
+            "name": self.job_name,
+            "email_notifications": {"no_alert_for_skipped_runs": False},
+            "timeout_seconds": 0,
+            "tasks": task_json,
+            "format": "MULTI_TASK",
+            "job_clusters": self.job_clusters,
+            "max_concurrent_runs": self.max_concurrent_runs,
+        }
+        return merge(default_json, self.extra_job_params)
+
+    def _create_or_reset_job(self, context: Context) -> int:
+        job_spec = self.create_workflow_json(context=context)
+        existing_jobs = self._hook.list_jobs(job_name=self.job_name)
+        job_id = existing_jobs[0]["job_id"] if existing_jobs else None
+        if job_id:
+            self.log.info(
+                "Updating existing Databricks workflow job %s with spec %s",
+                self.job_name,
+                json.dumps(job_spec, indent=2),
+            )
+            self._hook.reset_job(job_id, job_spec)
+        else:
+            self.log.info(
+                "Creating new Databricks workflow job %s with spec %s",
+                self.job_name,
+                json.dumps(job_spec, indent=2),
+            )
+            job_id = self._hook.create_job(job_spec)
+        return job_id
+
+    def _wait_for_job_to_start(self, run_id: int) -> None:
+        run_url = self._hook.get_run_page_url(run_id)
+        self.log.info("Check the progress of the Databricks job at %s", 
run_url)
+        life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+        if life_cycle_state not in (
+            RunLifeCycleState.PENDING.value,
+            RunLifeCycleState.RUNNING.value,
+            RunLifeCycleState.BLOCKED.value,
+        ):
+            raise AirflowException(f"Could not start the workflow job. State: 
{life_cycle_state}")
+        while life_cycle_state in (RunLifeCycleState.PENDING.value, 
RunLifeCycleState.BLOCKED.value):
+            self.log.info("Waiting for the Databricks job to start running")
+            time.sleep(5)
+            life_cycle_state = 
self._hook.get_run_state(run_id).life_cycle_state
+        self.log.info("Databricks job started. State: %s", life_cycle_state)
+
+    def execute(self, context: Context) -> Any:
+        if not isinstance(self.task_group, DatabricksWorkflowTaskGroup):
+            raise AirflowException("Task group must be a 
DatabricksWorkflowTaskGroup")
+
+        job_id = self._create_or_reset_job(context)
+
+        run_id = self._hook.run_now(
+            {
+                "job_id": job_id,
+                "jar_params": self.task_group.jar_params,
+                "notebook_params": self.notebook_params,
+                "python_params": self.task_group.python_params,
+                "spark_submit_params": self.task_group.spark_submit_params,
+            }
+        )
+
+        self._wait_for_job_to_start(run_id)
+
+        return {
+            "conn_id": self.databricks_conn_id,
+            "job_id": job_id,
+            "run_id": run_id,
+        }
+
+
+class DatabricksWorkflowTaskGroup(TaskGroup):
+    """
+    A task group that takes a list of tasks and creates a databricks workflow.
+
+    The DatabricksWorkflowTaskGroup takes a list of tasks and creates a 
databricks workflow
+    based on the metadata produced by those tasks. For a task to be eligible 
for this
+    TaskGroup, it must contain the ``_convert_to_databricks_workflow_task`` 
method. If any tasks
+    do not contain this method then the TaskGroup will raise an error at parse 
time.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DatabricksWorkflowTaskGroup`
+
+    :param databricks_conn_id: The name of the databricks connection to use.
+    :param existing_clusters: A list of existing clusters to use for this 
workflow.
+    :param extra_job_params: A dictionary containing properties which will 
override the default
+        Databricks Workflow Job definitions.
+    :param jar_params: A list of jar parameters to pass to the workflow. These 
parameters will be passed to all jar
+        tasks in the workflow.
+    :param job_clusters: A list of job clusters to use for this workflow.
+    :param max_concurrent_runs: The maximum number of concurrent runs for this 
workflow.
+    :param notebook_packages: A list of dictionaries of Python packages to be 
installed. Packages defined
+        at the workflow task group level are installed for each of the 
notebook tasks under it. And
+        packages defined at the notebook task level are installed specifically for 
the notebook task.
+    :param notebook_params: A dictionary of notebook parameters to pass to the 
workflow. These parameters
+        will be passed to all notebook tasks in the workflow.
+    :param python_params: A list of python parameters to pass to the workflow. 
These parameters will be passed to
+        all python tasks in the workflow.
+    :param spark_submit_params: A list of spark submit parameters to pass to 
the workflow. These parameters
+        will be passed to all spark submit tasks.
+    """
+
+    is_databricks = True
+
+    def __init__(
+        self,
+        databricks_conn_id: str,
+        existing_clusters: list[str] | None = None,
+        extra_job_params: dict[str, Any] | None = None,
+        jar_params: list[str] | None = None,
+        job_clusters: list[dict] | None = None,
+        max_concurrent_runs: int = 1,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        notebook_params: dict | None = None,
+        python_params: list | None = None,
+        spark_submit_params: list | None = None,
+        **kwargs,
+    ):
+        self.databricks_conn_id = databricks_conn_id
+        self.existing_clusters = existing_clusters or []
+        self.extra_job_params = extra_job_params or {}
+        self.jar_params = jar_params or []
+        self.job_clusters = job_clusters or []
+        self.max_concurrent_runs = max_concurrent_runs
+        self.notebook_packages = notebook_packages or []
+        self.notebook_params = notebook_params or {}
+        self.python_params = python_params or []
+        self.spark_submit_params = spark_submit_params or []
+        super().__init__(**kwargs)
+
+    def __exit__(
+        self, _type: type[BaseException] | None, _value: BaseException | None, 
_tb: TracebackType | None
+    ) -> None:
+        """Exit the context manager and add tasks to a single 
``_CreateDatabricksWorkflowOperator``."""
+        roots = list(self.get_roots())
+        tasks = _flatten_node(self)
+
+        create_databricks_workflow_task = _CreateDatabricksWorkflowOperator(
+            dag=self.dag,
+            task_group=self,
+            task_id="launch",
+            databricks_conn_id=self.databricks_conn_id,
+            existing_clusters=self.existing_clusters,
+            extra_job_params=self.extra_job_params,
+            job_clusters=self.job_clusters,
+            max_concurrent_runs=self.max_concurrent_runs,
+            notebook_params=self.notebook_params,
+        )
+
+        for task in tasks:
+            if not (
+                hasattr(task, "_convert_to_databricks_workflow_task")
+                and callable(task._convert_to_databricks_workflow_task)
+            ):
+                raise AirflowException(
+                    f"Task {task.task_id} does not support conversion to 
databricks workflow task."
+                )
+
+            task.workflow_run_metadata = create_databricks_workflow_task.output
+            
create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
+            create_databricks_workflow_task.add_task(task)
+
+        for root_task in roots:
+            root_task.set_upstream(create_databricks_workflow_task)
+
+        super().__exit__(_type, _value, _tb)
diff --git a/airflow/providers/databricks/provider.yaml 
b/airflow/providers/databricks/provider.yaml
index a6e2174640..80506dc16c 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -72,6 +72,7 @@ dependencies:
   # The 2.9.1 (to be released soon) already contains the fix
   - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
   - aiohttp>=3.9.2, <4
+  - mergedeep>=1.3.4
 
 additional-extras:
   # pip install apache-airflow-providers-databricks[sdk]
@@ -108,6 +109,12 @@ integrations:
       - /docs/apache-airflow-providers-databricks/operators/repos_delete.rst
     logo: /integration-logos/databricks/Databricks.png
     tags: [service]
+  - integration-name: Databricks Workflow
+    external-doc-url: https://docs.databricks.com/en/workflows/index.html
+    how-to-guide:
+      - /docs/apache-airflow-providers-databricks/operators/workflow.rst
+    logo: /integration-logos/databricks/Databricks.png
+    tags: [service]
 
 operators:
   - integration-name: Databricks
@@ -119,6 +126,9 @@ operators:
   - integration-name: Databricks Repos
     python-modules:
       - airflow.providers.databricks.operators.databricks_repos
+  - integration-name: Databricks Workflow
+    python-modules:
+      - airflow.providers.databricks.operators.databricks_workflow
 
 hooks:
   - integration-name: Databricks
diff --git 
a/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png
 
b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png
new file mode 100644
index 0000000000..3a3cb669e0
Binary files /dev/null and 
b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png
 differ
diff --git 
a/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png
 
b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png
new file mode 100644
index 0000000000..cb189b8105
Binary files /dev/null and 
b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png
 differ
diff --git a/docs/apache-airflow-providers-databricks/operators/workflow.rst 
b/docs/apache-airflow-providers-databricks/operators/workflow.rst
new file mode 100644
index 0000000000..f58514dd51
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/workflow.rst
@@ -0,0 +1,71 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:DatabricksWorkflowTaskGroup:
+
+
+DatabricksWorkflowTaskGroup
+===========================
+
+Use the 
:class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup`
 to launch and monitor
+Databricks notebook job runs as Airflow tasks. The task group launches a 
`Databricks Workflow <https://docs.databricks.com/en/workflows/index.html>`_ 
and runs the notebook jobs from within it, resulting in a `75% cost reduction 
<https://www.databricks.com/product/pricing>`_ ($0.40/DBU for all-purpose 
compute, $0.07/DBU for Jobs compute) when compared to executing 
``DatabricksNotebookOperator`` outside of ``DatabricksWorkflowTaskGroup``.
+
+
+There are a few advantages to defining your Databricks Workflows in Airflow:
+
+=======================================  
=============================================  =================================
+Authoring interface                      via Databricks (Web-based with 
Databricks UI)  via Airflow (Code with Airflow DAG)
+=======================================  
=============================================  =================================
+Workflow compute pricing                 ✅                                     
        ✅
+Notebook code in source control          ✅                                     
        ✅
+Workflow structure in source control                                           
         ✅
+Retry from beginning                     ✅                                     
        ✅
+Retry single task                                                              
         ✅
+Task groups within Workflows                                                   
         ✅
+Trigger workflows from other DAGs                                              
         ✅
+Workflow-level parameters                                                      
         ✅
+=======================================  
=============================================  =================================
+
+Examples
+--------
+
+Example of what a DAG looks like with a DatabricksWorkflowTaskGroup
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: 
/../../tests/system/providers/databricks/example_databricks_workflow.py
+    :language: python
+    :start-after: [START howto_databricks_workflow_notebook]
+    :end-before: [END howto_databricks_workflow_notebook]
+
+With this example, Airflow will produce a job named 
``<dag_name>.test_workflow_<USER>_<GROUP_ID>`` that will
+run task ``notebook_1`` and then ``notebook_2``. The job will be created in 
the Databricks workspace
+if it does not already exist. If the job already exists, it will be updated to 
match
+the workflow defined in the DAG.
+
+The following image displays the resulting Databricks Workflow in the Airflow 
UI (based on the above example provided)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. image:: ../img/databricks_workflow_task_group_airflow_graph_view.png
+
+The corresponding Databricks Workflow in the Databricks UI for the run 
triggered from the Airflow DAG is depicted below
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. image:: ../img/workflow_run_databricks_graph_view.png
+
+
+To minimize update conflicts, we recommend that you keep parameters in the 
``notebook_params`` of the
+``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` 
whenever possible.
+This is because parameters set on the ``DatabricksWorkflowTaskGroup`` are passed in at 
job trigger time and
+do not modify the job definition.
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 7f29984d32..01c1e33785 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -411,6 +411,7 @@
       "apache-airflow-providers-common-sql>=1.10.0",
       "apache-airflow>=2.7.0",
       "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+      "mergedeep>=1.3.4",
       "requests>=2.27.0,<3"
     ],
     "devel-deps": [
diff --git a/tests/providers/databricks/operators/test_databricks.py 
b/tests/providers/databricks/operators/test_databricks.py
index d6e7eb3892..2774385ea5 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -2014,3 +2014,122 @@ class TestDatabricksNotebookOperator:
             "Set it instead to `None` if you desire the task to run 
indefinitely."
         )
         assert str(exc_info.value) == exception_message
+
+    def test_extend_workflow_notebook_packages(self):
+        """Test that the operator can extend the notebook packages of a 
Databricks workflow task group."""
+        databricks_workflow_task_group = MagicMock()
+        databricks_workflow_task_group.notebook_packages = [
+            {"pypi": {"package": "numpy"}},
+            {"pypi": {"package": "pandas"}},
+        ]
+
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            notebook_packages=[
+                {"pypi": {"package": "numpy"}},
+                {"pypi": {"package": "scipy"}},
+            ],
+        )
+
+        
operator._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+        assert operator.notebook_packages == [
+            {"pypi": {"package": "numpy"}},
+            {"pypi": {"package": "scipy"}},
+            {"pypi": {"package": "pandas"}},
+        ]
+
+    def test_convert_to_databricks_workflow_task(self):
+        """Test that the operator can convert itself to a Databricks workflow 
task."""
+        dag = DAG(dag_id="example_dag", start_date=datetime.now())
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            notebook_packages=[
+                {"pypi": {"package": "numpy"}},
+                {"pypi": {"package": "scipy"}},
+            ],
+            dag=dag,
+        )
+
+        databricks_workflow_task_group = MagicMock()
+        databricks_workflow_task_group.notebook_packages = [
+            {"pypi": {"package": "numpy"}},
+        ]
+        databricks_workflow_task_group.notebook_params = {"param1": "value1"}
+
+        operator.notebook_packages = [{"pypi": {"package": "pandas"}}]
+        operator.notebook_params = {"param2": "value2"}
+        operator.task_group = databricks_workflow_task_group
+        operator.task_id = "test_task"
+        operator.upstream_task_ids = ["upstream_task"]
+        relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+        task_json = 
operator._convert_to_databricks_workflow_task(relevant_upstreams)
+
+        expected_json = {
+            "task_key": "example_dag__test_task",
+            "depends_on": [],
+            "timeout_seconds": 0,
+            "email_notifications": {},
+            "notebook_task": {
+                "notebook_path": "/path/to/notebook",
+                "source": "WORKSPACE",
+                "base_parameters": {
+                    "param2": "value2",
+                    "param1": "value1",
+                },
+            },
+            "libraries": [
+                {"pypi": {"package": "pandas"}},
+                {"pypi": {"package": "numpy"}},
+            ],
+        }
+
+        assert task_json == expected_json
+
+    def test_convert_to_databricks_workflow_task_no_task_group(self):
+        """Test that an error is raised if the operator is not in a 
TaskGroup."""
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            notebook_packages=[
+                {"pypi": {"package": "numpy"}},
+                {"pypi": {"package": "scipy"}},
+            ],
+        )
+        operator.task_group = None
+        relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+        with pytest.raises(
+            AirflowException,
+            match="Calling `_convert_to_databricks_workflow_task` without a 
parent TaskGroup.",
+        ):
+            operator._convert_to_databricks_workflow_task(relevant_upstreams)
+
+    def test_convert_to_databricks_workflow_task_cluster_conflict(self):
+        """Test that an error is raised if both `existing_cluster_id` and 
`job_cluster_key` are set."""
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            notebook_packages=[
+                {"pypi": {"package": "numpy"}},
+                {"pypi": {"package": "scipy"}},
+            ],
+        )
+        databricks_workflow_task_group = MagicMock()
+        operator.existing_cluster_id = "existing-cluster-id"
+        operator.job_cluster_key = "job-cluster-key"
+        operator.task_group = databricks_workflow_task_group
+        relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+        with pytest.raises(
+            ValueError,
+            match="Both existing_cluster_id and job_cluster_key are set. Only 
one can be set per task.",
+        ):
+            operator._convert_to_databricks_workflow_task(relevant_upstreams)
diff --git a/tests/providers/databricks/operators/test_databricks_workflow.py 
b/tests/providers/databricks/operators/test_databricks_workflow.py
new file mode 100644
index 0000000000..99f1a9d148
--- /dev/null
+++ b/tests/providers/databricks/operators/test_databricks_workflow.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow import DAG
+from airflow.exceptions import AirflowException
+from airflow.models.baseoperator import BaseOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.databricks.hooks.databricks import RunLifeCycleState
+from airflow.providers.databricks.operators.databricks_workflow import (
+    DatabricksWorkflowTaskGroup,
+    _CreateDatabricksWorkflowOperator,
+    _flatten_node,
+)
+from airflow.utils import timezone
+
+pytestmark = pytest.mark.db_test
+
+DEFAULT_DATE = timezone.datetime(2021, 1, 1)
+
+
+@pytest.fixture
+def mock_databricks_hook():
+    """Provide a mock DatabricksHook."""
+    with 
patch("airflow.providers.databricks.operators.databricks_workflow.DatabricksHook")
 as mock_hook:
+        yield mock_hook
+
+
+@pytest.fixture
+def context():
+    """Provide a mock context object."""
+    return MagicMock()
+
+
+@pytest.fixture
+def mock_task_group():
+    """Provide a mock DatabricksWorkflowTaskGroup with necessary attributes."""
+    mock_group = MagicMock(spec=DatabricksWorkflowTaskGroup)
+    mock_group.group_id = "test_group"
+    return mock_group
+
+
+def test_flatten_node():
+    """Test that _flatten_node returns a flat list of operators."""
+    task_group = MagicMock(spec=DatabricksWorkflowTaskGroup)
+    base_operator = MagicMock(spec=BaseOperator)
+    task_group.children = {"task1": base_operator, "task2": base_operator}
+
+    result = _flatten_node(task_group)
+    assert result == [base_operator, base_operator]
+
+
+def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
+    """Test that _CreateDatabricksWorkflowOperator.create_workflow_json 
returns the expected JSON."""
+    operator = _CreateDatabricksWorkflowOperator(
+        task_id="test_task",
+        databricks_conn_id="databricks_default",
+    )
+    operator.task_group = mock_task_group
+
+    task = MagicMock(spec=BaseOperator)
+    task._convert_to_databricks_workflow_task = MagicMock(return_value={})
+    operator.add_task(task)
+
+    workflow_json = operator.create_workflow_json(context=context)
+
+    assert ".test_group" in workflow_json["name"]
+    assert "tasks" in workflow_json
+    assert workflow_json["format"] == "MULTI_TASK"
+    assert workflow_json["email_notifications"] == 
{"no_alert_for_skipped_runs": False}
+    assert workflow_json["job_clusters"] == []
+    assert workflow_json["max_concurrent_runs"] == 1
+    assert workflow_json["timeout_seconds"] == 0
+
+
+def test_create_or_reset_job_existing(mock_databricks_hook, context, 
mock_task_group):
+    """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job resets 
the job if it already exists."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", 
databricks_conn_id="databricks_default")
+    operator.task_group = mock_task_group
+    operator._hook.list_jobs.return_value = [{"job_id": 123}]
+    operator._hook.create_job.return_value = 123
+
+    job_id = operator._create_or_reset_job(context)
+    assert job_id == 123
+    operator._hook.reset_job.assert_called_once()
+
+
+def test_create_or_reset_job_new(mock_databricks_hook, context, 
mock_task_group):
+    """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job 
creates a new job if it does not exist."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", 
databricks_conn_id="databricks_default")
+    operator.task_group = mock_task_group
+    operator._hook.list_jobs.return_value = []
+    operator._hook.create_job.return_value = 456
+
+    job_id = operator._create_or_reset_job(context)
+    assert job_id == 456
+    operator._hook.create_job.assert_called_once()
+
+
+def test_wait_for_job_to_start(mock_databricks_hook):
+    """Test that _CreateDatabricksWorkflowOperator._wait_for_job_to_start 
waits for the job to start."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", 
databricks_conn_id="databricks_default")
+    mock_hook_instance = mock_databricks_hook.return_value
+    mock_hook_instance.get_run_state.side_effect = [
+        MagicMock(life_cycle_state=RunLifeCycleState.PENDING.value),
+        MagicMock(life_cycle_state=RunLifeCycleState.RUNNING.value),
+    ]
+
+    operator._wait_for_job_to_start(123)
+    mock_hook_instance.get_run_state.assert_called()
+
+
+def test_execute(mock_databricks_hook, context, mock_task_group):
+    """Test that _CreateDatabricksWorkflowOperator.execute runs the task 
group."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", 
databricks_conn_id="databricks_default")
+    operator.task_group = mock_task_group
+    mock_task_group.jar_params = {}
+    mock_task_group.python_params = {}
+    mock_task_group.spark_submit_params = {}
+
+    mock_hook_instance = mock_databricks_hook.return_value
+    mock_hook_instance.run_now.return_value = 789
+    mock_hook_instance.list_jobs.return_value = [{"job_id": 123}]
+    mock_hook_instance.get_run_state.return_value = MagicMock(
+        life_cycle_state=RunLifeCycleState.RUNNING.value
+    )
+
+    task = MagicMock(spec=BaseOperator)
+    task._convert_to_databricks_workflow_task = MagicMock(return_value={})
+    operator.add_task(task)
+
+    result = operator.execute(context)
+
+    assert result == {
+        "conn_id": "databricks_default",
+        "job_id": 123,
+        "run_id": 789,
+    }
+    mock_hook_instance.run_now.assert_called_once()
+
+
+def test_execute_invalid_task_group(context):
+    """Test that _CreateDatabricksWorkflowOperator.execute raises an exception 
if the task group is invalid."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", 
databricks_conn_id="databricks_default")
+    operator.task_group = MagicMock()  # Not a DatabricksWorkflowTaskGroup
+
+    with pytest.raises(AirflowException, match="Task group must be a 
DatabricksWorkflowTaskGroup"):
+        operator.execute(context)
+
+
+@pytest.fixture
+def mock_databricks_workflow_operator():
+    with patch(
+        
"airflow.providers.databricks.operators.databricks_workflow._CreateDatabricksWorkflowOperator"
+    ) as mock_operator:
+        yield mock_operator
+
+
+def test_task_group_initialization():
+    """Test that DatabricksWorkflowTaskGroup initializes correctly."""
+    with DAG(dag_id="example_databricks_workflow_dag", 
start_date=DEFAULT_DATE) as example_dag:
+        with DatabricksWorkflowTaskGroup(
+            group_id="test_databricks_workflow", 
databricks_conn_id="databricks_conn"
+        ) as task_group:
+            task_1 = EmptyOperator(task_id="task1")
+            task_1._convert_to_databricks_workflow_task = 
MagicMock(return_value={})
+        assert task_group.group_id == "test_databricks_workflow"
+        assert task_group.databricks_conn_id == "databricks_conn"
+        assert task_group.dag == example_dag
+
+
+def test_task_group_exit_creates_operator(mock_databricks_workflow_operator):
+    """Test that DatabricksWorkflowTaskGroup creates a 
_CreateDatabricksWorkflowOperator on exit."""
+    with DAG(dag_id="example_databricks_workflow_dag", 
start_date=DEFAULT_DATE) as example_dag:
+        with DatabricksWorkflowTaskGroup(
+            group_id="test_databricks_workflow",
+            databricks_conn_id="databricks_conn",
+        ) as task_group:
+            task1 = MagicMock(task_id="task1")
+            task1._convert_to_databricks_workflow_task = 
MagicMock(return_value={})
+            task2 = MagicMock(task_id="task2")
+            task2._convert_to_databricks_workflow_task = 
MagicMock(return_value={})
+
+            task_group.add(task1)
+            task_group.add(task2)
+
+            task1.set_downstream(task2)
+
+    mock_databricks_workflow_operator.assert_called_once_with(
+        dag=example_dag,
+        task_group=task_group,
+        task_id="launch",
+        databricks_conn_id="databricks_conn",
+        existing_clusters=[],
+        extra_job_params={},
+        job_clusters=[],
+        max_concurrent_runs=1,
+        notebook_params={},
+    )
+
+
+def 
test_task_group_root_tasks_set_upstream_to_operator(mock_databricks_workflow_operator):
+    """Test that tasks added to a DatabricksWorkflowTaskGroup are set upstream 
to the operator."""
+    with DAG(dag_id="example_databricks_workflow_dag", 
start_date=DEFAULT_DATE):
+        with DatabricksWorkflowTaskGroup(
+            group_id="test_databricks_workflow1",
+            databricks_conn_id="databricks_conn",
+        ) as task_group:
+            task1 = MagicMock(task_id="task1")
+            task1._convert_to_databricks_workflow_task = 
MagicMock(return_value={})
+            task_group.add(task1)
+
+    create_operator_instance = mock_databricks_workflow_operator.return_value
+    task1.set_upstream.assert_called_once_with(create_operator_instance)
diff --git a/tests/system/providers/databricks/example_databricks_workflow.py 
b/tests/system/providers/databricks/example_databricks_workflow.py
new file mode 100644
index 0000000000..6b05f34684
--- /dev/null
+++ b/tests/system/providers/databricks/example_databricks_workflow.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG for using the DatabricksWorkflowTaskGroup and 
DatabricksNotebookOperator."""
+
+from __future__ import annotations
+
+import os
+from datetime import timedelta
+
+from airflow.models.dag import DAG
+from airflow.providers.databricks.operators.databricks import 
DatabricksNotebookOperator
+from airflow.providers.databricks.operators.databricks_workflow import 
DatabricksWorkflowTaskGroup
+from airflow.utils.timezone import datetime
+
+EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
+
+DATABRICKS_CONN_ID = os.getenv("DATABRICKS_CONN_ID", "databricks_conn")
+DATABRICKS_NOTIFICATION_EMAIL = os.getenv("DATABRICKS_NOTIFICATION_EMAIL", 
"your_em...@serviceprovider.com")
+
+GROUP_ID = os.getenv("DATABRICKS_GROUP_ID", "1234").replace(".", "_")
+USER = os.environ.get("USER")
+
+job_cluster_spec = [
+    {
+        "job_cluster_key": "Shared_job_cluster",
+        "new_cluster": {
+            "cluster_name": "",
+            "spark_version": "11.3.x-scala2.12",
+            "aws_attributes": {
+                "first_on_demand": 1,
+                "availability": "SPOT_WITH_FALLBACK",
+                "zone_id": "us-east-2b",
+                "spot_bid_price_percent": 100,
+                "ebs_volume_count": 0,
+            },
+            "node_type_id": "i3.xlarge",
+            "spark_env_vars": {"PYSPARK_PYTHON": 
"/databricks/python3/bin/python3"},
+            "enable_elastic_disk": False,
+            "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
+            "runtime_engine": "STANDARD",
+            "num_workers": 8,
+        },
+    }
+]
+dag = DAG(
+    dag_id="example_databricks_workflow",
+    start_date=datetime(2022, 1, 1),
+    schedule_interval=None,
+    catchup=False,
+    tags=["example", "databricks"],
+)
+with dag:
+    # [START howto_databricks_workflow_notebook]
+    task_group = DatabricksWorkflowTaskGroup(
+        group_id=f"test_workflow_{USER}_{GROUP_ID}",
+        databricks_conn_id=DATABRICKS_CONN_ID,
+        job_clusters=job_cluster_spec,
+        notebook_params={"ts": "{{ ts }}"},
+        notebook_packages=[
+            {
+                "pypi": {
+                    "package": "simplejson==3.18.0",  # Pin a specific 
version of a package like this.
+                    "repo": "https://pypi.org/simple",  # You can specify your 
required Pypi index here.
+                }
+            },
+        ],
+        extra_job_params={
+            "email_notifications": {
+                "on_start": [DATABRICKS_NOTIFICATION_EMAIL],
+            },
+        },
+    )
+    with task_group:
+        notebook_1 = DatabricksNotebookOperator(
+            task_id="workflow_notebook_1",
+            databricks_conn_id=DATABRICKS_CONN_ID,
+            notebook_path="/Shared/Notebook_1",
+            notebook_packages=[{"pypi": {"package": "Faker"}}],
+            source="WORKSPACE",
+            job_cluster_key="Shared_job_cluster",
+            execution_timeout=timedelta(seconds=600),
+        )
+        notebook_2 = DatabricksNotebookOperator(
+            task_id="workflow_notebook_2",
+            databricks_conn_id=DATABRICKS_CONN_ID,
+            notebook_path="/Shared/Notebook_2",
+            source="WORKSPACE",
+            job_cluster_key="Shared_job_cluster",
+            notebook_params={"foo": "bar", "ds": "{{ ds }}"},
+        )
+        notebook_1 >> notebook_2
+    # [END howto_databricks_workflow_notebook]
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to