This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 352e41f9c004a86147d041fafe533ec4bd7394e2
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Mon May 20 14:29:07 2019 -0700

    DL: Add model_weights, name and desc to the model arch table.
    
    JIRA: MADLIB-1348
    
    This commit adds three optional params to the model arch interface.
    1. model weights in bytea format
    2. name
    3. description
    
    The model_weights param allows the user to load pre-trained weights
    to enable transfer learning.
    
    1. Remove the image count from the serialized model weights that get
    stored in the model output table. This image count was unnecessary and
    also caused an inconsistency between the model weights format in the
    model_arch table and in the keras fit model output table.
    2. Use bytea to store/read model weights, both for the input param and
    for the output column type. Here is why:
      a. We found that in Python, using double precision[] or real[] took
      almost double the size compared to bytea.
      b. It also helps keep the internal state of the model and the output to
      the user consistent.
    
    Additionally:
    1. Use plpy.prepare to pass the model weights as a bytea parameter.
    2. Modify the deserialize code to accept a bytea string
    instead of a double precision[].
    3. Modify madpack code so that the install-check user can create Python
    UDFs in the madlib keras dev-check SQL.
    4. Add a dev-check test for transfer learning by creating a UDF
    that calls load_keras_model with some predefined weights so that the
    first iteration of fit always returns the same loss and metric.
    5. Rename serializer functions to better reflect their purpose.
    6. Move get_model_shapes to the wrapper and stop storing model_shapes in SD.
    7. Rename model_state to either state or serialized_weights depending on
    its content.
    
    Closes #399
    
    Co-authored-by: Orhan Kislal <okis...@apache.org>
---
 src/madpack/madpack.py                             |   5 +-
 .../deep_learning/keras_model_arch_table.py_in     |  67 ++++----
 .../deep_learning/keras_model_arch_table.sql_in    |  50 +++++-
 .../modules/deep_learning/madlib_keras.py_in       |  65 ++++----
 .../modules/deep_learning/madlib_keras.sql_in      |   6 +-
 .../deep_learning/madlib_keras_predict.py_in       |   7 +-
 .../deep_learning/madlib_keras_serializer.py_in    | 181 +++++++++++----------
 .../deep_learning/madlib_keras_wrapper.py_in       |  19 ++-
 .../test/keras_model_arch_table.sql_in             |  26 ++-
 .../modules/deep_learning/test/madlib_keras.sql_in |  98 +++++++++++
 .../test/unit_tests/test_madlib_keras.py_in        | 166 ++++++++-----------
 11 files changed, 409 insertions(+), 281 deletions(-)

diff --git a/src/madpack/madpack.py b/src/madpack/madpack.py
index d5eb3da..e735526 100755
--- a/src/madpack/madpack.py
+++ b/src/madpack/madpack.py
@@ -1027,8 +1027,9 @@ def run_install_check(args, testcase, madpack_cmd):
         _internal_run_query("DROP OWNED BY %s CASCADE;" % (test_user), True)
         _internal_run_query("DROP USER IF EXISTS %s;" % (test_user), True)
 
-    _internal_run_query("CREATE USER %s;" % (test_user), True)
+    _internal_run_query("CREATE USER %s WITH SUPERUSER NOINHERIT;" % 
(test_user), True)
     _internal_run_query("GRANT USAGE ON SCHEMA %s TO %s;" % (schema, 
test_user), True)
+    _internal_run_query("GRANT ALL PRIVILEGES ON DATABASE %s TO %s;" % 
(db_name, test_user), True)
 
     # 2) Run test SQLs
     info_(this, "> Running %s scripts for:" % madpack_cmd, verbose)
@@ -1053,6 +1054,8 @@ def run_install_check(args, testcase, madpack_cmd):
             from time import sleep
             sleep(1)
             _internal_run_query("DROP OWNED BY %s CASCADE;" % (test_user), 
show_error=True)
+
+        _internal_run_query("REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s;" % 
(db_name, test_user), True)
         _internal_run_query("DROP USER %s;" % (test_user), True)
 
 
diff --git 
a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in 
b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
index 8ed3ad6..e35b568 100644
--- a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
@@ -53,16 +53,14 @@ class Format:
            arch = plpy.execute(sql)[0]
 
     """
-    col_names = ('model_id', 'model_arch', 'model_weights', 
'__internal_madlib_id__')
-    col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT')
-    (MODEL_ID, MODEL_ARCH, MODEL_WEIGHTS, __INTERNAL_MADLIB_ID__) = col_names
-
-@MinWarning("warning")
-def _execute(sql,max_rows=0):
-    return plpy.execute(sql,max_rows)
-
-def load_keras_model(schema_madlib, keras_model_arch_table,
-                     model_arch, **kwargs):
+    col_names = ('model_id', 'model_arch', 'model_weights', 'name', 
'description',
+                 '__internal_madlib_id__')
+    col_types = ('SERIAL PRIMARY KEY', 'JSON', 'bytea', 'TEXT', 'TEXT', 'TEXT')
+    (MODEL_ID, MODEL_ARCH, MODEL_WEIGHTS, NAME, DESCRIPTION,
+     __INTERNAL_MADLIB_ID__) = col_names
+
+def load_keras_model(keras_model_arch_table, model_arch, model_weights,
+                     name, description, **kwargs):
     model_arch_table = quote_ident(keras_model_arch_table)
     if not table_exists(model_arch_table):
         col_defs = get_col_name_type_sql_string(Format.col_names,
@@ -71,7 +69,7 @@ def load_keras_model(schema_madlib, keras_model_arch_table,
         sql = "CREATE TABLE {model_arch_table} ({col_defs})" \
               .format(**locals())
 
-        _execute(sql)
+        plpy.execute(sql, 0)
         plpy.info("Keras Model Arch: Created new keras model arch table {0}." \
             .format(model_arch_table))
     else:
@@ -83,27 +81,25 @@ def load_keras_model(schema_madlib, keras_model_arch_table,
                                                       missing_cols))
 
     unique_str = unique_string(prefix_has_temp=False)
+    insert_query = plpy.prepare("INSERT INTO {model_arch_table} "
+                                "VALUES(DEFAULT, $1, $2, $3, $4, 
$5);".format(**locals()),
+                                Format.col_types[1:])
+    insert_res = plpy.execute(insert_query,[model_arch, model_weights, name, 
description,
+                               unique_str], 0)
+
+    select_query = """SELECT {model_id_col}, {model_arch_col} FROM 
{model_arch_table}
+                   WHERE {internal_id_col} = '{unique_str}'""".format(
+                    model_id_col=Format.MODEL_ID,
+                    model_arch_col=Format.MODEL_ARCH,
+                    model_arch_table=model_arch_table,
+                    internal_id_col=Format.__INTERNAL_MADLIB_ID__,
+                    unique_str=unique_str)
+    select_res = plpy.execute(select_query,1)
 
-    sql = """INSERT INTO {model_arch_table} ({model_arch_col}, 
{internal_id_col})
-                                    VALUES({model_arch}, '{unique_str}');
-             SELECT {model_id_col}, {model_arch_col}
-                 FROM {model_arch_table} WHERE {internal_id_col} = 
'{unique_str}'
-    """.format(model_arch_table=model_arch_table,
-               model_arch_col=Format.MODEL_ARCH,
-               unique_str=unique_str,
-               model_arch=quote_literal(model_arch),
-               model_id_col=Format.MODEL_ID,
-               internal_id_col=Format.__INTERNAL_MADLIB_ID__)
-    res = _execute(sql,1)
-
-    if len(res) != 1 or res[0][Format.MODEL_ARCH] != model_arch:
-        raise Exception("Failed to insert new row in {0} table--try again?"
-                       .format(model_arch_table))
     plpy.info("Keras Model Arch: Added model id {0} to {1} table".
-        format(res[0][Format.MODEL_ID], model_arch_table))
+        format(select_res[0][Format.MODEL_ID], model_arch_table))
 
-def delete_keras_model(schema_madlib, keras_model_arch_table,
-                       model_id, **kwargs):
+def delete_keras_model(keras_model_arch_table, model_id, **kwargs):
     model_arch_table = quote_ident(keras_model_arch_table)
     input_tbl_valid(model_arch_table, "Keras Model Arch")
 
@@ -116,7 +112,7 @@ def delete_keras_model(schema_madlib, 
keras_model_arch_table,
            DELETE FROM {model_arch_table} WHERE {model_id_col}={model_id}
           """.format(model_arch_table=model_arch_table, 
model_id_col=Format.MODEL_ID,
                      model_id=model_id)
