This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit b7ff51b020fe405ca829c3ed7cdbbe411bc6365e
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Wed May 22 12:07:54 2019 -0700

    DL: Rename Format class to be more meaningful.
    
    JIRA: MADLIB-1348
    
    In addition, renamed the local variable query_result to
    model_arch_result in fit() (madlib_keras.py_in).
    
    Closes #399
    
    Co-authored-by: Orhan Kislal <okis...@apache.org>
---
 .../deep_learning/keras_model_arch_table.py_in     | 24 +++++++++++-----------
 .../modules/deep_learning/madlib_keras.py_in       | 20 +++++++++---------
 .../deep_learning/madlib_keras_serializer.py_in    | 11 +++++-----
 .../deep_learning/madlib_keras_validator.py_in     |  6 +++---
 .../deep_learning/predict_input_params.py_in       |  4 ++--
 5 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
index e35b568..28ab753 100644
--- a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
@@ -34,7 +34,7 @@ from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import quote_ident
 from utilities.validate_args import table_exists
 
-class Format:
+class ModelArchSchema:
     """Expected format of keras_model_arch_table.
        Example uses:
 
@@ -63,8 +63,8 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights,
                      name, description, **kwargs):
     model_arch_table = quote_ident(keras_model_arch_table)
     if not table_exists(model_arch_table):
-        col_defs = get_col_name_type_sql_string(Format.col_names,
-                                                Format.col_types)
+        col_defs = get_col_name_type_sql_string(ModelArchSchema.col_names,
+                                                ModelArchSchema.col_types)
 
         sql = "CREATE TABLE {model_arch_table} ({col_defs})" \
               .format(**locals())
@@ -74,7 +74,7 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights,
             .format(model_arch_table))
     else:
         missing_cols = columns_missing_from_table(model_arch_table,
-                                                  Format.col_names)
+                                                  ModelArchSchema.col_names)
         if len(missing_cols) > 0:
             plpy.error("Keras Model Arch: Invalid keras model arch table {0},"
                        " missing columns: {1}".format(model_arch_table,
@@ -83,34 +83,34 @@ def load_keras_model(keras_model_arch_table, model_arch, model_weights,
     unique_str = unique_string(prefix_has_temp=False)
     insert_query = plpy.prepare("INSERT INTO {model_arch_table} "
                                 "VALUES(DEFAULT, $1, $2, $3, $4, $5);".format(**locals()),
-                                Format.col_types[1:])
+                                ModelArchSchema.col_types[1:])
     insert_res = plpy.execute(insert_query,[model_arch, model_weights, name, description,
                                unique_str], 0)
 
     select_query = """SELECT {model_id_col}, {model_arch_col} FROM {model_arch_table}
                    WHERE {internal_id_col} = '{unique_str}'""".format(
-                    model_id_col=Format.MODEL_ID,
-                    model_arch_col=Format.MODEL_ARCH,
+                    model_id_col=ModelArchSchema.MODEL_ID,
+                    model_arch_col=ModelArchSchema.MODEL_ARCH,
                     model_arch_table=model_arch_table,
-                    internal_id_col=Format.__INTERNAL_MADLIB_ID__,
+                    internal_id_col=ModelArchSchema.__INTERNAL_MADLIB_ID__,
                     unique_str=unique_str)
     select_res = plpy.execute(select_query,1)
 
     plpy.info("Keras Model Arch: Added model id {0} to {1} table".
-        format(select_res[0][Format.MODEL_ID], model_arch_table))
+              format(select_res[0][ModelArchSchema.MODEL_ID], model_arch_table))
 
 def delete_keras_model(keras_model_arch_table, model_id, **kwargs):
     model_arch_table = quote_ident(keras_model_arch_table)
     input_tbl_valid(model_arch_table, "Keras Model Arch")
 
-    missing_cols = columns_missing_from_table(model_arch_table, Format.col_names)
+    missing_cols = columns_missing_from_table(model_arch_table, ModelArchSchema.col_names)
     if len(missing_cols) > 0:
         plpy.error("Keras Model Arch: Invalid keras model arch table {0},"
                    " missing columns: {1}".format(model_arch_table, missing_cols))
 
     sql = """
            DELETE FROM {model_arch_table} WHERE {model_id_col}={model_id}
-          """.format(model_arch_table=model_arch_table, model_id_col=Format.MODEL_ID,
+          """.format(model_arch_table=model_arch_table, model_id_col=ModelArchSchema.MODEL_ID,
                      model_id=model_id)
     res = plpy.execute(sql, 0)
 
@@ -120,7 +120,7 @@ def delete_keras_model(keras_model_arch_table, model_id, **kwargs):
     else:
         plpy.error("Keras Model Arch: Model id {0} not found".format(model_id))
 
