This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6694eaa  Show the location of the queries when the assert_queries_count fails. (#11186)
6694eaa is described below

commit 6694eaa8313f2709f5712c4bf7f03355e843e517
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Sep 28 19:39:21 2020 +0100

    Show the location of the queries when the assert_queries_count fails. (#11186)
    
    Example output (I forced one of the existing tests to fail):
    
    ```
    E   AssertionError: The expected number of db queries is 3. The current number is 2.
    E
    E   Recorded query locations:
    E           scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:94:     1
    E           scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:101:    1
    ```
    
    This makes it a bit easier to see what the queries are, without having
    to re-run with full query tracing and then analyze the logs.
---
 tests/test_utils/asserts.py | 39 ++++++++++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py
index ca3cf2f..220331d 100644
--- a/tests/test_utils/asserts.py
+++ b/tests/test_utils/asserts.py
@@ -17,6 +17,8 @@
 
 import logging
 import re
+import traceback
+from collections import Counter
 from contextlib import contextmanager
 
 from sqlalchemy import event
@@ -33,11 +35,6 @@ def assert_equal_ignore_multiple_spaces(case, first, second, 
msg=None):
     return case.assertEqual(_trim(first), _trim(second), msg)
 
 
-class CountQueriesResult:
-    def __init__(self):
-        self.count = 0
-
-
 class CountQueries:
     """
     Counts the number of queries sent to Airflow Database in a given context.
@@ -46,18 +43,26 @@ class CountQueries:
     not be included.
     """
     def __init__(self):
-        self.result = CountQueriesResult()
+        self.result = Counter()
 
     def __enter__(self):
         event.listen(airflow.settings.engine, "after_cursor_execute", 
self.after_cursor_execute)
         return self.result
 
-    def __exit__(self, type_, value, traceback):
+    def __exit__(self, type_, value, tb):
         event.remove(airflow.settings.engine, "after_cursor_execute", 
self.after_cursor_execute)
-        log.debug("Queries count: %d", self.result.count)
+        log.debug("Queries count: %d", sum(self.result.values()))
 
     def after_cursor_execute(self, *args, **kwargs):
-        self.result.count += 1
+        stack = [
+            f for f in traceback.extract_stack()
+            if 'sqlalchemy' not in f.filename and
+               __file__ != f.filename and
+               ('session.py' not in f.filename and f.name != 'wrapper')
+        ]
+        stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" 
for f in stack][-3:])
+        lineno = stack[-1].lineno
+        self.result[f"{stack_info}:{lineno}"] += 1
 
 
 count_queries = CountQueries  # pylint: disable=invalid-name
@@ -67,7 +72,15 @@ count_queries = CountQueries  # pylint: disable=invalid-name
 def assert_queries_count(expected_count, message_fmt=None):
     with count_queries() as result:
         yield None
-    message_fmt = message_fmt or "The expected number of db queries is 
{expected_count}. " \
-                                 "The current number is {current_count}."
-    message = message_fmt.format(current_count=result.count, 
expected_count=expected_count)
-    assert expected_count == result.count, message
+
+    count = sum(result.values())
+    if expected_count != count:
+        message_fmt = message_fmt or "The expected number of db queries is 
{expected_count}. " \
+                                     "The current number is 
{current_count}.\n\n" \
+                                     "Recorded query locations:"
+        message = message_fmt.format(current_count=count, 
expected_count=expected_count)
+
+        for location, count in result.items():
+            message += f'\n\t{location}:\t{count}'
+
+        raise AssertionError(message)

Reply via email to