This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 65fad4affc Change default schema behaviour in SQLParser. (#32347)
65fad4affc is described below

commit 65fad4affc24b33c4499ad0fbcdfff535fbae3bf
Author: JDarDagran <kuba0...@gmail.com>
AuthorDate: Tue Jul 4 17:48:41 2023 +0200

    Change default schema behaviour in SQLParser. (#32347)
    
    Signed-off-by: Jakub Dardzinski <kuba0...@gmail.com>
---
 airflow/providers/openlineage/sqlparser.py         |  3 +-
 airflow/providers/openlineage/utils/sql.py         |  9 ++---
 tests/providers/openlineage/utils/test_sql.py      |  4 +++
 .../providers/openlineage/utils/test_sqlparser.py  | 39 ++++++++++++++--------
 4 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index 657428549e..ed3e92e58b 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -103,7 +103,7 @@ class SQLParser:
 
     def parse(self, sql: list[str] | str) -> SqlMeta | None:
         """Parse a single or a list of SQL statements."""
-        return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)
+        return parse(sql=sql, dialect=self.dialect)
 
     def parse_table_schemas(
         self,
@@ -126,6 +126,7 @@ class SQLParser:
         return get_table_schemas(
             hook,
             namespace,
+            self.default_schema,
             database or database_info.database,
             self.create_information_schema_query(tables=inputs, **database_kwargs) if inputs else None,
             self.create_information_schema_query(tables=outputs, **database_kwargs) if outputs else None,
diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index 317e46e442..fe43a25bae 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -63,12 +63,12 @@ class TableSchema:
     database: str | None
     fields: list[SchemaField]
 
-    def to_dataset(self, namespace: str, database: str | None = None) -> Dataset:
+    def to_dataset(self, namespace: str, database: str | None = None, schema: str | None = None) -> Dataset:
         # Prefix the table name with database and schema name using
         # the format: {database_name}.{table_schema}.{table_name}.
         name = ".".join(
             part
-            for part in [self.database if self.database else database, self.schema, self.table]
+            for part in [self.database or database, self.schema or schema, self.table]
             if part is not None
         )
         return Dataset(
@@ -81,6 +81,7 @@ class TableSchema:
 def get_table_schemas(
     hook: BaseHook,
     namespace: str,
+    schema: str | None,
     database: str | None,
     in_query: str | None,
     out_query: str | None,
@@ -97,12 +98,12 @@ def get_table_schemas(
     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
         if in_query:
             cursor.execute(in_query)
-            in_datasets = [x.to_dataset(namespace, database) for x in parse_query_result(cursor)]
+            in_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
         else:
             in_datasets = []
         if out_query:
             cursor.execute(out_query)
-            out_datasets = [x.to_dataset(namespace, database) for x in parse_query_result(cursor)]
+            out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
         else:
             out_datasets = []
     return in_datasets, out_datasets
diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py
index be929a1ad6..a82ab36bda 100644
--- a/tests/providers/openlineage/utils/test_sql.py
+++ b/tests/providers/openlineage/utils/test_sql.py
@@ -77,6 +77,7 @@ def test_get_table_schemas():
         hook=hook,
         namespace="bigquery",
         database=DB_NAME,
+        schema=DB_SCHEMA_NAME,
         in_query="fake_sql",
         out_query="another_fake_sql",
     )
@@ -139,6 +140,7 @@ def test_get_table_schemas_with_mixed_databases():
         hook=hook,
         namespace="bigquery",
         database=DB_NAME,
+        schema=DB_SCHEMA_NAME,
         in_query="fake_sql",
         out_query="another_fake_sql",
     )
@@ -179,6 +181,7 @@ def test_get_table_schemas_with_mixed_schemas():
         hook=hook,
         namespace="bigquery",
         database=DB_NAME,
+        schema=DB_SCHEMA_NAME,
         in_query="fake_sql",
         out_query="another_fake_sql",
     )
@@ -237,6 +240,7 @@ def test_get_table_schemas_with_other_database():
         hook=hook,
         namespace="bigquery",
         database=DB_NAME,
+        schema=DB_SCHEMA_NAME,
         in_query="fake_sql",
         out_query="another_fake_sql",
     )
diff --git a/tests/providers/openlineage/utils/test_sqlparser.py b/tests/providers/openlineage/utils/test_sqlparser.py
index 6f11a7ad94..31611c7c11 100644
--- a/tests/providers/openlineage/utils/test_sqlparser.py
+++ b/tests/providers/openlineage/utils/test_sqlparser.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from unittest import mock
 from unittest.mock import MagicMock
 
+import pytest
 from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
 from openlineage.client.run import Dataset
 from openlineage.common.sql import DbTableMeta
@@ -155,19 +156,20 @@ class TestSQLParser:
             database_info=db_info,
         )
 
+    @pytest.mark.parametrize("parser_returns_schema", [True, False])
     @mock.patch("airflow.providers.openlineage.sqlparser.SQLParser.parse")
-    def test_generate_openlineage_metadata_from_sql(self, mock_parse):
-        parser = SQLParser()
+    def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns_schema):
+        parser = SQLParser(default_schema="ANOTHER_SCHEMA")
         db_info = DatabaseInfo(scheme="myscheme", authority="host:port")
 
         hook = MagicMock()
 
