denimalpaca commented on code in PR #24476:
URL: https://github.com/apache/airflow/pull/24476#discussion_r906112005


##########
airflow/providers/core/sql/operators/sql.py:
##########
@@ -0,0 +1,308 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.operators.sql import BaseSQLOperator
+
+
+def parse_boolean(val: str) -> Union[str, bool]:
+    """Try to parse a string into boolean.
+
+    Raises ValueError if the input is not a valid true- or false-like string 
value.
+    """
+    val = val.lower()
+    if val in ('y', 'yes', 't', 'true', 'on', '1'):
+        return True
+    if val in ('n', 'no', 'f', 'false', 'off', '0'):
+        return False
+    raise ValueError(f"{val!r} is not a boolean-like string value")
+
+
+def _get_failed_tests(checks):
+    return [
+        f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
+        for check, check_values in checks.items()
+        if not check_values["success"]
+    ]
+
+
+class SQLColumnCheckOperator(BaseSQLOperator):
+    """
+    Performs one or more of the templated checks in the column_checks dictionary.
+    Checks are performed on a per-column basis specified by the column_mapping.
+    Each check can take one or more of the following options:
+
+    - equal_to: an exact value to equal, cannot be used with other comparison options
+    - greater_than: value that the result should be strictly greater than
+    - less_than: value that the result should be strictly less than
+    - geq_than: value that the result should be greater than or equal to
+    - leq_than: value that the result should be less than or equal to
+    - tolerance: the fraction (e.g. 0.2 for 20%) by which the result may be off from the expected value
+
+    :param table: the table to run checks on
+    :param column_mapping: the dictionary of columns and their associated checks, e.g.
+
+    .. code-block:: python
+
+        {
+            "col_name": {
+                "null_check": {
+                    "equal_to": 0,
+                },
+                "min": {
+                    "greater_than": 5,
+                    "leq_than": 10,
+                    "tolerance": 0.2,
+                },
+                "max": {"less_than": 1000, "geq_than": 10, "tolerance": 0.01},
+            }
+        }
+
+    :param conn_id: the connection ID used to connect to the database
+    :param database: name of the database that overrides the one defined in the connection
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SQLColumnCheckOperator`
+    """
+
+    # SQL expression templates keyed by check name. The literal word "column" in each template
+    # is replaced with the actual column name when execute() builds the query.
+    column_checks = {
+        "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS 
column_null_check",
+        "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
+        "unique_check": "COUNT(column) - COUNT(DISTINCT(column)) AS 
column_unique_check",
+        "min": "MIN(column) AS column_min",
+        "max": "MAX(column) AS column_max",
+    }
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        column_mapping: Dict[str, Dict[str, Any]],
+        conn_id: Optional[str] = None,
+        database: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        for checks in column_mapping.values():
+            for check, check_values in checks.items():
+                self._column_mapping_validation(check, check_values)
+
+        self.table = table
+        self.column_mapping = column_mapping
+        # OpenLineage needs a valid SQL query with the input/output table(s) 
to parse
+        self.sql = f"SELECT * FROM {self.table};"
+
+    def execute(self, context=None):
+        hook = self.get_db_hook()
+        failed_tests = []
+        for column in self.column_mapping:
+            checks = [*self.column_mapping[column]]
+            checks_sql = ",".join([self.column_checks[check].replace("column", 
column) for check in checks])
+
+            self.sql = f"SELECT {checks_sql} FROM {self.table};"
+            records = hook.get_first(self.sql)
+
+            if not records:
+                raise AirflowException(f"The following query returned zero 
rows: {self.sql}")
+
+            self.log.info(f"Record: {records}")
+
+            for idx, result in enumerate(records):
+                tolerance = 
self.column_mapping[column][checks[idx]].get("tolerance")
+
+                self.column_mapping[column][checks[idx]]["result"] = result
+                self.column_mapping[column][checks[idx]]["success"] = 
self._get_match(
+                    self.column_mapping[column][checks[idx]], result, tolerance
+                )
+
+            failed_tests.extend(_get_failed_tests(self.column_mapping[column]))
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{''.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+    def _get_match(self, check_values, record, tolerance=None) -> bool:
+        if "geq_than" in check_values:
+            if tolerance is not None:
+                return record >= check_values["geq_than"] * (1 - tolerance)
+            return record >= check_values["geq_than"]
+        elif "greater_than" in check_values:
+            if tolerance is not None:
+                return record > check_values["greater_than"] * (1 - tolerance)
+            return record > check_values["greater_than"]
+        if "leq_than" in check_values:
+            if tolerance is not None:
+                return record <= check_values["leq_than"] * (1 + tolerance)
+            return record <= check_values["leq_than"]
+        elif "less_than" in check_values:
+            if tolerance is not None:
+                return record < check_values["less_than"] * (1 + tolerance)
+            return record < check_values["less_than"]
+        if "equal_to" in check_values:
+            if tolerance is not None:
+                return (
+                    check_values["equal_to"] * (1 - tolerance)
+                    <= record
+                    <= check_values["equal_to"] * (1 + tolerance)
+                )
+        return record == check_values["equal_to"]

Review Comment:
   So, this is actually causing a static check error, because the function
needs a non-nested return statement. As I put in my commit message when making
this change, it doesn't actually matter logic-wise where the return statement
is: there is already a check that ensures that `check_values` contains a valid
comparison, so having a catch-all return at the end with `equal_to` should be
fine.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to