This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a59d2ecdc39 chore: use OL macros instead of building OL ids from
scratch (#59197)
a59d2ecdc39 is described below
commit a59d2ecdc39f5ae3446e55c037d1ce86a6df2fa7
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Dec 8 17:13:49 2025 +0100
chore: use OL macros instead of building OL ids from scratch (#59197)
---
.../providers/databricks/utils/openlineage.py | 88 ++++++----------------
.../unit/databricks/hooks/test_databricks_sql.py | 2 +-
.../unit/databricks/utils/test_openlineage.py | 60 ++++-----------
.../providers/dbt/cloud/utils/openlineage.py | 79 +++++++++++--------
.../tests/unit/dbt/cloud/utils/test_openlineage.py | 41 ++++++++--
.../providers/snowflake/utils/openlineage.py | 88 ++++++----------------
.../tests/unit/snowflake/hooks/test_snowflake.py | 2 +-
.../tests/unit/snowflake/utils/test_openlineage.py | 59 ++++-----------
8 files changed, 166 insertions(+), 253 deletions(-)
diff --git
a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
index 971e59f2915..56f4400df61 100644
--- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
+++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Any
import requests
from airflow.providers.common.compat.openlineage.check import
require_openlineage_version
-from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -37,60 +36,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
-def _get_logical_date(task_instance):
- # todo: remove when min airflow version >= 3.0
- if AIRFLOW_V_3_0_PLUS:
- dagrun = task_instance.get_template_context()["dag_run"]
- return dagrun.logical_date or dagrun.run_after
-
- if hasattr(task_instance, "logical_date"):
- date = task_instance.logical_date
- else:
- date = task_instance.execution_date
-
- return date
-
-
-def _get_dag_run_clear_number(task_instance):
- # todo: remove when min airflow version >= 3.0
- if AIRFLOW_V_3_0_PLUS:
- dagrun = task_instance.get_template_context()["dag_run"]
- return dagrun.clear_number
- return task_instance.dag_run.clear_number
-
-
-# todo: move this run_id logic into OpenLineage's listener to avoid differences
-def _get_ol_run_id(task_instance) -> str:
- """
- Get OpenLineage run_id from TaskInstance.
-
- It's crucial that the task_instance's run_id creation logic matches
OpenLineage's listener implementation.
- Only then can we ensure that the generated run_id aligns with the Airflow
task,
- enabling a proper connection between events.
- """
- from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
-
- # Generate same OL run id as is generated for current task instance
- return OpenLineageAdapter.build_task_instance_run_id(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- logical_date=_get_logical_date(task_instance),
- try_number=task_instance.try_number,
- map_index=task_instance.map_index,
- )
-
-
-# todo: move this run_id logic into OpenLineage's listener to avoid differences
-def _get_ol_dag_run_id(task_instance) -> str:
- from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
-
- return OpenLineageAdapter.build_dag_run_id(
- dag_id=task_instance.dag_id,
- logical_date=_get_logical_date(task_instance),
- clear_number=_get_dag_run_clear_number(task_instance),
- )
-
-
def _get_parent_run_facet(task_instance):
"""
Retrieve the ParentRunFacet associated with a specific Airflow task
instance.
@@ -101,22 +46,39 @@ def _get_parent_run_facet(task_instance):
"""
from openlineage.client.facet_v2 import parent_run
- from airflow.providers.openlineage.conf import namespace
+ from airflow.providers.openlineage.plugins.macros import (
+ lineage_job_name,
+ lineage_job_namespace,
+ lineage_root_job_name,
+ lineage_root_run_id,
+ lineage_run_id,
+ )
+
+ parent_run_id = lineage_run_id(task_instance)
+ parent_job_name = lineage_job_name(task_instance)
+ parent_job_namespace = lineage_job_namespace()
+
+ root_parent_run_id = lineage_root_run_id(task_instance)
+ root_parent_job_name = lineage_root_job_name(task_instance)
+
+ try: # Added in OL provider 2.9.0, try to use it if possible
+ from airflow.providers.openlineage.plugins.macros import
lineage_root_job_namespace
- parent_run_id = _get_ol_run_id(task_instance)
- root_parent_run_id = _get_ol_dag_run_id(task_instance)
+ root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+ except ImportError:
+ root_parent_job_namespace = lineage_job_namespace()
return parent_run.ParentRunFacet(
run=parent_run.Run(runId=parent_run_id),
job=parent_run.Job(
- namespace=namespace(),
- name=f"{task_instance.dag_id}.{task_instance.task_id}",
+ namespace=parent_job_namespace,
+ name=parent_job_name,
),
root=parent_run.Root(
run=parent_run.RootRun(runId=root_parent_run_id),
job=parent_run.RootJob(
- name=task_instance.dag_id,
- namespace=namespace(),
+ name=root_parent_job_name,
+ namespace=root_parent_job_namespace,
),
),
)
@@ -209,7 +171,7 @@ def _create_ol_event_pair(
return start, end
-@require_openlineage_version(provider_min_version="2.3.0")
+@require_openlineage_version(provider_min_version="2.5.0")
def emit_openlineage_events_for_databricks_queries(
task_instance,
hook: DatabricksSqlHook | DatabricksHook | None = None,
diff --git
a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 064a054d680..49d874973f8 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -570,7 +570,7 @@ def
test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
hook.get_openlineage_database_info = lambda x:
mock.MagicMock(authority="auth", scheme="scheme")
expected_err = (
- "OpenLineage provider version `1.99.0` is lower than required `2.3.0`,
"
+ "OpenLineage provider version `1.99.0` is lower than required `2.5.0`,
"
"skipping function `emit_openlineage_events_for_databricks_queries`
execution"
)
with pytest.raises(AirflowOptionalProviderFeatureException,
match=expected_err):
diff --git
a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
index 20305a9af4e..2d127040d32 100644
--- a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
+++ b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
@@ -34,7 +34,6 @@ from airflow.providers.databricks.hooks.databricks import
DatabricksHook
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.openlineage import (
_create_ol_event_pair,
- _get_ol_run_id,
_get_parent_run_facet,
_get_queries_details_from_databricks,
_process_data_from_api,
@@ -46,40 +45,9 @@ from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState
-def test_get_ol_run_id_ti_success():
- logical_date = timezone.datetime(2025, 1, 1)
- mock_ti = mock.MagicMock(
- dag_id="dag_id",
- task_id="task_id",
- map_index=1,
- try_number=1,
- logical_date=logical_date,
- state=TaskInstanceState.SUCCESS,
- )
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
-
- result = _get_ol_run_id(mock_ti)
- assert result == "01941f29-7c00-7087-8906-40e512c257bd"
-
-
-def test_get_ol_run_id_ti_failed():
- logical_date = timezone.datetime(2025, 1, 1)
- mock_ti = mock.MagicMock(
- dag_id="dag_id",
- task_id="task_id",
- map_index=1,
- try_number=1,
- logical_date=logical_date,
- state=TaskInstanceState.FAILED,
- )
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
-
- result = _get_ol_run_id(mock_ti)
- assert result == "01941f29-7c00-7087-8906-40e512c257bd"
-
-
def test_get_parent_run_facet():
logical_date = timezone.datetime(2025, 1, 1)
+ dr = mock.MagicMock(logical_date=logical_date, clear_number=0)
mock_ti = mock.MagicMock(
dag_id="dag_id",
task_id="task_id",
@@ -87,14 +55,18 @@ def test_get_parent_run_facet():
try_number=1,
logical_date=logical_date,
state=TaskInstanceState.SUCCESS,
+ dag_run=dr,
)
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
+ mock_ti.get_template_context.return_value = {"dag_run": dr}
result = _get_parent_run_facet(mock_ti)
assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd"
assert result.job.namespace == namespace()
assert result.job.name == "dag_id.task_id"
+ assert result.root.run.runId == "01941f29-7c00-743e-b109-28b18d0a19c5"
+ assert result.root.job.namespace == namespace()
+ assert result.root.job.name == "dag_id"
def test_run_api_call_success():
@@ -283,7 +255,7 @@ def test_create_ol_event_pair_success(mock_generate_uuid,
is_successful):
assert start_event.job == end_event.job
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid,
mock_version, time_machine):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
@@ -520,7 +492,7 @@ def
test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_without_metadata(
mock_generate_uuid, mock_version, time_machine
@@ -638,7 +610,7 @@ def
test_emit_openlineage_events_for_databricks_queries_without_metadata(
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids(
mock_generate_uuid, mock_version, time_machine
@@ -760,7 +732,7 @@ def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
@mock.patch(
"airflow.providers.openlineage.sqlparser.SQLParser.create_namespace",
return_value="databricks_ns"
)
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace(
mock_generate_uuid, mock_version, mock_parser, time_machine
@@ -878,7 +850,7 @@ def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns(
mock_generate_uuid, mock_version, time_machine
@@ -997,7 +969,7 @@ def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_query_ids(
mock_generate_uuid, mock_version, time_machine
@@ -1117,7 +1089,7 @@ def
test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_q
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version):
query_ids = []
original_query_ids = copy.deepcopy(query_ids)
@@ -1142,7 +1114,7 @@ def
test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_ho
fake_adapter.emit.assert_not_called() # No events should be emitted
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version):
query_ids = ["1", "2"]
original_query_ids = copy.deepcopy(query_ids)
@@ -1168,7 +1140,7 @@ def
test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_
fake_adapter.emit.assert_not_called() # No events should be emitted
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true(
mock_version,
):
@@ -1213,7 +1185,7 @@ def
test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
return_value=fake_listener,
):
expected_err = (
- "OpenLineage provider version `1.99.0` is lower than required
`2.3.0`, "
+ "OpenLineage provider version `1.99.0` is lower than required
`2.5.0`, "
"skipping function
`emit_openlineage_events_for_databricks_queries` execution"
)
diff --git
a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py
index 241e72d9874..188ca1255f6 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py
@@ -56,7 +56,48 @@ def _get_dag_run_clear_number(task_instance):
return task_instance.dag_run.clear_number
-@require_openlineage_version(provider_min_version="2.3.0")
+def _get_parent_run_metadata(task_instance):
+ """
+ Retrieve the ParentRunMetadata associated with a specific Airflow task
instance.
+
+ This metadata helps link OpenLineage events of child jobs to the original
Airflow task execution.
+ Establishing this connection enables better lineage tracking and
observability.
+ """
+ from openlineage.common.provider.dbt import ParentRunMetadata
+
+ from airflow.providers.openlineage.plugins.macros import (
+ lineage_job_name,
+ lineage_job_namespace,
+ lineage_root_job_name,
+ lineage_root_run_id,
+ lineage_run_id,
+ )
+
+ parent_run_id = lineage_run_id(task_instance)
+ parent_job_name = lineage_job_name(task_instance)
+ parent_job_namespace = lineage_job_namespace()
+
+ root_parent_run_id = lineage_root_run_id(task_instance)
+ root_parent_job_name = lineage_root_job_name(task_instance)
+
+ try: # Added in OL provider 2.9.0, try to use it if possible
+ from airflow.providers.openlineage.plugins.macros import
lineage_root_job_namespace
+
+ root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+ except ImportError:
+ root_parent_job_namespace = lineage_job_namespace()
+
+ return ParentRunMetadata(
+ run_id=parent_run_id,
+ job_name=parent_job_name,
+ job_namespace=parent_job_namespace,
+ root_parent_run_id=root_parent_run_id,
+ root_parent_job_name=root_parent_job_name,
+ root_parent_job_namespace=root_parent_job_namespace,
+ )
+
+
+@require_openlineage_version(provider_min_version="2.5.0")
def generate_openlineage_events_from_dbt_cloud_run(
operator: DbtCloudRunJobOperator | DbtCloudJobRunSensor, task_instance:
TaskInstance
) -> OperatorLineage:
@@ -74,14 +115,10 @@ def generate_openlineage_events_from_dbt_cloud_run(
:return: An empty OperatorLineage object indicating the completion of
events generation.
"""
- from openlineage.common.provider.dbt import DbtCloudArtifactProcessor,
ParentRunMetadata
+ from openlineage.common.provider.dbt import DbtCloudArtifactProcessor
- from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.extractors import OperatorLineage
- from airflow.providers.openlineage.plugins.adapter import (
- _PRODUCER,
- OpenLineageAdapter,
- )
+ from airflow.providers.openlineage.plugins.adapter import _PRODUCER
from airflow.providers.openlineage.plugins.listener import
get_openlineage_listener
# if no account_id set this will fallback
@@ -140,29 +177,7 @@ def generate_openlineage_events_from_dbt_cloud_run(
)
log.debug("Preparing OpenLineage parent job information to be included in
DBT events.")
- # generate same run id of current task instance
- parent_run_id = OpenLineageAdapter.build_task_instance_run_id(
- dag_id=task_instance.dag_id,
- task_id=operator.task_id,
- logical_date=_get_logical_date(task_instance),
- try_number=task_instance.try_number,
- map_index=task_instance.map_index,
- )
-
- root_parent_run_id = OpenLineageAdapter.build_dag_run_id(
- dag_id=task_instance.dag_id,
- logical_date=_get_logical_date(task_instance),
- clear_number=_get_dag_run_clear_number(task_instance),
- )
-
- parent_job = ParentRunMetadata(
- run_id=parent_run_id,
- job_name=f"{task_instance.dag_id}.{task_instance.task_id}",
- job_namespace=namespace(),
- root_parent_run_id=root_parent_run_id,
- root_parent_job_name=task_instance.dag_id,
- root_parent_job_namespace=namespace(),
- )
+ parent_metadata = _get_parent_run_metadata(task_instance)
adapter = get_openlineage_listener().adapter
# process each step in loop, sending generated events in the same order as
steps
@@ -178,7 +193,7 @@ def generate_openlineage_events_from_dbt_cloud_run(
processor = DbtCloudArtifactProcessor(
producer=_PRODUCER,
- job_namespace=namespace(),
+ job_namespace=parent_metadata.job_namespace,
skip_errors=False,
logger=operator.log,
manifest=manifest,
@@ -187,7 +202,7 @@ def generate_openlineage_events_from_dbt_cloud_run(
catalog=catalog,
)
- processor.dbt_run_metadata = parent_job
+ processor.dbt_run_metadata = parent_metadata
events = processor.parse().events()
log.debug("Found %s OpenLineage events for artifact no. %s.",
len(events), counter)
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
index e5d9431f8b3..9455a669f89 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
@@ -27,8 +27,14 @@ from packaging.version import parse
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
-from airflow.providers.dbt.cloud.utils.openlineage import
generate_openlineage_events_from_dbt_cloud_run
+from airflow.providers.dbt.cloud.utils.openlineage import (
+ _get_parent_run_metadata,
+ generate_openlineage_events_from_dbt_cloud_run,
+)
+from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
TASK_ID = "dbt_test"
DAG_ID = "dbt_dag"
@@ -94,12 +100,13 @@ def get_dbt_artifact(*args, **kwargs):
[
("1.99.0", True),
("2.0.0", True),
- ("2.3.0", False),
+ ("2.3.0", True),
+ ("2.5.0", False),
("2.99.0", False),
],
)
def test_previous_version_openlineage_provider(value, is_error):
- """When using OpenLineage, the dbt-cloud provider now depends on
openlineage provider >= 2.3"""
+ """When using OpenLineage, the dbt-cloud provider now depends on
openlineage provider >= 2.5"""
def _mock_version(package):
if package == "apache-airflow-providers-openlineage":
@@ -110,7 +117,7 @@ def test_previous_version_openlineage_provider(value,
is_error):
mock_task_instance = MagicMock()
expected_err = (
- f"OpenLineage provider version `{value}` is lower than required
`2.3.0`, "
+ f"OpenLineage provider version `{value}` is lower than required
`2.5.0`, "
"skipping function `generate_openlineage_events_from_dbt_cloud_run`
execution"
)
@@ -126,8 +133,32 @@ def test_previous_version_openlineage_provider(value,
is_error):
generate_openlineage_events_from_dbt_cloud_run(mock_operator,
mock_task_instance)
+def test_get_parent_run_metadata():
+ logical_date = timezone.datetime(2025, 1, 1)
+ dr = MagicMock(logical_date=logical_date, clear_number=0)
+ mock_ti = MagicMock(
+ dag_id="dag_id",
+ task_id="task_id",
+ map_index=1,
+ try_number=1,
+ logical_date=logical_date,
+ state=TaskInstanceState.SUCCESS,
+ dag_run=dr,
+ )
+ mock_ti.get_template_context.return_value = {"dag_run": dr}
+
+ result = _get_parent_run_metadata(mock_ti)
+
+ assert result.run_id == "01941f29-7c00-7087-8906-40e512c257bd"
+ assert result.job_namespace == namespace()
+ assert result.job_name == "dag_id.task_id"
+ assert result.root_parent_run_id == "01941f29-7c00-743e-b109-28b18d0a19c5"
+ assert result.root_parent_job_namespace == namespace()
+ assert result.root_parent_job_name == "dag_id"
+
+
class TestGenerateOpenLineageEventsFromDbtCloudRun:
- @patch("importlib.metadata.version", return_value="2.3.0")
+ @patch("importlib.metadata.version", return_value="3.0.0")
@patch("airflow.providers.openlineage.plugins.listener.get_openlineage_listener")
@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.build_task_instance_run_id")
@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.build_dag_run_id")
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
index d06a6c463e4..483f176cf59 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urlparse, urlunparse
from airflow.providers.common.compat.openlineage.check import
require_openlineage_version
-from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -109,60 +108,6 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
return urlunparse((parts.scheme, hostname, parts.path, parts.params,
parts.query, parts.fragment))
-def _get_logical_date(task_instance):
- # todo: remove when min airflow version >= 3.0
- if AIRFLOW_V_3_0_PLUS:
- dagrun = task_instance.get_template_context()["dag_run"]
- return dagrun.logical_date or dagrun.run_after
-
- if hasattr(task_instance, "logical_date"):
- date = task_instance.logical_date
- else:
- date = task_instance.execution_date
-
- return date
-
-
-def _get_dag_run_clear_number(task_instance):
- # todo: remove when min airflow version >= 3.0
- if AIRFLOW_V_3_0_PLUS:
- dagrun = task_instance.get_template_context()["dag_run"]
- return dagrun.clear_number
- return task_instance.dag_run.clear_number
-
-
-# todo: move this run_id logic into OpenLineage's listener to avoid differences
-def _get_ol_run_id(task_instance) -> str:
- """
- Get OpenLineage run_id from TaskInstance.
-
- It's crucial that the task_instance's run_id creation logic matches
OpenLineage's listener implementation.
- Only then can we ensure that the generated run_id aligns with the Airflow
task,
- enabling a proper connection between events.
- """
- from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
-
- # Generate same OL run id as is generated for current task instance
- return OpenLineageAdapter.build_task_instance_run_id(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- logical_date=_get_logical_date(task_instance),
- try_number=task_instance.try_number,
- map_index=task_instance.map_index,
- )
-
-
-# todo: move this run_id logic into OpenLineage's listener to avoid differences
-def _get_ol_dag_run_id(task_instance) -> str:
- from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
-
- return OpenLineageAdapter.build_dag_run_id(
- dag_id=task_instance.dag_id,
- logical_date=_get_logical_date(task_instance),
- clear_number=_get_dag_run_clear_number(task_instance),
- )
-
-
def _get_parent_run_facet(task_instance):
"""
Retrieve the ParentRunFacet associated with a specific Airflow task
instance.
@@ -173,22 +118,39 @@ def _get_parent_run_facet(task_instance):
"""
from openlineage.client.facet_v2 import parent_run
- from airflow.providers.openlineage.conf import namespace
+ from airflow.providers.openlineage.plugins.macros import (
+ lineage_job_name,
+ lineage_job_namespace,
+ lineage_root_job_name,
+ lineage_root_run_id,
+ lineage_run_id,
+ )
+
+ parent_run_id = lineage_run_id(task_instance)
+ parent_job_name = lineage_job_name(task_instance)
+ parent_job_namespace = lineage_job_namespace()
+
+ root_parent_run_id = lineage_root_run_id(task_instance)
+ root_parent_job_name = lineage_root_job_name(task_instance)
+
+ try: # Added in OL provider 2.9.0, try to use it if possible
+ from airflow.providers.openlineage.plugins.macros import
lineage_root_job_namespace
- parent_run_id = _get_ol_run_id(task_instance)
- root_parent_run_id = _get_ol_dag_run_id(task_instance)
+ root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+ except ImportError:
+ root_parent_job_namespace = lineage_job_namespace()
return parent_run.ParentRunFacet(
run=parent_run.Run(runId=parent_run_id),
job=parent_run.Job(
- namespace=namespace(),
- name=f"{task_instance.dag_id}.{task_instance.task_id}",
+ namespace=parent_job_namespace,
+ name=parent_job_name,
),
root=parent_run.Root(
run=parent_run.RootRun(runId=root_parent_run_id),
job=parent_run.RootJob(
- name=task_instance.dag_id,
- namespace=namespace(),
+ name=root_parent_job_name,
+ namespace=root_parent_job_namespace,
),
),
)
@@ -299,7 +261,7 @@ def _create_snowflake_event_pair(
return start, end
-@require_openlineage_version(provider_min_version="2.3.0")
+@require_openlineage_version(provider_min_version="2.5.0")
def emit_openlineage_events_for_snowflake_queries(
task_instance,
hook: SnowflakeHook | SnowflakeSqlApiHook | None = None,
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 4485c87c7a2..d92ca12752d 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -1046,7 +1046,7 @@ class TestPytestSnowflakeHook:
hook.get_openlineage_database_info = lambda x:
mock.MagicMock(authority="auth", scheme="scheme")
expected_err = (
- "OpenLineage provider version `1.99.0` is lower than required
`2.3.0`, "
+ "OpenLineage provider version `1.99.0` is lower than required
`2.5.0`, "
"skipping function `emit_openlineage_events_for_snowflake_queries`
execution"
)
with pytest.raises(AirflowOptionalProviderFeatureException,
match=expected_err):
diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
index 7f70ef51c50..0dcee3adc95 100644
--- a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
+++ b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
@@ -35,7 +35,6 @@ from airflow.providers.snowflake.hooks.snowflake import
SnowflakeHook
from airflow.providers.snowflake.hooks.snowflake_sql_api import
SnowflakeSqlApiHook
from airflow.providers.snowflake.utils.openlineage import (
_create_snowflake_event_pair,
- _get_ol_run_id,
_get_parent_run_facet,
_get_queries_details_from_snowflake,
_process_data_from_api,
@@ -117,40 +116,9 @@ def test_fix_account_name(name, expected):
)
-def test_get_ol_run_id_ti_success():
- logical_date = timezone.datetime(2025, 1, 1)
- mock_ti = mock.MagicMock(
- dag_id="dag_id",
- task_id="task_id",
- map_index=1,
- try_number=1,
- logical_date=logical_date,
- state=TaskInstanceState.SUCCESS,
- )
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
-
- result = _get_ol_run_id(mock_ti)
- assert result == "01941f29-7c00-7087-8906-40e512c257bd"
-
-
-def test_get_ol_run_id_ti_failed():
- logical_date = timezone.datetime(2025, 1, 1)
- mock_ti = mock.MagicMock(
- dag_id="dag_id",
- task_id="task_id",
- map_index=1,
- try_number=1,
- logical_date=logical_date,
- state=TaskInstanceState.FAILED,
- )
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
-
- result = _get_ol_run_id(mock_ti)
- assert result == "01941f29-7c00-7087-8906-40e512c257bd"
-
-
def test_get_parent_run_facet():
logical_date = timezone.datetime(2025, 1, 1)
+ dr = mock.MagicMock(logical_date=logical_date, clear_number=0)
mock_ti = mock.MagicMock(
dag_id="dag_id",
task_id="task_id",
@@ -158,14 +126,18 @@ def test_get_parent_run_facet():
try_number=1,
logical_date=logical_date,
state=TaskInstanceState.SUCCESS,
+ dag_run=dr,
)
- mock_ti.get_template_context.return_value = {"dag_run":
mock.MagicMock(logical_date=logical_date)}
+ mock_ti.get_template_context.return_value = {"dag_run": dr}
result = _get_parent_run_facet(mock_ti)
assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd"
assert result.job.namespace == namespace()
assert result.job.name == "dag_id.task_id"
+ assert result.root.run.runId == "01941f29-7c00-743e-b109-28b18d0a19c5"
+ assert result.root.job.namespace == namespace()
+ assert result.root.job.name == "dag_id"
def test_process_data_from_api():
@@ -578,7 +550,7 @@ def
test_create_snowflake_event_pair_success(mock_generate_uuid, is_successful):
assert start_event.job == end_event.job
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_snowflake_queries_with_extra_metadata(
mock_generate_uuid, mock_version, time_machine
@@ -818,7 +790,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_with_extra_metadata(
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_snowflake_queries_without_extra_metadata(
mock_generate_uuid, mock_version, time_machine
@@ -936,7 +908,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_without_extra_metadata(
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(
mock_generate_uuid, mock_version, time_machine
@@ -1056,7 +1028,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_without_query_ids(
@mock.patch("airflow.providers.openlineage.sqlparser.SQLParser.create_namespace",
return_value="snowflake_ns")
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_snowflake_queries_without_query_ids_and_namespace(
mock_generate_uuid, mock_version, mock_parser, time_machine
@@ -1175,7 +1147,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_without_query_ids_and_nam
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def
test_emit_openlineage_events_for_snowflake_queries_with_query_ids_and_hook_query_ids(
mock_generate_uuid, mock_version, time_machine
@@ -1294,7 +1266,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_with_query_ids_and_hook_q
assert fake_adapter.emit.call_args_list == expected_calls
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_snowflake_queries_missing_query_ids_and_hook(mock_version):
fake_adapter = mock.MagicMock()
fake_adapter.emit = mock.MagicMock()
@@ -1313,7 +1285,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_missing_query_ids_and_hoo
fake_adapter.emit.assert_not_called() # No events should be emitted
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_snowflake_queries_missing_query_namespace_and_hook(mock_version):
query_ids = ["1", "2"]
original_query_ids = copy.deepcopy(query_ids)
@@ -1338,7 +1310,7 @@ def
test_emit_openlineage_events_for_snowflake_queries_missing_query_namespace_a
fake_adapter.emit.assert_not_called() # No events should be emitted
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("importlib.metadata.version", return_value="3.0.0")
def
test_emit_openlineage_events_for_snowflake_queries_missing_hook_and_query_for_extra_metadata_true(
mock_version,
):
@@ -1368,7 +1340,6 @@ def
test_emit_openlineage_events_for_snowflake_queries_missing_hook_and_query_fo
fake_adapter.emit.assert_not_called() # No events should be emitted
-# emit_openlineage_events_for_snowflake_queries requires OL provider 2.3.0
@mock.patch("importlib.metadata.version", return_value="1.99.0")
def test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
query_ids = ["q1", "q2"]
@@ -1384,7 +1355,7 @@ def
test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
return_value=fake_listener,
):
expected_err = (
- "OpenLineage provider version `1.99.0` is lower than required
`2.3.0`, "
+ "OpenLineage provider version `1.99.0` is lower than required
`2.5.0`, "
"skipping function `emit_openlineage_events_for_snowflake_queries`
execution"
)