Advitya17 commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r479873861



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as 
Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate 
multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, 
model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, 
automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, 
object_table=None,
+                 use_gpus=False, validation_table=None, 
metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, 
'_info')
+            self.model_summary_table = add_postfix(self.model_output_table, 
'_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER 
PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, 
ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER 
PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, 
ModelSelectionSchema=ModelSelectionSchema,
+                                                  
ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 
6, 'eta': 3, 'skip_last': 0},
+                                                         
lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 
3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, 
"DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per 
execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the 
automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check 
validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) 
# initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations 
to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, 
self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, 
self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in 
initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM 
{random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & 
round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' 
with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, 
configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if 
self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, 
self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,

Review comment:
       That may not be possible, as this function call takes in a different 
(reconstructed) mst table for each diagonal. 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to