[AIRFLOW-858] Configurable database name for DB operators

Closes #2063 from s7anley/configurable-schema

(cherry picked from commit 94dc7fb0a6bb3c563d9df6566cd52a59bd0c4629)


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/5eb33358
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/5eb33358
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/5eb33358

Branch: refs/heads/v1-8-stable
Commit: 5eb33358f62a13192e537296becc315476112afb
Parents: eb12f01
Author: Ján Koščo <3k.stan...@gmail.com>
Authored: Sun Feb 12 15:43:41 2017 -0500
Committer: Chris Riccomini <criccom...@apache.org>
Committed: Wed Mar 29 14:19:19 2017 -0700

----------------------------------------------------------------------
 airflow/hooks/mssql_hook.py            | 10 +++++--
 airflow/hooks/mysql_hook.py            | 15 ++++++----
 airflow/hooks/postgres_hook.py         |  4 +--
 airflow/operators/mssql_operator.py    | 11 ++++++--
 airflow/operators/mysql_operator.py    |  8 ++++--
 airflow/operators/postgres_operator.py |  7 ++++-
 tests/operators/operators.py           | 43 +++++++++++++++++++++++++++++
 7 files changed, 81 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/mssql_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py
index 1450967..99a4c82 100644
--- a/airflow/hooks/mssql_hook.py
+++ b/airflow/hooks/mssql_hook.py
@@ -18,14 +18,18 @@ from airflow.hooks.dbapi_hook import DbApiHook
 
 
 class MsSqlHook(DbApiHook):
-    '''
+    """
     Interact with Microsoft SQL Server.
-    '''
+    """
 
     conn_name_attr = 'mssql_conn_id'
     default_conn_name = 'mssql_default'
     supports_autocommit = True
 
+    def __init__(self, *args, **kwargs):
+        super(MsSqlHook, self).__init__(*args, **kwargs)
+        self.schema = kwargs.pop("schema", None)
+
     def get_conn(self):
         """
         Returns a mssql connection object
@@ -35,7 +39,7 @@ class MsSqlHook(DbApiHook):
             server=conn.host,
             user=conn.login,
             password=conn.password,
-            database=conn.schema,
+            database=self.schema or conn.schema,
             port=conn.port)
         return conn
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/mysql_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
index e4f9533..bf1a721 100644
--- a/airflow/hooks/mysql_hook.py
+++ b/airflow/hooks/mysql_hook.py
@@ -19,18 +19,22 @@ from airflow.hooks.dbapi_hook import DbApiHook
 
 
 class MySqlHook(DbApiHook):
-    '''
+    """
     Interact with MySQL.
 
     You can specify charset in the extra field of your connection
     as ``{"charset": "utf8"}``. Also you can choose cursor as
     ``{"cursor": "SSCursor"}``. Refer to the MySQLdb.cursors for more details.
-    '''
+    """
 
     conn_name_attr = 'mysql_conn_id'
     default_conn_name = 'mysql_default'
     supports_autocommit = True
 
+    def __init__(self, *args, **kwargs):
+        super(MySqlHook, self).__init__(*args, **kwargs)
+        self.schema = kwargs.pop("schema", None)
+
     def get_conn(self):
         """
         Returns a mysql connection object
@@ -38,17 +42,16 @@ class MySqlHook(DbApiHook):
         conn = self.get_connection(self.mysql_conn_id)
         conn_config = {
             "user": conn.login,
-            "passwd": conn.password or ''
+            "passwd": conn.password or '',
+            "host": conn.host or 'localhost',
+            "db": self.schema or conn.schema or ''
         }
 
-        conn_config["host"] = conn.host or 'localhost'
         if not conn.port:
             conn_config["port"] = 3306
         else:
             conn_config["port"] = int(conn.port)
 
-        conn_config["db"] = conn.schema or ''
-
         if conn.extra_dejson.get('charset', False):
             conn_config["charset"] = conn.extra_dejson["charset"]
             if (conn_config["charset"]).lower() == 'utf8' or\

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/postgres_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index 584930d..4b460c1 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -19,11 +19,11 @@ from airflow.hooks.dbapi_hook import DbApiHook
 
 
 class PostgresHook(DbApiHook):
-    '''
+    """
     Interact with Postgres.
     You can specify ssl parameters in the extra field of your connection
     as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.
-    '''
+    """
     conn_name_attr = 'postgres_conn_id'
     default_conn_name = 'postgres_default'
     supports_autocommit = True

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/mssql_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py
index 0590454..0f0cd63 100644
--- a/airflow/operators/mssql_operator.py
+++ b/airflow/operators/mssql_operator.py
@@ -27,6 +27,8 @@ class MsSqlOperator(BaseOperator):
     :param sql: the sql code to be executed
     :type sql: string or string pointing to a template file.
     File must have a '.sql' extensions.
+    :param database: name of database which overwrite defined one in connection
+    :type database: string
     """
 
     template_fields = ('sql',)
