Repository: incubator-airflow
Updated Branches:
  refs/heads/master c3730650c -> 7a880a7e9


[AIRFLOW-2183] Refactor DruidHook to enable sql

Refactor DruidHook to be able to issue druid sql query to druid broker

Closes #3105 from feng-tao/airflow-2183
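
For context, DruidDbApiHook plugs into the standard DbApiHook query methods,
so issuing Druid SQL against the broker becomes a one-liner. A minimal sketch
(the connection id is the new default added in this commit; the SQL statement
and datasource name are purely illustrative):

    from airflow.hooks.druid_hook import DruidDbApiHook

    # Uses the druid_broker_default connection unless another id is passed.
    hook = DruidDbApiHook(druid_broker_conn_id='druid_broker_default')
    rows = hook.get_records('SELECT COUNT(*) FROM my_datasource')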


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/7a880a7e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/7a880a7e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/7a880a7e

Branch: refs/heads/master
Commit: 7a880a7e987f423d4baca67eaa9d20451fa8fa87
Parents: c373065
Author: Tao feng <tf...@lyft.com>
Authored: Wed Mar 14 09:20:16 2018 +0100
Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com>
Committed: Wed Mar 14 09:20:20 2018 +0100

----------------------------------------------------------------------
 .gitignore                     |  1 +
 airflow/hooks/__init__.py      |  5 ++-
 airflow/hooks/druid_hook.py    | 64 +++++++++++++++++++++++++++++++++++--
 airflow/utils/db.py            |  4 +++
 setup.py                       |  6 ++--
 tests/hooks/test_druid_hook.py | 52 ++++++++++++++++++++++++++++--
 6 files changed, 123 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/.gitignore
----------------------------------------------------------------------
diff --git a/.gitignore b/.gitignore
index a29c9ad..f5ed5ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ logs/
 __pycache__/
 *.py[cod]
 *$py.class
+.pytest_cache/
 
 # C extensions
 *.so

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/airflow/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index 6372b2f..3d75f9b 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -50,7 +50,10 @@ _hooks = {
     'S3_hook': ['S3Hook'],
     'zendesk_hook': ['ZendeskHook'],
     'http_hook': ['HttpHook'],
-    'druid_hook': ['DruidHook'],
+    'druid_hook': [
+        'DruidHook',
+        'DruidDbApiHook',
+    ],
     'jdbc_hook': ['JdbcHook'],
     'dbapi_hook': ['DbApiHook'],
     'mssql_hook': ['MsSqlHook'],
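
The registry change above only adds the new class name alongside the existing
one, so both hooks stay importable from the same module:

    from airflow.hooks.druid_hook import DruidDbApiHook, DruidHook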

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/airflow/hooks/druid_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py
index 9ce1f9a..f08cbc2 100644
--- a/airflow/hooks/druid_hook.py
+++ b/airflow/hooks/druid_hook.py
@@ -17,17 +17,22 @@ from __future__ import print_function
 import requests
 import time
 
+from pydruid.db import connect
+
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.dbapi_hook import DbApiHook
 
 
 class DruidHook(BaseHook):
     """
-    Connection to Druid
+    Connection to Druid overlord for ingestion
 
-    :param druid_ingest_conn_id: The connection id to the Druid overlord machine which accepts index jobs
+    :param druid_ingest_conn_id: The connection id to the Druid overlord machine
+                                 which accepts index jobs
     :type druid_ingest_conn_id: string
-    :param timeout: The interval between polling the Druid job for the status of the ingestion job
+    :param timeout: The interval between polling
+                    the Druid job for the status of the ingestion job
     :type timeout: int
     :param max_ingestion_time: The maximum ingestion time before assuming the job failed
     :type max_ingestion_time: int
@@ -90,3 +95,56 @@ class DruidHook(BaseHook):
                 raise AirflowException('Could not get status of the job, got %s', status)
 
         self.log.info('Successful index')
+
+
+class DruidDbApiHook(DbApiHook):
+    """
+    Interact with Druid broker
+
+    This hook is purely for users to query druid broker.
+    For ingestion, please use DruidHook.
+    """
+    conn_name_attr = 'druid_broker_conn_id'
+    default_conn_name = 'druid_broker_default'
+    supports_autocommit = False
+
+    def __init__(self, *args, **kwargs):
+        super(DruidDbApiHook, self).__init__(*args, **kwargs)
+
+    def get_conn(self):
+        """
+        Establish a connection to druid broker.
+        """
+        conn = self.get_connection(self.druid_broker_conn_id)
+        druid_broker_conn = connect(
+            host=conn.host,
+            port=conn.port,
+            path=conn.extra_dejson.get('endpoint', '/druid/v2/sql'),
+            scheme=conn.extra_dejson.get('schema', 'http')
+        )
+        self.log.info('Get the connection to druid broker on {host}'.format(host=conn.host))
+        return druid_broker_conn
+
+    def get_uri(self):
+        """
+        Get the connection uri for druid broker.
+
+    e.g.: druid://localhost:8082/druid/v2/sql
+        """
+        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        host = conn.host
+        if conn.port is not None:
+            host += ':{port}'.format(port=conn.port)
+        conn_type = 'druid' if not conn.conn_type else conn.conn_type
+        endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
+        return '{conn_type}://{host}/{endpoint}'.format(
+            conn_type=conn_type, host=host, endpoint=endpoint)
+
+    def set_autocommit(self, conn, autocommit):
+        raise NotImplementedError()
+
+    def get_pandas_df(self, sql, parameters=None):
+        raise NotImplementedError()
+
+    def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
+        raise NotImplementedError()
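
Since set_autocommit, get_pandas_df and insert_rows are explicitly disabled,
the hook is effectively read-only: the inherited DbApiHook helpers get_records
and get_first are the intended entry points. For completeness, a rough sketch
of driving the underlying pydruid DB-API connection directly (the datasource
name is a placeholder):

    hook = DruidDbApiHook()       # defaults to druid_broker_default
    conn = hook.get_conn()        # pydruid.db connection to the broker
    cur = conn.cursor()
    cur.execute('SELECT __time FROM my_datasource LIMIT 5')
    print(cur.fetchall())
    cur.close()
    conn.close()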

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 6acab4f..6c7f3c0 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -193,6 +193,10 @@ def initdb():
             host='yarn', extra='{"queue": "root.default"}'))
     merge_conn(
         models.Connection(
+            conn_id='druid_broker_default', conn_type='druid',
+            host='druid-broker', port=8082, extra='{"endpoint": "druid/v2/sql"}'))
+    merge_conn(
+        models.Connection(
             conn_id='druid_ingest_default', conn_type='druid',
             host='druid-overlord', port=8081, extra='{"endpoint": "druid/indexer/v1/task"}'))
     merge_conn(
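
Given this new default, get_uri() on the hook should render the broker entry
roughly as follows (a sketch assuming an unmodified druid_broker_default
connection):

    DruidDbApiHook().get_uri()    # -> 'druid://druid-broker:8082/druid/v2/sql'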

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 20742d4..83f199a 100644
--- a/setup.py
+++ b/setup.py
@@ -118,6 +118,7 @@ doc = [
     'Sphinx-PyPI-upload>=0.2.1'
 ]
 docker = ['docker-py>=1.6.0']
+druid = ['pydruid>=0.4.1']
 emr = ['boto3>=1.0.0']
 gcp_api = [
     'httplib2',
@@ -168,7 +169,7 @@ kubernetes = ['kubernetes>=3.0.0',
 
 zendesk = ['zdesk']
 
-all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant
+all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid
 devel = [
     'click',
     'freezegun',
@@ -190,7 +191,7 @@ devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups
 devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
 devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
              docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
-             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins)
+             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + druid)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -269,6 +270,7 @@ def do_setup():
             'devel_hadoop': devel_hadoop,
             'doc': doc,
             'docker': docker,
+            'druid': druid,
             'emr': emr,
             'gcp_api': gcp_api,
             'github_enterprise': github_enterprise,
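
With the druid extra registered above, pydruid>=0.4.1 can be pulled in
alongside Airflow via pip install apache-airflow[druid] (assuming installation
from the apache-airflow package of this release line); that is the package
providing the pydruid.db module the new hook imports.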

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7a880a7e/tests/hooks/test_druid_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_druid_hook.py b/tests/hooks/test_druid_hook.py
index ddab369..606cbb7 100644
--- a/tests/hooks/test_druid_hook.py
+++ b/tests/hooks/test_druid_hook.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 #
 
+import mock
 import requests
 import requests_mock
 import unittest
 
 from airflow.exceptions import AirflowException
-from airflow.hooks.druid_hook import DruidHook
+from airflow.hooks.druid_hook import DruidDbApiHook, DruidHook
 
 
 class TestDruidHook(unittest.TestCase):
@@ -111,6 +112,51 @@ class TestDruidHook(unittest.TestCase):
         self.assertTrue(shutdown_post.called_once)
 
 
+class TestDruidDbApiHook(unittest.TestCase):
 
-
-
+    def setUp(self):
+        super(TestDruidDbApiHook, self).setUp()
+        self.cur = mock.MagicMock()
+        self.conn = conn = mock.MagicMock()
+        self.conn.host = 'host'
+        self.conn.port = '1000'
+        self.conn.conn_type = 'druid'
+        self.conn.extra_dejson = {'endpoint': 'druid/v2/sql'}
+        self.conn.cursor.return_value = self.cur
+
+        class TestDruidDBApiHook(DruidDbApiHook):
+            def get_conn(self):
+                return conn
+
+            def get_connection(self, conn_id):
+                return conn
+
+        self.db_hook = TestDruidDBApiHook
+
+    def test_get_uri(self):
+        db_hook = self.db_hook()
+        self.assertEquals('druid://host:1000/druid/v2/sql', db_hook.get_uri())
+
+    def test_get_first_record(self):
+        statement = 'SQL'
+        result_sets = [('row1',), ('row2',)]
+        self.cur.fetchone.return_value = result_sets[0]
+
+        self.assertEqual(result_sets[0], self.db_hook().get_first(statement))
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+        self.cur.execute.assert_called_once_with(statement)
+
+    def test_get_records(self):
+        statement = 'SQL'
+        result_sets = [('row1',), ('row2',)]
+        self.cur.fetchall.return_value = result_sets
+
+        self.assertEqual(result_sets, self.db_hook().get_records(statement))
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+        self.cur.execute.assert_called_once_with(statement)
+
+
+if __name__ == '__main__':
+    unittest.main()
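
Thanks to the unittest.main() guard, the new tests can also be run standalone,
e.g. python tests/hooks/test_druid_hook.py from a source checkout (assuming
mock, requests_mock and pydruid are available in the environment).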