-    sql = "SELECT {0} FROM {1}".format(Format.MODEL_ID, model_arch_table)
+    sql = "SELECT {0} FROM {1}".format(ModelArchSchema.MODEL_ID, model_arch_table)
     res = plpy.execute(sql, 0)
     if not res:
         plpy.info("Keras Model Arch: Dropping empty keras model arch "\
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 3a912e1..be3b4e0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -43,7 +43,7 @@ from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_validator import FitInputValidator
 from madlib_keras_wrapper import *
-from keras_model_arch_table import Format
+from keras_model_arch_table import ModelArchSchema
 
 from utilities.control import MinWarning
 from utilities.model_arch_info import get_input_shape
@@ -97,19 +97,19 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     # Get the serialized master model
     start_deserialization = time.time()
     model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
-                                        Format.MODEL_ARCH, Format.MODEL_WEIGHTS,
-                                        model_arch_table, Format.MODEL_ID,
+                                        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
+                                        model_arch_table, ModelArchSchema.MODEL_ID,
                                         model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result:
+    model_arch_result = plpy.execute(model_arch_query)
+    if not  model_arch_result:
         plpy.error("no model arch found in table {0} with id {1}".format(
             model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result[Format.MODEL_ARCH]
+    model_arch_result = model_arch_result[0]
+    model_arch = model_arch_result[ModelArchSchema.MODEL_ARCH]
     input_shape = get_input_shape(model_arch)
     num_classes = get_num_classes(model_arch)
     fit_validator.validate_input_shapes(input_shape)
-    model_weights_serialized = query_result[Format.MODEL_WEIGHTS]
+    model_weights_serialized = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
 
     #TODO: Refactor the pg related logic in a future PR when we think
     # about making the fit function easier to read and maintain.
@@ -300,7 +300,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     create_output_table = plpy.prepare("""
         CREATE TABLE {0} AS SELECT
         $1 as model_data,
-        $2 as {1}""".format(model, Format.MODEL_ARCH), ["bytea", "json"])
+        $2 as {1}""".format(model, ModelArchSchema.MODEL_ARCH), ["bytea", "json"])
     plpy.execute(create_output_table, [serialized_weights, model_arch])
 
     if is_platform_pg():
@@ -564,7 +564,7 @@ def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table,
         plpy.error("no model arch found in table {0} with id {1}".format(
             model_arch_table, model_arch_id))
     query_result = query_result[0]
-    model_arch = query_result[Format.MODEL_ARCH]
+    model_arch = query_result[ModelArchSchema.MODEL_ARCH]
     compile_params = "$madlib$" + compile_params + "$madlib$"
 
     loss_metric = get_loss_metric_from_keras_eval(
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index c92b3c6..ba6672e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -34,16 +34,17 @@ import numpy as np
 
 
 """
+workflow
 1. Set initial weights in madlib keras fit function.
-1. Serialize these initial model weights as a byte string and pass it to keras step
-1. Deserialize the state passed from the previous step into a list of nd weights
+2. Serialize these initial model weights as a byte string and pass it to keras step
+3. Deserialize the state passed from the previous step into a list of nd weights
 that will be passed on to model.set_weights()
-1. At the end of each buffer in fit transition, serialize the image count and
+4. At the end of each buffer in fit transition, serialize the image count and
 the model weights into a bytestring that will be passed on to the fit merge function.
-1. In fit merge, deserialize the state as image and 1d np arrays. Do some averaging
+5. In fit merge, deserialize the state as image and 1d np arrays. Do some averaging
 operations and serialize them again into a state which contains the image
 and the 1d state. same for fit final 
-1. Return the final state from fit final to fit which will then be deserialized 
+6. Return the final state from fit final to fit which will then be deserialized
 as 1d weights to be passed on to the evaluate function
 """
 def get_image_count_from_state(state):
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 00426dd..5892308 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -18,7 +18,7 @@
 # under the License.
 
 import plpy
-from keras_model_arch_table import Format
+from keras_model_arch_table import ModelArchSchema
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
@@ -113,11 +113,11 @@ class PredictInputValidator:
                     module_name=self.module_name,
                     model_data=MODEL_DATA_COLNAME,
                     table=self.model_table))
-        _assert(is_var_valid(self.model_table, Format.MODEL_ARCH),
+        _assert(is_var_valid(self.model_table, ModelArchSchema.MODEL_ARCH),
                 "{module_name} error: column '{model_arch}' "
                 "does not exist in model table '{table}'.".format(
                     module_name=self.module_name,
-                    model_arch=Format.MODEL_ARCH,
+                    model_arch=ModelArchSchema.MODEL_ARCH,
                     table=self.model_table))
 
     def _validate_test_tbl_cols(self):
diff --git a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
index 7f88020..84b34ba 100644
--- a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
+++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
@@ -18,7 +18,7 @@
 # under the License.
 
 import plpy
-from keras_model_arch_table import Format
+from keras_model_arch_table import ModelArchSchema
 from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 
@@ -50,7 +50,7 @@ class PredictParamsProcessor:
         return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]
 
     def get_model_arch(self):
-        return self.model_arch_dict[Format.MODEL_ARCH]
+        return self.model_arch_dict[ModelArchSchema.MODEL_ARCH]
 
     def get_model_data(self):
         return self.model_arch_dict[MODEL_DATA_COLNAME]

Reply via email to