This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch log_queries
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8b2f9c79e311515f5bb82627cf739d99d04d8f6c
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Mon Apr 8 11:47:34 2024 -0700

    feat: improve event logging for queries + refactor
    
    The driver for this PR was to enrich event logging around database engines
    and database drivers for events that interact directly with analytics
    databases.
    
    Digging a bit into the logging code:
    - I realized that `request.form` is empty when the request carries a JSON
      payload, so `request.json` should be read as well. This lets the
      automated logging that proactively parses context out of the request
      object capture more information
    - Adding an `execute_sql` event that targets only actual query execution,
      so it **isn't** emitted when, for instance, results come from the cache.
      Using the context manager ensures the duration of the call is captured
      as well.
    - Some refactoring here and there
---
 superset/config.py      |  3 ++
 superset/models/core.py | 83 +++++++++++++++++++++++++------------------------
 superset/sql_lab.py     | 37 ++++++++++++----------
 superset/utils/core.py  |  7 +++++
 superset/utils/log.py   | 65 ++++++++++++++++++++++++++++----------
 5 files changed, 123 insertions(+), 72 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 1b06f96db8..90a31568e1 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -74,6 +74,9 @@ if TYPE_CHECKING:
 
 # Realtime stats logger, a StatsD implementation exists
 STATS_LOGGER = DummyStatsLogger()
+
+# By default, events are logged to the metadata database by `DBEventLogger`.
+# For debugging, use `StdOutEventLogger` (superset.utils.log) to print to stdout.
 EVENT_LOGGER = DBEventLogger()
 
 SUPERSET_LOG_VIEW = True
diff --git a/superset/models/core.py b/superset/models/core.py
index 92f6946f1e..69fc2bc8ac 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -66,6 +66,7 @@ from superset.db_engine_specs.base import MetricType, TimeGrain
 from superset.extensions import (
     cache_manager,
     encrypted_field_factory,
+    event_logger,
     security_manager,
     ssh_manager_factory,
 )