-    res = _execute(sql)
+    res = plpy.execute(sql, 0)
 
     if res.nrows() > 0:
         plpy.info("Keras Model Arch: Model id {0} has been deleted from {1}.".
@@ -125,16 +121,12 @@ def delete_keras_model(schema_madlib, 
keras_model_arch_table,
         plpy.error("Keras Model Arch: Model id {0} not found".format(model_id))
 
     sql = "SELECT {0} FROM {1}".format(Format.MODEL_ID, model_arch_table)
-    res = _execute(sql)
+    res = plpy.execute(sql, 0)
     if not res:
         plpy.info("Keras Model Arch: Dropping empty keras model arch "\
             "table 
{model_arch_table}".format(model_arch_table=model_arch_table))
         sql = "DROP TABLE {0}".format(model_arch_table)
-        try:
-            _execute(sql)
-        except plpy.SPIError, e:
-            plpy.warning("Keras Model Arch: Unable to drop empty keras model "\
-                "arch table {0}".format(model_arch_table))
+        plpy.execute(sql, 0)
 
 class KerasModelArchDocumentation:
     @staticmethod
@@ -185,8 +177,7 @@ class KerasModelArchDocumentation:
 
         'model_id'                -- SERIAL PRIMARY KEY. Model ID.
         'model_arch'              -- JSON. JSON blob of the model architecture.
-        'model_weights'           -- DOUBLE PRECISION[]. weights of the model 
for warm start.
-                                  -- This is currently NULL.
+        'model_weights'           -- bytea. weights of the model for warm 
start.
         '__internal_madlib_id__'  -- TEXT. Unique id for model arch.
 
         """.format(**locals())
diff --git 
a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in 
b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in
index 45dcaa7..70f8369 100644
--- a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in
+++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in
@@ -302,14 +302,55 @@ SELECT * FROM model_arch_library;
 </pre>
 */
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(
+    keras_model_arch_table VARCHAR,
+    model_arch             JSON,
+    model_weights          bytea,
+    name                   TEXT,
+    description            TEXT
+)
+    RETURNS VOID AS $$
+    PythonFunctionBodyOnlyNoSchema(`deep_learning', `keras_model_arch_table')
+    from utilities.control import AOControl
+    with AOControl(False):
+        keras_model_arch_table.load_keras_model(**globals())
+$$ LANGUAGE plpythonu VOLATILE;
+
 -- Function to add a keras model to arch table
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(
     keras_model_arch_table VARCHAR,
     model_arch             JSON
 )
 RETURNS VOID AS $$
-    PythonFunction(`deep_learning',`keras_model_arch_table',`load_keras_model')
-$$ LANGUAGE plpythonu VOLATILE;
+    SELECT MADLIB_SCHEMA.load_keras_model($1, $2, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(
+    keras_model_arch_table VARCHAR,
+    model_arch             JSON,
+    model_weights          bytea
+)
+    RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.load_keras_model($1, $2, $3, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(
+    keras_model_arch_table VARCHAR,
+    model_arch             JSON,
+    model_weights          bytea,
+    name                   TEXT
+)
+    RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.load_keras_model($1, $2, $3, $4, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+
+
+
+
 
 -- Functions for online help
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(
@@ -333,7 +374,10 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.delete_keras_model(
     model_id INTEGER
 )
 RETURNS VOID AS $$
-    
PythonFunction(`deep_learning',`keras_model_arch_table',`delete_keras_model')
+    PythonFunctionBodyOnlyNoSchema(`deep_learning',`keras_model_arch_table')
+    from utilities.control import AOControl
+    with AOControl(False):
+        keras_model_arch_table.delete_keras_model(**globals())
 $$ LANGUAGE plpythonu VOLATILE;
 
 -- Functions for online help
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 16cf7af..3a912e1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -140,7 +140,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
 
     if model_weights_serialized:
         # If warm start from previously trained model, set weights
-        model_weights = madlib_keras_serializer.deserialize_weights_orig(
+        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
             model_weights_serialized, model_shapes)
         master_model.set_weights(model_weights)
 
@@ -169,23 +169,23 @@ def fit(schema_madlib, source_table, 
model,model_arch_table,
         """.format(**locals()), ["bytea"])
 
     # Define the state for the model and loss/metric storage lists
-    model_state = madlib_keras_serializer.serialize_weights(0, model_weights)
+    serialized_weights = 
madlib_keras_serializer.serialize_nd_weights(model_weights)
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
     metrics_iters = []
 
     # get the size of serialized model weights string in KB
-    model_size = sys.getsizeof(model_state)/1024.0
+    model_size = sys.getsizeof(serialized_weights)/1024.0
 
     # Run distributed training for specified number of iterations
     for i in range(1, num_iterations+1):
         start_iteration = time.time()
         iteration_result = plpy.execute(run_training_iteration,
-                                        [model_state])[0]['iteration_result']
+                                        
[serialized_weights])[0]['iteration_result']
         end_iteration = time.time()
         plpy.info("Time for training in iteration {0}: {1} sec".
                   format(i, end_iteration - start_iteration))
-        model_state = madlib_keras_serializer.deserialize_iteration_state(
-            iteration_result)
+        serialized_weights = madlib_keras_serializer.\
+            get_serialized_1d_weights_from_state(iteration_result)
 
         if should_compute_metrics_this_iter(i, metrics_compute_frequency,
                                             num_iterations):
@@ -193,7 +193,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
             compute_loss_and_metrics(
                 schema_madlib, source_table, dependent_varname,
                 independent_varname, compile_params_to_pass, model_arch,
-                model_state, gpus_per_host, segments_per_host, seg_ids_train,
+                serialized_weights, gpus_per_host, segments_per_host, 
seg_ids_train,
                 images_per_seg_train, gp_segment_id_col,
                 training_metrics, training_loss,
                 i, "Training")
@@ -203,7 +203,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
                 compute_loss_and_metrics(
                     schema_madlib, validation_table, dependent_varname,
                     independent_varname, compile_params_to_pass, model_arch,
-                    model_state, gpus_per_host, segments_per_host, seg_ids_val,
+                    serialized_weights, gpus_per_host, segments_per_host, 
seg_ids_val,
                     images_per_seg_val, gp_segment_id_col,
                     validation_metrics, validation_loss,
                     i, "Validation")
@@ -301,7 +301,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
         CREATE TABLE {0} AS SELECT
         $1 as model_data,
         $2 as {1}""".format(model, Format.MODEL_ARCH), ["bytea", "json"])
-    plpy.execute(create_output_table, [model_state, model_arch])
+    plpy.execute(create_output_table, [serialized_weights, model_arch])
 
     if is_platform_pg():
         clear_keras_session()
@@ -343,12 +343,12 @@ def get_metrics_sql_string(metrics_list, 
is_metrics_specified):
 
 def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
                              independent_varname, compile_params, model_arch,
-                             model_state, gpus_per_host, segments_per_host,
+                             serialized_weights, gpus_per_host, 
segments_per_host,
                              seg_ids, rows_per_seg,
                              gp_segment_id_col, metrics_list, loss_list,
                              curr_iter, dataset_name):
     """
-    Compute the loss and metric using a given model (model_state) on the
+    Compute the loss and metric using a given model (serialized_weights) on the
     given dataset (table.)
     """
     start_val = time.time()
@@ -358,7 +358,7 @@ def compute_loss_and_metrics(schema_madlib, table, 
dependent_varname,
                                                       independent_varname,
                                                       compile_params,
                                                       model_arch,
-                                                      model_state,
+                                                      serialized_weights,
                                                       gpus_per_host,
                                                       segments_per_host,
                                                       seg_ids,
@@ -434,7 +434,7 @@ def get_images_per_seg(source_table, dependent_varname):
 def fit_transition(state, dependent_var, independent_var, model_architecture,
                    compile_params, fit_params, current_seg_id, seg_ids,
                    images_per_seg, gpus_per_host, segments_per_host,
-                   previous_state, **kwargs):
+                   prev_serialized_weights, **kwargs):
     if not independent_var or not dependent_var:
         return state
 
@@ -447,12 +447,10 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
         if not is_platform_pg():
             set_keras_session(gpus_per_host, segments_per_host)
         segment_model = model_from_json(model_architecture)
-        SD['model_shapes'] = 
madlib_keras_serializer.get_model_shapes(segment_model)
-        # Configure GPUs/CPUs
         compile_and_set_weights(segment_model, compile_params, device_name,
-                                previous_state, SD['model_shapes'])
+                                prev_serialized_weights)
+
         SD['segment_model'] = segment_model
-        image_count = 0
         agg_image_count = 0
     else:
         segment_model = SD['segment_model']
@@ -500,7 +498,7 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
         plpy.error('Processed {0} images, but there were supposed to be only 
{1}!'
             .format(agg_image_count, total_images))
 
-    new_model_state = madlib_keras_serializer.serialize_weights(
+    new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
         agg_image_count, updated_weights)
 
     del x_train
@@ -510,7 +508,7 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
     plpy.info("Processed {0} images: Fit took {1} sec, Total was {2} 
sec".format(
         image_count, end_fit - start_fit, end_transition - start_transition))
 
-    return new_model_state
+    return new_state
 
 def fit_merge(state1, state2, **kwargs):
 
@@ -519,8 +517,8 @@ def fit_merge(state1, state2, **kwargs):
         return state1 or state2
 
     # Deserialize states
-    image_count1, weights1 = 
madlib_keras_serializer.deserialize_weights_merge(state1)
-    image_count2, weights2 = 
madlib_keras_serializer.deserialize_weights_merge(state2)
+    image_count1, weights1 = 
madlib_keras_serializer.deserialize_as_image_1d_weights(state1)
+    image_count2, weights2 = 
madlib_keras_serializer.deserialize_as_image_1d_weights(state2)
 
     # Compute total image counts
     image_count = (image_count1 + image_count2) * 1.0
@@ -529,7 +527,7 @@ def fit_merge(state1, state2, **kwargs):
     total_weights = weights1 + weights2
 
     # Return the merged state
-    return madlib_keras_serializer.serialize_weights_merge(
+    return madlib_keras_serializer.serialize_state_with_1d_weights(
         image_count, total_weights)
 
 def fit_final(state, **kwargs):
@@ -537,13 +535,13 @@ def fit_final(state, **kwargs):
     if not state:
         return state
 
-    image_count, weights = 
madlib_keras_serializer.deserialize_weights_merge(state)
+    image_count, weights = 
madlib_keras_serializer.deserialize_as_image_1d_weights(state)
     if image_count == 0:
         plpy.error("fit_final: Total images processed is 0")
 
     # Averaging the weights
     weights /= image_count
-    return madlib_keras_serializer.serialize_weights_merge(
+    return madlib_keras_serializer.serialize_state_with_1d_weights(
         image_count, weights)
 
 def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table,
@@ -557,7 +555,7 @@ def evaluate1(schema_madlib, model_table, test_table, 
id_col, model_arch_table,
     # _validate_input_args(test_table, model_arch_table, output_table)
 
     model_data_query = "SELECT model_data from {0}".format(model_table)
-    model_data = plpy.execute(model_data_query)[0]['model_data']
+    serialized_weights = plpy.execute(model_data_query)[0]['model_data']
 
     model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
                        "WHERE id = {1}".format(model_arch_table, model_arch_id)
@@ -572,7 +570,7 @@ def evaluate1(schema_madlib, model_table, test_table, 
id_col, model_arch_table,
     loss_metric = get_loss_metric_from_keras_eval(
                     schema_madlib, test_table, dependent_varname,
                     independent_varname, compile_params, model_arch,
-                    model_data, False, None)
+                    serialized_weights, False, None)
 
     #TODO remove these infos after adding create table command
     plpy.info('len of evaluate result is {}'.format(len(loss_metric)))
@@ -581,7 +579,7 @@ def evaluate1(schema_madlib, model_table, test_table, 
id_col, model_arch_table,
 
 def get_loss_metric_from_keras_eval(schema_madlib, table, dependent_varname,
                                  independent_varname, compile_params,
-                                 model_arch, model_data, gpus_per_host,
+                                 model_arch, serialized_weights, gpus_per_host,
                                  segments_per_host, seg_ids, images_per_seg,
                                  gp_segment_id_col):
     """
@@ -607,12 +605,12 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, 
dependent_varname,
                 )) AS loss_metric
     FROM {table}
     """.format(**locals()), ["bytea"])
-    res = plpy.execute(evaluate_query, [model_data])
+    res = plpy.execute(evaluate_query, [serialized_weights])
     loss_metric = res[0]['loss_metric']
     return loss_metric
 
 def internal_keras_eval_transition(state, dependent_var, independent_var,
-                                   model_architecture, model_data, 
compile_params,
+                                   model_architecture, serialized_weights, 
compile_params,
                                    current_seg_id, seg_ids, images_per_seg,
                                    gpus_per_host, segments_per_host, **kwargs):
     SD = kwargs['SD']
@@ -624,12 +622,9 @@ def internal_keras_eval_transition(state, dependent_var, 
independent_var,
         if not is_platform_pg():
             set_keras_session(gpus_per_host, segments_per_host)
         model = model_from_json(model_architecture)
-        model_shapes = madlib_keras_serializer.get_model_shapes(model)
-        _, model_weights = madlib_keras_serializer.deserialize_weights(
-            model_data, model_shapes)
-        model.set_weights(model_weights)
-        with K.tf.device(device_name):
-            compile_model(model, compile_params)
+        compile_and_set_weights(model, compile_params, device_name,
+                                serialized_weights)
+
         SD['segment_model'] = model
         # These should already be 0, but just in case make sure
         agg_metric = 0
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 12bcb39..4db18b8 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -136,7 +136,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     images_per_seg             INTEGER[],
     gpus_per_host              INTEGER,
     segments_per_host          INTEGER,
-    previous_state             BYTEA
+    prev_serialized_weights    BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
@@ -183,7 +183,7 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* images_per_seg*/          INTEGER[],
     /* gpus_per_host  */         INTEGER,
     /* segments_per_host  */     INTEGER,
-    /* previous_state */         BYTEA
+    /* serialized_weights */     BYTEA
 )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition,
@@ -284,7 +284,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_eval_transition(
     dependent_var                      SMALLINT[],
     independent_var                    REAL[],
     model_architecture                 TEXT,
-    model_data                         BYTEA,
+    serialized_weights                 BYTEA,
     compile_params                     TEXT,
     current_seg_id                     INTEGER,
     seg_ids                            INTEGER[],
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index dbf29e9..b4d21fb 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -25,14 +25,10 @@ from keras import backend as K
 from keras.layers import *
 from keras.models import *
 from keras.optimizers import *
-import numpy as np
 
 from madlib_keras_helper import expand_input_dims
-from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_helper import strip_trailing_nulls_from_class_values
 from madlib_keras_validator import PredictInputValidator
-from madlib_keras_wrapper import get_device_name_and_set_cuda_env
-from madlib_keras_wrapper import set_model_weights
 from predict_input_params import PredictParamsProcessor
 from utilities.control import MinWarning
 from utilities.model_arch_info import get_input_shape
@@ -163,8 +159,9 @@ def internal_keras_predict(independent_var, 
model_architecture, model_data,
             if not is_platform_pg():
                 set_keras_session(gpus_per_host, segments_per_host)
             model = model_from_json(model_architecture)
-            model_shapes = madlib_keras_serializer.get_model_shapes(model)
+            model_shapes = get_model_shapes(model)
             set_model_weights(model, device_name, model_data, model_shapes)
+
             SD[model_key] = model
             SD[row_count_key] = 0
         else:
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index 341b154..c92b3c6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -16,68 +16,67 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import numpy as np
 
-def get_model_shapes(model):
-    model_shapes = []
-    for a in model.get_weights():
-        model_shapes.append(a.shape)
-    return model_shapes
-
 # TODO
-# Current serializing logic
-# serialized string -> byte string
-# np.array(np.array(loss, acc, 
buff_count).concatenate(weights_np_array)).tostring()
+# 1. Current serializing logic
+    # serialized string -> byte string
+    # np.array(np.array(image_count).concatenate(weights_np_array)).tostring()
+    # Proposed logic
+    # image_count can be a separate value
+    # weights -> np.array(weights).tostring()
+    # combine these 2 into one string by a random splitter
+    # serialized string -> imagecount_splitter_weights
+# 2. combine the serialize_state_with_nd_weights and 
serialize_state_with_1d_weights
+    # into one function called serialize_state. This function can infer the 
shape
+    # of the model weights and then flatten if they are nd weights.
+# 3. Same as 2 for deserialize
 
-# Proposed logic
-# loss , accuracy and image_count can be comma separated values
-# weights -> np.array.tostring()
-# combine these 2 into one string by a random splitter
-# serialized string -> loss_splitter_acc_splitter_buffer_splitter_weights
 
-def deserialize_weights(model_state, model_shapes):
+"""
+1. Set initial weights in madlib keras fit function.
+1. Serialize these initial model weights as a byte string and pass it to keras 
step
+1. Deserialize the state passed from the previous step into a list of nd 
weights
+that will be passed on to model.set_weights()
+1. At the end of each buffer in fit transition, serialize the image count and
+the model weights into a bytestring that will be passed on to the fit merge 
function.
+1. In fit merge, deserialize the state as image and 1d np arrays. Do some 
averaging
+operations and serialize them again into a state which contains the image
+and the 1d state. same for fit final 
+1. Return the final state from fit final to fit which will then be 
deserialized 
+as 1d weights to be passed on to the evaluate function
+"""
+def get_image_count_from_state(state):
     """
