This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 90e7b3fd05 Fix `total_entries` count on the event logs endpoint (#38625)
90e7b3fd05 is described below

commit 90e7b3fd057be1c014d0dd94eafbea14d46536b2
Author: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
AuthorDate: Sat Mar 30 02:17:59 2024 -0400

    Fix `total_entries` count on the event logs endpoint (#38625)
    
    The `total_entries` count should reflect how many log entries match the
    provided filters, not the total number of rows in the table.
---
 airflow/api_connexion/endpoints/event_log_endpoint.py  | 10 ++++++----
 .../api_connexion/endpoints/test_event_log_endpoint.py | 18 +++++++++++++-----
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py
index 68321d3324..3b3dbe6efd 100644
--- a/airflow/api_connexion/endpoints/event_log_endpoint.py
+++ b/airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from sqlalchemy import func, select
+from sqlalchemy import select
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import NotFound
@@ -31,6 +31,7 @@ from airflow.api_connexion.schemas.event_log_schema import (
 from airflow.auth.managers.models.resource_details import DagAccessEntity
 from airflow.models import Log
 from airflow.utils import timezone
+from airflow.utils.db import get_query_count
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
@@ -70,7 +71,7 @@ def get_event_logs(
 ) -> APIResponse:
     """Get all log entries from event log."""
     to_replace = {"event_log_id": "id", "when": "dttm"}
-    allowed_filter_attrs = [
+    allowed_sort_attrs = [
         "event_log_id",
         "when",
         "dag_id",
@@ -81,7 +82,6 @@ def get_event_logs(
         "owner",
         "extra",
     ]
-    total_entries = session.scalars(func.count(Log.id)).one()
     query = select(Log)
 
     if dag_id:
@@ -105,7 +105,9 @@ def get_event_logs(
     if after:
         query = query.where(Log.dttm > timezone.parse(after))
 
-    query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
+    total_entries = get_query_count(query, session=session)
+
+    query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)
     event_logs = session.scalars(query.offset(offset).limit(limit)).all()
     return event_log_collection_schema.dump(
         EventLogCollection(event_logs=event_logs, total_entries=total_entries)
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index 1bb79dd4b3..6e71a86b94 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -281,23 +281,27 @@ class TestGetEventLogs(TestEventLogEndpoint):
                 f"/api/v1/eventLogs?{attr}={attr_value}", 
environ_overrides={"REMOTE_USER": "test_granular"}
             )
             assert response.status_code == 200
-            assert {eventlog[attr] for eventlog in 
response.json["event_logs"]} == {attr_value}
+            assert response.json["total_entries"] == 1
+            assert len(response.json["event_logs"]) == 1
+            assert response.json["event_logs"][0][attr] == attr_value
 
     def test_should_filter_eventlogs_by_when(self, create_log_model, session):
         eventlog1 = create_log_model(event="TEST_EVENT_1", 
when=self.default_time)
         eventlog2 = create_log_model(event="TEST_EVENT_2", 
when=self.default_time_2)
         session.add_all([eventlog1, eventlog2])
         session.commit()
-        for when_attr, expected_eventlogs in {
-            "before": {"TEST_EVENT_1"},
-            "after": {"TEST_EVENT_2"},
+        for when_attr, expected_eventlog_event in {
+            "before": "TEST_EVENT_1",
+            "after": "TEST_EVENT_2",
         }.items():
             response = self.client.get(
                 
f"/api/v1/eventLogs?{when_attr}=2020-06-10T20%3A00%3A01%2B00%3A00",  # 
self.default_time + 1s
                 environ_overrides={"REMOTE_USER": "test"},
             )
             assert response.status_code == 200
-            assert {eventlog["event"] for eventlog in 
response.json["event_logs"]} == expected_eventlogs
+            assert response.json["total_entries"] == 1
+            assert len(response.json["event_logs"]) == 1
+            assert response.json["event_logs"][0]["event"] == 
expected_eventlog_event
 
     def test_should_filter_eventlogs_by_run_id(self, create_log_model, 
session):
         eventlog1 = create_log_model(event="TEST_EVENT_1", 
when=self.default_time, run_id="run_1")
@@ -314,6 +318,8 @@ class TestGetEventLogs(TestEventLogEndpoint):
                 environ_overrides={"REMOTE_USER": "test"},
             )
             assert response.status_code == 200
+            assert response.json["total_entries"] == len(expected_eventlogs)
+            assert len(response.json["event_logs"]) == len(expected_eventlogs)
             assert {eventlog["event"] for eventlog in 
response.json["event_logs"]} == expected_eventlogs
             assert all({eventlog["run_id"] == run_id for eventlog in 
response.json["event_logs"]})
 
@@ -327,6 +333,7 @@ class TestGetEventLogs(TestEventLogEndpoint):
         assert response.status_code == 200
         response_data = response.json
         assert len(response_data["event_logs"]) == 2
+        assert response_data["total_entries"] == 2
         assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in 
response_data["event_logs"]}
 
     def test_should_filter_eventlogs_by_excluded_events(self, 
create_log_model):
@@ -339,6 +346,7 @@ class TestGetEventLogs(TestEventLogEndpoint):
         assert response.status_code == 200
         response_data = response.json
         assert len(response_data["event_logs"]) == 1
+        assert response_data["total_entries"] == 1
         assert {"cli_scheduler"} == {x["event"] for x in 
response_data["event_logs"]}
 
 

Reply via email to