This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73c88418feb Fix DatabricksWorkflowTaskGroup ignoring upstream/downstream set before with block (#68924)
73c88418feb is described below

commit 73c88418feb682776636a98eb7e4b82964f7d699
Author: Noritaka Sekiyama <[email protected]>
AuthorDate: Wed Jun 24 23:34:38 2026 +0900

    Fix DatabricksWorkflowTaskGroup ignoring upstream/downstream set before with block (#68924)
    
    When ``>>`` was called before the ``with task_group:`` block, the
    dependency was recorded on the task group but never transferred to
    the launch task (upstream) or the leaf tasks (downstream), because
    those tasks did not exist yet. Transfer ``self.upstream_task_ids``
    and ``self.downstream_task_ids`` in ``__exit__()`` after creating
    the internal tasks.
    
    closes: #51598
    
    Co-authored-by: shubhgurav0590 <[email protected]>
---
 .../databricks/operators/databricks_workflow.py    |  9 +++
 .../operators/test_databricks_workflow.py          | 76 ++++++++++++++++++++++
 2 files changed, 85 insertions(+)

diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
index 779c2fc9f15..317bd444911 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
@@ -409,5 +409,14 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
 
             for root_task in roots:
                 root_task.set_upstream(create_databricks_workflow_task)
+
+            # When ``>>`` is called before the ``with`` block, the dependency is
+            # recorded on the task group but the launch/leaf tasks don't exist yet.
+            # Transfer those task-group-level dependencies now that the tasks are created.
+            for upstream_id in self.upstream_task_ids:
+                create_databricks_workflow_task.set_upstream(self.dag.get_task(upstream_id))
+            for downstream_id in self.downstream_task_ids:
+                for leaf_task in self.get_leaves():
+                    leaf_task.set_downstream(self.dag.get_task(downstream_id))
         finally:
             super().__exit__(_type, _value, _tb)
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
index 84069ee0ff7..0f42298beb7 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
@@ -571,3 +571,79 @@ class TestWorkflowDependsOnWirePayload:
         job_id, job_spec = launch_task._hook.reset_job.call_args.args
         assert job_id == 42
         self._assert_parent_depends_on(job_spec)
+
+
+class TestTaskGroupExternalDependencyTransfer:
+    """Verify that task-group-level ``>>`` deps transfer to the launch / leaf tasks.
+
+    When ``>>`` is called *before* the ``with`` block, the launch task doesn't
+    exist yet; ``__exit__`` must transfer ``self.upstream_task_ids`` to the
+    launch task and ``self.downstream_task_ids`` to the leaf tasks.
+    """
+
+    JOB_CLUSTERS = [
+        {
+            "job_cluster_key": "c",
+            "new_cluster": {
+                "spark_version": "15.4.x-scala2.12",
+                "num_workers": 0,
+                "node_type_id": "i3.xlarge",
+            },
+        }
+    ]
+
+    def test_upstream_set_before_with_block(self):
+        with DAG(dag_id="test_up_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+            start = EmptyOperator(task_id="start")
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            start >> tg
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+        launch = dag.get_task("tg.launch")
+        assert "start" in launch.upstream_task_ids
+
+    def test_upstream_set_after_with_block(self):
+        with DAG(dag_id="test_up_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+            start = EmptyOperator(task_id="start")
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+            start >> tg
+        launch = dag.get_task("tg.launch")
+        assert "start" in launch.upstream_task_ids
+
+    def test_downstream_set_before_with_block(self):
+        with DAG(dag_id="test_down_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            end = EmptyOperator(task_id="end")
+            tg >> end
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+        nb = dag.get_task("tg.nb")
+        assert "end" in nb.downstream_task_ids
+
+    def test_downstream_set_after_with_block(self):
+        with DAG(dag_id="test_down_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            end = EmptyOperator(task_id="end")
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+            tg >> end
+        nb = dag.get_task("tg.nb")
+        assert "end" in nb.downstream_task_ids

Reply via email to