This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73c88418feb Fix DatabricksWorkflowTaskGroup ignoring upstream/downstream set before with block (#68924)
73c88418feb is described below
commit 73c88418feb682776636a98eb7e4b82964f7d699
Author: Noritaka Sekiyama <[email protected]>
AuthorDate: Wed Jun 24 23:34:38 2026 +0900
Fix DatabricksWorkflowTaskGroup ignoring upstream/downstream set before with block (#68924)
When ``>>`` was called before the ``with task_group:`` block, the
dependency was recorded on the task group but never transferred to
the launch task (upstream) or the leaf tasks (downstream), because
those tasks did not exist yet. Transfer ``self.upstream_task_ids``
and ``self.downstream_task_ids`` in ``__exit__()`` after creating
the internal tasks.
closes: #51598
Co-authored-by: shubhgurav0590 <[email protected]>
---
.../databricks/operators/databricks_workflow.py | 9 +++
.../operators/test_databricks_workflow.py | 76 ++++++++++++++++++++++
2 files changed, 85 insertions(+)
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
index 779c2fc9f15..317bd444911 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
@@ -409,5 +409,14 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
for root_task in roots:
root_task.set_upstream(create_databricks_workflow_task)
+
+ # When ``>>`` is called before the ``with`` block, the dependency is
+ # recorded on the task group but the launch/leaf tasks don't exist yet.
+ # Transfer those task-group-level dependencies now that the tasks are created.
+ for upstream_id in self.upstream_task_ids:
+ create_databricks_workflow_task.set_upstream(self.dag.get_task(upstream_id))
+ for downstream_id in self.downstream_task_ids:
+ for leaf_task in self.get_leaves():
+ leaf_task.set_downstream(self.dag.get_task(downstream_id))
finally:
super().__exit__(_type, _value, _tb)
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
index 84069ee0ff7..0f42298beb7 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
@@ -571,3 +571,79 @@ class TestWorkflowDependsOnWirePayload:
job_id, job_spec = launch_task._hook.reset_job.call_args.args
assert job_id == 42
self._assert_parent_depends_on(job_spec)
+
+
+class TestTaskGroupExternalDependencyTransfer:
+ """Verify that task-group-level ``>>`` deps transfer to the launch / leaf tasks.
+
+ When ``>>`` is called *before* the ``with`` block, the launch task doesn't
+ exist yet; ``__exit__`` must transfer ``self.upstream_task_ids`` to the
+ launch task and ``self.downstream_task_ids`` to the leaf tasks.
+ """
+
+ JOB_CLUSTERS = [
+ {
+ "job_cluster_key": "c",
+ "new_cluster": {
+ "spark_version": "15.4.x-scala2.12",
+ "num_workers": 0,
+ "node_type_id": "i3.xlarge",
+ },
+ }
+ ]
+
+ def test_upstream_set_before_with_block(self):
+ with DAG(dag_id="test_up_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+ start = EmptyOperator(task_id="start")
+ tg = DatabricksWorkflowTaskGroup(
+ group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+ )
+ start >> tg
+ with tg:
+ DatabricksNotebookOperator(
+ task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+ )
+ launch = dag.get_task("tg.launch")
+ assert "start" in launch.upstream_task_ids
+
+ def test_upstream_set_after_with_block(self):
+ with DAG(dag_id="test_up_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+ start = EmptyOperator(task_id="start")
+ tg = DatabricksWorkflowTaskGroup(
+ group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+ )
+ with tg:
+ DatabricksNotebookOperator(
+ task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+ )
+ start >> tg
+ launch = dag.get_task("tg.launch")
+ assert "start" in launch.upstream_task_ids
+
+ def test_downstream_set_before_with_block(self):
+ with DAG(dag_id="test_down_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+ tg = DatabricksWorkflowTaskGroup(
+ group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+ )
+ end = EmptyOperator(task_id="end")
+ tg >> end
+ with tg:
+ DatabricksNotebookOperator(
+ task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+ )
+ nb = dag.get_task("tg.nb")
+ assert "end" in nb.downstream_task_ids
+
+ def test_downstream_set_after_with_block(self):
+ with DAG(dag_id="test_down_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+ tg = DatabricksWorkflowTaskGroup(
+ group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+ )
+ end = EmptyOperator(task_id="end")
+ with tg:
+ DatabricksNotebookOperator(
+ task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+ )
+ tg >> end
+ nb = dag.get_task("tg.nb")
+ assert "end" in nb.downstream_task_ids