khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r480247892
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
r_i_col=AutoMLSchema.RESOURCES,
**locals())
plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+ """The core AutoML function for running AutoML algorithms such as
Hyperband.
+ This function executes the hyperband rounds 'diagonally' to evaluate
multiple configurations together
+ and leverage the compute power of MPP databases such as Greenplum.
+ """
+ def __init__(self, schema_madlib, source_table, model_output_table,
model_arch_table, model_selection_table,
+ model_id_list, compile_params_grid, fit_params_grid,
automl_method='hyperband',
+ automl_params='R=6, eta=3, skip_last=0', random_state=None,
object_table=None,
+ use_gpus=False, validation_table=None,
metrics_compute_frequency=None,
+ name=None, description=None, **kwargs):
+ self.schema_madlib = schema_madlib
+ self.source_table = source_table
+ self.model_output_table = model_output_table
+ if self.model_output_table:
+ self.model_info_table = add_postfix(self.model_output_table,
'_info')
+ self.model_summary_table = add_postfix(self.model_output_table,
'_summary')
+ self.model_arch_table = model_arch_table
+ self.model_selection_table = model_selection_table
+ self.model_selection_summary_table = add_postfix(
+ model_selection_table, "_summary")
+ self.model_id_list = sorted(list(set(model_id_list)))
+ self.compile_params_grid = compile_params_grid
+ self.fit_params_grid = fit_params_grid
+
+ MstLoaderInputValidator(
+ model_arch_table=self.model_arch_table,
+ model_selection_table=self.model_selection_table,
+ model_selection_summary_table=self.model_selection_summary_table,
+ model_id_list=self.model_id_list,
+ compile_params_list=compile_params_grid,
+ fit_params_list=fit_params_grid,
+ object_table=object_table,
+ module_name='madlib_keras_automl'
+ )
+
+ self.automl_method = automl_method
+ self.automl_params = automl_params
+ self.random_state = random_state
+ self.validate_and_define_inputs()
+
+ self.object_table = object_table
+ self.use_gpus = use_gpus
+ self.validation_table = validation_table
+ self.metrics_compute_frequency = metrics_compute_frequency
+ self.name = name
+ self.description = description
+
+ if self.validation_table:
+ AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+ self.create_model_output_table()
+ self.create_model_output_info_table()
+
+ if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+ self.find_hyperband_config()
+
+ def create_model_output_table(self):
+ output_table_create_query = """
+ CREATE TABLE {self.model_output_table}
+ ({ModelSelectionSchema.MST_KEY} INTEGER
PRIMARY KEY,
+ {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+ {ModelArchSchema.MODEL_ARCH} JSON)
+ """.format(self=self,
ModelSelectionSchema=ModelSelectionSchema,
+ ModelArchSchema=ModelArchSchema)
+ with MinWarning('warning'):
+ plpy.execute(output_table_create_query)
+
+ def create_model_output_info_table(self):
+ info_table_create_query = """
+ CREATE TABLE {self.model_info_table}
+ ({ModelSelectionSchema.MST_KEY} INTEGER
PRIMARY KEY,
+ {ModelArchSchema.MODEL_ID} INTEGER,
+ {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+ {ModelSelectionSchema.FIT_PARAMS} TEXT,
+ model_type TEXT,
+ model_size DOUBLE PRECISION,
+ metrics_elapsed_time DOUBLE PRECISION[],
+ metrics_type TEXT[],
+ training_metrics_final DOUBLE PRECISION,
+ training_loss_final DOUBLE PRECISION,
+ training_metrics DOUBLE PRECISION[],
+ training_loss DOUBLE PRECISION[],
+ validation_metrics_final DOUBLE PRECISION,
+ validation_loss_final DOUBLE PRECISION,
+ validation_metrics DOUBLE PRECISION[],
+ validation_loss DOUBLE PRECISION[],
+ {AutoMLSchema.METRICS_ITERS} INTEGER[])
+ """.format(self=self,
ModelSelectionSchema=ModelSelectionSchema,
+
ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+ with MinWarning('warning'):
+ plpy.execute(info_table_create_query)
+
+ def validate_and_define_inputs(self):
+
+ if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+ automl_params_dict = extract_keyvalue_params(self.automl_params,
+ default_values={'R':
6, 'eta': 3, 'skip_last': 0},
+
lower_case_names=False)
+ # casting dict values to int
+ for i in automl_params_dict:
+ automl_params_dict[i] = int(automl_params_dict[i])
+ _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <=
3,
+ "DL: Only R, eta, and skip_last may be specified")
+ for i in automl_params_dict:
+ if i == AutoMLSchema.R:
+ self.R = automl_params_dict[AutoMLSchema.R]
+ elif i == AutoMLSchema.ETA:
+ self.eta = automl_params_dict[AutoMLSchema.ETA]
+ elif i == AutoMLSchema.SKIP_LAST:
+ self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+ else:
+ plpy.error("DL: {0} is an invalid param".format(i))
+ _assert(self.eta > 1, "DL: eta must be greater than 1")
+ _assert(self.R >= self.eta, "DL: R should not be less than eta")
+ self.s_max = int(math.floor(math.log(self.R, self.eta)))
+ _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1,
"DL: skip_last must be " +
+ "non-negative and less than {0}".format(self.s_max))
+ # total number of resources/iterations (without reuse) per
execution of Succesive Halving (n,r)
+ self.B = (self.s_max + 1) * self.R
+ else:
+ plpy.error("DL: Only hyperband is currently supported as the
automl method")
+
+ def _is_valid_metrics_compute_frequency(self, num_iterations):
+ """
+ Utility function (same as that in the Fit Multiple function) to check
validity of mcf value for computing
+ metrics during an AutoML algorithm run.
+ :param num_iterations: interations/resources to allocate for training.
+ :return: boolean on validity of the mcf value.
+ """
+ return self.metrics_compute_frequency is None or \
+ (self.metrics_compute_frequency >= 1 and \
+ self.metrics_compute_frequency <= num_iterations)
+
+ def find_hyperband_config(self):
+ """
+ Runs the diagonal hyperband algorithm.
+ """
+ initial_vals = {}
+
+ # get hyper parameter configs for each s
+ for s in reversed(range(self.s_max+1)):
+ n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s)))
# initial number of configurations
+ r = self.R * math.pow(self.eta, -s) # initial number of iterations
to run configurations for
+ initial_vals[s] = (n, int(round(r)))
+ self.start_training_time = self.get_current_timestamp()
+ random_search = MstSearch(self.model_arch_table,
self.model_selection_table, self.model_id_list,
+ self.compile_params_grid,
self.fit_params_grid, 'random',
+ sum([initial_vals[k][0] for k in
initial_vals][self.skip_last:]), self.random_state,
+ self.object_table)
+ random_search.load() # for populating mst tables
+
+ # for creating the summary table for usage in fit multiple
+ plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+ "SELECT * FROM
{random_search.model_selection_summary_table}".format(
+ AutoMLSchema=AutoMLSchema, random_search=random_search))
+ ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+ # outer loop on diagonal
+ for i in range((self.s_max+1) - int(self.skip_last)):
+ # inner loop on s desc
+ temp_lst = []
+ configs_prune_lookup = {}
+ for s in range(self.s_max, self.s_max-i-1, -1):
+ n = initial_vals[s][0]
+ n_i = n * math.pow(self.eta, -i+self.s_max-s)
+ configs_prune_lookup[s] = int(round(n_i))
+ temp_lst.append("{0} configs under bracket={1} &
round={2}".format(int(n_i), s, s-self.s_max+i))
+ plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + '
with {0} iterations ***'.format(
+ int(initial_vals[self.s_max-i][1])))
+
+ self.reconstruct_temp_mst_table(i, ranges_dict,
configs_prune_lookup)
+ self.warm_start = int(i != 0)
+ num_iterations = int(initial_vals[self.s_max-i][1])
+ mcf = self.metrics_compute_frequency if
self._is_valid_metrics_compute_frequency(num_iterations) else None
+ model_training = FitMultipleModel(self.schema_madlib,
self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+ AutoMLSchema.TEMP_MST_TABLE,
num_iterations, self.use_gpus,
+ self.validation_table, mcf,
self.warm_start, self.name, self.description)
+ self.update_model_output_table(model_training)
+ self.update_model_output_info_table(i, model_training,initial_vals)
+ self.end_training_time = self.get_current_timestamp()
+ self.update_model_selection_table()
+ self.generate_model_output_summary_table(model_training)
+ self.remove_temp_tables(model_training)
+
+ def get_current_timestamp(self):
+ """for start and end times for the chosen AutoML algorithm. Showcased
in the output summary table"""
+ return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+ def mst_key_ranges_dict(self, initial_vals):
+ """
+ Extracts the ranges of model configs (using mst_keys) belonging to /
sampled as part of
+ executing a particular SHA bracket.
+ """
+ d = {}
+ for s_val in sorted(initial_vals.keys(), reverse=True): # going from
s_max to 0
+ if s_val == self.s_max:
+ d[s_val] = (1, initial_vals[s_val][0])
+ else:
+ d[s_val] = (d[s_val+1][1]+1,
d[s_val+1][1]+initial_vals[s_val][0])
+ return d
+
+ def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+ """
+ Drops and Reconstructs a temp mst table for evaluation along
particular diagonals of hyperband.
+ :param i: outer diagonal loop iteration.
+ :param ranges_dict: model config ranges to group by bracket number.
+ :param configs_prune_lookup: Lookup dictionary for configs to evaluate
for a diagonal.
+ :return:
+ """
+ if i == 0:
+ _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+ lower_bound, upper_bound = ranges_dict[self.s_max]
+ plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT
* FROM {self.model_selection_table} "
+ "WHERE mst_key >= {lower_bound} AND mst_key <=
{upper_bound}".format(self=self,
+
AutoMLSchema=AutoMLSchema,
+
lower_bound=lower_bound,
+
upper_bound=upper_bound,))
+ return
+ # dropping and repopulating temp_mst_table
+ drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+ # {mst_key} changed from SERIAL to INTEGER for safe insertions and
preservation of mst_key values
+ create_query = """
+ CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+ {mst_key} INTEGER,
+ {model_id} INTEGER,
+ {compile_params} VARCHAR,
+ {fit_params} VARCHAR,
+ unique ({model_id}, {compile_params}, {fit_params})
+ );
+ """.format(AutoMLSchema=AutoMLSchema,
+ mst_key=ModelSelectionSchema.MST_KEY,
+ model_id=ModelSelectionSchema.MODEL_ID,
+
compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+ fit_params=ModelSelectionSchema.FIT_PARAMS)
+ with MinWarning('warning'):
+ plpy.execute(create_query)
+
+ query = ""
+ new_configs = True
+ for s_val in configs_prune_lookup:
+ lower_bound, upper_bound = ranges_dict[s_val]
+ if new_configs:
+ query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT
mst_key, model_id, compile_params, fit_params " \
+ "FROM {self.model_selection_table} WHERE mst_key >=
{lower_bound} " \
+ "AND mst_key <= {upper_bound};".format(self=self,
AutoMLSchema=AutoMLSchema,
+
lower_bound=lower_bound, upper_bound=upper_bound)
+ new_configs = False
+ else:
+ query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT
mst_key, model_id, compile_params, fit_params " \
+ "FROM {self.model_info_table} WHERE mst_key >=
{lower_bound} " \
+ "AND mst_key <= {upper_bound} ORDER BY
{AutoMLSchema.LOSS_METRIC} " \
+ "LIMIT {configs_prune_lookup_val};".format(self=self,
AutoMLSchema=AutoMLSchema,
+
lower_bound=lower_bound, upper_bound=upper_bound,
+
configs_prune_lookup_val=configs_prune_lookup[s_val])
+ plpy.execute(query)
+
+ def update_model_output_table(self, model_training):
+ """
+ Updates gathered information of a hyperband diagonal run to the
overall model output table.
+ :param model_training: Fit Multiple function call object.
+ """
+ # updates model weights for any previously trained configs
+ plpy.execute("UPDATE {self.model_output_table} a SET model_weights="
Review comment:
Yes, this is still applicable
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org