reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537825924



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same 
slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = 
model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT 
NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the 
cache on the very first hop
             # For the very_first_hop, we want to run the transition function 
on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore 
we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore 
we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the 
independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var 
as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT 
{dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED 
BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT 
{self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, 
{self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: 
{}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS 
NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is 
False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in 
model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = 
{self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = 
{self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights 
IS NULL".format(self=self))

Review comment:
       I was also thinking at first that we could do it that way, using a WHERE 
clause.  But unfortunately, that can only be used to filter the input rows of a 
query, not the output rows.  The only way to filter output rows I can think of 
is to make this query a sub-query of an outer query that does the filtering.
   
   In terms of performance, adding a subquery would be pretty similar to the 
way it is now, I think, except that it would run the filtering operation as a 
separate slice... one that mostly still has to wait for the first slice to 
complete before it can run.  And I think we want to avoid any extra slices for 
this one, in case it messes up the PL/Python SD/GD state.
   
   I also wonder if we could use a GROUP BY with a HAVING clause.  But that 
might also add another slice, and/or slightly reduce performance... possibly 
for the same reason that the UDA ran slightly slower.
   
   The DELETE operation is really fast, since it just scans through the rows 
and marks some as deleted... the only undesirable effect of this is that it 
leaves the rows fragmented on disk, but since we're about to TRUNCATE and DROP 
the table anyway, it seems like that shouldn't matter.
   
   We could also change the hop query to filter out the rows with NULL weights 
using a WHERE clause when it's copying from `model_output_tbl` to 
`model_input_tbl`.  But then I guess we'd have to do something similar for 
eval.  And essentially, this would just be another way of doing the same thing 
the DELETE accomplishes: telling both hop and eval to "ignore these rows".
   
   If you have any other ideas, let me know and I can try them.  (Perhaps using 
the WITH keyword?  My impression has always been that WITH is just another way 
to introduce a separate sub-query / slice, but I'm not sure.)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to