Repository: incubator-airflow Updated Branches: refs/heads/master b65dc43d2 -> acc9a3617
[AIRFLOW-2228] Enhancements in ValueCheckOperator Allow ValueCheckOperator to accept a tolerance of 1. Modify pass_value to be a template field, so that its value can be determined at runtime. Add tolerance value in airflow exception. This gives an idea about the allowed range for resultant records. Closes #3149 from sakshi2894/AIRFLOW-2228 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/acc9a361 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/acc9a361 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/acc9a361 Branch: refs/heads/master Commit: acc9a3617e7d4520bb95191c80f4d3d2e64f622d Parents: b65dc43 Author: Sakshi Bansal <saks...@qubole.com> Authored: Mon Mar 26 21:15:22 2018 +0200 Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com> Committed: Mon Mar 26 21:15:22 2018 +0200 ---------------------------------------------------------------------- airflow/operators/check_operator.py | 30 ++++++---- tests/operators/test_check_operator.py | 90 +++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/airflow/operators/check_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index ff82539..9994671 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -115,7 +115,7 @@ class ValueCheckOperator(BaseOperator): __mapper_args__ = { 'polymorphic_identity': 'ValueCheckOperator' } - template_fields = ('sql',) + template_fields = ('sql', 'pass_value',) template_ext = ('.hql', '.sql',) ui_color = '#fff7e6' @@ -127,10 +127,9 @@ class ValueCheckOperator(BaseOperator): super(ValueCheckOperator, self).__init__(*args, **kwargs) self.sql = sql self.conn_id = conn_id - self.pass_value = _convert_to_float_if_possible(pass_value) + self.pass_value = str(pass_value) tol = _convert_to_float_if_possible(tolerance) self.tol = tol if isinstance(tol, float) else None - self.is_numeric_value_check = isinstance(self.pass_value, float) self.has_tolerance = self.tol is not None def execute(self, context=None): @@ -138,23 +137,32 @@ class ValueCheckOperator(BaseOperator): records = self.get_db_hook().get_first(self.sql) if not records: raise AirflowException("The query returned None") - test_results = [] - except_temp = ("Test failed.\nPass value:{self.pass_value}\n" + + pass_value_conv = _convert_to_float_if_possible(self.pass_value) + is_numeric_value_check = isinstance(pass_value_conv, float) + + tolerance_pct_str = None + if (self.tol is not None): + tolerance_pct_str = str(self.tol * 100) + '%' + + except_temp = ("Test failed.\nPass value:{pass_value_conv}\n" + "Tolerance:{tolerance_pct_str}\n" "Query:\n{self.sql}\nResults:\n{records!s}") - if not self.is_numeric_value_check: - tests = [str(r) == self.pass_value for r in records] - elif self.is_numeric_value_check: + if not is_numeric_value_check: + tests = [str(r) == pass_value_conv for r in records] + elif is_numeric_value_check: try: num_rec = [float(r) for r in records] except (ValueError, TypeError) as e: cvestr = "Converting a result to float failed.\n" - raise AirflowException(cvestr+except_temp.format(**locals())) + raise AirflowException(cvestr + except_temp.format(**locals())) if self.has_tolerance: tests = [ - r / (1 + self.tol) <= self.pass_value <= r / (1 - self.tol) + pass_value_conv * (1 - self.tol) <= + r <= pass_value_conv * (1 + self.tol) for r in num_rec] else: - tests = [r == self.pass_value for r in num_rec] + tests = [r == pass_value_conv for r in num_rec] if not all(tests): raise AirflowException(except_temp.format(**locals())) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/tests/operators/test_check_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py new file mode 100644 index 0000000..903d547 --- /dev/null +++ b/tests/operators/test_check_operator.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime +from airflow.models import DAG +from airflow.exceptions import AirflowException +from airflow.operators.check_operator import ValueCheckOperator + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + + +class ValueCheckOperatorTest(unittest.TestCase): + + def setUp(self): + self.task_id = 'test_task' + self.conn_id = 'default_conn' + + def __construct_operator(self, sql, pass_value, tolerance=None): + + dag = DAG('test_dag', start_date=datetime(2017, 1, 1)) + + return ValueCheckOperator( + dag=dag, + task_id=self.task_id, + conn_id=self.conn_id, + sql=sql, + pass_value=pass_value, + tolerance=tolerance) + + def test_pass_value_template_string(self): + pass_value_str = "2018-03-22" + operator = self.__construct_operator('select date from tab1;', "{{ ds }}") + result = operator.render_template('pass_value', operator.pass_value, + {'ds': pass_value_str}) + + self.assertEqual(operator.task_id, self.task_id) + self.assertEqual(result, pass_value_str) + + def test_pass_value_template_string_float(self): + pass_value_float = 4.0 + operator = self.__construct_operator('select date from tab1;', pass_value_float) + result = operator.render_template('pass_value', operator.pass_value, {}) + + self.assertEqual(operator.task_id, self.task_id) + self.assertEqual(result, str(pass_value_float)) + + @mock.patch.object(ValueCheckOperator, 'get_db_hook') + def test_execute_pass(self, mock_get_db_hook): + + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [10] + mock_get_db_hook.return_value = mock_hook + + sql = 'select value from tab1 limit 1;' + + operator = self.__construct_operator(sql, 5, 1) + + operator.execute(None) + + mock_hook.get_first.assert_called_with(sql) + + @mock.patch.object(ValueCheckOperator, 'get_db_hook') + def test_execute_fail(self, mock_get_db_hook): + + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [11] + mock_get_db_hook.return_value = mock_hook + + operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1) + + with self.assertRaisesRegexp(AirflowException, 'Tolerance:100.0%'): + operator.execute()