This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new df03bc2  DL: Add function for predict byom
df03bc2 is described below

commit df03bc2029fffa081685c9ffd9471386666bc9a6
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Fri Jul 19 17:39:27 2019 -0700

    DL: Add function for predict byom
    
    JIRA: MADLIB-1371, MADLIB-1359
    
    Previously, a user had to train a deep learning model in MADlib before
    they could use that model for prediction. This commit adds a new
    function called `madlib_keras_predict_byom` which allows the user to
    run prediction with their own model, which does not have to be trained
    with MADlib (see the usage sketch below).

    * Refactored the code to reuse the logic shared by predict and
    predict_byom
    * Updated the user documentation
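
    Usage sketch (mirroring the example added to the user docs in this
    commit; 'model_arch_library' and 'iris_test' are the tables created in
    that example):

        SELECT madlib.madlib_keras_predict_byom(
            'model_arch_library',  -- model arch table
            1,                     -- model arch id
            'iris_test',           -- test table
            'id',                  -- id column
            'attributes',          -- independent var
            'iris_predict_byom',   -- output table
            'response',            -- pred_type
            0,                     -- gpus_per_host
            ARRAY['Iris-setosa', 'Iris-versicolor',
                  'Iris-virginica'],  -- class_values
            1.0                    -- normalizing_const
        );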
    
    Co-authored-by: Nandish Jayaram <njaya...@apache.org>
    Co-authored-by: Orhan Kislal <okis...@apache.org>
---
 .../deep_learning/input_data_preprocessor.py_in    |   2 +-
 .../modules/deep_learning/madlib_keras.py_in       |  71 +++--
 .../modules/deep_learning/madlib_keras.sql_in      | 337 +++++++++++++++++++-
 .../deep_learning/madlib_keras_helper.py_in        |   2 +
 .../deep_learning/madlib_keras_predict.py_in       | 312 +++++++++++++++----
 .../deep_learning/madlib_keras_validator.py_in     | 346 ++++++++++-----------
 .../model_arch_info.py_in                          |  36 +++
 .../modules/deep_learning/test/madlib_keras.sql_in | 196 +++++++++---
 .../test/unit_tests/test_madlib_keras.py_in        | 203 +++++++++---
 9 files changed, 1140 insertions(+), 365 deletions(-)

diff --git 
a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in 
b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
index 82bec97..6a03eca 100644
--- a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
+++ b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
@@ -58,7 +58,7 @@ class InputDataPreprocessorDL(object):
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
         self.buffer_size = buffer_size
-        self.normalizing_const = normalizing_const if normalizing_const is not 
None else 1.0
+        self.normalizing_const = normalizing_const if normalizing_const is not 
None else DEFAULT_NORMALIZING_CONST
         self.num_classes = num_classes
         self.module_name = module_name
         self.output_summary_table = None
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 47ce306..fa55093 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -29,25 +29,19 @@ import time
 import keras
 
 from keras import backend as K
-from keras import utils as keras_utils
 from keras.layers import *
 from keras.models import *
 from keras.optimizers import *
 from keras.regularizers import *
-import madlib_keras_serializer
 from madlib_keras_helper import *
 from madlib_keras_validator import *
 from madlib_keras_wrapper import *
-from keras_model_arch_table import ModelArchSchema
+from model_arch_info import *
 
-from utilities.control import MinWarning
-from utilities.model_arch_info import get_input_shape
-from utilities.model_arch_info import get_num_classes
 from utilities.utilities import _assert
 from utilities.utilities import is_platform_pg
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import madlib_version
-from utilities.validate_args import get_col_value_and_type
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import quote_ident
 from utilities.control import MinWarning
@@ -68,7 +62,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
 
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
-        mb_dep_var_col, mb_indep_var_col,
+        model_arch_id, mb_dep_var_col, mb_indep_var_col,
         num_iterations, metrics_compute_frequency, warm_start)
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
@@ -88,23 +82,13 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
 
     # Get the serialized master model
     start_deserialization = time.time()
-    model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
-        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
-        model_arch_table, ModelArchSchema.MODEL_ID,
-        model_arch_id)
-    model_arch_result = plpy.execute(model_arch_query)
-    if not  model_arch_result:
-        plpy.error("no model arch found in table {0} with id {1}".format(
-            model_arch_table, model_arch_id))
-    model_arch_result = model_arch_result[0]
-    model_arch = model_arch_result[ModelArchSchema.MODEL_ARCH]
-    input_shape = get_input_shape(model_arch)
+    model_arch, model_weights = get_model_arch_weights(model_arch_table, 
model_arch_id)
     num_classes = get_num_classes(model_arch)
+    input_shape = get_input_shape(model_arch)
     fit_validator.validate_input_shapes(input_shape)
-
     gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
 
-    serialized_weights = get_initial_weights(model, model_arch_result,
+    serialized_weights = get_initial_weights(model, model_arch, model_weights,
                                              warm_start, gpus_per_host)
     # Compute total images on each segment
     seg_ids_train, images_per_seg_train = 
get_image_count_per_seg_for_minibatched_data_from_db(source_table)
@@ -289,7 +273,7 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
-def get_initial_weights(model_table, model_arch_result, warm_start, 
gpus_per_host):
+def get_initial_weights(model_table, model_arch, serialized_weights, 
warm_start, gpus_per_host):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
@@ -317,10 +301,8 @@ def get_initial_weights(model_table, model_arch_result, 
warm_start, gpus_per_hos
             SELECT model_data FROM {0}
         """.format(model_table))[0]['model_data']
     else:
-        serialized_weights = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
         if not serialized_weights:
-            model = model_from_json(
-                model_arch_result[ModelArchSchema.MODEL_ARCH])
+            model = model_from_json(model_arch)
             serialized_weights = madlib_keras_serializer.serialize_nd_weights(
                 model.get_weights())
     return serialized_weights
@@ -518,10 +500,12 @@ def get_segments_and_gpus(gpus_per_host):
 
 def evaluate(schema_madlib, model_table, test_table, output_table, 
gpus_per_host, **kwargs):
     module_name = 'madlib_keras_evaluate'
-    input_validator = EvaluateInputValidator(test_table, model_table, 
output_table, module_name)
-
-    model_summary_table = input_validator.model_summary_table
-    test_summary_table = input_validator.test_summary_table
+    if test_table:
+        test_summary_table = add_postfix(test_table, "_summary")
+    model_summary_table = None
+    if model_table:
+        model_summary_table = add_postfix(model_table, "_summary")
+    validate_evaluate(module_name, model_table, model_summary_table, 
test_table, test_summary_table, output_table)
 
     segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
 
@@ -531,7 +515,8 @@ def evaluate(schema_madlib, model_table, test_table, 
output_table, gpus_per_host
     model_arch = res['model_arch']
 
     input_shape = get_input_shape(model_arch)
-    input_validator.validate_input_shape(input_shape)
+    InputValidator.validate_input_shape(
+        test_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, input_shape, 2)
 
     compile_params_query = "SELECT compile_params, metrics_type FROM 
{0}".format(model_summary_table)
     res = plpy.execute(compile_params_query)[0]
@@ -540,10 +525,11 @@ def evaluate(schema_madlib, model_table, test_table, 
output_table, gpus_per_host
 
     seg_ids, images_per_seg = 
get_image_count_per_seg_for_minibatched_data_from_db(test_table)
 
-    loss, metric =\
-        get_loss_metric_from_keras_eval(schema_madlib, test_table, 
compile_params, model_arch,
-                                        model_data, gpus_per_host, 
segments_per_host,
-                                        seg_ids, images_per_seg)
+    loss, metric = \
+        get_loss_metric_from_keras_eval(
+            schema_madlib, test_table, compile_params, model_arch,
+            model_data, gpus_per_host, segments_per_host,
+            seg_ids, images_per_seg)
 
     if not metrics_type:
         metrics_type = None
@@ -555,6 +541,23 @@ def evaluate(schema_madlib, model_table, test_table, 
output_table, gpus_per_host
             SELECT $1 as loss, $2 as metric, $3 as 
metrics_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]"])
         plpy.execute(create_output_table, [loss, metric, metrics_type])
 
+def validate_evaluate(module_name, model_table, model_summary_table, 
test_table, test_summary_table, output_table):
+    def _validate_test_summary_tbl():
+        input_tbl_valid(test_summary_table, module_name,
+                error_suffix_str="Please ensure that the test table ({0}) "
+                                 "has been preprocessed by "
+                                 "the image preprocessor.".format(test_table))
+        cols_in_tbl_valid(test_summary_table, [CLASS_VALUES_COLNAME,
+            NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+            DEPENDENT_VARNAME_COLNAME, INDEPENDENT_VARNAME_COLNAME], 
module_name)
+
+    InputValidator.validate_predict_evaluate_tables(
+        module_name, model_table, model_summary_table,
+        test_table, output_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL)
+    _validate_test_summary_tbl()
+    validate_dependent_var_for_minibatch(test_table,
+                                         MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
+
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     model_arch, serialized_weights, 
gpus_per_host,
                                     segments_per_host, seg_ids, 
images_per_seg):
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index c5d8d35..c69f158 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -35,6 +35,7 @@ m4_include(`SQLCommon.m4')
 <li class="level1"><a href="#keras_fit">Fit</a></li>
 <li class="level1"><a href="#keras_evaluate">Evaluate</a></li>
 <li class="level1"><a href="#keras_predict">Predict</a></li>
+<li class="level1"><a href="#keras_predict_byom">Predict BYOM</a></li>
 <li class="level1"><a href="#example">Examples</a></li>
 <li class="level1"><a href="#notes">Notes</a></li>
 <li class="level1"><a href="#background">Technical Background</a></li>
@@ -616,6 +617,127 @@ madlib_keras_predict(
   </DD>
 </DL>
 
+
+
+@anchor keras_predict_byom
+@par Predict BYOM (Bring your own model)
+The predict byom function has the following format:
+<pre class="syntax">
+madlib_keras_predict_byom(
+    model_arch_table,
+    model_arch_id,
+    test_table,
+    id_col,
+    independent_varname,
+    output_table,
+    pred_type,
+    gpus_per_host,
+    class_values,
+    normalizing_const
+    )
+</pre>
+
+
+\b Arguments
+<dl class="arglist">
+
+<DT>model_arch_table</DT>
+  <DD>TEXT. Name of the architecture table containing the model
+  to use for prediction. The model weights and architecture can be loaded to
+  this table by using the
+  <a href="group__grp__keras__model__arch.html">load_keras_model</a> function.
+  </DD>
+
+  <DT>model_arch_id</DT>
+  <DD>INTEGER. The id of the entry in 'model_arch_table' containing the model
+  architecture and model weights to use for prediction.
+  </DD>
+
+  <DT>test_table</DT>
+  <DD>TEXT. Name of the table containing the dataset to
+  predict on. Note that the test data is not preprocessed (unlike
+  fit and evaluate), so provide one test image per row for prediction.
+  Also see the comment below for the 'independent_varname' parameter
+  regarding normalization.
+
+  </DD>
+
+  <DT>id_col</DT>
+  <DD>TEXT. Name of the id column in the test data table.
+  </DD>
+
+  <DT>independent_varname</DT>
+  <DD>TEXT. Column with independent variables in the test table.
+  If a 'normalizing_const' was used when preprocessing the
+  training dataset, pass the same value in the 'normalizing_const'
+  parameter below so that the same normalization is applied to
+  the independent variables used in predict.
+  </DD>
+
+  <DT>output_table</DT>
+  <DD>TEXT. Name of the table that prediction output will be
+  written to. Table contains:</DD>
+    <table class="output">
+      <tr>
+        <th>id</th>
+        <td>Gives the 'id' for each prediction, corresponding to each row from 
the test_table.</td>
+      </tr>
+      <tr>
+        <th>estimated_dependent_var</th>
+        <td>
+        (For pred_type='response') The estimated class for classification. If
+        class_values is passed in as NULL, then we assume that the class
+        labels are [0,1,2...,n-1] where n is the number of classes in the
+        model architecture.
+        </td>
+      </tr>
+      <tr>
+        <th>prob_CLASS</th>
+        <td>
+         (For pred_type='prob' for classification)
+         The probability of a given class.
+         If class_values is passed in as NULL, we create just one column called
+         'prob' which is an array of probabilities of all the classes.
+         Otherwise if class_values is not NULL, then there will be one
+         column for each class in the training data.
+        </td>
+      </tr>
+    </table>
+
+  <DT>pred_type (optional)</DT>
+  <DD>TEXT, default: 'response'. The type of output desired, where 'response'
+   gives the actual prediction and 'prob' gives the probability value for each 
class.
+  </DD>
+
+  <DT>gpus_per_host (optional)</DT>
+  <DD>INTEGER, default: 0 (i.e., CPU).
+    Number of GPUs per segment host to be used for prediction.
+    For example, if you specify 4 for this parameter and your database cluster
+    is set up to have 4 segments per segment host, it means that each segment
+    will have a dedicated GPU. A value of 0 means that CPUs, not GPUs, will
+    be used for prediction.
+
+    @note
+    We have seen some memory related issues when segments share GPU resources.
+    For example, if you specify 1 for this parameter and your database cluster
+    is set up to have 4 segments per segment host, it means that all 4 segments
+    on a segment host will share the same GPU. The current recommended
+    configuration is 1 GPU per segment.
+  </DD>
+
+  <DT>class_values (optional)</DT>
+  <DD>TEXT[], default: NULL.
+    List of class labels that were used while training the model. See the
+    description of the 'output_table' parameter above for more details.
+  </DD>
+
+  <DT>normalizing_const (optional)</DT>
+  <DD>DOUBLE PRECISION, default: 1.0.
+  The normalizing constant to divide each value in the 'independent_varname'
+  array by. For example, you would use 255 for this value if the image data is
+  in the form 0-255.
+  </DD>
+</DL>
+
+
 @anchor example
 @par Examples
 
@@ -814,7 +936,6 @@ SELECT COUNT(*) FROM iris_train;
 
 -# Call the preprocessor for deep learning.  For the training dataset:
 <pre class="example">
-DROP TABLE IF EXISTS mlp_prediction;
 \\x off
 DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
 SELECT madlib.training_preprocessor_dl('iris_train',         -- Source table
@@ -1049,6 +1170,142 @@ WHERE q.actual=q.estimated;
 (1 row)
 </pre>
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+-# Predict BYOM.
+We will use the validation dataset for prediction
+as well, which is not usual but serves to show the
+syntax. See <a href="group__grp__keras__model__arch.html">load_keras_model</a>
+for details on how to load the model architecture and weights.
+
+
+The prediction is in the 'estimated_dependent_var'
+column:
+<pre class="example">
+UPDATE model_arch_library SET model_weights = (SELECT model_data FROM iris_model) WHERE model_id = 1;
+
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib.madlib_keras_predict_byom('model_arch_library', -- model arch table
+                                   1, -- model arch id
+                                   'iris_test',  -- test_table
+                                   'id',  -- id column
+                                   'attributes', -- independent var
+                                   'iris_predict_byom',  -- output table
+                                   'response', -- pred_type
+                                   0, -- gpus_per_host
+                                   ARRAY['Iris-setosa', 'Iris-versicolor',
+                                   'Iris-virginica'], -- class_values
+                                   1.0 -- normalizing_const
+                                   );
+SELECT * FROM iris_predict_byom ORDER BY id;
+</pre>
+<pre class="result">
+ id  | estimated_dependent_var
+-----+-------------------------
+   1 | Iris-setosa
+   4 | Iris-setosa
+   9 | Iris-setosa
+  27 | Iris-setosa
+  32 | Iris-setosa
+  35 | Iris-setosa
+  40 | Iris-setosa
+  41 | Iris-setosa
+  44 | Iris-setosa
+  46 | Iris-setosa
+  55 | Iris-versicolor
+  56 | Iris-versicolor
+  66 | Iris-versicolor
+  69 | Iris-versicolor
+  75 | Iris-versicolor
+  76 | Iris-versicolor
+ 102 | Iris-virginica
+ 105 | Iris-virginica
+ 108 | Iris-virginica
+ 113 | Iris-virginica
+ 115 | Iris-virginica
+ 116 | Iris-virginica
+ 118 | Iris-virginica
+ 119 | Iris-virginica
+ 122 | Iris-virginica
+ 125 | Iris-virginica
+ 133 | Iris-virginica
+ 134 | Iris-virginica
+ 135 | Iris-virginica
+ 138 | Iris-virginica
+ </pre>
+Count misclassifications:
+<pre class="example">
+SELECT COUNT(*) FROM iris_predict_byom JOIN iris_test USING (id)
+WHERE iris_predict_byom.estimated_dependent_var != iris_test.class_text;
+</pre>
+<pre class="result">
+ count
+-------+
+     6
+(1 row)
+</pre>
+Accuracy:
+<pre class="example">
+SELECT round(count(*)*100/(150*0.2),2) as test_accuracy_percent from
+    (select iris_test.class_text as actual, iris_predict_byom.estimated_dependent_var as estimated
+     from iris_predict_byom inner join iris_test
+     on iris_test.id=iris_predict_byom.id) q
+WHERE q.actual=q.estimated;
+</pre>
+<pre class="result">
+ test_accuracy_percent
+-----------------------+
+                 80.00
+(1 row)
+</pre>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 <h4>Classification with Other Parameters</h4>
 
 -# Validation dataset.  Now use a validation dataset
@@ -1571,7 +1828,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_predict(
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
-        madlib_keras_predict.predict(schema_madlib,
+        madlib_keras_predict.Predict(schema_madlib,
                model_table,
                test_table,
                id_col,
@@ -1622,6 +1879,82 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_predict(
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+-------------------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR,
+    pred_type               VARCHAR,
+    gpus_per_host           INTEGER,
+    class_values            TEXT[],
+    normalizing_const       DOUBLE PRECISION
+) RETURNS VOID AS $$
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    with AOControl(False):
+        madlib_keras_predict.PredictBYOM(**globals())
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR,
+    pred_type               VARCHAR,
+    gpus_per_host           INTEGER,
+    class_values            TEXT[]
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, 
$8, $9, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR,
+    pred_type               VARCHAR,
+    gpus_per_host           INTEGER
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, 
$8, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR,
+    pred_type               VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, 
NULL, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, 
NULL, NULL, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
     model_table             VARCHAR,
     test_table              VARCHAR,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 17bdda4..e8218a6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -45,6 +45,8 @@ MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL = "independent_var"
 FLOAT32_SQL_TYPE = 'REAL'
 SMALLINT_SQL_TYPE = 'SMALLINT'
 
+DEFAULT_NORMALIZING_CONST = 1.0
+
 #####################################################################
 
 # Prepend a dimension to np arrays using expand_dims.
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index ca7a9ad..819ff98 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -18,7 +18,6 @@
 # under the License.
 
 import plpy
-import os
 
 import keras
 from keras import backend as K
@@ -26,82 +25,190 @@ from keras.layers import *
 from keras.models import *
 from keras.optimizers import *
 
+from model_arch_info import *
 from madlib_keras_helper import *
-from madlib_keras_validator import PredictInputValidator
+from madlib_keras_validator import *
 from predict_input_params import PredictParamsProcessor
 from utilities.control import MinWarning
-from utilities.model_arch_info import get_input_shape
+from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import create_cols_from_array_sql_string
 from utilities.utilities import get_segments_per_host
-from utilities.utilities import is_platform_pg
 from utilities.utilities import unique_string
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import quote_ident
 
 from madlib_keras_wrapper import *
 
-MODULE_NAME = 'madlib_keras_predict'
+class BasePredict():
+    def __init__(self, schema_madlib, table_to_validate, test_table, id_col,
+                 independent_varname, output_table, pred_type, gpus_per_host):
+        self.schema_madlib = schema_madlib
+        self.table_to_validate = table_to_validate
+        self.test_table = test_table
+        self.id_col = id_col
+        self.independent_varname = independent_varname
+        self.output_table = output_table
+        self.pred_type = pred_type
+        self.gpus_per_host = gpus_per_host
+        self._set_default_gpus_pred_type()
+
+    def _set_default_gpus_pred_type(self):
+        self.pred_type =  'response' if not self.pred_type else self.pred_type
+        self.is_response = True if self.pred_type == 'response' else False
+        self.gpus_per_host = 0 if self.gpus_per_host is None else 
self.gpus_per_host
+
+
+    def call_internal_keras(self):
+        if self.is_response:
+            pred_col_name = add_postfix("estimated_", self.dependent_varname)
+            pred_col_type = self.dependent_vartype
+        else:
+            pred_col_name = "prob"
+            pred_col_type = 'double precision'
+
+        intermediate_col = unique_string()
+        class_values = 
strip_trailing_nulls_from_class_values(self.class_values)
+
+        prediction_select_clause = create_cols_from_array_sql_string(
+            class_values, intermediate_col, pred_col_name,
+            pred_col_type, self.is_response, self.module_name)
+        gp_segment_id_col, seg_ids_test, \
+        images_per_seg_test = 
get_image_count_per_seg_for_non_minibatched_data_from_db(
+            self.test_table)
+        segments_per_host = get_segments_per_host()
+
+        predict_query = plpy.prepare("""
+            CREATE TABLE {self.output_table} AS
+            SELECT {self.id_col}, {prediction_select_clause}
+            FROM (
+                SELECT {self.test_table}.{self.id_col},
+                       ({self.schema_madlib}.internal_keras_predict
+                           ({self.independent_varname},
+                            $1,
+                            $2,
+                            {self.is_response},
+                            {self.normalizing_const},
+                            {gp_segment_id_col},
+                            ARRAY{seg_ids_test},
+                            ARRAY{images_per_seg_test},
+                            {self.gpus_per_host},
+                            {segments_per_host})
+                       ) AS {intermediate_col}
+            FROM {self.test_table}
+            ) q
+            """.format(self=self, 
prediction_select_clause=prediction_select_clause,
+                       seg_ids_test=seg_ids_test,
+                       images_per_seg_test=images_per_seg_test,
+                       gp_segment_id_col=gp_segment_id_col,
+                       segments_per_host=segments_per_host,
+                       intermediate_col=intermediate_col),
+                                     ["text", "bytea"])
+        plpy.execute(predict_query, [self.model_arch, self.model_weights])
+
+    def set_default_class_values(self, class_values):
+        self.class_values = class_values
+        if self.pred_type == 'prob':
+            return
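+        # For 'response' predictions with no class_values provided, default the
+        # labels to [0, 1, ..., num_classes-1] taken from the model architecture.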
+        if self.class_values is None:
+            num_classes = get_num_classes(self.model_arch)
+            self.class_values = range(0, num_classes)
 
 @MinWarning("warning")
-def predict(schema_madlib, model_table, test_table, id_col,
-            independent_varname, output_table, pred_type, gpus_per_host, 
**kwargs):
-    if not pred_type:
-        pred_type = 'response'
-    input_validator = PredictInputValidator(
-        test_table, model_table, id_col, independent_varname,
-        output_table, pred_type, MODULE_NAME)
-
-    param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
-    class_values = param_proc.get_class_values()
-    input_validator.validate_pred_type(class_values)
-    dependent_varname = param_proc.get_dependent_varname()
-    dependent_vartype = param_proc.get_dependent_vartype()
-    model_data = param_proc.get_model_data()
-    model_arch = param_proc.get_model_arch()
-    normalizing_const = param_proc.get_normalizing_const()
-    input_shape = get_input_shape(model_arch)
-    input_validator.validate_input_shape(input_shape)
-
-    is_response = True if pred_type == 'response' else False
-    intermediate_col = unique_string()
-    if is_response:
-        pred_col_name = add_postfix("estimated_", dependent_varname)
-        pred_col_type = dependent_vartype
-    else:
-        pred_col_name = "prob"
-        pred_col_type = 'double precision'
-
-    class_values = strip_trailing_nulls_from_class_values(class_values)
-
-    prediction_select_clause = create_cols_from_array_sql_string(
-        class_values, intermediate_col, pred_col_name,
-        pred_col_type, is_response, MODULE_NAME)
-
-    gp_segment_id_col, seg_ids_test, \
-    images_per_seg_test = 
get_image_count_per_seg_for_non_minibatched_data_from_db(test_table)
-    segments_per_host = get_segments_per_host()
-
-    predict_query = plpy.prepare("""
-        CREATE TABLE {output_table} AS
-        SELECT {id_col}, {prediction_select_clause}
-        FROM (
-            SELECT {test_table}.{id_col},
-                   ({schema_madlib}.internal_keras_predict
-                       ({independent_varname},
-                        $1,
-                        $2,
-                        {is_response},
-                        {normalizing_const},
-                        {gp_segment_id_col},
-                        ARRAY{seg_ids_test},
-                        ARRAY{images_per_seg_test},
-                        {gpus_per_host},
-                        {segments_per_host})
-                   ) AS {intermediate_col}
-        FROM {test_table}
-        ) q
-        """.format(**locals()), ["text", "bytea"])
-    plpy.execute(predict_query, [model_arch, model_data])
+class Predict(BasePredict):
+    def __init__(self, schema_madlib, model_table,
+                 test_table, id_col, independent_varname,
+                 output_table, pred_type, gpus_per_host,
+                 **kwargs):
+
+        self.module_name = 'madlib_keras_predict'
+        self.model_table = model_table
+        if self.model_table:
+            self.model_summary_table = add_postfix(self.model_table, 
"_summary")
+
+        BasePredict.__init__(self, schema_madlib, model_table, test_table,
+                              id_col, independent_varname,
+                              output_table, pred_type,
+                              gpus_per_host)
+        param_proc = PredictParamsProcessor(model_table, self.module_name)
+        self.dependent_vartype = param_proc.get_dependent_vartype()
+        self.model_weights = param_proc.get_model_data()
+        self.model_arch = param_proc.get_model_arch()
+        class_values = param_proc.get_class_values()
+        self.set_default_class_values(class_values)
+        self.normalizing_const = param_proc.get_normalizing_const()
+        self.dependent_varname = param_proc.get_dependent_varname()
+
+        self.validate()
+        BasePredict.call_internal_keras(self)
+
+    def validate(self):
+        InputValidator.validate_predict_evaluate_tables(
+            self.module_name, self.model_table, self.model_summary_table,
+            self.test_table, self.output_table, self.independent_varname)
+
+        InputValidator.validate_id_in_test_tbl(
+            self.module_name, self.test_table, self.id_col)
+
+        InputValidator.validate_class_values(
+            self.module_name, self.class_values, self.pred_type, 
self.model_arch)
+        input_shape = get_input_shape(self.model_arch)
+        InputValidator.validate_pred_type(
+            self.module_name, self.pred_type, self.class_values)
+        InputValidator.validate_input_shape(
+            self.test_table, self.independent_varname, input_shape, 1)
+
+@MinWarning("warning")
+class PredictBYOM(BasePredict):
+    def __init__(self, schema_madlib, model_arch_table, model_arch_id,
+                 test_table, id_col, independent_varname, output_table,
+                 pred_type, gpus_per_host, class_values, normalizing_const,
+                 **kwargs):
 
+        self.module_name='madlib_keras_predict_byom'
+        self.model_arch_table = model_arch_table
+        self.model_arch_id = model_arch_id
+        self.class_values = class_values
+        self.normalizing_const = normalizing_const
+        self.dependent_varname = 'dependent_var'
+        BasePredict.__init__(self, schema_madlib, model_arch_table,
+                             test_table, id_col, independent_varname,
+                             output_table, pred_type, gpus_per_host)
+        if self.is_response:
+            self.dependent_vartype = 'text'
+        else:
+            self.dependent_vartype = 'double precision'
+        ## Set default values for norm const and class_values
+        # gpus_per_host and pred_type are defaulted in base_predict's init
+        self.normalizing_const = normalizing_const
+        if self.normalizing_const is None:
+            self.normalizing_const = DEFAULT_NORMALIZING_CONST
+        InputValidator.validate_predict_byom_tables(
+            self.module_name, self.model_arch_table, self.model_arch_id,
+            self.test_table, self.id_col, self.output_table,
+            self.independent_varname)
+        self.validate_and_set_defaults()
+        BasePredict.call_internal_keras(self)
+
+    def validate_and_set_defaults(self):
+        # Set some defaults first and then validate and then set some more 
defaults
+        self.model_arch, self.model_weights = get_model_arch_weights(
+            quote_ident(self.model_arch_table), self.model_arch_id)
+        # Assert model_weights and model_arch are not empty.
+        _assert(self.model_weights and self.model_arch,
+                "{0}: Model weights and architecture should not be 
NULL.".format(
+                    self.module_name))
+        self.set_default_class_values(self.class_values)
+
+        InputValidator.validate_pred_type(
+            self.module_name, self.pred_type, self.class_values)
+        InputValidator.validate_normalizing_const(
+            self.module_name, self.normalizing_const)
+        InputValidator.validate_class_values(
+            self.module_name, self.class_values, self.pred_type, 
self.model_arch)
+        InputValidator.validate_input_shape(
+            self.test_table, self.independent_varname,
+            get_input_shape(self.model_arch), 1)
 
 def internal_keras_predict(independent_var, model_architecture, model_data,
                            is_response, normalizing_const, current_seg_id, 
seg_ids,
@@ -216,9 +323,86 @@ estimated_COL_NAME: (For pred_type='response') The 
estimated class for
 prob_CLASS:         (For pred_type='prob' for classification) The
                     probability of a given class. There will be one
                     column for each class in the training data.
+                    TODO change this
 """
     else:
         help_string = "No such option. Use 
{schema_madlib}.madlib_keras_predict()"
 
     return help_string.format(schema_madlib=schema_madlib)
+
+def predict_byom_help(schema_madlib, message, **kwargs):
+    """
+    Help function for keras predict byom
+
+    Args:
+        @param schema_madlib
+        @param message: string, Help message string
+        @param kwargs
+
+    Returns:
+        String. Help/usage information
+    """
+    if not message:
+        help_string = """
+-----------------------------------------------------------------------
+                            SUMMARY
+-----------------------------------------------------------------------
+This function allows the user to predict with their own pre-trained model (note
+that this model does not have to be trained using MADlib).
+
+For more details on function usage:
+    SELECT {schema_madlib}.madlib_keras_predict_byom('usage')
+            """
+    elif message in ['usage', 'help', '?']:
+        help_string = """
+-----------------------------------------------------------------------
+                            USAGE
+-----------------------------------------------------------------------
+ SELECT {schema_madlib}.madlib_keras_predict_byom(
+    model_arch_table, -- Name of the table containing the model architecture
+                            and the pre-trained model weights
+    model_arch_id,    -- This is the id in 'model_arch_table' containing the
+                         model architecture
+    test_table,     --  Name of the table containing the dataset to predict on
+    id_col,         --  Name of the id column in the test data table
+    independent_varname,    --  Name of the column with independent
+                                variables in the test table
+    output_table,   --  Name of the output table
+    pred_type,      --  The type of the desired output
+    gpus_per_host,   --  Number of GPUs per segment host to
+                        be used for prediction
+    class_values,     -- List of class labels that were used while training the
+                         model. If class_values is passed in as NULL, the output
+                         table will have a column named 'prob' which is an array
+                         of probabilities of all the classes.
+                         Otherwise if class_values is not NULL, then the output
+                         table will contain a column for each class/label from
+                         the training data
+    normalizing_const -- Normalizing constant used for standardizing arrays in 
+                         independent_varname
+    );
+
+-----------------------------------------------------------------------
+                            OUTPUT
+-----------------------------------------------------------------------
+The output table ('output_table' above) contains the following columns:
+
+id:                 Gives the 'id' for each prediction, corresponding
+                    to each row from the test_table.
+estimated_dependent_var: (For pred_type='response') The estimated class for
+                    classification. If class_values is passed in as NULL, then we
+                    assume that the class labels are [0,1,2...,n-1] where n is
+                    the number of classes in the model architecture.
+prob_CLASS:         (For pred_type='prob' for classification) The
+                    probability of a given class.
+                    If class_values is passed in as NULL, we create just one column
+                    called 'prob' which is an array of probabilities of all the classes.
+                    Otherwise if class_values is not NULL, then there will be one
+                    column for each class in the training data.
+"""
+    else:
+        help_string = "No such option. Use 
{schema_madlib}.madlib_keras_predict_byom()"
+
+    return help_string.format(schema_madlib=schema_madlib)
 # ---------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index b111fc4..e9d7d14 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -19,6 +19,7 @@
 
 import plpy
 from keras_model_arch_table import ModelArchSchema
+from model_arch_info import get_input_shape, get_num_classes
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
@@ -45,182 +46,177 @@ from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 
-
-def _validate_input_shapes(table, independent_varname, input_shape, offset):
-    """
-    Validate if the input shape specified in model architecture is the same
-    as the shape of the image specified in the indepedent var of the input
-    table.
-    offset: This offset is the index of the start of the image array. We also
-    need to consider that sql array indexes start from 1
-    For ex if the image is of shape [32,32,3] and is minibatched, the image 
will
-    look like [10, 32, 32, 3]. The offset in this case is 1 (start the index 
at 1) +
-    1 (ignore the buffer size 10) = 2.
-    If the image is not batched then it will look like [32, 32 ,3] and the 
offset in
-    this case is 1 (start the index at 1).
-    """
-    array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
-        independent_varname, i+offset, i) for i in range(len(input_shape)))
-    query = """
-        SELECT {0}
-        FROM {1}
-        LIMIT 1
-    """.format(array_upper_query, table)
-    # This query will fail if an image in independent var does not have the
-    # same number of dimensions as the input_shape.
-    result = plpy.execute(query)[0]
-    _assert(len(result) == len(input_shape),
-        "model_keras error: The number of dimensions ({0}) of each image"
-        " in model architecture and {1} in {2} ({3}) do not match.".format(
-            len(input_shape), independent_varname, table, len(result)))
-    for i in range(len(input_shape)):
-        key_name = "n_{0}".format(i)
-        if result[key_name] != input_shape[i]:
-            # Construct the shape in independent varname to display
-            # meaningful error msg.
-            input_shape_from_table = [result["n_{0}".format(i)]
-                for i in range(len(input_shape))]
-            plpy.error("model_keras error: Input shape {0} in the model"
-                " architecture does not match the input shape {1} of column"
-                " {2} in table {3}.".format(
-                    input_shape, input_shape_from_table,
-                    independent_varname, table))
-
 class InputValidator:
-    def __init__(self, test_table, model_table, independent_varname,
-                 output_table, module_name):
-        self.test_table = test_table
-        self.model_table = model_table
-        self.independent_varname = independent_varname
-        self.output_table = output_table
-        if self.model_table:
-            self.model_summary_table = add_postfix(
-                self.model_table, "_summary")
-        self.module_name = module_name
-        self._validate_input_args()
-
-    def _validate_input_args(self):
-        input_tbl_valid(self.model_table, self.module_name)
-        self._validate_model_data_cols()
-        input_tbl_valid(self.model_summary_table, self.module_name)
-        self._validate_model_summary_tbl_cols()
-        input_tbl_valid(self.test_table, self.module_name)
-        self._validate_test_tbl_cols()
-        output_tbl_valid(self.output_table, self.module_name)
-
-
-    def _validate_model_data_cols(self):
-        _assert(is_var_valid(self.model_table, MODEL_DATA_COLNAME),
+    @staticmethod
+    def validate_predict_evaluate_tables(
+        module_name, model_table, model_summary_table, test_table, 
output_table,
+        independent_varname):
+        InputValidator._validate_model_data_tbl(module_name, model_table)
+        InputValidator._validate_model_summary_tbl(
+            module_name, model_summary_table)
+        InputValidator._validate_test_tbl(
+            module_name, test_table, independent_varname)
+        output_tbl_valid(output_table, module_name)
+
+    @staticmethod
+    def validate_id_in_test_tbl(module_name, test_table, id_col):
+        _assert(is_var_valid(test_table, id_col),
+                "{module_name} error: invalid id column "
+                "('{id_col}') for test table ({table}).".format(
+                    module_name=module_name,
+                    id_col=id_col,
+                    table=test_table))
+
+    @staticmethod
+    def validate_predict_byom_tables(module_name, model_arch_table, 
model_arch_id,
+                                     test_table, id_col, output_table,
+                                     independent_varname):
+        InputValidator.validate_model_arch_table(
+            module_name, model_arch_table, model_arch_id)
+        InputValidator._validate_test_tbl(
+            module_name, test_table, independent_varname)
+        InputValidator.validate_id_in_test_tbl(module_name, test_table, id_col)
+
+        output_tbl_valid(output_table, module_name)
+
+
+    @staticmethod
+    def validate_pred_type(module_name, pred_type, class_values):
+        if not pred_type in ['prob', 'response']:
+            plpy.error("{0}: Invalid value for pred_type param ({1}). Must be 
"\
+                "either response or prob.".format(module_name, pred_type))
+
+
+    @staticmethod
+    def validate_input_shape(table, independent_varname, input_shape, offset):
+        """
+        Validate that the input shape specified in the model architecture is the
+        same as the shape of the image specified in the independent var of the
+        input table.
+        offset: This offset is the index of the start of the image array. We also
+        need to consider that SQL array indexes start from 1.
+        For example, if the image is of shape [32, 32, 3] and is minibatched, the
+        image will look like [10, 32, 32, 3]. The offset in this case is
+        1 (start the index at 1) + 1 (ignore the buffer size 10) = 2.
+        If the image is not batched then it will look like [32, 32, 3] and the
+        offset in this case is 1 (start the index at 1).
+        """
+        array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
+            independent_varname, i+offset, i) for i in range(len(input_shape)))
+        query = """
+            SELECT {0}
+            FROM {1}
+            LIMIT 1
+        """.format(array_upper_query, table)
+        # This query will fail if an image in independent var does not have the
+        # same number of dimensions as the input_shape.
+        result = plpy.execute(query)[0]
+        _assert(len(result) == len(input_shape),
+            "model_keras error: The number of dimensions ({0}) of each image"
+            " in model architecture and {1} in {2} ({3}) do not match.".format(
+                len(input_shape), independent_varname, table, len(result)))
+        for i in range(len(input_shape)):
+            key_name = "n_{0}".format(i)
+            if result[key_name] != input_shape[i]:
+                # Construct the shape in independent varname to display
+                # meaningful error msg.
+                input_shape_from_table = [result["n_{0}".format(i)]
+                    for i in range(len(input_shape))]
+                plpy.error("model_keras error: Input shape {0} in the model"
+                    " architecture does not match the input shape {1} of 
column"
+                    " {2} in table {3}.".format(
+                        input_shape, input_shape_from_table,
+                        independent_varname, table))
+
+    @staticmethod
+    def validate_model_arch_table(module_name, model_arch_table, 
model_arch_id):
+        input_tbl_valid(model_arch_table, module_name)
+        _assert(model_arch_id is not None,
+            "{0}: Invalid model architecture ID.".format(module_name))
+
+
+    @staticmethod
+    def validate_normalizing_const(module_name, normalizing_const):
+        _assert(normalizing_const > 0,
+                "{0} error: Normalizing constant has to be greater than 0.".
+                format(module_name))
+
+    @staticmethod
+    def validate_class_values(module_name, class_values, pred_type, 
model_arch):
+        if not class_values:
+            return
+        num_classes = len(class_values)
+        _assert(num_classes == get_num_classes(model_arch),
+                "{0}: The number of class values do not match the " \
+                "provided architecture.".format(module_name))
+        if pred_type == 'prob' and num_classes+1 >= 1600:
+            plpy.error({"{0}: The output will have {1} columns, exceeding the 
"\
+                " max number of columns that can be created (1600)".format(
+                    module_name, num_classes+1)})
+
+    @staticmethod
+    def validate_model_weights(module_name, model_arch, model_weights):
+        _assert(model_weights and model_arch,
+                "{0}: Model weights and architecture must be valid.".format(
+                    module_name))
+
+    @staticmethod
+    def _validate_model_data_tbl(module_name, model_table):
+        input_tbl_valid(model_table, module_name)
+        _assert(is_var_valid(model_table, MODEL_DATA_COLNAME),
                 "{module_name} error: column '{model_data}' "
                 "does not exist in model table '{table}'.".format(
-                    module_name=self.module_name,
+                    module_name=module_name,
                     model_data=MODEL_DATA_COLNAME,
-                    table=self.model_table))
-        _assert(is_var_valid(self.model_table, ModelArchSchema.MODEL_ARCH),
+                    table=model_table))
+        _assert(is_var_valid(model_table, ModelArchSchema.MODEL_ARCH),
                 "{module_name} error: column '{model_arch}' "
                 "does not exist in model table '{table}'.".format(
-                    module_name=self.module_name,
+                    module_name=module_name,
                     model_arch=ModelArchSchema.MODEL_ARCH,
-                    table=self.model_table))
+                    table=model_table))
 
-    def _validate_test_tbl_cols(self):
-        _assert(is_var_valid(self.test_table, self.independent_varname),
+    @staticmethod
+    def _validate_test_tbl(module_name, test_table, independent_varname):
+        input_tbl_valid(test_table, module_name)
+        _assert(is_var_valid(test_table, independent_varname),
                 "{module_name} error: invalid independent_varname "
                 "('{independent_varname}') for test table "
                 "({table}).".format(
-                    module_name=self.module_name,
-                    independent_varname=self.independent_varname,
-                    table=self.test_table))
+                    module_name=module_name,
+                    independent_varname=independent_varname,
+                    table=test_table))
 
-    def _validate_model_summary_tbl_cols(self):
+    @staticmethod
+    def _validate_model_summary_tbl(module_name, model_summary_table):
+        input_tbl_valid(model_summary_table, module_name)
         cols_to_check_for = [CLASS_VALUES_COLNAME,
                              DEPENDENT_VARNAME_COLNAME,
                              DEPENDENT_VARTYPE_COLNAME,
                              MODEL_ARCH_ID_COLNAME,
                              MODEL_ARCH_TABLE_COLNAME,
-                             NORMALIZING_CONST_COLNAME]
-        _assert(columns_exist_in_table(
-            self.model_summary_table, cols_to_check_for),
-            "{0} error: One or more expected columns missing in model "
-            "summary table ('{1}'). The expected columns are {2}.".format(
-                self.module_name, self.model_summary_table, cols_to_check_for))
-
-class EvaluateInputValidator(InputValidator):
-    def __init__(self, test_table, model_table, output_table, module_name):
-        self.test_summary_table = None
-        if test_table:
-            self.test_summary_table = add_postfix(test_table, "_summary")
-
-        self.independent_varname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
-        InputValidator.__init__(self, test_table, model_table,
-                                self.independent_varname,
-                                output_table, module_name)
-
-    def _validate_input_args(self):
-        input_tbl_valid(self.test_summary_table, self.module_name,
-                        error_suffix_str="Please ensure that the test table 
({0}) "
-                                         "has been preprocessed by "
-                                         "the image 
preprocessor.".format(self.test_table))
-        self._validate_test_summary_tbl_cols()
-        InputValidator._validate_input_args(self)
-        validate_dependent_var_for_minibatch(self.test_table,
-                                             
MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
-
-    def _validate_model_summary_tbl_cols(self):
-        cols_to_check_for = [COMPILE_PARAMS_COLNAME, METRIC_TYPE_COLNAME]
+                             NORMALIZING_CONST_COLNAME,
+                             COMPILE_PARAMS_COLNAME,
+                             METRIC_TYPE_COLNAME]
         _assert(columns_exist_in_table(
-            self.model_summary_table, cols_to_check_for),
+            model_summary_table, cols_to_check_for),
             "{0} error: One or more expected columns missing in model "
             "summary table ('{1}'). The expected columns are {2}.".format(
-                self.module_name, self.model_summary_table, cols_to_check_for))
+                module_name, model_summary_table, cols_to_check_for))
 
-    def _validate_test_summary_tbl_cols(self):
-        cols_in_tbl_valid(self.test_summary_table, [CLASS_VALUES_COLNAME,
-            NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
-            DEPENDENT_VARNAME_COLNAME, INDEPENDENT_VARNAME_COLNAME], 
self.module_name)
-
-    def validate_input_shape(self, input_shape_from_arch):
-        _validate_input_shapes(self.test_table, self.independent_varname,
-                               input_shape_from_arch, 2)
-
-class PredictInputValidator(InputValidator):
-    def __init__(self, test_table, model_table, id_col, independent_varname,
-                 output_table, pred_type, module_name):
-        self.id_col = id_col
-        self.pred_type = pred_type
-        InputValidator.__init__(self, test_table, model_table, 
independent_varname,
-                               output_table, module_name)
-
-    def validate_pred_type(self, class_values):
-        if not self.pred_type in ['prob', 'response']:
-            plpy.error("{0}: Invalid value for pred_type param ({1}). Must be 
"\
-                "either response or prob.".format(self.module_name, 
self.pred_type))
-        if self.pred_type == 'prob' and class_values and len(class_values)+1 
>= 1600:
-            plpy.error({"{0}: The output will have {1} columns, exceeding the 
"\
-                " max number of columns that can be created (1600)".format(
-                    self.module_name, len(class_values)+1)})
 
-    def validate_input_shape(self, input_shape_from_arch):
-        _validate_input_shapes(self.test_table, self.independent_varname,
-                               input_shape_from_arch, 1)
 
-    def _validate_test_tbl_cols(self):
-        InputValidator._validate_test_tbl_cols(self)
-        _assert(is_var_valid(self.test_table, self.id_col),
-                "{module_name} error: invalid id column "
-                "('{id_col}') for test table ({table}).".format(
-                    module_name=self.module_name,
-                    id_col=self.id_col,
-                    table=self.test_table))
 
 class FitInputValidator:
     def __init__(self, source_table, validation_table, output_model_table,
-                 model_arch_table, dependent_varname, independent_varname,
-                 num_iterations, metrics_compute_frequency, warm_start):
+                 model_arch_table, model_arch_id, dependent_varname,
+                 independent_varname, num_iterations,
+                 metrics_compute_frequency, warm_start):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
         self.model_arch_table = model_arch_table
+        self.model_arch_id = model_arch_id
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
         self.metrics_compute_frequency = metrics_compute_frequency
@@ -236,30 +232,6 @@ class FitInputValidator:
         self.module_name = 'madlib_keras_fit'
         self._validate_input_args()
 
-    def _validate_input_table(self, table):
-        _assert(is_var_valid(table, self.independent_varname),
-                "{module_name}: invalid independent_varname "
-                "('{independent_varname}') for table ({table}). "
-                "Please ensure that the input table ({table}) "
-                "has been preprocessed by the image preprocessor.".format(
-                    module_name=self.module_name,
-                    independent_varname=self.independent_varname,
-                    table=table))
-
-        _assert(is_var_valid(table, self.dependent_varname),
-                "{module_name}: invalid dependent_varname "
-                "('{dependent_varname}') for table ({table}). "
-                "Please ensure that the input table ({table}) "
-                "has been preprocessed by the image preprocessor.".format(
-                    module_name=self.module_name,
-                    dependent_varname=self.dependent_varname,
-                    table=table))
-
-    def _is_valid_metrics_compute_frequency(self):
-        return self.metrics_compute_frequency is None or \
-               (self.metrics_compute_frequency >= 1 and \
-               self.metrics_compute_frequency <= self.num_iterations)
-
     def _validate_input_args(self):
         _assert(self.num_iterations > 0,
             "{0}: Number of iterations cannot be < 
1.".format(self.module_name))
@@ -281,8 +253,8 @@ class FitInputValidator:
                                              self.dependent_varname)
 
         self._validate_validation_table()
-
-        input_tbl_valid(self.model_arch_table, self.module_name)
+        InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
+            self.model_arch_id)
         if self.warm_start:
             input_tbl_valid(self.output_model_table, self.module_name)
             input_tbl_valid(self.output_summary_model_table, self.module_name)
@@ -290,6 +262,31 @@ class FitInputValidator:
             output_tbl_valid(self.output_model_table, self.module_name)
             output_tbl_valid(self.output_summary_model_table, self.module_name)
 
+    def _validate_input_table(self, table):
+        _assert(is_var_valid(table, self.independent_varname),
+                "{module_name}: invalid independent_varname "
+                "('{independent_varname}') for table ({table}). "
+                "Please ensure that the input table ({table}) "
+                "has been preprocessed by the image preprocessor.".format(
+                    module_name=self.module_name,
+                    independent_varname=self.independent_varname,
+                    table=table))
+
+        _assert(is_var_valid(table, self.dependent_varname),
+                "{module_name}: invalid dependent_varname "
+                "('{dependent_varname}') for table ({table}). "
+                "Please ensure that the input table ({table}) "
+                "has been preprocessed by the image preprocessor.".format(
+                    module_name=self.module_name,
+                    dependent_varname=self.dependent_varname,
+                    table=table))
+
+    def _is_valid_metrics_compute_frequency(self):
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+               self.metrics_compute_frequency <= self.num_iterations)
+
+
 
     def _validate_validation_table(self):
         if self.validation_table and self.validation_table.strip() != '':
@@ -305,9 +302,10 @@ class FitInputValidator:
 
 
     def validate_input_shapes(self, input_shape):
-        _validate_input_shapes(self.source_table, self.independent_varname,
+        InputValidator.validate_input_shape(self.source_table, self.independent_varname,
                                input_shape, 2)
         if self.validation_table:
-            _validate_input_shapes(
+            InputValidator.validate_input_shape(
                 self.validation_table, self.independent_varname,
                 input_shape, 2)
+
diff --git a/src/ports/postgres/modules/utilities/model_arch_info.py_in b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
similarity index 66%
rename from src/ports/postgres/modules/utilities/model_arch_info.py_in
rename to src/ports/postgres/modules/deep_learning/model_arch_info.py_in
index a03594a..c749144 100644
--- a/src/ports/postgres/modules/utilities/model_arch_info.py_in
+++ b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
@@ -22,6 +22,7 @@ m4_changequote(`<!', `!>')
 import sys
 import json
 import plpy
+from keras_model_arch_table import ModelArchSchema
 
 def _get_layers(model_arch):
     d = json.loads(model_arch)
@@ -41,6 +42,22 @@ def get_input_shape(model_arch):
     plpy.error('Unable to get input shape from model architecture.')
 
 def get_num_classes(model_arch):
+    """
+     We assume that the last dense layer in the model architecture contains 
the num_classes (units)
+     An example can be:
+     ```
+     ...
+     model.add(Flatten())
+     model.add(Dense(512))
+     model.add(Activation('relu'))
+     model.add(Dropout(0.5))
+     model.add(Dense(num_classes))
+     model.add(Activation('softmax'))
+     ```
+     where an Activation layer may follow the final Dense layer.
+    :param model_arch: model architecture as a JSON string
+    :return: number of units in the last Dense layer (i.e. num_classes)
+    """
     arch_layers = _get_layers(model_arch)
     i = len(arch_layers) - 1
     while i >= 0:
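
[Editor's note] The docstring above describes a "last Dense layer" convention, and the loop starting here walks the architecture backwards to apply it. A minimal, self-contained sketch of that rule follows; the helper name `last_dense_units` and the JSON-layout handling are illustrative assumptions, not part of this commit.

```python
# Illustrative only: infer num_classes as the units of the last Dense layer
# in a Keras model-architecture JSON string (the convention documented above).
import json

def last_dense_units(model_arch_json):
    config = json.loads(model_arch_json)['config']
    # Sequential JSON differs across Keras versions: sometimes a bare list of
    # layers, sometimes a dict with a 'layers' key.
    layers = config['layers'] if isinstance(config, dict) else config
    for layer in reversed(layers):
        if layer['class_name'] == 'Dense':
            return layer['config']['units']
    raise ValueError('No Dense layer found; cannot infer num_classes.')
```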
@@ -66,3 +83,22 @@ def get_model_arch_layers_str(model_arch):
         else:
             layers += "{1}\n".format(class_name)
     return layers
+
+def get_model_arch_weights(model_arch_table, model_arch_id):
+
+    # Assumes the model arch table and id have already been validated by the caller.
+    model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
+        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
+        model_arch_table, ModelArchSchema.MODEL_ID,
+        model_arch_id)
+    model_arch_result = plpy.execute(model_arch_query)
+    if not model_arch_result:
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_arch_id))
+
+    model_arch_result = model_arch_result[0]
+
+    model_arch = model_arch_result[ModelArchSchema.MODEL_ARCH]
+    model_weights = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
+
+    return model_arch, model_weights
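
[Editor's note] As context for how predict_byom consumes this helper: a minimal sketch of turning the (architecture JSON, weights) pair returned by `get_model_arch_weights` back into a usable Keras model. The flat float32 byte-string layout assumed for `model_weights`, and the helper name `rebuild_keras_model`, are illustrative assumptions; the actual deserialization is handled by the madlib_keras serializer/wrapper code.

```python
# Illustrative only: rebuild a Keras model from an architecture JSON string and
# a flat float32 byte string of weights (layout assumed for this sketch).
import numpy as np
from keras.models import model_from_json

def rebuild_keras_model(model_arch, model_weights):
    model = model_from_json(model_arch)            # architecture only; weights are uninitialized
    if model_weights is not None:
        flat = np.frombuffer(model_weights, dtype=np.float32)
        shaped, offset = [], 0
        for w in model.get_weights():              # use the freshly built layer shapes as a template
            shaped.append(flat[offset:offset + w.size].reshape(w.shape))
            offset += w.size
        model.set_weights(shaped)
    return model
```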
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index dacf236..28a500e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -383,21 +383,20 @@ SELECT madlib_keras_predict(
     0);
 
 -- Validate that prediction output table exists and has correct schema
-SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
-    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
-        AND attname = 'id';
+SELECT assert(UPPER(pg_typeof(id)::TEXT) = 'INTEGER',
+    'id column should be INTEGER type') FROM cifar10_predict;
 
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
     'SMALLINT', 'prediction column should be SMALLINT type')
-    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
-        AND attname = 'estimated_y';
+FROM cifar10_predict;
 
 -- Validate correct number of rows returned.
-SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have two rows') FROM cifar10_predict;
+SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have two rows')
+FROM cifar10_predict;
 
 -- First test that all values are in set of class values; if this breaks, it's definitely a problem.
 SELECT assert(estimated_y IN (0,1),
-              'Predicted value not in set of defined class values for model')
+    'Predicted value not in set of defined class values for model')
 FROM cifar10_predict;
 
 DROP TABLE IF EXISTS cifar10_predict;
@@ -512,15 +511,13 @@ SELECT madlib_keras_predict(
     'prob',
     0);
 
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(prob_0)::TEXT) =
     'DOUBLE PRECISION', 'column prob_0 should be double precision type')
-    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
-        AND attname = 'prob_0';
+FROM cifar10_predict;
 
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(prob_1)::TEXT) =
     'DOUBLE PRECISION', 'column prob_1 should be double precision type')
-    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
-        AND attname = 'prob_1';
+FROM cifar10_predict;
 
 SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
 FROM pg_attribute
@@ -616,20 +613,17 @@ SELECT madlib_keras_predict(
 -- Validate the output datatype of newly created prediction columns
 -- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
 -- class_values
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(prob_cat)::TEXT) =
     'DOUBLE PRECISION', 'column prob_cat should be double precision type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_cat';
+FROM cifar10_predict;
 
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(prob_dog)::TEXT) =
     'DOUBLE PRECISION', 'column prob_dog should be double precision type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_dog';
+FROM cifar10_predict;
 
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
     'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_NULL';
+FROM cifar10_predict;
 
 -- Must have exactly 4 cols (3 for class_values and 1 for id)
 SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
@@ -650,11 +644,10 @@ SELECT madlib_keras_predict(
 -- Validate the output datatype of newly created prediction columns
 -- for prediction type = 'response' and class_values 'TEXT' with NULL
 -- as a valid class_values
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
-    'TEXT', 'prediction column should be TEXT type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass
-      AND attname = 'estimated_y';
+SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) = 'TEXT',
+       'prediction column should be TEXT type')
+FROM  cifar10_predict LIMIT 1;
+
 
 -- Tests where the assumption is user has one-hot encoded, so class_values
 -- in input summary table will be NULL.
@@ -674,10 +667,9 @@ SELECT madlib_keras_predict(
 -- Validate the output datatype of newly created prediction column
 -- for prediction type = 'response' and class_value = NULL
 -- Returns: Array of probabilities for user's one-hot encoded data
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
-    'DOUBLE PRECISION[]', 'column prob should be double precision[] type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob';
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
+       'column prob should be double precision[] type')
+FROM  cifar10_predict LIMIT 1;
 
 -- Predict with pred_type=response
 DROP TABLE IF EXISTS cifar10_predict;
@@ -694,11 +686,14 @@ SELECT madlib_keras_predict(
 -- for prediction type = 'response' and class_value = NULL
 -- Returns: Index of class value in user's one-hot encoded data with
 -- highest probability
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
-    'DOUBLE PRECISION', 'prediction column should be double precision type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass
-      AND attname = 'estimated_y';
+SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) = 'TEXT',
+       'column estimated_y_text should be text type')
+FROM  cifar10_predict LIMIT 1;
+
+SELECT assert(
+  estimated_y_text IN ('0', '1'),
+  'Predict failure for null class value and response pred_type.')
+FROM cifar10_predict;
 
 -- Test predict with INTEGER class_values
 -- with NULL as a valid class value
@@ -747,13 +742,11 @@ SELECT madlib_keras_predict(
 -- Validate the output datatype of newly created prediction column
 -- for prediction type = 'prob' and class_values 'INT' with NULL
 -- as a valid class_values
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
     'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_NULL';
-
+FROM cifar10_predict;
 -- Must have exactly 6 cols (5 for class_values and 1 for id)
-SELECT assert(COUNT(*)=6, 'Predict out table must have exactly four cols.')
+SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
 FROM pg_attribute
 WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
 
@@ -772,10 +765,9 @@ SELECT madlib_keras_predict(
 -- for prediction type = 'response' and class_values 'TEXT' with NULL
 -- as a valid class_values
 -- Returns: class_value with highest probability
-SELECT assert(UPPER(atttypid::regtype::TEXT) =
+SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
     'SMALLINT', 'prediction column should be smallint type')
-FROM pg_attribute
-WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'estimated_y';
+FROM cifar10_predict;
 
 -- Test case with a different input shape (3, 32, 32) instead of (32, 32, 3).
 -- Create a new table with image shape 3, 32, 32
@@ -1066,6 +1058,27 @@ SELECT madlib_keras_fit('iris_data_packed',   -- source table
                          1 -- metrics_compute_frequency
                         );
 
+DROP TABLE IF EXISTS iris_train, iris_test;
+-- Set seed so results are reproducible
+SELECT setseed(0);
+SELECT train_test_split('iris_data',     -- Source table
+                        'iris',          -- Output table root name
+                        0.8,            -- Train proportion
+                        NULL,           -- Test proportion (0.2)
+                        NULL,           -- Strata definition
+                        NULL,           -- Output all columns
+                        NULL,           -- Sample without replacement
+                        TRUE            -- Separate output tables
+                        );
+
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict('iris_model', -- model
+                            'iris_test',  -- test_table
+                            'id',  -- id column
+                            'attributes', -- independent var
+                            'iris_predict'  -- output table
+                            );
+
 -- Test that our code is indeed learning something and not broken. The loss
 -- from the first iteration should be less than the 5th, while the accuracy
 -- must be greater.
@@ -1179,3 +1192,96 @@ SELECT assert(
   abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
  'Transfer learning test failed because training loss and metrics don''t match the expected value.')
 FROM iris_model_first_run AS first, iris_model_transfer_summary AS second;
+
+---------------------- Predict BYOM test --------------------------------
+
+-- class_values not NULL, pred_type is response
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'response',
+                                 -1,
+                                 ARRAY['Iris-setosa', 'Iris-versicolor',
+                                  'Iris-virginica']
+                                 );
+
+SELECT assert(
+  p0.estimated_class_text = p1.estimated_dependent_var,
+  'Predict byom failure for non null class value and response pred_type.')
+FROM iris_predict AS p0,  iris_predict_byom AS p1
+WHERE p0.id=p1.id;
+SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
+       'Predict byom failure for non null class value and response pred_type.
+        Expected estimated_dependent_var to be of type TEXT')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values NULL, pred_type is NULL (response)
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom'
+                                 );
+SELECT assert(
+  p1.estimated_dependent_var IN ('0', '1', '2'),
+  'Predict byom failure for null class value and null pred_type.')
+FROM iris_predict_byom AS p1;
+SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
+       'Predict byom failure for null class value and null pred_type.
+        Expected estimated_dependent_var to be of type TEXT')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values not NULL, pred_type is prob
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'prob',
+                                 -1,
+                                 ARRAY['Iris-setosa', 'Iris-versicolor',
+                                  'Iris-virginica'],
+                                 1.0
+                                 );
+
+SELECT assert(
+  (p1."prob_Iris-setosa" + p1."prob_Iris-virginica" + 
p1."prob_Iris-versicolor") - 1 < 1e-6,
+    'Predict byom failure for non null class value and prob pred_type.')
+FROM iris_predict_byom AS p1;
+SELECT assert(UPPER(pg_typeof("prob_Iris-setosa")::TEXT) = 'DOUBLE PRECISION',
+       'Predict byom failure for non null class value and prob pred_type.
+       Expeceted "prob_Iris-setosa" to be of type DOUBLE PRECISION')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values NULL, pred_type is prob
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'prob',
+                                 0,
+                                 NULL
+                                 );
+SELECT assert(
+  (prob[1] + prob[2] + prob[3]) - 1 < 1e-6,
+    'Predict byom failure for null class value and prob pred_type.')
+FROM iris_predict_byom;
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
+       'Predict byom failure for null class value and prob pred_type. Expected prob to
+       be of type DOUBLE PRECISION[]')
+FROM  iris_predict_byom LIMIT 1;
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 2a1c39e..9cce86a 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -301,7 +301,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(True, res)
 
 
-class MadlibKerasPredictTestCase(unittest.TestCase):
+class InternalKerasPredictTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -406,6 +406,90 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
         self.assertEqual(False, 'row_count' in k['SD'])
         self.assertEqual(False, 'segment_model_predict' in k['SD'])
 
+
+class MadlibKerasPredictBYOMTestCase(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        self.num_classes = 5
+        self.model = Sequential()
+        self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
+                              input_shape=(1,1,1,), padding='same'))
+        self.model.add(Dense(self.num_classes))
+
+        self.pred_type = 'prob'
+        self.gpus_per_host = 2
+        self.class_values = ['foo', 'bar', 'baaz', 'foo2', 'bar2']
+        self.normalizing_const = 255.0
+
+        import madlib_keras_predict
+        self.module = madlib_keras_predict
+        self.module.get_model_arch_weights = Mock(return_value=(
+            self.model.to_json(), 'weights'))
+        self.module.InputValidator.validate_predict_byom_tables = Mock()
+        self.module.InputValidator.validate_input_shape = Mock()
+        self.module.BasePredict.call_internal_keras = Mock()
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_predictbyom_defaults_1(self):
+        res = self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                 'model_arch_id', 'test_table', 'id_col',
+                                 'independent_varname', 'output_table', None,
+                                 None, None, None)
+        self.assertEqual('response', res.pred_type)
+        self.assertEqual(0, res.gpus_per_host)
+        self.assertEqual([0,1,2,3,4], res.class_values)
+        self.assertEqual(1.0, res.normalizing_const)
+        self.assertEqual('text', res.dependent_vartype)
+
+    def test_predictbyom_defaults_2(self):
+        res = self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                       'model_arch_id', 'test_table', 'id_col',
+                                       'independent_varname', 'output_table',
+                                       self.pred_type, self.gpus_per_host,
+                                       self.class_values, self.normalizing_const)
+        self.assertEqual('prob', res.pred_type)
+        self.assertEqual(2, res.gpus_per_host)
+        self.assertEqual(['foo', 'bar', 'baaz', 'foo2', 'bar2'], res.class_values)
+        self.assertEqual(255.0, res.normalizing_const)
+        self.assertEqual('double precision', res.dependent_vartype)
+
+    def test_predictbyom_exception_invalid_params(self):
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                     'model_arch_id', 'test_table', 'id_col',
+                                     'independent_varname', 'output_table',
+                                     'invalid_pred_type', self.gpus_per_host,
+                                     self.class_values, self.normalizing_const)
+        self.assertIn('invalid_pred_type', str(error.exception))
+
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                     'model_arch_id', 'test_table', 'id_col',
+                                     'independent_varname', 'output_table',
+                                     self.pred_type, self.gpus_per_host,
+                                     ["foo", "bar", "baaz"], 
self.normalizing_const)
+        self.assertIn('class values', str(error.exception).lower())
+
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.module.PredictBYOM('schema_madlib', 'model_arch_table',
+                                     'model_arch_id', 'test_table', 'id_col',
+                                     'independent_varname', 'output_table',
+                                     self.pred_type, self.gpus_per_host,
+                                     self.class_values, 0)
+        self.assertIn('normalizing const', str(error.exception).lower())
+
+
 class MadlibKerasWrapperTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -748,58 +832,37 @@ class MadlibKerasFitInputValidatorTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
-    def test_validate_input_shapes_shapes_do_not_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
-        self.subject._validate_input_args = Mock()
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes(
-                'dummy_tbl', 'dummy_col', [32,32,3], 2)
-
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 32}]
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes(
-                'dummy_tbl', 'dummy_col', [32,32,3], 2)
-
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': None}]
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes(
-                'dummy_tbl', 'dummy_col', [3,32], 2)
-
-    def test_validate_input_shapes_shapes_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 3}]
-        self.subject._validate_input_args = Mock()
-        self.subject._validate_input_shapes(
-            'dummy_tbl', 'dummy_col', [32,32,3], 1)
 
     def test_is_valid_metrics_compute_frequency_True_None(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table',
+            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, None, False)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table',
+            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 3, False)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table',
+            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 0, False)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table',
+            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 6, False)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
-class PredictInputValidatorTestCases(unittest.TestCase):
+
+class InputValidatorTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -813,34 +876,83 @@ class PredictInputValidatorTestCases(unittest.TestCase):
         self.module_patcher.start()
         import madlib_keras_validator
         self.module = madlib_keras_validator
-        self.module.PredictInputValidator._validate_input_args = Mock()
-        self.subject = self.module.PredictInputValidator(
-            'test_table', 'model_table', 'id_col', 'independent_varname',
-            'output_table', 'pred_type', 'module_name')
+        self.subject = self.module.InputValidator
+
+        self.module_name = 'module'
+        self.test_table = 'test_table'
+        self.model_table = 'model_table'
+        self.id_col = 'id_col'
+        self.ind_var = 'ind_var'
+        self.model_arch_table = 'model_arch_table'
+        self.model_arch_id = 2
+        self.num_classes = 1598
+        self.model = Sequential()
+        self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
+                              input_shape=(1,1,1,), padding='same'))
+        self.model.add(Dense(self.num_classes))
         self.classes = ['train', 'boat', 'car', 'airplane']
 
     def tearDown(self):
         self.module_patcher.stop()
 
     def test_validate_pred_type_invalid_pred_type(self):
-        self.subject.pred_type = 'invalid'
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.validate_pred_type(
+                self.module_name, 'invalid_pred_type', ['cat', 'dog'])
+        self.assertIn('type', str(error.exception).lower())
+
+    def test_validate_class_values_greater_than_1600_class_values(self):
+        self.model.add(Dense(1599))
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.validate_class_values(
+                self.module_name, range(1599), 'prob', self.model.to_json())
+        self.assertIn('1600', str(error.exception))
+
+    def test_validate_class_values_valid_class_values_prob(self):
+        self.subject.validate_class_values(
+            self.module_name, range(self.num_classes), 'prob', self.model.to_json())
+        self.subject.validate_class_values(
+            self.module_name, None, 'prob', self.model.to_json())
+
+    def test_validate_class_values_valid_pred_type_valid_class_values_response(self):
+        self.subject.validate_class_values(
+            self.module_name, range(self.num_classes), 'response', self.model.to_json())
+        self.subject.validate_class_values(
+            self.module_name, None, 'response', self.model.to_json())
+
+    def test_validate_input_shape_shapes_do_not_match(self):
+        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.validate_input_shape(
+                self.test_table, self.ind_var, [32,32,3], 2)
+
+        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 32}]
         with self.assertRaises(plpy.PLPYException):
-            self.subject.validate_pred_type(['cat', 'dog'])
+            self.subject.validate_input_shape(
+                self.test_table, self.ind_var, [32,32,3], 2)
 
-    def test_validate_pred_type_valid_pred_type_invalid_num_class_values(self):
-        self.subject.pred_type = 'prob'
+        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': None}]
         with self.assertRaises(plpy.PLPYException):
-            self.subject.validate_pred_type(range(1599))
+            self.subject.validate_input_shape(
+                self.test_table, self.ind_var, [3,32], 2)
 
-    def test_validate_pred_type_valid_pred_type_valid_class_values_prob(self):
-        self.subject.pred_type = 'prob'
-        self.subject.validate_pred_type(range(1598))
-        self.subject.validate_pred_type(None)
+    def test_validate_input_shape_shapes_match(self):
+        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 3}]
+        self.subject.validate_input_shape(
+            self.test_table, self.ind_var, [32,32,3], 1)
+
+    def test_validate_model_arch_table_none_values(self):
+        with self.assertRaises(plpy.PLPYException) as error:
+            obj = self.subject.validate_model_arch_table(
+                self.module_name, None, self.model_arch_id)
+        self.assertIn('null', str(error.exception).lower())
+
+        self.module.input_tbl_valid = Mock()
+        with self.assertRaises(plpy.PLPYException) as error:
+            obj = self.subject.validate_model_arch_table(
+                self.module_name, self.model_arch_table, None)
+        self.assertIn('id', str(error.exception).lower())
 
-    def test_validate_pred_type_valid_pred_type_valid_class_values_response(self):
-        self.subject.pred_type = 'response'
-        self.subject.validate_pred_type(range(1598))
-        self.subject.validate_pred_type(None)
 
 class MadlibSerializerTestCase(unittest.TestCase):
     def setUp(self):
@@ -921,6 +1033,7 @@ class MadlibSerializerTestCase(unittest.TestCase):
         self.assertEqual(np.array([0,1,3,4,5], dtype=np.float32).tostring(),
                          res)
 
+
 class MadlibKerasHelperTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
