kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560539401



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same 
slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = 
model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT 
NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the 
cache on the very first hop
             # For the very_first_hop, we want to run the transition function 
on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore 
we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore 
we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the 
independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var 
as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT 
{dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED 
BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT 
{self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, 
{self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: 
{}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS 
NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is 
False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in 
model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = 
{self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = 
{self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights 
IS NULL".format(self=self))

Review comment:
       It would be interesting to figure out how much extra time a subquery 
would add, and whether there is a better, more efficient way to write this 
query. But that can be done in a future PR. 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to