@@ -36,14 +38,17 @@ class MsSqlOperator(BaseOperator):
     @apply_defaults
     def __init__(
             self, sql, mssql_conn_id='mssql_default', parameters=None,
-            autocommit=False, *args, **kwargs):
+            autocommit=False, database=None, *args, **kwargs):
         super(MsSqlOperator, self).__init__(*args, **kwargs)
         self.mssql_conn_id = mssql_conn_id
         self.sql = sql
         self.parameters = parameters
         self.autocommit = autocommit
+        self.database = database
 
     def execute(self, context):
         logging.info('Executing: ' + str(self.sql))
-        hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
-        hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+        hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id,
+                         schema=self.database)
+        hook.run(self.sql, autocommit=self.autocommit,
+                 parameters=self.parameters)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/mysql_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py
index b3a3c73..156ada8 100644
--- a/airflow/operators/mysql_operator.py
+++ b/airflow/operators/mysql_operator.py
@@ -29,6 +29,8 @@ class MySqlOperator(BaseOperator):
     :type sql: Can receive a str representing a sql statement,
         a list of str (sql statements), or reference to a template file.
         Template reference are recognized by str ending in '.sql'
+    :param database: name of database which overwrite defined one in connection
+    :type database: string
     """
 
     template_fields = ('sql',)
@@ -38,16 +40,18 @@ class MySqlOperator(BaseOperator):
     @apply_defaults
     def __init__(
             self, sql, mysql_conn_id='mysql_default', parameters=None,
-            autocommit=False, *args, **kwargs):
+            autocommit=False, database=None, *args, **kwargs):
         super(MySqlOperator, self).__init__(*args, **kwargs)
         self.mysql_conn_id = mysql_conn_id
         self.sql = sql
         self.autocommit = autocommit
         self.parameters = parameters
+        self.database = database
 
     def execute(self, context):
         logging.info('Executing: ' + str(self.sql))
-        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
+                         schema=self.database)
         hook.run(
             self.sql,
             autocommit=self.autocommit,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/postgres_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py
index c4f56a4..0de5aa5 100644
--- a/airflow/operators/postgres_operator.py
+++ b/airflow/operators/postgres_operator.py
@@ -29,6 +29,8 @@ class PostgresOperator(BaseOperator):
     :type sql: Can receive a str representing a sql statement,
         a list of str (sql statements), or reference to a template file.
         Template reference are recognized by str ending in '.sql'
+    :param database: name of database which overwrite defined one in connection
+    :type database: string
     """
 
     template_fields = ('sql',)
@@ -40,14 +42,17 @@ class PostgresOperator(BaseOperator):
             self, sql,
             postgres_conn_id='postgres_default', autocommit=False,
             parameters=None,
+            database=None,
             *args, **kwargs):
         super(PostgresOperator, self).__init__(*args, **kwargs)
         self.sql = sql
         self.postgres_conn_id = postgres_conn_id
         self.autocommit = autocommit
         self.parameters = parameters
+        self.database = database
 
     def execute(self, context):
         logging.info('Executing: ' + str(self.sql))
-        self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
+        self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
+                                 schema=self.database)
         self.hook.run(self.sql, self.autocommit, parameters=self.parameters)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/tests/operators/operators.py
----------------------------------------------------------------------
diff --git a/tests/operators/operators.py b/tests/operators/operators.py
index 7aaf12e..19901ae 100644
--- a/tests/operators/operators.py
+++ b/tests/operators/operators.py
@@ -114,6 +114,27 @@ class MySqlTest(unittest.TestCase):
             dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
+    def test_overwrite_schema(self):
+        """
+        Verifies option to overwrite connection schema
+        """
+        import airflow.operators.mysql_operator
+
+        sql = "SELECT 1;"
+        t = operators.mysql_operator.MySqlOperator(
+            task_id='test_mysql_operator_test_schema_overwrite',
+            sql=sql,
+            dag=self.dag,
+            database="foobar",
+        )
+
+        from _mysql_exceptions import OperationalError
+        try:
+            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+                  ignore_ti_state=True)
+        except OperationalError as e:
+            assert "Unknown database 'foobar'" in str(e)
+
 
 @skipUnlessImported('airflow.operators.postgres_operator', 'PostgresOperator')
 class PostgresTest(unittest.TestCase):
@@ -193,6 +214,28 @@ class PostgresTest(unittest.TestCase):
             autocommit=True)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
+    def test_overwrite_schema(self):
+        """
+        Verifies option to overwrite connection schema
+        """
+        import airflow.operators.postgres_operator
+
+        sql = "SELECT 1;"
+        t = operators.postgres_operator.PostgresOperator(
+            task_id='postgres_operator_test_schema_overwrite',
+            sql=sql,
+            dag=self.dag,
+            autocommit=True,
+            database="foobar",
+        )
+
+        from psycopg2._psycopg import OperationalError
+        try:
+            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+                  ignore_ti_state=True)
+        except OperationalError as e:
+            assert 'database "foobar" does not exist' in str(e)
+
 
 @skipUnlessImported('airflow.operators.hive_operator', 'HiveOperator')
 @skipUnlessImported('airflow.operators.postgres_operator', 'PostgresOperator')

Reply via email to