-    Parameters:
-        model_state: a stringified (serialized) state containing
-        image_count and model_weights, passed from postgres
-        model_shapes: a list of tuples containing the shapes of
-        each element in keras.get_weights()
-    Returns:
-        image_count: the buffer count from state
-        model_weights: a list of numpy arrays that can be inputted into 
keras.set_weights()
+    :param state: bytestring serialized model state containing image count
+    and weights
+    :return: image count as float
     """
-    if not model_state or not model_shapes:
-        return None
-    state = np.fromstring(model_state, dtype=np.float32)
+    image_count , _  = deserialize_as_image_1d_weights(state)
+    return image_count
 
-    model_weights_serialized = state[1:]
-    i, j, model_weights = 0, 0, []
-    while j < len(model_shapes):
-        next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
-        weight_arr_portion = model_weights_serialized[i:next_pointer]
-        model_weights.append(weight_arr_portion.reshape(model_shapes[j]))
-        i, j = next_pointer, j + 1
-    #TODO: float(state[0]) is the image_count, which can be get from
-    # get_image_count_from_state() we defined below, we should check if
-    # we still need to return it here when refactoring
-    return float(state[0]), model_weights
+def get_serialized_1d_weights_from_state(state):
+    """
+    Output of this function is used to deserialize the output of each iteration
+    of madlib keras step UDA.
 
