This is an automated email from the ASF dual-hosted git repository. dimberman pushed a commit to branch v1-10-stable in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-stable by this push: new 2e1f813 Add DBApiHook check for 2.0 migration (#12730) 2e1f813 is described below commit 2e1f813c35e60d9e13575639bc913d1cbafcd1ff Author: Daniel Imberman <daniel.imber...@gmail.com> AuthorDate: Fri Dec 11 09:19:16 2020 -0800 Add DBApiHook check for 2.0 migration (#12730) * Add DBApiHook check for 2.0 migration Adds a check that ensures that any hook that implements the run, get_pandas_df or get_records functions inherits from DbApiHook rather than directly from BaseHook * exception for grpc_hook * fix plugin * fix plugin * fix plugin * py2 compliance and add full lineage * black * fix --- airflow/upgrade/rules/db_api_functions.py | 97 ++++++++++++++++++++++++++++ tests/upgrade/rules/test_db_api_functions.py | 71 ++++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py new file mode 100644 index 0000000..1801c36 --- /dev/null +++ b/airflow/upgrade/rules/db_api_functions.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from airflow.hooks.base_hook import BaseHook
from airflow.upgrade.rules.base_rule import BaseRule


def _check_method(cls, method):
    """Probe ``cls.<method>`` and return an error message if the class implements it.

    The method is called with a dummy SQL string on an instance created via
    ``cls.__new__(cls)``, so the class's ``__init__`` is never run.

    A NotImplementedError is taken to mean "this class does not implement the
    method". The check assumes BaseHook's default versions raise it. Any
    other outcome counts as the class providing its own implementation: a
    normal return, or any other exception (for example an AttributeError
    because ``__init__`` was skipped).

    :param cls: the hook class to probe.
    :param method: name of the method to call.
    :return: the error string from ``return_error_string``, or None when the
        method raised NotImplementedError.
    """
    try:
        # NOTE(review): this executes the hook's own code with dummy input;
        # any side effects of that code happen while the upgrade check runs.
        getattr(cls.__new__(cls), method)("fake SQL")
    except NotImplementedError:
        return None
    except Exception:
        return return_error_string(cls, method)
    return return_error_string(cls, method)


def check_get_pandas_df(cls):
    """Return an error message if *cls* implements ``get_pandas_df``, else None."""
    return _check_method(cls, "get_pandas_df")


def check_run(cls):
    """Return an error message if *cls* implements ``run``, else None."""
    return _check_method(cls, "run")


def check_get_records(cls):
    """Return an error message if *cls* implements ``get_records``, else None."""
    return _check_method(cls, "get_records")


def return_error_string(cls, method):
    """Build the message reported when *cls* implements *method* on top of BaseHook.

    The suggested base class points at ``airflow.hooks.dbapi_hook``, the module
    DbApiHook actually lives in. The previous text named a non-existent
    ``db_api_hook`` module.
    """
    return (
        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
        "Please make this class inherit from airflow.hooks.dbapi_hook.DbApiHook instead".format(
            cls, method
        )
    )


def _is_dbapi_hook(cls):
    """Return True if *cls* is, or inherits from, a class named ``DbApiHook``.

    The match is by name, like the original first-generation filter, so this
    module does not need to import DbApiHook.
    """
    return any(klass.__name__ == "DbApiHook" for klass in cls.__mro__)


def get_all_non_dbapi_children():
    """Return every transitive subclass of BaseHook that is not a DbApiHook.

    Subclasses are collected breadth-first, one inheritance generation at a
    time, starting from ``BaseHook.__subclasses__()``.

    Each class appears at most once, even when it is reachable through
    several parents (diamond inheritance). Classes that inherit from
    DbApiHook (including via multiple inheritance) are skipped together with
    their subclasses. Those subclasses also inherit from DbApiHook, so they
    satisfy the rule.
    """
    result = []
    seen = set()
    generation = BaseHook.__subclasses__()
    while generation:
        next_generation = []
        for child in generation:
            if child in seen or _is_dbapi_hook(child):
                continue
            seen.add(child)
            result.append(child)
            next_generation.extend(child.__subclasses__())
        generation = next_generation
    return result


class DbApiRule(BaseRule):
    """Upgrade rule: hooks implementing DB functions must inherit from DbApiHook."""

    title = "Hooks that run DB functions must inherit from DBApiHook"

    description = (
        "Hooks that run DB functions must inherit from DBApiHook instead of BaseHook"
    )

    def check(self):
        """Return one error message per (hook class, DB function) violation.

        For each non-DbApiHook subclass of BaseHook, the functions are checked
        in the order get_pandas_df, run, get_records.
        """
        method_checks = (check_get_pandas_df, check_run, check_get_records)
        incorrect_implementations = []
        for child in get_all_non_dbapi_children():
            for method_check in method_checks:
                error = method_check(child)
                if error:
                    incorrect_implementations.append(error)
        return incorrect_implementations


# diff --git a/tests/upgrade/rules/test_db_api_functions.py b/tests/upgrade/rules/test_db_api_functions.py
# new file mode 100644
# index 0000000..d73a041
# --- /dev/null
# +++ b/tests/upgrade/rules/test_db_api_functions.py
# @@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import TestCase

from airflow.hooks.base_hook import BaseHook
from airflow.hooks.dbapi_hook import DbApiHook
from airflow.upgrade.rules.db_api_functions import DbApiRule


class MyHook(BaseHook):
    """Incorrect hook: implements run and get_pandas_df directly on BaseHook."""

    def run(self, sql):
        pass

    def get_pandas_df(self, sql):
        pass

    def get_conn(self):
        pass


class GrandChildHook(MyHook):
    """Incorrect hook two levels below BaseHook.

    It inherits run and get_pandas_df from MyHook and adds get_records, so
    all three DB functions should be reported for it.
    """

    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

    def get_records(self, sql):
        pass


class ProperDbApiHook(DbApiHook):
    """Correct hook: implements the DB functions on top of DbApiHook."""

    def bulk_dump(self, table, tmp_file):
        pass

    def bulk_load(self, table, tmp_file):
        pass

    # ``*args, **kwargs`` (previously a misnamed ``*kwargs``, which only
    # collects positional arguments) so extra arguments are accepted
    # whether passed positionally or by keyword.
    def get_records(self, sql, *args, **kwargs):
        pass

    def run(self, sql, *args, **kwargs):
        pass

    def get_pandas_df(self, sql, *args, **kwargs):
        pass


class TestSqlHookCheck(TestCase):
    def test_fails_on_incorrect_hook(self):
        db_api_rule_failures = DbApiRule().check()

        myhook_errors = [d for d in db_api_rule_failures if "MyHook" in d]
        grandchild_errors = [d for d in db_api_rule_failures if "GrandChild" in d]
        self.assertEqual(len(myhook_errors), 2)
        self.assertEqual(len(grandchild_errors), 3)

        # MyHook implements exactly run and get_pandas_df; make sure those
        # are the functions reported, not just that two errors exist.
        for method in ("run", "get_pandas_df"):
            self.assertTrue(
                any("function {} ".format(method) in e for e in myhook_errors),
                "expected MyHook to be reported for {}".format(method),
            )

        # A hook built on DbApiHook must never be reported.
        proper_db_api_hook_failures = [
            failure for failure in db_api_rule_failures if "ProperDbApiHook" in failure
        ]
        self.assertEqual(len(proper_db_api_hook_failures), 0)