kaknikhil commented on a change in pull request #425: DL: Add training for multiple models
URL: https://github.com/apache/madlib/pull/425#discussion_r309943433
 
 

 ##########
 File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
 ##########
 @@ -0,0 +1,424 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+import time
+import sys
+# Do not remove `import keras` although it's not directly used in this file.
+# For ex if the user passes in the optimizer as keras.optimizers.SGD instead 
of just
+# SGD, then without this import this python file won't find the SGD module
+import keras
+
+# from keras import backend as K
+# from keras import utils as keras_utils
+from keras.layers import *
+from keras.models import *
+from keras.optimizers import *
+from keras.regularizers import *
+import madlib_keras_serializer
+from madlib_keras import compute_loss_and_metrics
+from madlib_keras import get_initial_weights
+from madlib_keras import get_segments_and_gpus
+from madlib_keras import get_source_summary_table_dict
+from madlib_keras import reset_cuda_env
+from madlib_keras_helper import *
+from madlib_keras_validator import *
+from madlib_keras_wrapper import *
+from keras_model_arch_table import ModelArchSchema
+
+from utilities.control import MinWarning
+from utilities.utilities import add_postfix
+from utilities.utilities import rotate
+from utilities.utilities import madlib_version
+from utilities.utilities import is_platform_pg
+
+import json
+from collections import defaultdict
+import random
+import datetime
+mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+
+
+class ModelSelectionSchema:
+    MST_KEY = 'mst_key'
+    MODEL_ARCH_ID = 'model_arch_id'
+    COMPILE_PARAMS = 'compile_params'
+    FIT_PARAMS = 'fit_params'
+    col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')
+
+@MinWarning("warning")
+class FitMultipleModel():
+    def __init__(self, schema_madlib, source_table, model_output_table,
+                 model_arch_table, model_selection_table, num_iterations,
+                 gpus_per_host=0, **kwargs):
+
+        if is_platform_pg():
+            plpy.error("DL: Multiple model training is not supported on 
Postgresql.")
+        self.source_table = source_table
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(model_output_table, '_info')
+            self.model_summary_table = add_postfix(
+                model_output_table, '_summary')
+        self.num_iterations = num_iterations
+        self.module_name = 'madlib_keras_fit_multiple_model'
+        self.schema_madlib = schema_madlib
+        self.version = madlib_version(self.schema_madlib)
+        self.fit_validator = FitInputValidator(
+            self.source_table, None, self.model_output_table,
+            self.model_arch_table, mb_dep_var_col, mb_indep_var_col,
+            self.num_iterations, 1, False)
+        input_tbl_valid(self.model_selection_table, self.module_name)
+        output_tbl_valid(self.model_info_table, self.module_name)
+        self.msts = self.query_msts()
+        self.mst_key_col = ModelSelectionSchema.MST_KEY
+        self.model_arch_id_col = ModelSelectionSchema.MODEL_ARCH_ID
+        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
+        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
+        self.dist_keys = self.query_dist_keys()
+        self.grand_schedule = self.generate_schedule()
+
+        self.seg_ids_train, self.images_per_seg_train = \
+            get_image_count_per_seg_for_minibatched_data_from_db(
+                self.source_table)
+        self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
+            gpus_per_host)
+        self.create_model_output_table()
+        self.train_mst_metric_eval_time = defaultdict(list)
+        self.train_mst_loss = defaultdict(list)
+        self.train_mst_metric = defaultdict(list)
 
 Review comment:
  I would recommend calling the `fit_multiple_model` function here, in the constructor, instead of from the SQL file. That makes the code easier to read and navigate.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to