-def get_image_count_from_state(model_state):
-    if not model_state:
-        return None
-    state = np.fromstring(model_state, dtype=np.float32)
-    return float(state[0])
+    :param state: bytestring serialized model state containing image count
+    and weights
+    :return: model weights serialized as bytestring
+    """
+    _ , weights = deserialize_as_image_1d_weights(state)
+    return weights.tostring()
 
-def serialize_weights(image_count, model_weights):
+def serialize_state_with_nd_weights(image_count, model_weights):
     """
-    Parameters:
-        image_count: float values
-        model_weights: a list of numpy arrays, what you get from
+    This function is called when the output of keras.get_weights() (list of nd
+    np arrays) has to be converted into a serialized model state.
+
+    :param image_count: float value
+    :param model_weights: a list of numpy arrays, what you get from
         keras.get_weights()
-    Returns:
-        A stringified (serialized) state containing all these values, to be
-        passed to postgres
+    :return: Image count and model weights serialized into a bytestring format
+
     """
     if model_weights is None:
         return None
@@ -88,36 +87,37 @@ def serialize_weights(image_count, model_weights):
     new_model_string = np.float32(new_model_string)
     return new_model_string.tostring()
 
-def deserialize_iteration_state(iteration_result):
+
+def serialize_state_with_1d_weights(image_count, model_weights):
     """
