This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 29dce5ea8ff8b22a7f1f97a9407d2dc11a1b06d2
Author: Domino Valdano <dom...@apache.org>
AuthorDate: Thu Oct 7 10:26:44 2021 -0700

    DEBUG: Adds WithTracebackForwarding() macro and report_segment_tracebacks 
param
    
    This adds a new macro WithTracebackForwarding() to SQLCommon.m4_in,
    which can be used to cause a UDF intended to run on the segments to
    forward any traceback information attached to an exception back to
    the coordinator.
    
    In order to have the coordinator intercept the forwarded traceback
    message and attach it to the DETAILS of the exception thrown on the
    coordinator, there is also a new optional flag for DEBUG.plpy_execute(),
    which must be set to True:
    
    DEBUG.plpy_execute(..., report_segment_tracebacks=True)
    
    WithTracebackForwarding() should wrap any python statement which might
    raise an exception on a segment.  It will enclose it in an appropriate
    try/except block and handle the exception.
    
    For example, see definition of plpython UDF __dbscan_leaf() which
    is called from DEBUG.plpy_execute() on coordinator and gets run on the
    segments:
    
    CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._dbscan_leaf(
        ...
    ) RETURNS SETOF MADLIB_SCHEMA.dbscan_record AS
    $$
    ...
        PythonFunctionBodyOnlyNoSchema(dbscan,dbscan)
        WithTracebackForwarding(return dbscan.dbscan_leaf(*args))
    ...
    $$ LANGUAGE plpythonu VOLATILE
---
 src/ports/postgres/madpack/SQLCommon.m4_in       | 41 +++++++++++++++++++++++
 src/ports/postgres/modules/utilities/debug.py_in | 42 +++++++++++++++++++-----
 2 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/src/ports/postgres/madpack/SQLCommon.m4_in 
b/src/ports/postgres/madpack/SQLCommon.m4_in
index cc58ea2..4a7c042 100644
--- a/src/ports/postgres/madpack/SQLCommon.m4_in
+++ b/src/ports/postgres/madpack/SQLCommon.m4_in
@@ -14,6 +14,47 @@
 m4_changequote(<!,!>)
 
 /*
+ * WithTracebackForwarding
+ *
+ * @param $1 python statement which might raise an exception
+ *
+ * Use this macro in the sql definition of a plpythonu function
+ *   that runs on the segments.  If the function raises an exception,
+ *   traceback information will be attached to the exception message
+ *   which gets forwarded back to the coordinator.
+ *
+ * On the coordinator side, to attach the message to the DETAIL of the
+ *   exception before displaying, you must call the segment UDF
+ *   or UDA like this:
+ *
+ *   DEBUG.plpy_execute(sql, ..., report_segment_tracebacks=True)
+ */
+m4_define(<!WithTracebackForwarding!>, <!
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        $1
+    except Exception as e:
+        global SD
+        global GD
+
+        for k in SD.keys():
+            del SD[k]
+        del SD
+        for k in GD.keys():
+            del GD[k]
+        del GD
+
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.message + 'SegmentTraceback' + detail
+        e.message = message
+        e.args = (message,)
+        raise e
+!>)
+
+/*
  * PythonFunction
  *
  * @param $1 directory
diff --git a/src/ports/postgres/modules/utilities/debug.py_in 
b/src/ports/postgres/modules/utilities/debug.py_in
index 05fc880..33fb21b 100644
--- a/src/ports/postgres/modules/utilities/debug.py_in
+++ b/src/ports/postgres/modules/utilities/debug.py_in
@@ -19,11 +19,9 @@
 
 import plpy as plpy_orig
 import time
-from deep_learning.madlib_keras_model_selection import ModelSelectionSchema
-from deep_learning.madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
 
-mst_key_col = ModelSelectionSchema.MST_KEY
-dist_key_col = DISTRIBUTION_KEY_COLNAME
+mst_key_col = 'mst_key'
+dist_key_col = '__dist_id__'
 
 start_times = dict()
 timings_enabled = False
@@ -120,13 +118,18 @@ def plpy_prepare(*args, **kwargs):
 
 plpy_execute_enabled = False
 def plpy_execute(*args, **kwargs):
-    """ debug.plpy.execute(q, ..., force=False)
+    """ debug.plpy.execute(q, ..., force=False, 
report_segment_tracebacks=False)
 
         Replace plpy.execute(q, ...) with
         debug.plpy.execute(q, ...) to debug
-        a query.  Shows the query itself, the
-        EXPLAIN of it, and how long the query
+        a query.  If enabled, shows the query itself,
+        the EXPLAIN of it, and how long the query
         takes to execute.
+
+        If report_segment_tracebacks=True, any tracebacks forwarded from
+        WithTracebackForwarding() on the segment will be attached to
+        the DETAILS of the ERROR message
+
     """
 
     force = False
@@ -134,6 +137,11 @@ def plpy_execute(*args, **kwargs):
         force = kwargs['force']
         del kwargs['force']
 
+    report_segment_tracebacks=False
+    if 'report_segment_tracebacks' in kwargs:
+        report_segment_tracebacks = kwargs['report_segment_tracebacks']
+        del kwargs['report_segment_tracebacks']
+
     plpy = plpy_orig # override global plpy,
                      # to avoid infinite recursion
 
@@ -164,13 +172,31 @@ def plpy_execute(*args, **kwargs):
     explain_query = "EXPLAIN" + sql
     if prep:
         explain_query = plpy.prepare(explain_query, *prep.args, **prep.kwargs)
+
     res = plpy.execute(explain_query, *args[1:], **kwargs)
     for r in res:
         plpy.info(r['QUERY PLAN'])
 
     # Run actual sql command, with timing
     start = time.time()
-    res = plpy.execute(*args, **kwargs)
+    if report_segment_tracebacks:
+        try:
+            res = plpy.execute(*args, **kwargs)
+        except plpy.SPIError as e:
+            msg = e.message
+            if 'SegmentTraceback' in msg:
+                e.message, detail = msg.split('SegmentTraceback')
+            else:
+                raise e
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+    else:
+        res = plpy.execute(*args, **kwargs)
 
     # Print how long execution of query took
     plpy.info("Query took {0}s".format(time.time() - start))

Reply via email to