This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6594a2f0f45 Add OpenLineage support to AthenaSQLHook (#66844)
6594a2f0f45 is described below

commit 6594a2f0f45aea1bb233c3810a7ebb9861057553
Author: Rahul Madan <[email protected]>
AuthorDate: Mon May 18 14:12:10 2026 +0530

    Add OpenLineage support to AthenaSQLHook (#66844)
    
    * Add OpenLineage support to AthenaSQLHook
    
    Signed-off-by: Rahul Madan <[email protected]>
    
    * Added tests for athena sql hook
    
    Signed-off-by: Rahul Madan <[email protected]>
    
    * Address review: hook-constructor region wins + support aws_domain extra
    
    Signed-off-by: Rahul Madan <[email protected]>
    
    ---------
    
    Signed-off-by: Rahul Madan <[email protected]>
---
 .../providers/amazon/aws/hooks/athena_sql.py       | 31 ++++++++
 .../tests/unit/amazon/aws/hooks/test_athena_sql.py | 93 ++++++++++++++++++++++
 2 files changed, 124 insertions(+)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
index 94348612700..a9791aec246 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -177,6 +177,37 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
             aws_domain=self.conn.extra_dejson.get("aws_domain", "amazonaws.com"),
         )
 
+    def get_openlineage_database_info(self, connection):
+        """Return Amazon Athena specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        region_name = self.region_name or connection.extra_dejson.get("region_name")
+        aws_domain = connection.extra_dejson.get("aws_domain", "amazonaws.com")
+        authority = f"athena.{region_name}.{aws_domain}" if region_name else f"athena.{aws_domain}"
+
+        return DatabaseInfo(
+            scheme="awsathena",
+            authority=authority,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "table_catalog",
+            ],
+            database=connection.extra_dejson.get("catalog", "AwsDataCatalog"),
+            is_information_schema_cross_db=True,
+        )
+
+    def get_openlineage_database_dialect(self, _) -> str:
+        """Return Athena dialect. Athena uses Trino SQL engine."""
+        return "trino"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Return Athena default schema."""
+        return self.conn.schema or "default"
+
     def get_uri(self) -> str:
         """Overridden to use the Athena dialect as driver name."""
         from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
index fc6fe82737c..7b5cec5f7bc 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
@@ -181,3 +181,96 @@ class TestAthenaSQLHookConn:
         assert hook._verify is False
         assert hook._region_name == "us-west-2"
         assert hook._config is not None
+
+
+class TestAthenaSQLHookOpenLineage:
+    """Static tests for the OpenLineage methods on AthenaSQLHook."""
+
+    EXPECTED_INFORMATION_SCHEMA_COLUMNS = [
+        "table_schema",
+        "table_name",
+        "column_name",
+        "ordinal_position",
+        "data_type",
+        "table_catalog",
+    ]
+
+    @staticmethod
+    def _make_hook(connection: Connection, hook_region: str | None = None) -> AthenaSQLHook:
+        hook = AthenaSQLHook(region_name=hook_region) if hook_region else AthenaSQLHook()
+        hook.get_connection = mock.Mock(return_value=connection)  # type: ignore[method-assign]
+        return hook
+
+    @pytest.mark.parametrize(
+        ("extras", "hook_region", "expected_authority"),
+        [
+            # region from connection extras when hook-constructor region not set
+            ({"region_name": "us-east-1"}, None, "athena.us-east-1.amazonaws.com"),
+            # hook-constructor region (explicit user override) wins over extras region
+            ({"region_name": "eu-west-1"}, "us-east-2", "athena.us-east-2.amazonaws.com"),
+            # hook-constructor region used when extras have none
+            ({}, "ap-south-1", "athena.ap-south-1.amazonaws.com"),
+            # graceful fallback when neither is set
+            ({}, None, "athena.amazonaws.com"),
+            # aws_domain extra changes the domain (AWS GovCloud / China / ISO partitions)
+            (
+                {"region_name": "cn-north-1", "aws_domain": "amazonaws.com.cn"},
+                None,
+                "athena.cn-north-1.amazonaws.com.cn",
+            ),
+            # aws_domain still applied when region falls back
+            ({"aws_domain": "amazonaws.com.cn"}, None, "athena.amazonaws.com.cn"),
+        ],
+    )
+    def test_get_openlineage_database_info_region_extraction(self, extras, hook_region, expected_authority):
+        conn = Connection(conn_type="athena", schema="default", extra=extras)
+        hook = self._make_hook(conn, hook_region)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.authority == expected_authority
+
+    def test_get_openlineage_database_info_returns_expected_fields(self):
+        """Snapshot of the DatabaseInfo shape so accidental changes are caught."""
+        conn = Connection(
+            conn_type="athena",
+            schema="default",
+            extra={"region_name": "us-east-1"},
+        )
+        hook = self._make_hook(conn)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.scheme == "awsathena"
+        assert info.authority == "athena.us-east-1.amazonaws.com"
+        assert info.database == "AwsDataCatalog"
+        assert info.is_information_schema_cross_db is True
+        assert info.information_schema_columns == self.EXPECTED_INFORMATION_SCHEMA_COLUMNS
+
+    def test_get_openlineage_database_info_custom_catalog(self):
+        conn = Connection(
+            conn_type="athena",
+            schema="default",
+            extra={"region_name": "us-east-1", "catalog": "MyCatalog"},
+        )
+        hook = self._make_hook(conn)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.database == "MyCatalog"
+
+    def test_get_openlineage_database_dialect_returns_trino(self):
+        conn = Connection(conn_type="athena", extra={"region_name": "us-east-1"})
+        hook = self._make_hook(conn)
+        assert hook.get_openlineage_database_dialect(conn) == "trino"
+
+    @pytest.mark.parametrize(
+        ("connection_schema", "expected_schema"),
+        [
+            ("mydb", "mydb"),
+            (None, "default"),
+            ("", "default"),
+        ],
+    )
+    def test_get_openlineage_default_schema(self, connection_schema, expected_schema):
+        conn = Connection(
+            conn_type="athena",
+            schema=connection_schema,
+            extra={"region_name": "us-east-1"},
+        )
+        hook = self._make_hook(conn)
+        assert hook.get_openlineage_default_schema() == expected_schema

Reply via email to