This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b7d0bf9800 fix OpenLineage extraction for AthenaOperator (#40545)
b7d0bf9800 is described below

commit b7d0bf9800974e2029a777e20417e3498e665503
Author: Kacper Muda <mudakac...@gmail.com>
AuthorDate: Thu Jul 4 11:15:26 2024 +0200

    fix OpenLineage extraction for AthenaOperator (#40545)
    
    Signed-off-by: Kacper Muda <mudakac...@gmail.com>
---
 airflow/providers/amazon/aws/operators/athena.py   | 26 ++++++++++++++++------
 .../providers/amazon/aws/operators/test_athena.py  | 23 ++++++++++++++++++-
 2 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 5d30b93143..0178d60a12 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -175,9 +175,6 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 f"query_execution_id is {self.query_execution_id}."
             )
 
-        # Save output location from API response for later use in OpenLineage.
-        self.output_location = 
self.hook.get_output_location(self.query_execution_id)
-
         return self.query_execution_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
@@ -185,6 +182,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
 
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for operation on 
cluster to complete: {event}")
+
+        # Save query_execution_id to be later used by listeners
+        self.query_execution_id = event["value"]
         return event["value"]
 
     def on_kill(self) -> None:
@@ -208,14 +208,21 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                     )
                     self.hook.poll_query_status(self.query_execution_id, 
sleep_time=self.sleep_time)
 
-    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
         """
         Retrieve OpenLineage data by parsing SQL queries and enriching them 
with Athena API.
 
         In addition to CTAS query, query and calculation results are stored in 
S3 location.
-        For that reason additional output is attached with this location.
+        For that reason additional output is attached with this location. 
Instead of using the complete
+        path where the results are saved (user's prefix + some UUID), we are 
creating a dataset with the
+        user-provided path only. This should make it easier to match this 
dataset across different processes.
         """
-        from openlineage.client.facet import ExtractionError, 
ExtractionErrorRunFacet, SqlJobFacet
+        from openlineage.client.facet import (
+            ExternalQueryRunFacet,
+            ExtractionError,
+            ExtractionErrorRunFacet,
+            SqlJobFacet,
+        )
         from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors.base import 
OperatorLineage
@@ -265,6 +272,11 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
             )
         )
 
+        if self.query_execution_id:
+            run_facets["externalQuery"] = ExternalQueryRunFacet(
+                externalQueryId=self.query_execution_id, source="awsathena"
+            )
+
         if self.output_location:
             parsed = urlparse(self.output_location)
             
outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}", 
name=parsed.path or "/"))
@@ -301,7 +313,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 )
             }
             fields = [
-                SchemaField(name=column["Name"], type=column["Type"], 
description=column["Comment"])
+                SchemaField(name=column["Name"], type=column["Type"], 
description=column.get("Comment"))
                 for column in table_metadata["TableMetadata"]["Columns"]
             ]
             if fields:
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 66fb6b297f..5d5a6b88c3 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -21,6 +21,7 @@ from unittest import mock
 
 import pytest
 from openlineage.client.facet import (
+    ExternalQueryRunFacet,
     SchemaDatasetFacet,
     SchemaField,
     SqlJobFacet,
@@ -264,6 +265,24 @@ class TestAthenaOperator:
             query_execution_id=ATHENA_QUERY_ID,
         )
 
+    def 
test_execute_complete_reassigns_query_execution_id_after_deferring(self):
+        """Assert that we use query_execution_id from event after deferral."""
+
+        operator = AthenaOperator(
+            task_id="test_athena_operator",
+            query="SELECT * FROM TEST_TABLE",
+            database="TEST_DATABASE",
+            deferrable=True,
+        )
+        assert operator.query_execution_id is None
+
+        query_execution_id = "123456"
+        operator.execute_complete(
+            context=None,
+            event={"status": "success", "value": query_execution_id},
+        )
+        assert operator.query_execution_id == query_execution_id
+
     @mock.patch.object(AthenaHook, "region_name", 
new_callable=mock.PropertyMock)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_operator_openlineage_data(self, mock_conn, mock_region_name):
@@ -285,6 +304,7 @@ class TestAthenaOperator:
             max_polling_attempts=3,
             dag=self.dag,
         )
+        op.query_execution_id = "12345"  # Mocking what will be available 
after execution
 
         expected_lineage = OperatorLineage(
             inputs=[
@@ -365,5 +385,6 @@ class TestAthenaOperator:
                     query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM 
DISCOUNTS",
                 )
             },
+            run_facets={"externalQuery": 
ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")},
         )
-        assert op.get_openlineage_facets_on_start() == expected_lineage
+        assert op.get_openlineage_facets_on_complete(None) == expected_lineage

Reply via email to