-    Parameters:
-        iteration_result: the output of the step function
-    Returns:
-        new_model_state: the stringified (serialized) state to pass in to next
-        iteration of step function training, represents the averaged weights
-        from the last iteration of training; zeros out image_count in this 
state
-        because the new iteration must start with
-        fresh values
+    This function is called when the weights are to be passed to the keras fit
+    merge and final functions.
+
+    :param image_count: float value
+    :param model_weights: a single flattened numpy array containing all of the
+        weights
+    :return: Image count and model weights serialized into a bytestring format
+
     """
-    if not iteration_result:
+    if model_weights is None:
         return None
-    state = np.fromstring(iteration_result, dtype=np.float32)
-    new_model_string = np.array(state)
-    new_model_string[0]= 0
-    new_model_string = np.float32(new_model_string)
-    return new_model_string.tostring()
+    merge_state = np.array([image_count])
+    merge_state = np.concatenate((merge_state, model_weights))
+    merge_state = np.float32(merge_state)
+    return merge_state.tostring()
 
 
-def deserialize_weights_merge(state):
+def deserialize_as_image_1d_weights(state):
     """
-    Parameters:
-        state: the stringified (serialized) state containing loss, accuracy, 
image_count, and
-            model_weights, passed from postgres to merge function
-    Returns:
+    This function is called when the model state needs to be deserialized in
+    the keras fit merge and final functions.
+
+    :param state: the stringified (serialized) state containing image_count and
+            model_weights
+    :return:
         image_count: total buffer counts processed
         model_weights: a single flattened numpy array containing all of the
-        weights, flattened because all we have to do is average them (so don't
-        have to reshape)
+        weights
     """
     if not state:
         return None
@@ -125,33 +125,40 @@ def deserialize_weights_merge(state):
     return float(state[0]), state[1:]
 
 
-def serialize_weights_merge(image_count, model_weights):
+def serialize_nd_weights(model_weights):
     """
-    Parameters:
-        image_count: float values
-        model_weights: a single flattened numpy array containing all of the
-        weights, averaged in merge function over the 2 states
-    Returns:
-        A stringified (serialized) state containing all these values, to be
-        passed to postgres
+    This function is called for passing the initial model weights from the 
keras
+    fit function to the keras fit transition function.
+    :param model_weights: a list of numpy arrays, what you get from
+        keras.get_weights()
+    :return: Model weights serialized into a bytestring format
     """
     if model_weights is None:
         return None
-    new_model_string = np.array([image_count])
-    new_model_string = np.concatenate((new_model_string, model_weights))
-    new_model_string = np.float32(new_model_string)
-    return new_model_string.tostring()
+    flattened_weights = [w.flatten() for w in model_weights]
+    model_weights_serialized = np.concatenate(flattened_weights)
+    return np.float32(model_weights_serialized).tostring()
 
 
-def deserialize_weights_orig(model_weights_serialized, model_shapes):
+def deserialize_as_nd_weights(model_weights_serialized, model_shapes):
     """
-    Original deserialization for warm-start, used only to parse model received
-    from query at the top of this file
+    The output of this function is used to set keras model weights using the
+    function model.set_weights()
+    :param model_weights_serialized: bytestring containing model weights
+    :param model_shapes: list containing the shapes of each layer.
+    :return: list of nd numpy arrays containing all of the
+        weights
     """
+    if not model_weights_serialized or not model_shapes:
+        return None
+
     i, j, model_weights = 0, 0, []
+    model_weights_serialized = np.fromstring(model_weights_serialized, 
dtype=np.float32)
     while j < len(model_shapes):
         next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
         weight_arr_portion = model_weights_serialized[i:next_pointer]
         
model_weights.append(np.array(weight_arr_portion).reshape(model_shapes[j]))
         i, j = next_pointer, j + 1
     return model_weights
+
+
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index eef30bf..37c1b73 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -96,21 +96,28 @@ def clear_keras_session():
     K.clear_session()
     sess.close()
 
+def get_model_shapes(model):
+    model_shapes = []
+    for a in model.get_weights():
+        model_shapes.append(a.shape)
+    return model_shapes
+
 def compile_and_set_weights(segment_model, compile_params, device_name,
-                            previous_state, model_shapes):
+                            serialized_weights):
+    model_shapes = get_model_shapes(segment_model)
     with K.tf.device(device_name):
         compile_model(segment_model, compile_params)
-        _, model_weights = madlib_keras_serializer.deserialize_weights(
-            previous_state, model_shapes)
+        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
+            serialized_weights, model_shapes)
         segment_model.set_weights(model_weights)
 
 # TODO: This can be refactored to be part of compile_and_set_weights(),
 # by making compile_params an optional param in that function. Doing that
 # now might create more merge conflicts with other JIRAs, so get to this later.
-def set_model_weights(segment_model, device_name, state, model_shapes):
+def set_model_weights(segment_model, device_name, serialized_weights, 
model_shapes):
     with K.tf.device(device_name):
-        _, model_weights = madlib_keras_serializer.deserialize_weights(
-            state, model_shapes)
+        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
+            serialized_weights, model_shapes)
         segment_model.set_weights(model_weights)
 
 """
diff --git 
a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in 
b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
index e6a51cc..e3d589c 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
@@ -33,8 +33,8 @@ SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 
'model_id column shoul
         AND attname = 'model_id';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'JSON', 'model_arch column 
should be JSON type' ) FROM pg_attribute WHERE attrelid = 
'test_keras_model_arch_table'::regclass
         AND attname = 'model_arch';
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
-    'DOUBLE PRECISION[]', 'model_weights column should be DOUBLE PRECISION[] 
type')
+SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA',
+    'model_weights column should be bytea type')
     FROM pg_attribute WHERE attrelid = 'test_keras_model_arch_table'::regclass
         AND attname = 'model_weights';
 
