This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit b7ff51b020fe405ca829c3ed7cdbbe411bc6365e Author: Nikhil Kak <n...@pivotal.io> AuthorDate: Wed May 22 12:07:54 2019 -0700 DL: Rename Format class to be more meaningful. JIRA: MADLIB-1348 In addition, renamed a variable. Closes #399 Co-authored-by: Orhan Kislal <okis...@apache.org> --- .../deep_learning/keras_model_arch_table.py_in | 24 +++++++++++----------- .../modules/deep_learning/madlib_keras.py_in | 20 +++++++++--------- .../deep_learning/madlib_keras_serializer.py_in | 11 +++++----- .../deep_learning/madlib_keras_validator.py_in | 6 +++--- .../deep_learning/predict_input_params.py_in | 4 ++-- 5 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in index e35b568..28ab753 100644 --- a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in +++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in @@ -34,7 +34,7 @@ from utilities.validate_args import input_tbl_valid from utilities.validate_args import quote_ident from utilities.validate_args import table_exists -class Format: +class ModelArchSchema: """Expected format of keras_model_arch_table. Example uses: @@ -63,8 +63,8 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights, name, description, **kwargs): model_arch_table = quote_ident(keras_model_arch_table) if not table_exists(model_arch_table): - col_defs = get_col_name_type_sql_string(Format.col_names, - Format.col_types) + col_defs = get_col_name_type_sql_string(ModelArchSchema.col_names, + ModelArchSchema.col_types) sql = "CREATE TABLE {model_arch_table} ({col_defs})" \ .format(**locals()) @@ -74,7 +74,7 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights, .format(model_arch_table)) else: missing_cols = columns_missing_from_table(model_arch_table, - Format.col_names) + ModelArchSchema.col_names) if len(missing_cols) > 0: plpy.error("Keras Model Arch: Invalid keras model arch table {0}," " missing columns: {1}".format(model_arch_table, @@ -83,34 +83,34 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights, unique_str = unique_string(prefix_has_temp=False) insert_query = plpy.prepare("INSERT INTO {model_arch_table} " "VALUES(DEFAULT, $1, $2, $3, $4, $5);".format(**locals()), - Format.col_types[1:]) + ModelArchSchema.col_types[1:]) insert_res = plpy.execute(insert_query,[model_arch, model_weights, name, description, unique_str], 0) select_query = """SELECT {model_id_col}, {model_arch_col} FROM {model_arch_table} WHERE {internal_id_col} = '{unique_str}'""".format( - model_id_col=Format.MODEL_ID, - model_arch_col=Format.MODEL_ARCH, + model_id_col=ModelArchSchema.MODEL_ID, + model_arch_col=ModelArchSchema.MODEL_ARCH, model_arch_table=model_arch_table, - internal_id_col=Format.__INTERNAL_MADLIB_ID__, + internal_id_col=ModelArchSchema.__INTERNAL_MADLIB_ID__, unique_str=unique_str) select_res = plpy.execute(select_query,1) plpy.info("Keras Model Arch: Added model id {0} to {1} table". - format(select_res[0][Format.MODEL_ID], model_arch_table)) + format(select_res[0][ModelArchSchema.MODEL_ID], model_arch_table)) def delete_keras_model(keras_model_arch_table, model_id, **kwargs): model_arch_table = quote_ident(keras_model_arch_table) input_tbl_valid(model_arch_table, "Keras Model Arch") - missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) + missing_cols = columns_missing_from_table(model_arch_table, ModelArchSchema.col_names) if len(missing_cols) > 0: plpy.error("Keras Model Arch: Invalid keras model arch table {0}," " missing columns: {1}".format(model_arch_table, missing_cols)) sql = """ DELETE FROM {model_arch_table} WHERE {model_id_col}={model_id} - """.format(model_arch_table=model_arch_table, model_id_col=Format.MODEL_ID, + """.format(model_arch_table=model_arch_table, model_id_col=ModelArchSchema.MODEL_ID, model_id=model_id) res = plpy.execute(sql, 0) @@ -120,7 +120,7 @@ def delete_keras_model(keras_model_arch_table, model_id, **kwargs): else: plpy.error("Keras Model Arch: Model id {0} not found".format(model_id)) - sql = "SELECT {0} FROM {1}".format(Format.MODEL_ID, model_arch_table) + sql = "SELECT {0} FROM {1}".format(ModelArchSchema.MODEL_ID, model_arch_table) res = plpy.execute(sql, 0) if not res: plpy.info("Keras Model Arch: Dropping empty keras model arch "\ diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in index 3a912e1..be3b4e0 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in @@ -43,7 +43,7 @@ from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME from madlib_keras_helper import NORMALIZING_CONST_COLNAME from madlib_keras_validator import FitInputValidator from madlib_keras_wrapper import * -from keras_model_arch_table import Format +from keras_model_arch_table import ModelArchSchema from utilities.control import MinWarning from utilities.model_arch_info import get_input_shape @@ -97,19 +97,19 @@ def fit(schema_madlib, source_table, model,model_arch_table, # Get the serialized master model start_deserialization = time.time() model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format( - Format.MODEL_ARCH, Format.MODEL_WEIGHTS, - model_arch_table, Format.MODEL_ID, + ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS, + model_arch_table, ModelArchSchema.MODEL_ID, model_arch_id) - query_result = plpy.execute(model_arch_query) - if not query_result: + model_arch_result = plpy.execute(model_arch_query) + if not model_arch_result: plpy.error("no model arch found in table {0} with id {1}".format( model_arch_table, model_arch_id)) - query_result = query_result[0] - model_arch = query_result[Format.MODEL_ARCH] + model_arch_result = model_arch_result[0] + model_arch = model_arch_result[ModelArchSchema.MODEL_ARCH] input_shape = get_input_shape(model_arch) num_classes = get_num_classes(model_arch) fit_validator.validate_input_shapes(input_shape) - model_weights_serialized = query_result[Format.MODEL_WEIGHTS] + model_weights_serialized = model_arch_result[ModelArchSchema.MODEL_WEIGHTS] #TODO: Refactor the pg related logic in a future PR when we think # about making the fit function easier to read and maintain. @@ -300,7 +300,7 @@ def fit(schema_madlib, source_table, model,model_arch_table, create_output_table = plpy.prepare(""" CREATE TABLE {0} AS SELECT $1 as model_data, - $2 as {1}""".format(model, Format.MODEL_ARCH), ["bytea", "json"]) + $2 as {1}""".format(model, ModelArchSchema.MODEL_ARCH), ["bytea", "json"]) plpy.execute(create_output_table, [serialized_weights, model_arch]) if is_platform_pg(): @@ -564,7 +564,7 @@ def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table, plpy.error("no model arch found in table {0} with id {1}".format( model_arch_table, model_arch_id)) query_result = query_result[0] - model_arch = query_result[Format.MODEL_ARCH] + model_arch = query_result[ModelArchSchema.MODEL_ARCH] compile_params = "$madlib$" + compile_params + "$madlib$" loss_metric = get_loss_metric_from_keras_eval( diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in index c92b3c6..ba6672e 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in @@ -34,16 +34,17 @@ import numpy as np """ +workflow 1. Set initial weights in madlib keras fit function. -1. Serialize these initial model weights as a byte string and pass it to keras step -1. Deserialize the state passed from the previous step into a list of nd weights +2. Serialize these initial model weights as a byte string and pass it to keras step +3. Deserialize the state passed from the previous step into a list of nd weights that will be passed on to model.set_weights() -1. At the end of each buffer in fit transition, serialize the image count and +4. At the end of each buffer in fit transition, serialize the image count and the model weights into a bytestring that will be passed on to the fit merge function. -1. In fit merge, deserialize the state as image and 1d np arrays. Do some averaging +5. In fit merge, deserialize the state as image and 1d np arrays. Do some averaging operations and serialize them again into a state which contains the image and the 1d state. same for fit final -1. Return the final state from fit final to fit which will then be deserialized +6. Return the final state from fit final to fit which will then be deserialized as 1d weights to be passed on to the evaluate function """ def get_image_count_from_state(state): diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in index 00426dd..5892308 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in @@ -18,7 +18,7 @@ # under the License. import plpy -from keras_model_arch_table import Format +from keras_model_arch_table import ModelArchSchema from madlib_keras_helper import CLASS_VALUES_COLNAME from madlib_keras_helper import COMPILE_PARAMS_COLNAME from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME @@ -113,11 +113,11 @@ class PredictInputValidator: module_name=self.module_name, model_data=MODEL_DATA_COLNAME, table=self.model_table)) - _assert(is_var_valid(self.model_table, Format.MODEL_ARCH), + _assert(is_var_valid(self.model_table, ModelArchSchema.MODEL_ARCH), "{module_name} error: column '{model_arch}' " "does not exist in model table '{table}'.".format( module_name=self.module_name, - model_arch=Format.MODEL_ARCH, + model_arch=ModelArchSchema.MODEL_ARCH, table=self.model_table)) def _validate_test_tbl_cols(self): diff --git a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in index 7f88020..84b34ba 100644 --- a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in +++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in @@ -18,7 +18,7 @@ # under the License. import plpy -from keras_model_arch_table import Format +from keras_model_arch_table import ModelArchSchema from utilities.utilities import add_postfix from utilities.validate_args import input_tbl_valid @@ -50,7 +50,7 @@ class PredictParamsProcessor: return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME] def get_model_arch(self): - return self.model_arch_dict[Format.MODEL_ARCH] + return self.model_arch_dict[ModelArchSchema.MODEL_ARCH] def get_model_data(self): return self.model_arch_dict[MODEL_DATA_COLNAME]