This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7202ee84b3 Add possibility to override the conn type for Druid (#42793)
7202ee84b3 is described below

commit 7202ee84b3204bfcada7effe34912cfd31712e9f
Author: Rasnar <[email protected]>
AuthorDate: Fri Oct 11 08:37:14 2024 +0200

    Add possibility to override the conn type for Druid (#42793)
    
    * Add possibility to override the conn type for Druid
    
    Minor fix, which allows to use the schema which are specified in
    the schema rather than `http` as default. At the same time it doesn't
    change the logic as any conn_type can be selected. Intuitively it's
    expected that anything specified in `schema` field will actually take
    precedence in the building the desired url.
    
    * Add druid endpoint connection from another PR
    
    * Fix missing scheme in test
    
    * Set schema to None where it's unused
    
    Even though we don't need it directly set, by default the mock will set
    it to an internal object, thus we need to override it to None.
    
    ---------
    
    Co-authored-by: Oleg Auckenthaler <[email protected]>
---
 .../airflow/providers/apache/druid/hooks/druid.py  |  5 ++-
 providers/tests/apache/druid/hooks/test_druid.py   | 44 +++++++++++++++++++---
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/providers/src/airflow/providers/apache/druid/hooks/druid.py 
b/providers/src/airflow/providers/apache/druid/hooks/druid.py
index ca315b3a2c..c865adef41 100644
--- a/providers/src/airflow/providers/apache/druid/hooks/druid.py
+++ b/providers/src/airflow/providers/apache/druid/hooks/druid.py
@@ -86,7 +86,10 @@ class DruidHook(BaseHook):
         """Get Druid connection url."""
         host = self.conn.host
         port = self.conn.port
-        conn_type = self.conn.conn_type or "http"
+        if self.conn.schema:
+            conn_type = self.conn.schema
+        else:
+            conn_type = self.conn.conn_type or "http"
         if ingestion_type == IngestionType.BATCH:
             endpoint = self.conn.extra_dejson.get("endpoint", "")
         else:
diff --git a/providers/tests/apache/druid/hooks/test_druid.py 
b/providers/tests/apache/druid/hooks/test_druid.py
index 9befbf37f0..f01175942b 100644
--- a/providers/tests/apache/druid/hooks/test_druid.py
+++ b/providers/tests/apache/druid/hooks/test_druid.py
@@ -42,9 +42,14 @@ class TestDruidSubmitHook:
             self.is_sql_based_ingestion = False
 
             def get_conn_url(self, ingestion_type: IngestionType = 
IngestionType.BATCH):
+                if self.conn.schema:
+                    conn_type = self.conn.schema
+                else:
+                    conn_type = "http"
+
                 if ingestion_type == IngestionType.MSQ:
-                    return "http://druid-overlord:8081/druid/v2/sql/task";
-                return "http://druid-overlord:8081/druid/indexer/v1/task";
+                    return 
f"{conn_type}://druid-overlord:8081/druid/v2/sql/task"
+                return 
f"{conn_type}://druid-overlord:8081/druid/indexer/v1/task"
 
         self.db_hook = TestDRuidhook()
 
@@ -257,7 +262,8 @@ class TestDruidHook:
     def test_conn_property(self, mock_get_connection):
         get_conn_value = MagicMock()
         get_conn_value.host = "test_host"
-        get_conn_value.conn_type = "https"
+        get_conn_value.conn_type = "http"
+        get_conn_value.schema = None
         get_conn_value.port = "1"
         get_conn_value.extra_dejson = {"endpoint": "ingest"}
         mock_get_connection.return_value = get_conn_value
@@ -268,8 +274,22 @@ class TestDruidHook:
     def test_get_conn_url(self, mock_get_connection):
         get_conn_value = MagicMock()
         get_conn_value.host = "test_host"
-        get_conn_value.conn_type = "https"
+        get_conn_value.conn_type = "http"
+        get_conn_value.schema = None
+        get_conn_value.port = "1"
+        get_conn_value.extra_dejson = {"endpoint": "ingest"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook(timeout=1, max_ingestion_time=5)
+        assert hook.get_conn_url() == "http://test_host:1/ingest";
+
+    
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_get_conn_url_with_schema(self, mock_get_connection):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "http"
+        get_conn_value.schema = None
         get_conn_value.port = "1"
+        get_conn_value.schema = "https"
         get_conn_value.extra_dejson = {"endpoint": "ingest"}
         mock_get_connection.return_value = get_conn_value
         hook = DruidHook(timeout=1, max_ingestion_time=5)
@@ -279,8 +299,21 @@ class TestDruidHook:
     def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
         get_conn_value = MagicMock()
         get_conn_value.host = "test_host"
-        get_conn_value.conn_type = "https"
+        get_conn_value.conn_type = "http"
+        get_conn_value.schema = None
+        get_conn_value.port = "1"
+        get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": 
"sql_ingest"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook(timeout=1, max_ingestion_time=5)
+        assert hook.get_conn_url(IngestionType.MSQ) == 
"http://test_host:1/sql_ingest";
+
+    
@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_get_conn_url_with_ingestion_type_and_schema(self, 
mock_get_connection):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "http"
         get_conn_value.port = "1"
+        get_conn_value.schema = "https"
         get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": 
"sql_ingest"}
         mock_get_connection.return_value = get_conn_value
         hook = DruidHook(timeout=1, max_ingestion_time=5)
@@ -343,6 +376,7 @@ class TestDruidDbApiHook:
         self.conn = conn = MagicMock()
         self.conn.host = "host"
         self.conn.port = "1000"
+        self.conn.schema = None
         self.conn.conn_type = "druid"
         self.conn.extra_dejson = {"endpoint": "druid/v2/sql"}
         self.conn.cursor.return_value = self.cur

Reply via email to