@@ -559,6 +560,20 @@ class Database(
         """
         return self.db_engine_spec.get_default_schema_for_query(self, query)
 
+    @staticmethod
+    def post_process_df(df: pd.DataFrame) -> pd.DataFrame:
+        def column_needs_conversion(df_series: pd.Series) -> bool:
+            return (
+                isinstance(df_series, pd.Series)
+                and not df_series.empty
+                and isinstance(df_series.iloc[0], (list, dict))  # first row, by position
+            )
+
+        for col, coltype in df.dtypes.to_dict().items():  # JSON-encode list/dict cols
+            if coltype == numpy.object_ and column_needs_conversion(df[col]):
+                df[col] = df[col].apply(utils.json_dumps_w_dates)
+        return df
+
     @property
     def quote_identifier(self) -> Callable[[str], str]:
         """Add quotes to potential identifier expressions if needed"""
@@ -576,15 +591,15 @@ class Database(
         sqls = self.db_engine_spec.parse_sql(sql)
         with self.get_sqla_engine_with_context(schema) as engine:
             engine_url = engine.url
-        mutate_after_split = config["MUTATE_AFTER_SPLIT"]
-        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
 
-        def needs_conversion(df_series: pd.Series) -> bool:
-            return (
-                not df_series.empty
-                and isinstance(df_series, pd.Series)
-                and isinstance(df_series[0], (list, dict))
-            )
+        def _mutate_sql_if_needed(sql_: str) -> str:
+            if config["MUTATE_AFTER_SPLIT"]:  # mutate each split statement
+                return config["SQL_QUERY_MUTATOR"](
+                    sql_,
+                    security_manager=security_manager,
+                    database=self,  # pass the Database this SQL runs against
+                )
+            return sql_
 
         def _log_query(sql: str) -> None:
             if log_query:
@@ -598,42 +613,30 @@ class Database(
 
         with self.get_raw_connection(schema=schema) as conn:
             cursor = conn.cursor()
-            for sql_ in sqls[:-1]:
-                if mutate_after_split:
-                    sql_ = sql_query_mutator(
-                        sql_,
-                        security_manager=security_manager,
-                        database=None,
-                    )
+            df = None
+            for i, sql_ in enumerate(sqls):
+                sql_ = _mutate_sql_if_needed(sql_)
                 _log_query(sql_)
-                self.db_engine_spec.execute(cursor, sql_, self)
-                cursor.fetchall()
-
-            if mutate_after_split:
-                last_sql = sql_query_mutator(
-                    sqls[-1],
-                    security_manager=security_manager,
-                    database=None,
-                )
-                _log_query(last_sql)
-                self.db_engine_spec.execute(cursor, last_sql, self)
-            else:
-                _log_query(sqls[-1])
-                self.db_engine_spec.execute(cursor, sqls[-1], self)
-
-            data = self.db_engine_spec.fetch_data(cursor)
-            result_set = SupersetResultSet(
-                data, cursor.description, self.db_engine_spec
-            )
-            df = result_set.to_pandas_df()
+                with event_logger.log_context(
+                    action="execute_sql",
+                    database=self,
+                    object_ref=__name__,
+                ):
+                    self.db_engine_spec.execute(cursor, sql_, self)
+                    if i < len(sqls) - 1:
+                        # If it's not the last, we don't keep the results
+                        cursor.fetchall()
+                    else:
+                        # Last query, fetch and process the results
+                        data = self.db_engine_spec.fetch_data(cursor)
+                        result_set = SupersetResultSet(
+                            data, cursor.description, self.db_engine_spec
+                        )
+                        df = result_set.to_pandas_df()
             if mutator:
                 df = mutator(df)
 
-            for col, coltype in df.dtypes.to_dict().items():
-                if coltype == numpy.object_ and needs_conversion(df[col]):
-                    df[col] = df[col].apply(utils.json_dumps_w_dates)
-
-            return df
+            return self.post_process_df(df)
 
     def compile_sqla_query(self, qry: Select, schema: str | None = None) -> 
str:
         with self.get_sqla_engine_with_context(schema) as engine:
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index e34f7e2fde..cffd6c457a 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -46,7 +46,7 @@ from superset.exceptions import (
     SupersetErrorException,
     SupersetErrorsException,
 )
-from superset.extensions import celery_app
+from superset.extensions import celery_app, event_logger
 from superset.models.core import Database
 from superset.models.sql_lab import Query
 from superset.result_set import SupersetResultSet
@@ -281,21 +281,26 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
                 log_params,
             )
         db.session.commit()
-        with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            db_engine_spec.execute_with_cursor(cursor, sql, query)
-
-        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logger.debug(
-                "Query %d: Fetching data for query object: %s",
-                query.id,
-                str(query.to_dict()),
-            )
-            data = db_engine_spec.fetch_data(cursor, increased_limit)
-            if query.limit is None or len(data) <= query.limit:
-                query.limiting_factor = LimitingFactor.NOT_LIMITED
-            else:
-                # return 1 row less than increased_query
-                data = data[:-1]
+        with event_logger.log_context(
+            action="execute_sql",
+            database=database,
+            object_ref=__name__,
+        ):
+            with stats_timing("sqllab.query.time_executing_query", 
stats_logger):
+                db_engine_spec.execute_with_cursor(cursor, sql, query)
+
+            with stats_timing("sqllab.query.time_fetching_results", 
stats_logger):
+                logger.debug(
+                    "Query %d: Fetching data for query object: %s",
+                    query.id,
+                    str(query.to_dict()),
+                )
+                data = db_engine_spec.fetch_data(cursor, increased_limit)
+                if query.limit is None or len(data) <= query.limit:
+                    query.limiting_factor = LimitingFactor.NOT_LIMITED
+                else:
+                    # return 1 row less than increased_query
+                    data = data[:-1]
     except SoftTimeLimitExceeded as ex:
         query.status = QueryStatus.TIMED_OUT
 
diff --git a/superset/utils/core.py b/superset/utils/core.py
index de1034ddb0..84544e69ec 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1907,3 +1907,10 @@ def remove_extra_adhoc_filters(form_data: dict[str, Any]) -> None:
         form_data[key] = [
             filter_ for filter_ in value or [] if not filter_.get("isExtra")
         ]
+
+
+def to_int(v: Any, value_if_invalid: int = 0) -> int:
+    try:  # int(v), or value_if_invalid for e.g. None, "abc", nan or inf
+        return int(v)
+    except (ValueError, TypeError, OverflowError):  # int(inf) -> OverflowError
+        return value_if_invalid
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 1de599bf08..fece0aa8a8 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -32,7 +32,7 @@ from flask_appbuilder.const import API_URI_RIS_KEY
 from sqlalchemy.exc import SQLAlchemyError
 
 from superset.extensions import stats_logger_manager
-from superset.utils.core import get_user_id, LoggerLevel
+from superset.utils.core import get_user_id, LoggerLevel, to_int
 
 if TYPE_CHECKING:
     from superset.stats_logger import BaseStatsLogger
@@ -47,10 +47,11 @@ def collect_request_payload() -> dict[str, Any]:
 
     payload: dict[str, Any] = {
         "path": request.path,
-        **request.form.to_dict(),
-        # url search params can overwrite POST body
-        **request.args.to_dict(),
     }
+    payload.update(**request.form.to_dict())
+    payload.update(**request.args.to_dict())
+    if request.is_json:
+        payload.update(request.json)
 
     # save URL match pattern in addition to the request path
     url_rule = str(request.url_rule)
@@ -136,6 +137,7 @@ class AbstractEventLogger(ABC):
         duration: timedelta | None = None,
         object_ref: str | None = None,
         log_to_statsd: bool = True,
+        database: Any | None = None,
         **payload_override: dict[str, Any] | None,
     ) -> None:
         # pylint: disable=import-outside-toplevel
@@ -161,15 +163,19 @@ class AbstractEventLogger(ABC):
 
         payload = collect_request_payload()
         if object_ref:
-            payload["object_ref"] = object_ref
+            payload["object_ref"] = str(object_ref)
         if payload_override:
             payload.update(payload_override)
 
-        dashboard_id: int | None = None
-        try:
-            dashboard_id = int(payload.get("dashboard_id"))  # type: ignore
-        except (TypeError, ValueError):
-            dashboard_id = None
+        dashboard_id = to_int(payload.get("dashboard_id"))
+
+        database_params = {"database_id": payload.get("database_id")}
+        if database:
+            database_params = {
+                "database_id": database.id,
+                "engine": database.backend,
+                "database_driver": database.driver,
+            }
 
         if "form_data" in payload:
             form_data, _ = get_form_data()
@@ -178,10 +184,7 @@ class AbstractEventLogger(ABC):
         else:
             slice_id = payload.get("slice_id")
 
-        try:
-            slice_id = int(slice_id)  # type: ignore
-        except (TypeError, ValueError):
-            slice_id = 0
+        slice_id = to_int(slice_id)
 
         if log_to_statsd:
             stats_logger_manager.instance.incr(action)
@@ -196,11 +199,13 @@ class AbstractEventLogger(ABC):
         self.log(
             user_id,
             action,
-            records=records,
             dashboard_id=dashboard_id,
+            records=records,
+            object_ref=object_ref,
             slice_id=slice_id,
             duration_ms=duration_ms,
             referrer=referrer,
+            **database_params,
         )
 
     @contextmanager
@@ -209,6 +214,7 @@ class AbstractEventLogger(ABC):
         action: str,
         object_ref: str | None = None,
         log_to_statsd: bool = True,
+        **kwargs: Any,
     ) -> Iterator[Callable[..., None]]:
         """
         Log an event with additional information from the request context.
@@ -216,7 +222,7 @@ class AbstractEventLogger(ABC):
         :param object_ref: reference to the Python object that triggered this 
action
         :param log_to_statsd: whether to update statsd counter for the action
         """
-        payload_override = {}
+        payload_override = kwargs.copy()
         start = datetime.now()
         # yield a helper to add additional payload
         yield lambda **kwargs: payload_override.update(kwargs)
@@ -359,3 +365,30 @@ class DBEventLogger(AbstractEventLogger):
         except SQLAlchemyError as ex:
             logging.error("DBEventLogger failed to log event(s)")
             logging.exception(ex)
+
+
+class StdOutEventLogger(AbstractEventLogger):
+    """Debug-only event logger: prints each event's fields as a dict to stdout"""
+
+    def log(  # pylint: disable=too-many-arguments,too-many-locals
+        self,
+        user_id: int | None,
+        action: str,
+        dashboard_id: int | None,
+        duration_ms: int | None,
+        slice_id: int | None,
+        referrer: str | None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        event = {
+            "user_id": user_id,
+            "action": action,
+            "dashboard_id": dashboard_id,
+            "duration_ms": duration_ms,
+            "slice_id": slice_id,
+            "referrer": referrer,
+            **kwargs,
+        }
+        print("-=" * 20)
+        print("StdOutEventLogger: ", event)

Reply via email to