@@ -84,8 +84,8 @@ SELECT assert(COUNT(model_id) = 0, 'model id 3 should have 
been deleted!')
        *  It should archrt to the user that the model_id wasn't found but not
        *  raise an exception or change anything. */
 SELECT delete_keras_model('test_keras_model_arch_table', 1);
-SELECT assert(COUNT(relname) = 0, 'Table test_keras_model_arch_table should 
have been deleted.')
-    FROM pg_class where relname = 'test_keras_model_arch_table';
+SELECT assert(trap_error($$SELECT * from test_keras_model_arch_table$$) = 1,
+              'Table test_keras_model_arch_table should have been deleted.');
 
 SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
 DELETE FROM test_keras_model_arch_table;
@@ -104,6 +104,22 @@ SELECT assert(trap_error($$SELECT 
delete_keras_model('test_keras_model_arch_tabl
 SELECT assert(trap_error($$SELECT 
load_keras_model('test_keras_model_arch_table', '{"config" : 1}')$$) = 1, 
'Passing an invalid table to load_keras_model() should raise exception.');
 
 /* Test deletion where no table exists */
-DROP TABLE test_keras_model_arch_table;
+DROP TABLE IF EXISTS test_keras_model_arch_table;
 SELECT assert(trap_error($$SELECT 
delete_keras_model('test_keras_model_arch_table', 3)$$) = 1,
               'Deleting a non-existent table should raise exception.');
+
+DROP TABLE IF EXISTS test_keras_model_arch_table;
+SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', 
'dummy weights'::bytea);
+SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', 
NULL, 'my name', 'my desc');
+
+/* Test model weights */
+SELECT assert(model_weights = 'dummy weights', 'Incorrect model_weights in the 
model arch table.')
+FROM test_keras_model_arch_table WHERE model_id = 1;
+SELECT assert(model_weights IS NULL, 'model_weights is not NULL')
+FROM test_keras_model_arch_table WHERE model_id = 2;
+
+/* Test name and description */
+SELECT assert(name IS NULL AND description IS NULL, 'Name or description is 
not NULL.')
+FROM test_keras_model_arch_table WHERE model_id = 1;
+SELECT assert(name = 'my name' AND description = 'my desc', 'Incorrect name or 
description in the model arch table.')
+FROM test_keras_model_arch_table WHERE model_id = 2;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index fc96f1e..6e6065e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -806,3 +806,101 @@ SELECT madlib_keras_predict(
     'cifar10_predict',
     'prob',
     0);
+
+
+-- Test cases for transfer learning
+-- 1. Create a model arch table with weights all set to 0.008. 0.008 is just a
+-- random number we chose after a few experiments so that we can 
deterministically
+-- assert the loss and metric values reported by madlib_keras_fit.
+-- 2. Run keras fit and then update the model arch table with the output of 
the keras
+-- fit run.
+CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID 
AS $$
+from keras.layers import *
+from keras import Sequential
+import numpy as np
+import plpy
+
+model = Sequential()
+model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', 
input_shape=(32,32,3,)))
+model.add(MaxPooling2D(pool_size=(2, 2)))
+model.add(Dropout(0.25))
+model.add(Flatten())
+model.add(Dense(2, activation='softmax'))
+
+# we don't really need to get the weights from the model and the flatten them 
since
+# we are using np.ones_like to replace all the weights with a constant.
+# We want to keep the flatten code and the concatenation code just for 
reference
+weights = model.get_weights()
+weights_flat = [ w.flatten() for w in weights ]
+weights1d = np.array([j for sub in weights_flat for j in sub])
+# Adjust weights so that the learning for the first iteration can be 
deterministic
+# 0.008 is just a random number we chose after a few experiments
+weights1d = np.ones_like(weights1d)*0.008
+weights_bytea = weights1d.tostring()
+
+model_config = model.to_json()
+
+plan1 = plpy.prepare("""SELECT load_keras_model(
+                        'test_keras_model_arch_table',
+                        $1, $2)
+                    """, ['json','bytea'])
+plpy.execute(plan1, [model_config, weights_bytea])
+
+$$ LANGUAGE plpythonu VOLATILE;
+
+DROP TABLE IF EXISTS test_keras_model_arch_table;
+SELECT create_model_arch_transfer_learning();
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'test_keras_model_arch_table',
+    1,
+    $$ optimizer=SGD(lr=0.001, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1);
+SELECT training_loss_final FROM keras_saved_out_summary;
+
+-- We want to keep this select in case the assert fails and we need
+-- to know the actual values in the table without re running the entire test
+\x
+select * from keras_saved_out_summary;
+\x
+
+-- This assert is a work in progress (we are hoping that these asserts will 
not be flaky).
+-- We want to be able to assert that the loss/metric for the first iteration is
+-- deterministic if we set weights using the load_keras function. Although we
+-- have seen that the loss/metric values are different in the 3rd/4th decimal
+-- every time we run fit after loading the weights.
+
+-- TODO https://github.com/apache/madlib/pull/399#discussion_r288336557
+-- Might be a cleaner assert if we can assert the weights themselves.
+-- For instance, if we use weights1d = np.ones_like(weights1d) instead of
+-- weights1d = np.ones_like(weights1d)*0.008, and freeze the first layer,
+-- then even after multiple iterations the weights in the first layer should 
all
+-- be 1. Look at How can I "freeze" Keras layers?
+-- section in https://keras.io/getting-started/faq/ for how to freeze layers.
+SELECT assert(abs(training_loss_final - 0.6) < 0.1 AND
+              abs(training_metrics_final - 0.4) < 0.1,
+       'Transfer learning test failed.')
+FROM keras_saved_out_summary;
+DROP FUNCTION create_model_arch_transfer_learning();
+
+UPDATE test_keras_model_arch_table SET model_weights = model_data FROM 
keras_saved_out WHERE model_id = 1;
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'test_keras_model_arch_table',
+    1,
+    $$ optimizer=SGD(lr=0.001, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+SELECT training_loss_final, training_metrics_final FROM 
keras_saved_out_summary;
+
+--assert training loss and metric deterministic
+SELECT assert(abs(training_loss_final - 0.64) < 0.01 AND
+              abs(training_metrics_final - 0.47) < 0.01,
+       'Transfer learning test failed.')
+FROM keras_saved_out_summary;
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index edfb69a..286798f 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -61,9 +61,8 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.compile_params = "optimizer=SGD(lr=0.01, decay=1e-6, 
nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']"
         self.fit_params = "batch_size=1, epochs=1"
         self.model_weights = [3,4,5,6]
-        self.model_shapes = []
-        for a in self.model.get_weights():
-            self.model_shapes.append(a.shape)
+        self.serialized_weights = np.array(self.model_weights, dtype=np.float32
+                                           ).tostring()
 
         self.all_seg_ids = [0,1,2]
 
@@ -86,17 +85,15 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.subject.is_platform_pg = Mock(return_value = True)
         starting_image_count = 0
         ending_image_count = len(self.dependent_var)
-        previous_state = [starting_image_count]
-        previous_state.extend(self.model_weights)
-        previous_state = np.array(previous_state, dtype=np.float32)
+        previous_state = np.array(self.model_weights, dtype=np.float32)
 
         k = {'SD' : {}}
 
-        new_model_state = self.subject.fit_transition(
+        new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var , 
self.model.to_json(),
             self.compile_params, self.fit_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 4, previous_state.tostring(), **k)
-        state = np.fromstring(new_model_state, dtype=np.float32)
+        state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
@@ -107,7 +104,6 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the first buffer
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
-        self.assertTrue(k['SD']['model_shapes'])
 
     def test_fit_transition_first_buffer_pass_gpdb(self):
         #TODO should we mock tensorflow's close_session and keras'
@@ -119,17 +115,15 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.subject.is_platform_pg = Mock(return_value = False)
         starting_image_count = 0
         ending_image_count = len(self.dependent_var)
-        previous_state = [starting_image_count]
-        previous_state.extend(self.model_weights)
-        previous_state = np.array(previous_state, dtype=np.float32)
+        previous_state = np.array(self.model_weights, dtype=np.float32)
 
         k = {'SD' : {}}
 
-        new_model_state = self.subject.fit_transition(
+        new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var , 
self.model.to_json(),
             self.compile_params, self.fit_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 4, previous_state.tostring(), **k)
-        state = np.fromstring(new_model_state, dtype=np.float32)
+        state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
@@ -140,7 +134,6 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the first buffer
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
-        self.assertTrue(k['SD']['model_shapes'])
 
     def test_fit_transition_middle_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
@@ -157,16 +150,15 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         state = np.array(state, dtype=np.float32)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring(), 
self.model_shapes)
-        k = {'SD': {'model_shapes': self.model_shapes}}
-        k['SD']['segment_model'] = self.model
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model': self.model}}
 
-        new_model_state = self.subject.fit_transition(
+        new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
 
-        state = np.fromstring(new_model_state, dtype=np.float32)
+        state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
@@ -194,15 +186,14 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         multiplied_weights = 
mult(self.total_images_per_seg[0],self.model_weights)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring(), 
self.model_shapes)
-        k = {'SD': { 'model_shapes': self.model_shapes}}
-        k['SD']['segment_model'] = self.model
-        new_model_state = self.subject.fit_transition(
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model' :self.model}}
+        new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var , 
self.model.to_json(),
             None, self.fit_params, 0, self.all_seg_ids, 
self.total_images_per_seg,
             0, 4, 'dummy_previous_state', **k)
 
-        state = np.fromstring(new_model_state, dtype=np.float32)
+        state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
@@ -231,15 +222,14 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         multiplied_weights = 
mult(self.total_images_per_seg[0],self.model_weights)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring(), 
self.model_shapes)
-        k = {'SD': { 'model_shapes': self.model_shapes}}
-        k['SD']['segment_model'] = self.model
-        new_model_state = self.subject.fit_transition(
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model' :self.model}}
+        new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
 
-        state = np.fromstring(new_model_state, dtype=np.float32)
+        state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
@@ -264,7 +254,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         total_images_per_seg = [0,1,1]
 
         with self.assertRaises(plpy.PLPYException) as error:
-            new_model_state = self.subject.fit_transition(
+            new_state = self.subject.fit_transition(
                 None, self.dependent_var, self.independent_var , 
self.model.to_json(),
                 self.compile_params, self.fit_params, 0, self.all_seg_ids,
                 total_images_per_seg, 0, 4, previous_state.tostring(), **k)
@@ -283,7 +273,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         total_images_per_seg = [1,1,1]
 
         with self.assertRaises(plpy.PLPYException) as error:
-            new_model_state = self.subject.fit_transition(
+            new_state = self.subject.fit_transition(
                 None, self.dependent_var, self.independent_var , 
self.model.to_json(),
                 self.compile_params, self.fit_params, 0, self.all_seg_ids,
                 total_images_per_seg, 0, 4, previous_state.tostring(), **k)
@@ -916,73 +906,64 @@ class MadlibSerializerTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
-    def test_deserialize_weights_merge_null_state_returns_none(self):
-        self.assertEqual(None, self.subject.deserialize_weights_merge(None))
+    def test_deserialize_image_1d_weights_null_state_returns_none(self):
+        self.assertEqual(None, 
self.subject.deserialize_as_image_1d_weights(None))
 
-    def test_deserialize_weights_merge_returns_not_none(self):
-        dummy_model_state = np.array([2,3,4,5,6], dtype=np.float32)
-        res = 
self.subject.deserialize_weights_merge(dummy_model_state.tostring())
+    def test_deserialize_image_1d_weights_returns_not_none(self):
+        dummy_state = np.array([2,3,4,5,6], dtype=np.float32)
+        res = 
self.subject.deserialize_as_image_1d_weights(dummy_state.tostring())
         self.assertEqual(2, res[0])
         self.assertEqual([3,4,5,6], res[1].tolist())
 
-    def test_deserialize_weights_null_input_returns_none(self):
-        dummy_model_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
-        self.assertEqual(None, 
self.subject.deserialize_weights(dummy_model_state.tostring(), None))
-        self.assertEqual(None, self.subject.deserialize_weights(None, [1,2,3]))
-        self.assertEqual(None, self.subject.deserialize_weights(None, None))
+    def test_deserialize_nd_weights_null_input_returns_none(self):
+        dummy_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
+        self.assertEqual(None, 
self.subject.deserialize_as_nd_weights(dummy_state.tostring(), None))
+        self.assertEqual(None, self.subject.deserialize_as_nd_weights(None, 
[1, 2, 3]))
+        self.assertEqual(None, self.subject.deserialize_as_nd_weights(None, 
None))
 
-    def test_deserialize_weights_valid_input_returns_not_none(self):
-        dummy_model_state = np.array([0,3,4,5], dtype=np.float32)
+    def test_deserialize_nd_weights_valid_input_returns_not_none(self):
+        dummy_model_weights = np.array([3,4,5], dtype=np.float32)
         dummy_model_shape = [(2, 1, 1, 1), (1,)]
-        res = self.subject.deserialize_weights(dummy_model_state.tostring(), 
dummy_model_shape)
-        self.assertEqual(0, res[0])
-        self.assertEqual([[[[3.0]]], [[[4.0]]]], res[1][0].tolist())
-        self.assertEqual([5], res[1][1].tolist())
+        res = 
self.subject.deserialize_as_nd_weights(dummy_model_weights.tostring(),
+                                                     dummy_model_shape)
+        self.assertEqual([[[[3.0]]], [[[4.0]]]], res[0].tolist())
+        self.assertEqual([5], res[1].tolist())
 
-    def test_deserialize_weights_invalid_input_fails(self):
+    def test_deserialize_nd_weights_invalid_input_fails(self):
         # pass an invalid state with missing model weights
-        invalid_model_state = np.array([0,1,2], dtype=np.float32)
+        invalid_model_weights = np.array([1,2], dtype=np.float32)
         dummy_model_shape = [(2, 1, 1, 1), (1,)]
 
         # we expect keras failure(ValueError) because we cannot reshape
         # model weights of size 0 into shape (2,2,3,1)
         with self.assertRaises(ValueError):
-            self.subject.deserialize_weights(invalid_model_state.tostring(), 
dummy_model_shape)
+            
self.subject.deserialize_as_nd_weights(invalid_model_weights.tostring(),
+                                                   dummy_model_shape)
 
-        invalid_model_state = np.array([0,1,2,3,4], dtype=np.float32)
+        invalid_model_weights = np.array([1,2,3,4], dtype=np.float32)
         dummy_model_shape = [(2, 2, 3, 1), (1,)]
         # we expect keras failure(ValueError) because we cannot reshape
         # model weights of size 2 into shape (2,2,3,1)
         with self.assertRaises(ValueError):
-            self.subject.deserialize_weights(invalid_model_state.tostring(), 
dummy_model_shape)
-
-    def test_deserialize_iteration_state_none_input_returns_none(self):
-        self.assertEqual(None, self.subject.deserialize_iteration_state(None))
+            
self.subject.deserialize_as_nd_weights(invalid_model_weights.tostring(),
+                                                   dummy_model_shape)
 
-    def test_deserialize_iteration_state_returns_valid_output(self):
-        dummy_iteration_state = np.array([2,3,4,5], dtype=np.float32)
-        res = self.subject.deserialize_iteration_state(
-            dummy_iteration_state.tostring())
-        self.assertEqual(res,
-                         np.array([0,3,4,5], dtype=np.float32).tostring())
-
-
-    def test_serialize_weights_none_weights_returns_none(self):
-        res = self.subject.serialize_weights(0,None)
+    def test_serialize_image_nd_weights_none_weights_returns_none(self):
+        res = self.subject.serialize_state_with_nd_weights(0, None)
         self.assertEqual(None , res)
 
-    def test_serialize_weights_valid_output(self):
-        res = self.subject.serialize_weights(0,[np.array([1,3]),
-                                                    np.array([4,5])])
+    def test_serialize_image_nd_weights_valid_output(self):
+        res = self.subject.serialize_state_with_nd_weights(0, [np.array([1, 
3]),
+                                                               
np.array([4,5])])
         self.assertEqual(np.array([0,1,3,4,5], dtype=np.float32).tostring(),
                          res)
 
-    def test_serialize_weights_merge_none_weights_returns_none(self):
-        res = self.subject.serialize_weights_merge(0,None)
+    def test_serialize_image_1d_weights_none_weights_returns_none(self):
+        res = self.subject.serialize_state_with_1d_weights(0, None)
         self.assertEqual(None , res)
 
-    def test_serialize_weights_merge_valid_output(self):
-        res = self.subject.serialize_weights_merge(0,np.array([1,3,4,5]))
+    def test_serialize_image_1d_weights_valid_output(self):
+        res = self.subject.serialize_state_with_1d_weights(0, np.array([1, 3, 
4, 5]))
         self.assertEqual(np.array([0,1,3,4,5], dtype=np.float32).tostring(),
                          res)
 
@@ -1093,10 +1074,8 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
 
         self.compile_params = "optimizer=SGD(lr=0.01, decay=1e-6, 
nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']"
         self.model_weights = [3,4,5,6]
-        self.model_shapes = []
-        for a in self.model.get_weights():
-            self.model_shapes.append(a.shape)
-
+        self.serialized_weights = np.array(self.model_weights, dtype='float32'
+                                           ).tostring()
         self.loss = 0.5947071313858032
         self.accuracy = 1.0
         self.all_seg_ids = [0,1,2]
@@ -1122,13 +1101,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         k = {'SD' : {}}
         state = [0,0,0]
 
-        serialized_weights = [0] # not used
-        serialized_weights.extend(self.model_weights)
-        serialized_weights = np.array(serialized_weights, 
dtype=np.float32).tostring()
-
         new_state = self.subject.internal_keras_eval_transition(
             state, self.dependent_var , self.independent_var, 
self.model.to_json(),
-            serialized_weights, self.compile_params, 0, self.all_seg_ids,
+            self.serialized_weights, self.compile_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 3, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
@@ -1155,12 +1130,8 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
 
         k = {'SD' : {}}
 
-        model_state = [starting_image_count]
-        model_state.extend(self.model_weights)
-        model_state = np.array(model_state, dtype=np.float32)
-
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', model_state.tostring(), 
self.model_shapes)
+                                             '/cpu:0', self.serialized_weights)
 
         state = [self.loss * starting_image_count, self.accuracy * 
starting_image_count, starting_image_count]
         k['SD']['segment_model'] = self.model
@@ -1192,14 +1163,12 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         ending_image_count = starting_image_count + len(self.dependent_var)
         k = {'SD' : {}}
 
-        model_state = [starting_image_count]
-        model_state.extend(self.model_weights)
-        model_state = np.array(model_state, dtype=np.float32)
-
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', model_state.tostring(), 
self.model_shapes)
+                                             '/cpu:0', self.serialized_weights)
 
-        state = [self.loss * starting_image_count, self.accuracy * 
starting_image_count, starting_image_count]
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count,
+                 starting_image_count]
 
         k['SD']['segment_model'] = self.model
         new_state = self.subject.internal_keras_eval_transition(
@@ -1217,7 +1186,8 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         self.assertAlmostEqual(self.accuracy * ending_image_count, 
agg_accuracy, 4)
         # Clear session and sess.close must get called for the last buffer in 
gpdb,
         #  but not in postgres
-        self.assertEqual(0 if is_platform_pg else 1, 
self.subject.clear_keras_session.call_count)
+        self.assertEqual(0 if is_platform_pg else 1,
+                         self.subject.clear_keras_session.call_count)
 
     def test_internal_keras_eval_transition_first_buffer_pg(self):
         self._test_internal_keras_eval_transition_first_buffer(True)
@@ -1302,12 +1272,12 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         starting_image_count = 5
 
         k = {'SD' : {}}
-        model_state = [self.loss, self.accuracy, starting_image_count]
-        model_state.extend(self.model_weights)
-        model_state = np.array(model_state, dtype=np.float32)
+        state = [self.loss, self.accuracy, starting_image_count]
+        state.extend(self.model_weights)
+        state = np.array(state, dtype=np.float32)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', model_state.tostring(), 
self.model_shapes)
+                                             '/cpu:0', state.tostring())
 
         state = [self.loss * starting_image_count, self.accuracy * 
starting_image_count, starting_image_count]
 

Reply via email to