-        rows = lambda name: [
-            (DB_SCHEMA_NAME, name, "ID", 1, "int4"),
-            (DB_SCHEMA_NAME, name, "AMOUNT_OFF", 2, "int4"),
-            (DB_SCHEMA_NAME, name, "CUSTOMER_EMAIL", 3, "varchar"),
-            (DB_SCHEMA_NAME, name, "STARTS_ON", 4, "timestamp"),
-            (DB_SCHEMA_NAME, name, "ENDS_ON", 5, "timestamp"),
+        rows = lambda schema, table: [
+            (schema, table, "ID", 1, "int4"),
+            (schema, table, "AMOUNT_OFF", 2, "int4"),
+            (schema, table, "CUSTOMER_EMAIL", 3, "varchar"),
+            (schema, table, "STARTS_ON", 4, "timestamp"),
+            (schema, table, "ENDS_ON", 5, "timestamp"),
         ]
 
         sql = """CREATE TABLE table_out (
@@ -182,13 +184,17 @@ class TestSQLParser:
         """
 
         hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [
-            rows("TABLE_IN"),
-            rows("TABLE_OUT"),
+            rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_IN"),
+            rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_OUT"),
         ]
 
         mock_sql_meta = MagicMock()
-        mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.TABLE_IN")]
-        mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.TABLE_OUT")]
+        if parser_returns_schema:
+            mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.TABLE_IN")]
+            mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.TABLE_OUT")]
+        else:
+            mock_sql_meta.in_tables = [DbTableMeta("TABLE_IN")]
+            mock_sql_meta.out_tables = [DbTableMeta("TABLE_OUT")]
         mock_sql_meta.errors = []
 
         mock_parse.return_value = mock_sql_meta
@@ -201,15 +207,20 @@ class TestSQLParser:
             ENDS_ON timestamp
 
 )"""
+        expected_schema = "PUBLIC" if parser_returns_schema else "ANOTHER_SCHEMA"
         expected = OperatorLineage(
             inputs=[
                 Dataset(
-                    namespace="myscheme://host:port", name="PUBLIC.TABLE_IN", facets={"schema": SCHEMA_FACET}
+                    namespace="myscheme://host:port",
+                    name=f"{expected_schema}.TABLE_IN",
+                    facets={"schema": SCHEMA_FACET},
                 )
             ],
             outputs=[
                 Dataset(
-                    namespace="myscheme://host:port", name="PUBLIC.TABLE_OUT", facets={"schema": SCHEMA_FACET}
+                    namespace="myscheme://host:port",
+                    name=f"{expected_schema}.TABLE_OUT",
+                    facets={"schema": SCHEMA_FACET},
                 )
             ],
             job_facets={"sql": SqlJobFacet(query=formatted_sql)},

Reply via email to