This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 96885c8  DL: Split madlib_keras devcheck file
96885c8 is described below

commit 96885c895b6b75456fbe448aa786814bf72421a9
Author: Domino Valdano <dvald...@pivotal.io>
AuthorDate: Mon Jul 29 15:46:04 2019 -0700

    DL: Split madlib_keras devcheck file
    
    Prior to this commit, the dev-check test file for the deep_learning
    module was too big, and there was no way of running only a subset of
    the tests (e.g. just the predict tests).
    This commit separates the madlib_keras.sql_in test file into:
    1. Setup files: for data creation used in the tests
    2. Specific test files: for testing fit, predict, evaluate and transfer
    learning
    
    Additionally, madpack.py is updated to ignore any setup files
    (`*.setup.sql_in`) in the test dir, so that they are not run as tests
    in dev-check/install-check.
    
    Co-authored-by: Ekta Khanna <ekha...@pivotal.io>
    Co-authored-by: Orhan Kislal <okis...@apache.org>
---
 src/madpack/madpack.py                             |    1 +
 .../modules/deep_learning/test/madlib_keras.sql_in | 1287 --------------------
 .../test/madlib_keras_cifar.setup.sql_in           |  152 +++
 .../test/madlib_keras_evaluate.sql_in              |   61 +
 .../deep_learning/test/madlib_keras_fit.sql_in     |  379 ++++++
 .../test/madlib_keras_iris.setup.sql_in            |  266 ++++
 .../deep_learning/test/madlib_keras_predict.sql_in |  316 +++++
 .../test/madlib_keras_predict_byom.sql_in          |  137 +++
 .../test/madlib_keras_transfer_learning.sql_in     |  116 ++
 9 files changed, 1428 insertions(+), 1287 deletions(-)

diff --git a/src/madpack/madpack.py b/src/madpack/madpack.py
index e735526..74fff37 100755
--- a/src/madpack/madpack.py
+++ b/src/madpack/madpack.py
@@ -674,6 +674,7 @@ def _process_py_sql_files_in_modules(modset, args_dict):
 
         # Loop through all SQL files for this module
         source_files = glob.glob(mask)
+        source_files = [s for s in source_files if '.setup' not in s]
         if calling_operation == INSTALL_DEV_CHECK and madpack_cmd != 
'install-check':
             source_files = [s for s in source_files if '.ic' not in s]
 
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
deleted file mode 100644
index 28a500e..0000000
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ /dev/null
@@ -1,1287 +0,0 @@
-/* ---------------------------------------------------------------------*//**
- *
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- *
- *//* ---------------------------------------------------------------------*/
-drop table if exists cifar_10_sample;
-create table cifar_10_sample(id INTEGER, y SMALLINT, y_text TEXT, imgpath 
TEXT, x  REAL[]);
-copy cifar_10_sample from stdin delimiter '|';
-1|0|'cat'|'0/img0.jpg'|{{{202,204,199},{202,204,199},{204,206,201},{206,208,203},{208,210,205},{209,211,206},{210,212,207},{212,214,210},{213,215,212},{215,217,214},{216,218,215},{216,218,215},{215,217,214},{216,218,215},{216,218,215},{216,218,214},{217,219,214},{217,219,214},{218,220,215},{218,219,214},{216,217,212},{217,218,213},{218,219,214},{214,215,209},{213,214,207},{212,213,206},{211,212,205},{209,210,203},{208,209,202},{207,208,200},{205,206,199},{203,204,198}},{{206,208,203},{20
 [...]
-2|1|'dog'|'0/img2.jpg'|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119
 [...]
-\.
-
-
--- normalize the indep variable
--- TODO Calling this function makes keras.fit fail with the exception 
(investigate later)
--- NOTICE:  Releasing segworker groups to finish aborting the transaction.
--- ERROR:  could not connect to segment: initialization of segworker group 
failed (cdbgang.c:237)
--- update cifar_10_sample_val SET independent_var = 
array_scalar_mult(independent_var::real[], (1/255.0)::real);
-
--- Prepare the minibatched data manually instead of calling
--- training_preprocessor_dl since it internally calls array_scalar_mult.
--- Please refer to MADLIB-1326 for more details on the issue.
-
-DROP TABLE IF EXISTS cifar_10_sample_batched;
-CREATE TABLE cifar_10_sample_batched(
-    buffer_id smallint,
-    dependent_var smallint[],
-    dependent_var_text_with_null smallint[],
-    independent_var real[]);
-copy cifar_10_sample_batched from stdin delimiter '|';
-0|{{0,1}}|{{0,0,1,0,0}}|{{{{0.494118,0.462745,0.431373},{0.478431,0.45098,0.423529},{0.494118,0.466667,0.435294},{0.498039,0.466667,0.427451},{0.509804,0.478431,0.435294},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.521569,0.490196,0.447059},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.52549,0.494118,0.45098},{0.513726,0.482353,0.439216},{0.513726,0.482353,0.439216},{0.52549,0.494118,0.45098},{0.521569,0.490196,0.447059},{0.533333,0.501961,0.458824},{0.537
 [...]
-1|{{1,0}}|{{0,1,0,0,0}}|{{{{0.792157,0.8,0.780392},{0.792157,0.8,0.780392},{0.8,0.807843,0.788235},{0.807843,0.815686,0.796079},{0.815686,0.823529,0.803922},{0.819608,0.827451,0.807843},{0.823529,0.831373,0.811765},{0.831373,0.839216,0.823529},{0.835294,0.843137,0.831373},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.839216},{0.85098,0.858824,
 [...]
-\.
-
-DROP TABLE IF EXISTS cifar_10_sample_batched_summary;
-CREATE TABLE cifar_10_sample_batched_summary(
-    source_table text,
-    output_table text,
-    dependent_varname text,
-    independent_varname text,
-    dependent_vartype text,
-    class_values smallint[],
-    buffer_size integer,
-    normalizing_const numeric);
-INSERT INTO cifar_10_sample_batched_summary values (
-    'cifar_10_sample',
-    'cifar_10_sample_batched',
-    'y',
-    'x',
-    'smallint',
-    ARRAY[0,1],
-    1,
-    255.0);
-
-drop table if exists cifar_10_sample_val;
-create table cifar_10_sample_val(independent_var REAL[], dependent_var 
SMALLINT[], buffer_id SMALLINT);
-copy cifar_10_sample_val from stdin delimiter '|';
-{{{{0.494118,0.462745,0.431373},{0.478431,0.45098,0.423529},{0.494118,0.466667,0.435294},{0.498039,0.466667,0.427451},{0.509804,0.478431,0.435294},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.521569,0.490196,0.447059},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.52549,0.494118,0.45098},{0.513726,0.482353,0.439216},{0.513726,0.482353,0.439216},{0.52549,0.494118,0.45098},{0.521569,0.490196,0.447059},{0.533333,0.501961,0.458824},{0.537255,0.505882,0.462745},{
 [...]
-{{{{0.792157,0.8,0.780392},{0.792157,0.8,0.780392},{0.8,0.807843,0.788235},{0.807843,0.815686,0.796079},{0.815686,0.823529,0.803922},{0.819608,0.827451,0.807843},{0.823529,0.831373,0.811765},{0.831373,0.839216,0.823529},{0.835294,0.843137,0.831373},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.839216},{0.85098,0.858824,0.839216},{0.85098,0.858
 [...]
-\.
-
-DROP TABLE IF EXISTS cifar_10_sample_val_summary;
-CREATE TABLE cifar_10_sample_val_summary AS
-       SELECT * FROM cifar_10_sample_batched_summary;
-
---- NOTE:  In order to test fit_merge, we need at least 2 rows in the batched 
table (1 on each segment).
---- ALSO NOTE: As part of supporting Postgres, an issue was reported JIRA 
MADLIB-1326.
---- Once this bug is fixed, we should uncomment these 2 lines, which was used 
to generate
---- the 4 tables above (cifar_10_sample{|_val}_batched{|_summary}).  Only the 
original
----- cifar_10_sample table should be hard-coded, so we don't have to keep 
re-generating
----- all of these tables by hand every time something changes.
---- SELECT 
minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 
1, 255);
-
-DROP TABLE IF EXISTS model_arch;
-SELECT load_keras_model('model_arch',
-  $${
-  "class_name": "Sequential",
-  "keras_version": "2.1.6",
-  "config": [{
-       "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
-       "name": "conv2d_1",
-       "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
-       "dtype": "float32", "activation": "relu", "trainable": true,
-       "data_format": "channels_last", "filters": 32, "padding": "valid",
-       "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
-       "bias_initializer": {"class_name": "Zeros", "config": {}},
-       "batch_input_shape": [null, 32, 32, 3], "use_bias": true,
-       "activity_regularizer": null, "kernel_size": [3, 3]}},
-       {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_last", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
-       {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
-       {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_last"}},
-       {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
-       {"class_name": "Zeros", "config": {}}, "units": 2, "use_bias": true, 
"activity_regularizer": null}
-       }], "backend": "tensorflow"}$$);
-
--- -- Please do not break up the compile_params string
--- -- It might break the assertion
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_saved_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3,
-    NULL,
-    'cifar_10_sample_val');
-
-SELECT assert(
-        model_arch_table = 'model_arch' AND
-        model_arch_id = 1 AND
-        model_type = 'madlib_keras' AND
-        start_training_time         < now() AND
-        end_training_time > start_training_time AND
-        source_table = 'cifar_10_sample_batched' AND
-        validation_table = 'cifar_10_sample_val' AND
-        model = 'keras_saved_out' AND
-        dependent_varname = 'y' AND
-        dependent_vartype = 'smallint' AND
-        independent_varname = 'x' AND
-        normalizing_const = 255.0 AND
-        pg_typeof(normalizing_const) = 'real'::regtype AND
-        name is NULL AND
-        description is NULL AND
-        model_size > 0 AND
-        madlib_version is NOT NULL AND
-        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text AND
-        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
-        num_iterations = 3 AND
-        metrics_compute_frequency = 3 AND
-        num_classes = 2 AND
-        class_values = '{0,1}' AND
-        metrics_type = '{mae}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 1 AND
-        array_upper(training_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1 AND
-        validation_metrics_final >= 0 AND
-        validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 1 AND
-        array_upper(validation_loss, 1) = 1 ,
-        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
-
-SELECT assert(
-        model_data IS NOT NULL AND
-        model_arch IS NOT NULL, 'Keras model output validation failed. 
Actual:' || __to_char(k))
-FROM (SELECT * FROM keras_saved_out) k;
-
--- Test that evaluate works as expected:
-DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out', 0);
-
-SELECT assert(loss IS NOT NULL AND
-        metric IS NOT NULL AND
-        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' 
|| __to_char(evaluate_out))
-FROM evaluate_out;
-
--- Test that passing NULL / None instead of 0 for gpus_per_host works
-DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out');
-SELECT assert(loss IS NOT NULL AND
-        metric IS NOT NULL AND
-        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' 
|| __to_char(evaluate_out))
-FROM evaluate_out;
-
--- Test that evaluate errors out correctly if model_arch field missing from 
fit output
-DROP TABLE IF EXISTS evaluate_out;
-ALTER TABLE keras_saved_out DROP COLUMN model_arch;
-SELECT assert(trap_error($TRAP$
-       SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out');
-       $TRAP$) = 1, 'Should error out if model_arch column is missing from 
model_table');
-
--- Verify number of iterations for which metrics and loss are computed
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_saved_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    7,
-    NULL,
-    'cifar_10_sample_val',
-    4);
-SELECT assert(
-        num_iterations = 7 AND
-        metrics_compute_frequency = 4 AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        metrics_type = '{accuracy}' AND
-        array_upper(training_metrics, 1) = 2 AND
-        array_upper(training_loss, 1) = 2 AND
-        array_upper(metrics_elapsed_time, 1) = 2 AND
-        validation_metrics_final >= 0 AND
-        validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 2 AND
-        array_upper(validation_loss, 1) = 2 ,
-        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
--- Fit with gpus_per_host set to 2 must error out on machines
--- that don't have GPUs. Since Jenkins builds are run on docker containers
--- that don't have GPUs, these queries must error out.
-DROP TABLE IF EXISTS keras_saved_out_gpu, keras_saved_out_gpu_summary;
-SELECT assert(trap_error($TRAP$madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_saved_out_gpu',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3,
-    2,
-    'cifar_10_sample_val');$TRAP$) = 1,
-       'Fit with gpus_per_host=2 must error out.');
-
--- Prediction with gpus_per_host set to 2 must error out on machines
--- that don't have GPUs. Since Jenkins builds are run on docker containers
--- that don't have GPUs, these queries must error out.
-
--- IMPORTANT: The following test must be run when we have a valid
--- keras_saved_out model table. Otherwise, it will fail because of a
--- non-existent model table, while we want to trap failure due to
--- gpus_per_host=2
-DROP TABLE IF EXISTS cifar10_predict_gpu;
-SELECT assert(trap_error($TRAP$madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample',
-    'id',
-    'x',
-    'cifar10_predict_gpu',
-    NULL,
-    2);$TRAP$) = 1,
-    'Prediction with gpus_per_host=2 must error out.');
-
--- Test for
-  -- Non null name and description columns
-       -- Null validation table
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    2,
-    NULL,
-    NULL,
-    1,
-    NULL,
-    'model name',
-    'model desc');
-
-SELECT assert(
-    source_table = 'cifar_10_sample_batched' AND
-    model = 'keras_out' AND
-    dependent_varname = 'y' AND
-    independent_varname = 'x' AND
-    model_arch_table = 'model_arch' AND
-    model_arch_id = 1 AND
-    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
-    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
-    num_iterations = 2 AND
-    validation_table is NULL AND
-    metrics_compute_frequency = 1 AND
-    name = 'model name' AND
-    description = 'model desc' AND
-    model_type = 'madlib_keras' AND
-    model_size > 0 AND
-    start_training_time         < now() AND
-    end_training_time > start_training_time AND
-    array_upper(metrics_elapsed_time, 1) = 2 AND
-    dependent_vartype = 'smallint' AND
-    madlib_version is NOT NULL AND
-    num_classes = 2 AND
-    class_values = '{0,1}' AND
-    metrics_type = '{accuracy}' AND
-    normalizing_const = 255.0 AND
-    training_metrics_final is not NULL AND
-    training_loss_final is not NULL AND
-    array_upper(training_metrics, 1) = 2 AND
-    array_upper(training_loss, 1) = 2 AND
-    validation_metrics_final is  NULL AND
-    validation_loss_final is  NULL AND
-    validation_metrics is NULL AND
-    validation_loss is NULL,
-    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_out_summary) summary;
-
-SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
-
--- Validate metrics=NULL works with fit
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-'cifar_10_sample_batched',
-'keras_saved_out',
-'model_arch',
-1,
-$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy'$$::text,
-$$ batch_size=2, epochs=1, verbose=0 $$::text,
-1);
-
-SELECT assert(
-        metrics_type is NULL AND
-        training_metrics IS NULL AND
-        array_upper(training_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1 AND
-        validation_metrics_final IS NULL AND
-        validation_loss_final  >= 0  AND
-        validation_metrics IS NULL AND
-        array_upper(validation_loss, 1) = 1,
-        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
-
--- Validate that metrics=NULL works with evaluate
-DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out', 0);
-
-SELECT assert(loss IS NOT NULL AND
-        metric IS NULL AND
-        metrics_type IS NULL, 'Evaluate output validation for NULL metric 
failed.  Actual:' || __to_char(evaluate_out))
-FROM evaluate_out;
-
--- Validate metrics=[] works with fit
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-'cifar_10_sample_batched',
-'keras_saved_out',
-'model_arch',
-1,
-$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=[]$$::text,
-$$ batch_size=2, epochs=1, verbose=0 $$::text,
-1);
-
-SELECT assert(
-        metrics_type IS NULL AND
-        training_metrics IS NULL AND
-        array_upper(training_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1 AND
-        validation_metrics_final IS NULL AND
-        validation_loss_final  >= 0  AND
-        validation_metrics IS NULL AND
-        array_upper(validation_loss, 1) = 1,
-        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
-
--- Validate metrics=[] works with evaluate
-DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out', 0);
-
-SELECT assert(loss IS NOT NULL AND
-        metric IS NULL AND
-        metrics_type IS NULL, 'Evaluate output validation for [] metric 
failed.  Actual:' || __to_char(evaluate_out))
-FROM evaluate_out;
-
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample',
-    'id',
-    'x',
-    'cifar10_predict',
-    NULL,
-    0);
-
--- Validate that prediction output table exists and has correct schema
-SELECT assert(UPPER(pg_typeof(id)::TEXT )= 'INTEGER',
-    'id column should be INTEGER type') FROM cifar10_predict;
-
-SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
-    'SMALLINT', 'prediction column should be SMALLINT type')
-FROM cifar10_predict;
-
--- Validate correct number of rows returned.
-SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have 
two rows')
-FROM cifar10_predict;
-
--- First test that all values are in set of class values; if this breaks, it's 
definitely a problem.
-SELECT assert(estimated_y IN (0,1),
-    'Predicted value not in set of defined class values for model')
-FROM cifar10_predict;
-
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT assert(trap_error($TRAP$madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_batched',
-    'id',
-    'x',
-    'cifar10_predict',
-    NULL,
-    0);$TRAP$) = 1,
-    'Passing batched image table to predict should error out.');
-
--- Compile and fit parameter tests
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer='SGD', loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    1,
-    NULL,
-    NULL,
-    NULL,
-    NULL, 'model name', 'model desc');
-
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer='Adam()', loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    1,
-    NULL,
-    NULL,
-    NULL,
-    NULL, 'model name', 'model desc');
-
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer=Adam(epsilon=None), loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    1,
-    0,
-    NULL,
-    NULL,
-    NULL, 'model name', 'model desc');
-
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, 
loss='categorical_crossentropy' $$::text,
-    $$ epochs=10, verbose=0, shuffle=True, initial_epoch=1, steps_per_epoch=2 
$$::text,
-    1,
-    NULL,
-    NULL,
-    NULL,
-    False, 'model name', 'model desc');
-
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, 
loss='categorical_crossentropy' $$::text,
-    NULL,
-    1,
-    NULL,
-    NULL,
-    NULL,
-    False, 'model name', 'model desc');
-
--- -- negative test case for passing non numeric y to fit
--- induce failure by passing a non numeric column
-DROP TABLE IF EXISTS cifar_10_sample_val_failure;
-CREATE TABLE cifar_10_sample_val_failure AS SELECT * FROM cifar_10_sample_val;
-ALTER TABLE cifar_10_sample_val_failure rename dependent_var to 
dependent_var_original;
-ALTER TABLE cifar_10_sample_val_failure rename buffer_id to dependent_var;
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT assert(trap_error($TRAP$madlib_keras_fit(
-           'cifar_10_sample_batched',
-           'keras_out',
-           'model_arch',
-           1,
-           $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-           $$ batch_size=2, epochs=1, verbose=0 $$::text,
-           2,
-           NULL,
-          'cifar_10_sample_val_failure');$TRAP$) = 1,
-       'Passing y of type non numeric array to fit should error out.');
-
--- Test with pred_type=prob
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
-SELECT assert(UPPER(pg_typeof(prob_0)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_0 should be double precision type')
-FROM cifar10_predict;
-
-SELECT assert(UPPER(pg_typeof(prob_1)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_1 should be double precision type')
-FROM cifar10_predict;
-
-SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
-FROM pg_attribute
-WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
-
--- Tests with text class values:
--- Modify input data to have text classes, and mini-batch it.
-DROP TABLE IF EXISTS cifar_10_sample_text_batched;
--- Create a new table using the text based column for dep var.
-CREATE TABLE cifar_10_sample_text_batched AS
-    SELECT buffer_id, independent_var, dependent_var_text_with_null AS 
dependent_var
-    FROM cifar_10_sample_batched;
--- Insert a new row with NULL as the dependent var (one-hot encoded)
-INSERT INTO cifar_10_sample_text_batched(buffer_id, independent_var, 
dependent_var)
-    SELECT 2, independent_var, ARRAY[[0,1,0,0,0]]
-    FROM cifar_10_sample_batched
-    WHERE cifar_10_sample_batched.buffer_id=0;
--- Create the necessary summary table for the batched input.
-DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
-CREATE TABLE cifar_10_sample_text_batched_summary(
-    source_table text,
-    output_table text,
-    dependent_varname text,
-    independent_varname text,
-    dependent_vartype text,
-    class_values text[],
-    buffer_size integer,
-    normalizing_const numeric);
-INSERT INTO cifar_10_sample_text_batched_summary values (
-    'cifar_10_sample',
-    'cifar_10_sample_text_batched',
-    'y_text',
-    'x',
-    'text',
-    ARRAY[NULL,'cat','dog',NULL,NULL],
-    1,
-    255.0);
-
--- Change model_arch to reflect 5 num_classes
-DROP TABLE IF EXISTS model_arch;
-SELECT load_keras_model('model_arch',
-  $${
-  "class_name": "Sequential",
-  "keras_version": "2.1.6",
-  "config": [{
-    "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
-    "name": "conv2d_1",
-    "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
-    "dtype": "float32", "activation": "relu", "trainable": true,
-    "data_format": "channels_last", "filters": 32, "padding": "valid",
-    "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
-    "bias_initializer": {"class_name": "Zeros", "config": {}},
-    "batch_input_shape": [null, 32, 32, 3], "use_bias": true,
-    "activity_regularizer": null, "kernel_size": [3, 3]}},
-    {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_last", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
-    {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
-    {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_last"}},
-    {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
-    {"class_name": "Zeros", "config": {}}, "units": 5, "use_bias": true, 
"activity_regularizer": null}
-    }], "backend": "tensorflow"}$$);
-
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_text_batched',
-    'keras_saved_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3);
--- Assert fit has correct class_values
-SELECT assert(
-    dependent_vartype = 'text' AND
-    class_values = '{NULL,cat,dog,NULL,NULL}',
-    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
-
--- Predict with pred_type=prob
-DROP TABLE IF EXISTS cifar_10_sample_text;
-CREATE TABLE cifar_10_sample_text AS
-    SELECT id, x, y_text
-    FROM cifar_10_sample;
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_text',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
--- Validate the output datatype of newly created prediction columns
--- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
--- class_values
-SELECT assert(UPPER(pg_typeof(prob_cat)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_cat should be double precision type')
-FROM cifar10_predict;
-
-SELECT assert(UPPER(pg_typeof(prob_dog)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_dog should be double precision type')
-FROM cifar10_predict;
-
-SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
-    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM cifar10_predict;
-
--- Must have exactly 4 cols (3 for class_values and 1 for id)
-SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
-FROM pg_attribute
-WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
-
--- Predict with pred_type=response
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_text',
-    'id',
-    'x',
-    'cifar10_predict',
-    'response',
-    0);
-
--- Validate the output datatype of newly created prediction columns
--- for prediction type = 'response' and class_values 'TEXT' with NULL
--- as a valid class_values
-SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) = 'TEXT',
-       'prediction column should be TEXT type')
-FROM  cifar10_predict LIMIT 1;
-
-
--- Tests where the assumption is user has one-hot encoded, so class_values
--- in input summary table will be NULL.
-UPDATE keras_saved_out_summary SET class_values=NULL;
-
--- Predict with pred_type=prob
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_text',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'prob' and class_value = NULL
--- Returns: Array of probabilities for user's one-hot encoded data
-SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
-       'column prob should be double precision[] type')
-FROM  cifar10_predict LIMIT 1;
-
--- Predict with pred_type=response
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_text',
-    'id',
-    'x',
-    'cifar10_predict',
-    'response',
-    0);
-
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'response' and class_value = NULL
--- Returns: Index of class value in user's one-hot encoded data with
--- highest probability
-SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) = 'TEXT',
-       'column estimated_y_text should be text type')
-FROM  cifar10_predict LIMIT 1;
-
-SELECT assert(
-  estimated_y_text IN ('0', '1'),
-  'Predict failure for null class value and response pred_type.')
-FROM cifar10_predict;
-
--- Test predict with INTEGER class_values
--- with NULL as a valid class value
-INSERT INTO cifar_10_sample(id, x, y, imgpath)
-SELECT 3, x, NULL, '0/img3.jpg' FROM cifar_10_sample
-WHERE y = 1;
-INSERT INTO cifar_10_sample(id, x, y, imgpath)
-SELECT 4, x, 4, '0/img4.jpg' FROM cifar_10_sample
-WHERE y = 0;
-INSERT INTO cifar_10_sample(id, x, y, imgpath)
-SELECT 5, x, 5, '0/img5.jpg' FROM cifar_10_sample
-WHERE y = 1;
-
-DROP TABLE IF EXISTS cifar_10_sample_int_batched;
-DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
-SELECT 
training_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x',
 2, 255, 5);
-
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_int_batched',
-    'keras_saved_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3);
-
--- Assert fit has correct class_values
-SELECT assert(
-    dependent_vartype = 'smallint' AND
-    class_values = '{NULL,0,1,4,5}',
-    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_saved_out_summary) summary;
-
--- Predict with pred_type=prob
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'prob' and class_values 'INT' with NULL
--- as a valid class_values
-SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
-    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM cifar10_predict;
--- Must have exactly 6 cols (5 for class_values and 1 for id)
-SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
-FROM pg_attribute
-WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
-
--- Predict with pred_type=response
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample',
-    'id',
-    'x',
-    'cifar10_predict',
-    'response',
-    0);
-
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'response' and class_values 'INT' with NULL
--- as a valid class_values
--- Returns: class_value with highest probability
-SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
-    'SMALLINT', 'prediction column should be smallint type')
-FROM cifar10_predict;
-
--- Test case with a different input shape (3, 32, 32) instead of (32, 32, 3).
--- Create a new table with image shape 3, 32, 32
-drop table if exists cifar_10_sample_test_shape;
-create table cifar_10_sample_test_shape(id INTEGER, y SMALLINT, x  REAL[] );
-copy cifar_10_sample_test_shape from stdin delimiter '|';
-1|0|{{{248,248,250,245,245,246,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245},{247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245},{245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247},{248,248,250,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247
 [...]
-\.
-
-DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched;
-DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched_summary;
-SELECT 
training_preprocessor_dl('cifar_10_sample_test_shape','cifar_10_sample_test_shape_batched','y','x',
 NULL, 255, 3);
-
--- Change model_arch to reflect channels_first
-DROP TABLE IF EXISTS model_arch;
-SELECT load_keras_model('model_arch',
-  $${
-  "class_name": "Sequential",
-  "keras_version": "2.1.6",
-  "config": [{
-    "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
-    "name": "conv2d_1",
-    "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
-    "dtype": "float32", "activation": "relu", "trainable": true,
-    "data_format": "channels_first", "filters": 32, "padding": "valid",
-    "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
-    "bias_initializer": {"class_name": "Zeros", "config": {}},
-    "batch_input_shape": [null, 3, 32, 32], "use_bias": true,
-    "activity_regularizer": null, "kernel_size": [3, 3]}},
-    {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_first", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
-    {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
-    {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_first"}},
-    {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
-    {"class_name": "Zeros", "config": {}}, "units": 3, "use_bias": true, 
"activity_regularizer": null}
-    }], "backend": "tensorflow"}$$);
-
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_test_shape_batched',
-    'keras_saved_out',
-    'model_arch',
-    1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3);
-
--- Predict with correctly shaped data, must go thru.
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_test_shape',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
--- Prediction with incorrectly shaped data must error out.
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT assert(trap_error($TRAP$madlib_keras_predict(
-        'keras_saved_out',
-        'cifar_10_sample',
-        'id',
-        'x',
-        'cifar10_predict',
-        'prob',
-        0);$TRAP$) = 1,
-    'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should 
have failed.');
-
--- Test model_arch is retrieved from model data table and not model 
architecture
-DROP TABLE IF EXISTS model_arch;
-DROP TABLE IF EXISTS cifar10_predict;
-SELECT madlib_keras_predict(
-    'keras_saved_out',
-    'cifar_10_sample_test_shape',
-    'id',
-    'x',
-    'cifar10_predict',
-    'prob',
-    0);
-
--------------------- TRANSFER LEARNING and WARM START -----------------
-
-DROP TABLE IF EXISTS iris_data;
-CREATE TABLE iris_data(
-    id serial,
-    attributes numeric[],
-    class_text varchar
-);
-INSERT INTO iris_data(id, attributes, class_text) VALUES
-(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),
-(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),
-(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),
-(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),
-(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),
-(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),
-(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),
-(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),
-(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),
-(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
-(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),
-(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),
-(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),
-(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),
-(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),
-(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),
-(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),
-(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),
-(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),
-(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),
-(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),
-(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),
-(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),
-(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),
-(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),
-(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),
-(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),
-(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),
-(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),
-(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),
-(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),
-(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),
-(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),
-(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),
-(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
-(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),
-(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),
-(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
-(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),
-(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),
-(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),
-(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),
-(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),
-(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),
-(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),
-(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),
-(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),
-(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),
-(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),
-(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),
-(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),
-(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),
-(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),
-(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),
-(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),
-(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),
-(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),
-(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),
-(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),
-(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),
-(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),
-(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),
-(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),
-(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),
-(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),
-(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),
-(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),
-(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),
-(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),
-(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),
-(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),
-(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),
-(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),
-(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),
-(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),
-(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),
-(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),
-(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),
-(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),
-(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),
-(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),
-(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),
-(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),
-(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),
-(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),
-(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),
-(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),
-(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),
-(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),
-(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),
-(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),
-(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),
-(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),
-(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),
-(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),
-(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),
-(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),
-(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),
-(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),
-(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),
-(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),
-(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
-(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),
-(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),
-(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),
-(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),
-(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),
-(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),
-(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),
-(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),
-(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),
-(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),
-(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),
-(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),
-(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),
-(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),
-(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),
-(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),
-(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),
-(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),
-(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),
-(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),
-(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),
-(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),
-(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),
-(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),
-(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),
-(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),
-(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),
-(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),
-(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),
-(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),
-(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),
-(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),
-(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),
-(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),
-(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),
-(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),
-(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),
-(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),
-(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),
-(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),
-(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
-(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),
-(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),
-(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),
-(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),
-(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),
-(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),
-(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');
-
-DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_summary;
-SELECT training_preprocessor_dl('iris_data',         -- Source table
-                                'iris_data_packed',  -- Output table
-                                'class_text',        -- Dependent variable
-                                'attributes'         -- Independent variable
-                                );
-
-DROP TABLE IF EXISTS iris_model_arch;
--- NOTE: The seed is set to 0 for every layer.
-SELECT load_keras_model('iris_model_arch',  -- Output table,
-$$
-{
-"class_name": "Sequential",
-"keras_version": "2.1.6",
-"config":
-    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
-    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "dtype": "float32", "activation": "relu", 
"trainable": true,
-    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
-    "activity_regularizer": null}}, {"class_name": "Dense",
-    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
-    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "activation": "relu", "trainable": true, 
"kernel_regularizer": null,
-    "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, 
"use_bias": true,
-    "activity_regularizer": null}}, {"class_name": "Dense", "config": 
{"kernel_initializer":
-    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
-    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
-    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
-    "trainable": true, "kernel_regularizer": null, "bias_initializer": 
{"class_name": "Zeros",
-    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
-    "backend": "tensorflow"}
-$$
-);
-
-DROP TABLE IF EXISTS iris_model, iris_model_summary;
-SELECT madlib_keras_fit('iris_data_packed',   -- source table
-                        'iris_model',          -- model output table
-                        'iris_model_arch',  -- model arch table
-                         1,                    -- model arch id
-                         $$ loss='categorical_crossentropy', optimizer='adam', 
metrics=['accuracy'] $$,  -- compile_params
-                         $$ batch_size=5, epochs=3 $$,  -- fit_params
-                         5,                    -- num_iterations
-                         NULL, NULL,
-                         1 -- metrics_compute_frequency
-                        );
-
-DROP TABLE IF EXISTS iris_train, iris_test;
--- Set seed so results are reproducible
-SELECT setseed(0);
-SELECT train_test_split('iris_data',     -- Source table
-                        'iris',          -- Output table root name
-                        0.8,            -- Train proportion
-                        NULL,           -- Test proportion (0.2)
-                        NULL,           -- Strata definition
-                        NULL,           -- Output all columns
-                        NULL,           -- Sample without replacement
-                        TRUE            -- Separate output tables
-                        );
-
-DROP TABLE IF EXISTS iris_predict;
-SELECT madlib_keras_predict('iris_model', -- model
-                            'iris_test',  -- test_table
-                            'id',  -- id column
-                            'attributes', -- independent var
-                            'iris_predict'  -- output table
-                            );
-
--- Test that our code is indeed learning something and not broken. The loss
--- from the 5th iteration should be less than the first, while the accuracy
--- must be greater.
-SELECT assert(
-  array_upper(training_loss, 1) = 5 AND
-  array_upper(training_metrics, 1) = 5,
-  'metrics compute frequency must be 1.')
-FROM iris_model_summary;
-
-SELECT assert(
-  training_loss[5]-training_loss[1] < 0 AND
-  training_metrics[5]-training_metrics[1] > 0,
-    'The loss and accuracy should have improved with more iterations.'
-)
-FROM iris_model_summary;
-
--- Make a copy of the loss and metrics array, to compare it with runs after
--- warm start and transfer learning.
-DROP TABLE IF EXISTS iris_model_first_run;
-CREATE TABLE iris_model_first_run AS
-SELECT training_loss_final, training_metrics_final
-FROM iris_model_summary;
-
--- Duplicate the architecture, but note that trainable is set to FALSE.
--- This is to ensure we don't learn anything new, that would help us
--- deterministically assert the accuracy and loss after transfer learning
--- and warm start.
-SELECT load_keras_model('iris_model_arch',  -- Output table,
-$$
-{
-"class_name": "Sequential",
-"keras_version": "2.1.6",
-"config":
-    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
-    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "dtype": "float32", "activation": "relu",
-    "trainable": false,
-    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
-    "activity_regularizer": null}}, {"class_name": "Dense",
-    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
-    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "activation": "relu",
-    "trainable": false,
-    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": 
null}},
-    {"class_name": "Dense", "config": {"kernel_initializer":
-    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
-    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
-    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
-    "trainable": false,
-    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
-    "backend": "tensorflow"}
-$$
-);
--- Copy weights that were learnt from the previous run, for transfer
--- learning. Copy it now, because using warm_start will overwrite it.
-UPDATE iris_model_arch set model_weights = (select model_data from iris_model) 
 WHERE model_id = 2;
-
--- Warm start test
-SELECT madlib_keras_fit('iris_data_packed',   -- source table
-                       'iris_model',          -- model output table
-                       'iris_model_arch',  -- model arch table
-                        2,                    -- model arch id
-                        $$ loss='categorical_crossentropy', optimizer='adam', 
metrics=['accuracy'] $$,  -- compile_params
-                        $$ batch_size=5, epochs=3 $$,  -- fit_params
-                        2,                    -- num_iterations,
-                        NULL, NULL, 1,
-                        true -- warm start
-                      );
-
-SELECT assert(
-  array_upper(training_loss, 1) = 2 AND
-  array_upper(training_metrics, 1) = 2,
-  'metrics compute frequency must be 1.')
-FROM iris_model_summary;
-
-SELECT assert(
-  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
-  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
-  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
-  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
-  'warm start test failed because training loss and metrics don''t match the 
expected value from the previous run of keras fit.')
-FROM iris_model_first_run AS first, iris_model_summary AS second;
-
--- Transfer learning test
-DROP TABLE IF EXISTS iris_model_transfer, iris_model_transfer_summary;
-SELECT madlib_keras_fit('iris_data_packed',   -- source table
-                       'iris_model_transfer',          -- model output table
-                       'iris_model_arch',  -- model arch table
-                        2,                    -- model arch id
-                        $$ loss='categorical_crossentropy', optimizer='adam', 
metrics=['accuracy'] $$,  -- compile_params
-                        $$ batch_size=5, epochs=3 $$,  -- fit_params
-                        2,
-                        NULL, NULL, 1
-                      );
-
-SELECT assert(
-  array_upper(training_loss, 1) = 2 AND
-  array_upper(training_metrics, 1) = 2,
-  'metrics compute frequency must be 1.')
-FROM iris_model_transfer_summary;
-
-SELECT assert(
-  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
-  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
-  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
-  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
-  'Transfer learning test failed because training loss and metrics don''t 
match the expected value.')
-FROM iris_model_first_run AS first, iris_model_transfer_summary AS second;
-
----------------------- Predict BYOM test --------------------------------
-
--- class_values not NULL, pred_type is response
-DROP TABLE IF EXISTS iris_predict_byom;
-SELECT madlib_keras_predict_byom(
-                                 'iris_model_arch',
-                                 2,
-                                 'iris_test',
-                                 'id',
-                                 'attributes',
-                                 'iris_predict_byom',
-                                 'response',
-                                 -1,
-                                 ARRAY['Iris-setosa', 'Iris-versicolor',
-                                  'Iris-virginica']
-                                 );
-
-SELECT assert(
-  p0.estimated_class_text = p1.estimated_dependent_var,
-  'Predict byom failure for non null class value and response pred_type.')
-FROM iris_predict AS p0,  iris_predict_byom AS p1
-WHERE p0.id=p1.id;
-SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
-       'Predict byom failure for non null class value and response pred_type.
-        Expeceted estimated_dependent_var to be of type TEXT')
-FROM  iris_predict_byom LIMIT 1;
-
--- class_values NULL, pred_type is NULL (response)
-DROP TABLE IF EXISTS iris_predict_byom;
-SELECT madlib_keras_predict_byom(
-                                 'iris_model_arch',
-                                 2,
-                                 'iris_test',
-                                 'id',
-                                 'attributes',
-                                 'iris_predict_byom'
-                                 );
-SELECT assert(
-  p1.estimated_dependent_var IN ('0', '1', '2'),
-  'Predict byom failure for null class value and null pred_type.')
-FROM iris_predict_byom AS p1;
--- Type check for the class_values NULL / pred_type NULL case: the failure
--- message names that case so a failing run points at the right scenario.
-SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
-       'Predict byom failure for null class value and null pred_type.
-        Expected estimated_dependent_var to be of type TEXT')
-FROM  iris_predict_byom LIMIT 1;
-
--- class_values not NULL, pred_type is prob
-DROP TABLE IF EXISTS iris_predict_byom;
-SELECT madlib_keras_predict_byom(
-                                 'iris_model_arch',
-                                 2,
-                                 'iris_test',
-                                 'id',
-                                 'attributes',
-                                 'iris_predict_byom',
-                                 'prob',
-                                 -1,
-                                 ARRAY['Iris-setosa', 'Iris-versicolor',
-                                  'Iris-virginica'],
-                                 1.0
-                                 );
-
--- The per-class probabilities of every row must sum to 1. The difference is
--- wrapped in abs() so that a sum below 1 is reported as a failure too; the
--- one-sided check (sum - 1 < 1e-6) would accept any sum smaller than 1.
-SELECT assert(
-  abs((p1."prob_Iris-setosa" + p1."prob_Iris-virginica" +
-       p1."prob_Iris-versicolor") - 1) < 1e-6,
-    'Predict byom failure for non null class value and prob pred_type.')
-FROM iris_predict_byom AS p1;
-SELECT assert(UPPER(pg_typeof("prob_Iris-setosa")::TEXT) = 'DOUBLE PRECISION',
-       'Predict byom failure for non null class value and prob pred_type.
-       Expeceted "prob_Iris-setosa" to be of type DOUBLE PRECISION')
-FROM  iris_predict_byom LIMIT 1;
-
--- class_values NULL, pred_type is prob
-DROP TABLE IF EXISTS iris_predict_byom;
-SELECT madlib_keras_predict_byom(
-                                 'iris_model_arch',
-                                 2,
-                                 'iris_test',
-                                 'id',
-                                 'attributes',
-                                 'iris_predict_byom',
-                                 'prob',
-                                 0,
-                                 NULL
-                                 );
--- The three entries of the prob array must sum to 1 for every row. abs()
--- makes the tolerance two-sided, so a sum below 1 fails as well.
-SELECT assert(
-  abs((prob[1] + prob[2] + prob[3]) - 1) < 1e-6,
-    'Predict byom failure for null class value and prob pred_type.')
-FROM iris_predict_byom;
-SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
-       'Predict byom failure for null class value and prob pred_type. 
Expeceted prob to
-       be of type DOUBLE PRECISION[]')
-FROM  iris_predict_byom LIMIT 1;
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
new file mode 100644
index 0000000..913921f
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
@@ -0,0 +1,152 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+-------------------- CIFAR 10 test input tables -----------------
+
+DROP TABLE IF EXISTS cifar_10_sample;
+CREATE TABLE cifar_10_sample(id INTEGER, y SMALLINT, y_text TEXT, imgpath 
TEXT, x  REAL[]);
+COPY cifar_10_sample FROM STDIN DELIMITER '|';
+1|0|'cat'|'0/img0.jpg'|{{{202,204,199},{202,204,199},{204,206,201},{206,208,203},{208,210,205},{209,211,206},{210,212,207},{212,214,210},{213,215,212},{215,217,214},{216,218,215},{216,218,215},{215,217,214},{216,218,215},{216,218,215},{216,218,214},{217,219,214},{217,219,214},{218,220,215},{218,219,214},{216,217,212},{217,218,213},{218,219,214},{214,215,209},{213,214,207},{212,213,206},{211,212,205},{209,210,203},{208,209,202},{207,208,200},{205,206,199},{203,204,198}},{{206,208,203},{20
 [...]
+2|1|'dog'|'0/img2.jpg'|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119
 [...]
+\.
+
+DROP TABLE IF EXISTS cifar_10_sample_batched;
+DROP TABLE IF EXISTS cifar_10_sample_batched_summary;
+SELECT 
training_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 
1, 255);
+
+DROP TABLE IF EXISTS cifar_10_sample_val;
+SELECT 
validation_preprocessor_dl('cifar_10_sample','cifar_10_sample_val','y','x', 
'cifar_10_sample_batched', 1);
+--- NOTE:  In order to test fit_merge, we need at least 2 rows in the batched 
table (1 on each segment).
+
+-- Text class values.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched;
+-- Create a new table using the text based column for dep var.
+CREATE TABLE cifar_10_sample_text_batched AS
+    SELECT buffer_id, independent_var, dependent_var
+    FROM cifar_10_sample_batched;
+-- Overwrite the labels with 5-element one-hot vectors and add a third row (buffer_id 2), matching the 5-element class_values in the summary table below.
+UPDATE cifar_10_sample_text_batched set dependent_var = ARRAY[[0,0,1,0,0]] 
where buffer_id=0;
+UPDATE cifar_10_sample_text_batched set dependent_var = ARRAY[[0,1,0,0,0]] 
where buffer_id=1;
+INSERT INTO cifar_10_sample_text_batched(buffer_id, independent_var, 
dependent_var)
+    SELECT 2, independent_var, ARRAY[[0,1,0,0,0]]
+    FROM cifar_10_sample_batched
+    WHERE cifar_10_sample_batched.buffer_id=0;
+-- Create the necessary summary table for the batched input.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
+CREATE TABLE cifar_10_sample_text_batched_summary(
+    source_table text,
+    output_table text,
+    dependent_varname text,
+    independent_varname text,
+    dependent_vartype text,
+    class_values text[],
+    buffer_size integer,
+    normalizing_const numeric);
+INSERT INTO cifar_10_sample_text_batched_summary values (
+    'cifar_10_sample',
+    'cifar_10_sample_text_batched',
+    'y_text',
+    'x',
+    'text',
+    ARRAY[NULL,'cat','dog',NULL,NULL],
+    1,
+    255.0);
+
+DROP TABLE IF EXISTS cifar_10_sample_int_batched;
+DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
+SELECT 
training_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x',
 2, 255, 5);
+
+-- This table is for testing a different input shape (3, 32, 32) instead of 
(32, 32, 3).
+-- Create a table with image shape 3, 32, 32
+drop table if exists cifar_10_sample_test_shape;
+create table cifar_10_sample_test_shape(id INTEGER, y SMALLINT, x  REAL[] );
+copy cifar_10_sample_test_shape from stdin delimiter '|';
+1|0|{{{248,248,250,245,245,246,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245},{247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245},{245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247},{248,248,250,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247
 [...]
+\.
+
+DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched;
+DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched_summary;
+SELECT 
training_preprocessor_dl('cifar_10_sample_test_shape','cifar_10_sample_test_shape_batched','y','x',
 NULL, 255, 3);
+
+DROP TABLE IF EXISTS model_arch;
+SELECT load_keras_model('model_arch',
+  $${
+  "class_name": "Sequential",
+  "keras_version": "2.1.6",
+  "config": [{
+       "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
+       "name": "conv2d_1",
+       "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
+       "dtype": "float32", "activation": "relu", "trainable": true,
+       "data_format": "channels_last", "filters": 32, "padding": "valid",
+       "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
+       "bias_initializer": {"class_name": "Zeros", "config": {}},
+       "batch_input_shape": [null, 32, 32, 3], "use_bias": true,
+       "activity_regularizer": null, "kernel_size": [3, 3]}},
+       {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_last", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
+       {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
+       {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_last"}},
+       {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
+       {"class_name": "Zeros", "config": {}}, "units": 2, "use_bias": true, 
"activity_regularizer": null}
+       }], "backend": "tensorflow"}$$);
+
+SELECT load_keras_model('model_arch',
+  $${
+  "class_name": "Sequential",
+  "keras_version": "2.1.6",
+  "config": [{
+       "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
+       "name": "conv2d_1",
+       "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
+       "dtype": "float32", "activation": "relu", "trainable": true,
+       "data_format": "channels_last", "filters": 32, "padding": "valid",
+       "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
+       "bias_initializer": {"class_name": "Zeros", "config": {}},
+       "batch_input_shape": [null, 32, 32, 3], "use_bias": true,
+       "activity_regularizer": null, "kernel_size": [3, 3]}},
+       {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_last", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
+       {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
+       {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_last"}},
+       {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
+       {"class_name": "Zeros", "config": {}}, "units": 5, "use_bias": true, 
"activity_regularizer": null}
+       }], "backend": "tensorflow"}$$);
+
+SELECT load_keras_model('model_arch',
+  $${
+  "class_name": "Sequential",
+  "keras_version": "2.1.6",
+  "config": [{
+    "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}},
+    "name": "conv2d_1",
+    "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": 
null,
+    "dtype": "float32", "activation": "relu", "trainable": true,
+    "data_format": "channels_first", "filters": 32, "padding": "valid",
+    "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
+    "bias_initializer": {"class_name": "Zeros", "config": {}},
+    "batch_input_shape": [null, 3, 32, 32], "use_bias": true,
+    "activity_regularizer": null, "kernel_size": [3, 3]}},
+    {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", 
"trainable": true, "data_format": "channels_first", "pool_size": [2, 2], 
"padding": "valid", "strides": [2, 2]}},
+    {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
+    {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_first"}},
+    {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
+    {"class_name": "Zeros", "config": {}}, "units": 3, "use_bias": true, 
"activity_regularizer": null}
+    }], "backend": "tensorflow"}$$);
+
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
new file mode 100644
index 0000000..dfb40a0
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
@@ -0,0 +1,61 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
+)
+
+-- -- Please do not break up the compile_params string
+-- -- It might break the assertion
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+
+-- Test that evaluate works as expected:
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out', 0);
+
+SELECT assert(loss IS NOT NULL AND
+        metric IS NOT NULL AND
+        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' 
|| __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test that omitting gpus_per_host (i.e. NULL / None instead of 0) works
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out');
+SELECT assert(loss IS NOT NULL AND
+        metric IS NOT NULL AND
+        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' 
|| __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test that evaluate errors out correctly if model_arch field missing from 
fit output
+DROP TABLE IF EXISTS evaluate_out;
+ALTER TABLE keras_saved_out DROP COLUMN model_arch;
+SELECT assert(trap_error($TRAP$
+       SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 
'evaluate_out');
+       $TRAP$) = 1, 'Should error out if model_arch column is missing from 
model_table');
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
new file mode 100644
index 0000000..c46f307
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -0,0 +1,379 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
+)
+
+-- -- Please do not break up the compile_params string
+-- -- It might break the assertion
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'cifar_10_sample_val');
+
+SELECT assert(
+        model_arch_table = 'model_arch' AND
+        model_arch_id = 1 AND
+        model_type = 'madlib_keras' AND
+        start_training_time         < now() AND
+        end_training_time > start_training_time AND
+        source_table = 'cifar_10_sample_batched' AND
+        validation_table = 'cifar_10_sample_val' AND
+        model = 'keras_saved_out' AND
+        dependent_varname = 'y' AND
+        dependent_vartype = 'smallint' AND
+        independent_varname = 'x' AND
+        normalizing_const = 255.0 AND
+        pg_typeof(normalizing_const) = 'real'::regtype AND
+        name is NULL AND
+        description is NULL AND
+        model_size > 0 AND
+        madlib_version is NOT NULL AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']$$::text AND
+        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+        num_iterations = 3 AND
+        metrics_compute_frequency = 3 AND
+        num_classes = 2 AND
+        class_values = '{0,1}' AND
+        metrics_type = '{mae}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1 AND
+        validation_metrics_final >= 0 AND
+        validation_loss_final  >= 0  AND
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 ,
+        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+SELECT assert(
+        model_data IS NOT NULL AND
+        model_arch IS NOT NULL, 'Keras model output validation failed. 
Actual:' || __to_char(k))
+FROM (SELECT * FROM keras_saved_out) k;
+
+-- Verify number of iterations for which metrics and loss are computed
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'cifar_10_sample_val',
+    2);
+SELECT assert(
+        num_iterations = 3 AND
+        metrics_compute_frequency = 2 AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        metrics_type = '{accuracy}' AND
+        array_upper(training_metrics, 1) = 2 AND
+        array_upper(training_loss, 1) = 2 AND
+        array_upper(metrics_elapsed_time, 1) = 2 AND
+        validation_metrics_final >= 0 AND
+        validation_loss_final  >= 0  AND
+        array_upper(validation_metrics, 1) = 2 AND
+        array_upper(validation_loss, 1) = 2 ,
+        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+-- Fit with gpus_per_host set to 2 must error out on machines
+-- that don't have GPUs. Since Jenkins builds are run on docker containers
+-- that don't have GPUs, these queries must error out.
+-- The trapped text has to be a complete SELECT statement: a bare
+-- function call is not valid SQL, so it would be reported as an error
+-- regardless of GPU availability and fit would never actually be invoked.
+DROP TABLE IF EXISTS keras_saved_out_gpu, keras_saved_out_gpu_summary;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out_gpu',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    2,
+    'cifar_10_sample_val');$TRAP$) = 1,
+       'Fit with gpus_per_host=2 must error out.');
+
+-- Test for
+  -- Non null name and description columns
+       -- Null validation table
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    2,
+    NULL,
+    NULL,
+    1,
+    NULL,
+    'model name',
+    'model desc');
+
+SELECT assert(
+    source_table = 'cifar_10_sample_batched' AND
+    model = 'keras_out' AND
+    dependent_varname = 'y' AND
+    independent_varname = 'x' AND
+    model_arch_table = 'model_arch' AND
+    model_arch_id = 1 AND
+    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
+    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+    num_iterations = 2 AND
+    validation_table is NULL AND
+    metrics_compute_frequency = 1 AND
+    name = 'model name' AND
+    description = 'model desc' AND
+    model_type = 'madlib_keras' AND
+    model_size > 0 AND
+    start_training_time         < now() AND
+    end_training_time > start_training_time AND
+    array_upper(metrics_elapsed_time, 1) = 2 AND
+    dependent_vartype = 'smallint' AND
+    madlib_version is NOT NULL AND
+    num_classes = 2 AND
+    class_values = '{0,1}' AND
+    metrics_type = '{accuracy}' AND
+    normalizing_const = 255.0 AND
+    training_metrics_final is not NULL AND
+    training_loss_final is not NULL AND
+    array_upper(training_metrics, 1) = 2 AND
+    array_upper(training_loss, 1) = 2 AND
+    validation_metrics_final is  NULL AND
+    validation_loss_final is  NULL AND
+    validation_metrics is NULL AND
+    validation_loss is NULL,
+    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_out_summary) summary;
+
+SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
+
+-- Validate metrics=NULL works with fit
+-- No metrics entry is given in compile_params, so only the loss is expected
+-- to be populated. The validation table is passed explicitly because the
+-- assertion below checks validation_loss / validation_loss_final; without a
+-- validation table those columns are NULL (see the NULL validation table
+-- test in this file) and the comparisons would yield NULL instead of TRUE.
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy'$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    NULL,
+    'cifar_10_sample_val');
+
+SELECT assert(
+        metrics_type is NULL AND
+        training_metrics IS NULL AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1 AND
+        validation_metrics_final IS NULL AND
+        validation_loss_final  >= 0  AND
+        validation_metrics IS NULL AND
+        array_upper(validation_loss, 1) = 1,
+        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+-- Validate metrics=[] works with fit
+-- An empty metrics list is expected to behave like no metrics at all: only
+-- the loss columns are populated. The validation table is passed explicitly
+-- because the assertion below checks validation_loss / validation_loss_final;
+-- without a validation table those columns are NULL (see the NULL validation
+-- table test in this file) and the comparisons would yield NULL, not TRUE.
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=[]$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    NULL,
+    'cifar_10_sample_val');
+
+SELECT assert(
+        metrics_type IS NULL AND
+        training_metrics IS NULL AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1 AND
+        validation_metrics_final IS NULL AND
+        validation_loss_final  >= 0  AND
+        validation_metrics IS NULL AND
+        array_upper(validation_loss, 1) = 1,
+        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+-- Compile and fit parameter tests
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer='SGD', loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    NULL,
+    NULL,
+    NULL,
+    NULL, 'model name', 'model desc');
+
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer='Adam()', loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    NULL,
+    NULL,
+    NULL,
+    NULL, 'model name', 'model desc');
+
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer=Adam(epsilon=None), loss='categorical_crossentropy', 
metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    0,
+    NULL,
+    NULL,
+    NULL, 'model name', 'model desc');
+
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, 
loss='categorical_crossentropy' $$::text,
+    $$ epochs=10, verbose=0, shuffle=True, initial_epoch=1, steps_per_epoch=2 
$$::text,
+    1,
+    NULL,
+    NULL,
+    NULL,
+    False, 'model name', 'model desc');
+
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, 
loss='categorical_crossentropy' $$::text,
+    NULL,
+    1,
+    NULL,
+    NULL,
+    NULL,
+    False, 'model name', 'model desc');
+
+-- Negative test case: fit must reject a validation table whose dependent_var
+-- is not a numeric array. The failure is induced by copying the validation
+-- table and renaming buffer_id to dependent_var (buffer_id is presumably a
+-- scalar id column -- verify against validation_preprocessor_dl output).
+-- The trapped text has to be a complete SELECT statement: a bare function
+-- call is not valid SQL, so it would be reported as an error no matter what
+-- the table looks like and fit would never actually be invoked.
+DROP TABLE IF EXISTS cifar_10_sample_val_failure;
+CREATE TABLE cifar_10_sample_val_failure AS SELECT * FROM cifar_10_sample_val;
+ALTER TABLE cifar_10_sample_val_failure rename dependent_var to dependent_var_original;
+ALTER TABLE cifar_10_sample_val_failure rename buffer_id to dependent_var;
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+           'cifar_10_sample_batched',
+           'keras_out',
+           'model_arch',
+           1,
+           $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+           $$ batch_size=2, epochs=1, verbose=0 $$::text,
+           2,
+           NULL,
+          'cifar_10_sample_val_failure');$TRAP$) = 1,
+       'Passing y of type non numeric array to fit should error out.');
+
+-- Tests with text class values:
+-- Uses cifar_10_sample_text_batched (text classes, already mini-batched), which is created by the cifar setup file included above.
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_text_batched',
+    'keras_saved_out',
+    'model_arch',
+    2,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+-- Assert fit has correct class_values
+SELECT assert(
+    dependent_vartype = 'text' AND
+    class_values = '{NULL,cat,dog,NULL,NULL}',
+    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+-- Test with INTEGER class_values
+-- with NULL as a valid class value
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 3, x, NULL, '0/img3.jpg' FROM cifar_10_sample
+WHERE y = 1;
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 4, x, 4, '0/img4.jpg' FROM cifar_10_sample
+WHERE y = 0;
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 5, x, 5, '0/img5.jpg' FROM cifar_10_sample
+WHERE y = 1;
+
+DROP TABLE IF EXISTS cifar_10_sample_int_batched;
+DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
+SELECT 
training_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x',
 2, 255, 5);
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_int_batched',
+    'keras_saved_out',
+    'model_arch',
+    2,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+
+-- Assert fit has correct class_values
+SELECT assert(
+    dependent_vartype = 'smallint' AND
+    class_values = '{NULL,0,1,4,5}',
+    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_test_shape_batched',
+    'keras_saved_out',
+    'model_arch',
+    3,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
new file mode 100644
index 0000000..066adb8
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
@@ -0,0 +1,266 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+-------------------- IRIS test input tables -----------------
+
+DROP TABLE IF EXISTS iris_data;
+CREATE TABLE iris_data(
+    id serial,
+    attributes numeric[],
+    class_text varchar
+);
+INSERT INTO iris_data(id, attributes, class_text) VALUES
+(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),
+(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),
+(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),
+(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),
+(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),
+(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),
+(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),
+(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),
+(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),
+(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),
+(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),
+(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),
+(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),
+(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),
+(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),
+(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),
+(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),
+(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),
+(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),
+(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),
+(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),
+(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),
+(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),
+(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),
+(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),
+(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),
+(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),
+(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),
+(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),
+(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),
+(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),
+(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),
+(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),
+(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),
+(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),
+(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),
+(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),
+(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),
+(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),
+(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),
+(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),
+(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),
+(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),
+(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),
+(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),
+(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),
+(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),
+(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),
+(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),
+(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),
+(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),
+(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),
+(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),
+(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),
+(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),
+(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),
+(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),
+(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),
+(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),
+(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),
+(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),
+(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),
+(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),
+(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),
+(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),
+(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),
+(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),
+(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),
+(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),
+(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),
+(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),
+(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),
+(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),
+(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),
+(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),
+(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),
+(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),
+(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),
+(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),
+(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),
+(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),
+(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),
+(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),
+(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),
+(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),
+(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),
+(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),
+(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),
+(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),
+(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),
+(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),
+(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),
+(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),
+(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),
+(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),
+(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),
+(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),
+(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),
+(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
+(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),
+(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),
+(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),
+(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),
+(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),
+(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),
+(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),
+(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),
+(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),
+(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),
+(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),
+(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),
+(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),
+(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),
+(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),
+(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),
+(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),
+(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),
+(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),
+(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),
+(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),
+(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),
+(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),
+(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),
+(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),
+(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),
+(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),
+(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),
+(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),
+(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),
+(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),
+(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),
+(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),
+(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),
+(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),
+(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),
+(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),
+(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),
+(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),
+(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),
+(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
+(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),
+(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),
+(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),
+(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),
+(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),
+(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),
+(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');
+
+DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_summary;
+SELECT training_preprocessor_dl('iris_data',         -- Source table
+                                'iris_data_packed',  -- Output table
+                                'class_text',        -- Dependent variable
+                                'attributes'         -- Independent variable
+                                );
+
+DROP TABLE IF EXISTS iris_model_arch;
+-- NOTE: The seed is set to 0 for every layer.
+SELECT load_keras_model('iris_model_arch',  -- Output table,
+$$
+{
+"class_name": "Sequential",
+"keras_version": "2.1.6",
+"config":
+    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "dtype": "float32", "activation": "relu", 
"trainable": true,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
+    "activity_regularizer": null}}, {"class_name": "Dense",
+    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "activation": "relu", "trainable": true, 
"kernel_regularizer": null,
+    "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, 
"use_bias": true,
+    "activity_regularizer": null}}, {"class_name": "Dense", "config": 
{"kernel_initializer":
+    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
+    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
+    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
+    "trainable": true, "kernel_regularizer": null, "bias_initializer": 
{"class_name": "Zeros",
+    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
+    "backend": "tensorflow"}
+$$
+);
+
+-- Duplicate the architecture, but note that trainable is set to FALSE.
+-- This ensures we don't learn anything new, which lets us
+-- deterministically assert the accuracy and loss after transfer learning
+-- and warm start.
+SELECT load_keras_model('iris_model_arch',  -- Output table,
+$$
+{
+"class_name": "Sequential",
+"keras_version": "2.1.6",
+"config":
+    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "dtype": "float32", "activation": "relu",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
+    "activity_regularizer": null}}, {"class_name": "Dense",
+    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "activation": "relu",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": 
null}},
+    {"class_name": "Dense", "config": {"kernel_initializer":
+    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
+    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
+    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
+    "backend": "tensorflow"}
+$$
+);
+
+DROP TABLE IF EXISTS iris_train, iris_test;
+-- Set seed so results are reproducible
+SELECT setseed(0);
+SELECT train_test_split('iris_data',     -- Source table
+                        'iris',          -- Output table root name
+                        0.8,            -- Train proportion
+                        NULL,           -- Test proportion (0.2)
+                        NULL,           -- Strata definition
+                        NULL,           -- Output all columns
+                        NULL,           -- Sample without replacement
+                        TRUE            -- Separate output tables
+                        );
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
new file mode 100644
index 0000000..7a27c7b
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -0,0 +1,316 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
+)
+
+-- Please do not break up the compile_params string
+-- It might break the assertion
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+
+-- Prediction with gpus_per_host set to 2 must error out on machines
+-- that don't have GPUs. Since Jenkins builds are run on docker containers
+-- that don't have GPUs, these queries must error out.
+
+-- IMPORTANT: The following test must be run when we have a valid
+-- keras_saved_out model table. Otherwise, it will fail because of a
+-- non-existent model table, while we want to trap failure due to
+-- gpus_per_host=2
+DROP TABLE IF EXISTS cifar10_predict_gpu;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict_gpu',
+    NULL,
+    2);$TRAP$) = 1,
+    'Prediction with gpus_per_host=2 must error out.');
+
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    NULL,
+    0);
+
+-- Validate that prediction output table exists and has correct schema
+SELECT assert(UPPER(pg_typeof(id)::TEXT) = 'INTEGER', 'id column should be 
INTEGER type')
+    FROM cifar10_predict;
+
+SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
+    'SMALLINT', 'prediction column should be SMALLINT type')
+    FROM cifar10_predict;
+
+-- Validate correct number of rows returned.
+SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have 
two rows')
+FROM cifar10_predict;
+
+-- First test that all values are in the set of class values; if this breaks,
+-- it's definitely a problem.
+SELECT assert(estimated_y IN (0,1),
+    'Predicted value not in set of defined class values for model')
+FROM cifar10_predict;
+
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_batched',
+    'id',
+    'x',
+    'cifar10_predict',
+    NULL,
+    0);$TRAP$) = 1,
+    'Passing batched image table to predict should error out.');
+
+-- Test with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
+
+-- Validate the pred_type='prob' output for integer class values: the
+-- per-class probability columns prob_0 and prob_1 must both be
+-- DOUBLE PRECISION.
+SELECT assert(UPPER(pg_typeof(prob_0)::TEXT) =
+    'DOUBLE PRECISION', 'column prob_0 should be double precision type')
+    FROM  cifar10_predict;
+
+SELECT assert(UPPER(pg_typeof(prob_1)::TEXT) =
+    'DOUBLE PRECISION', 'column prob_1 should be double precision type')
+    FROM  cifar10_predict;
+
+-- Count the columns of the output table via the system catalog.
+-- attnum > 0 skips system columns, so only user columns are counted
+-- (presumably id, prob_0 and prob_1 -- confirm against predict output).
+SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Tests with text class values:
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_text_batched',
+    'keras_saved_out',
+    'model_arch',
+    2,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
+
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar_10_sample_text;
+CREATE TABLE cifar_10_sample_text AS
+    SELECT id, x, y_text
+    FROM cifar_10_sample;
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
+
+-- Validate the output datatype of newly created prediction columns
+-- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
+-- class_values
+SELECT assert(UPPER(pg_typeof(prob_cat)::TEXT) =
+    'DOUBLE PRECISION', 'column prob_cat should be double precision type')
+FROM cifar10_predict;
+
+SELECT assert(UPPER(pg_typeof(prob_dog)::TEXT) =
+    'DOUBLE PRECISION', 'column prob_dog should be double precision type')
+FROM cifar10_predict;
+
+SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
+    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
+FROM cifar10_predict;
+
+-- Must have exactly 4 cols (3 for class_values and 1 for id)
+SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response',
+    0);
+
+-- Validate the output datatype of newly created prediction columns
+-- for prediction type = 'response' and class_values 'TEXT' with NULL
+-- as a valid class_values
+SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) =
+    'TEXT', 'prediction column should be TEXT type')
+FROM  cifar10_predict LIMIT 1;
+
+-- Tests where the assumption is user has one-hot encoded, so class_values
+-- in input summary table will be NULL.
+UPDATE keras_saved_out_summary SET class_values=NULL;
+
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'prob' and class_value = NULL
+-- Returns: Array of probabilities for user's one-hot encoded data
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) =
+    'DOUBLE PRECISION[]', 'column prob should be double precision[] type')
+FROM cifar10_predict LIMIT 1;
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response',
+    0);
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'response' and class_value = NULL
+-- Returns: Index of class value in user's one-hot encoded data with
+-- highest probability
+SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) =
+    'TEXT', 'column estimated_y_text should be text type')
+FROM cifar10_predict LIMIT 1;
+
+-- Test predict with INTEGER class_values
+-- with NULL as a valid class value
+-- Update output_summary table to reflect
+-- class_values {NULL,0,1,4,5} and dependent_vartype is SMALLINT
+UPDATE keras_saved_out_summary
+SET dependent_varname = 'y',
+    class_values = ARRAY[NULL,0,1,4,5]::INTEGER[],
+    dependent_vartype = 'smallint';
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'prob' and class_values 'INT' with NULL
+-- as a valid class_values
+SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
+    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
+FROM cifar10_predict;
+
+-- Must have exactly 6 cols (5 for class_values and 1 for id)
+SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response',
+    0);
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'response' and class_values 'INT' with NULL
+-- as a valid class_values
+-- Returns: class_value with highest probability
+SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
+    'SMALLINT', 'prediction column should be smallint type')
+FROM cifar10_predict;
+
+-- Predict with correctly shaped data, must go thru.
+-- Update output_summary table to reflect
+-- class_values, num_classes and model_arch_id for shaped data
+UPDATE keras_saved_out
+SET model_arch = (SELECT model_arch from model_arch where model_id = 3);
+UPDATE keras_saved_out_summary
+SET model_arch_id = 3,
+    num_classes = 3,
+    class_values = ARRAY[0,NULL,NULL]::INTEGER[];
+
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_test_shape',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
+
+-- Prediction with incorrectly shaped data must error out.
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
+        'keras_saved_out',
+        'cifar_10_sample',
+        'id',
+        'x',
+        'cifar10_predict',
+        'prob',
+        0);$TRAP$) = 1,
+    'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should 
have failed.');
+
+-- Test that model_arch is retrieved from the model data table and not from
+-- the model architecture table
+DROP TABLE IF EXISTS model_arch;
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_test_shape',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
new file mode 100644
index 0000000..0e493aa
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
@@ -0,0 +1,137 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+DROP TABLE IF EXISTS iris_model, iris_model_summary;
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                        'iris_model',          -- model output table
+                        'iris_model_arch',  -- model arch table
+                         1,                    -- model arch id
+                         $$ loss='categorical_crossentropy', optimizer='adam', 
metrics=['accuracy'] $$,  -- compile_params
+                         $$ batch_size=5, epochs=3 $$,  -- fit_params
+                         3);                    -- num_iterations
+
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict('iris_model', -- model
+                            'iris_test',  -- test_table
+                            'id',  -- id column
+                            'attributes', -- independent var
+                            'iris_predict'  -- output table
+                            );
+
+-- Copy weights that were learnt from the previous run, for transfer
+-- learning. Copy it now, because using warm_start will overwrite it.
+UPDATE iris_model_arch set model_weights = (select model_data from iris_model) 
 WHERE model_id = 2;
+
+-- class_values not NULL, pred_type is response
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'response',
+                                 -1,
+                                 ARRAY['Iris-setosa', 'Iris-versicolor',
+                                  'Iris-virginica']
+                                 );
+
+-- With class_values supplied, predict_byom must return the same class label
+-- for each test row as madlib_keras_predict produced in iris_predict.
+-- Rows are matched on id with an explicit join.
+SELECT assert(
+  p0.estimated_class_text = p1.estimated_dependent_var,
+  'Predict byom failure for non null class value and response pred_type.')
+FROM iris_predict AS p0
+INNER JOIN iris_predict_byom AS p1
+    ON p0.id = p1.id;
+-- The predicted label column must be of type TEXT for text class values.
+SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
+       'Predict byom failure for non null class value and response pred_type.
+        Expected estimated_dependent_var to be of type TEXT')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values NULL, pred_type is NULL (response)
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom'
+                                 );
+-- No class_values were passed, so every prediction is asserted to be one of
+-- the class indexes '0', '1' or '2' (as text).
+SELECT assert(
+  p1.estimated_dependent_var IN ('0', '1', '2'),
+  'Predict byom failure for null class value and null pred_type.')
+FROM iris_predict_byom AS p1;
+-- The prediction column must still be of type TEXT in this case; the failure
+-- message names the scenario actually under test (null class value, null
+-- pred_type).
+SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
+       'Predict byom failure for null class value and null pred_type.
+        Expected estimated_dependent_var to be of type TEXT')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values not NULL, pred_type is prob
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'prob',
+                                 -1,
+                                 ARRAY['Iris-setosa', 'Iris-versicolor',
+                                  'Iris-virginica'],
+                                 1.0
+                                 );
+
+-- The per-class probabilities of every row must sum to 1 (within 1e-6).
+-- abs() makes the tolerance two-sided: without it, any sum below 1
+-- (even 0) would satisfy "sum - 1 < 1e-6" and the check would pass.
+SELECT assert(
+  abs((p1."prob_Iris-setosa" +
+       p1."prob_Iris-virginica" +
+       p1."prob_Iris-versicolor") - 1) < 1e-6,
+    'Predict byom failure for non null class value and prob pred_type.')
+FROM iris_predict_byom AS p1;
+-- Each per-class probability column must be DOUBLE PRECISION.
+SELECT assert(UPPER(pg_typeof("prob_Iris-setosa")::TEXT) = 'DOUBLE PRECISION',
+       'Predict byom failure for non null class value and prob pred_type.
+       Expected "prob_Iris-setosa" to be of type DOUBLE PRECISION')
+FROM  iris_predict_byom LIMIT 1;
+
+-- class_values NULL, pred_type is prob
+DROP TABLE IF EXISTS iris_predict_byom;
+SELECT madlib_keras_predict_byom(
+                                 'iris_model_arch',
+                                 2,
+                                 'iris_test',
+                                 'id',
+                                 'attributes',
+                                 'iris_predict_byom',
+                                 'prob',
+                                 0,
+                                 NULL
+                                 );
+-- Without class_values the probabilities come back as a single array column;
+-- its three entries must sum to 1 (within 1e-6). abs() makes the tolerance
+-- two-sided: without it, any sum below 1 would pass the check.
+SELECT assert(
+  abs((prob[1] + prob[2] + prob[3]) - 1) < 1e-6,
+    'Predict byom failure for null class value and prob pred_type.')
+FROM iris_predict_byom;
+-- The prob column must be an array of DOUBLE PRECISION.
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
+       'Predict byom failure for null class value and prob pred_type.
+       Expected prob to be of type DOUBLE PRECISION[]')
+FROM  iris_predict_byom LIMIT 1;
\ No newline at end of file
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
new file mode 100644
index 0000000..33f5886
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -0,0 +1,116 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+DROP TABLE IF EXISTS iris_model, iris_model_summary;  -- initial training run starts from clean output tables
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                        'iris_model',         -- model output table
+                        'iris_model_arch',    -- model arch table
+                        1,                    -- model arch id
+                        $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                        $$ batch_size=5, epochs=3 $$,  -- fit_params
+                        5,                    -- num_iterations
+                        NULL, NULL,           -- NOTE(review): two arguments left NULL; confirm their meaning against the madlib_keras_fit signature
+                        1                     -- metrics_compute_frequency
+                       );
+
+-- Test that our code is indeed learning something and not broken. The loss
+-- from the 5th iteration should be less than the first, while the accuracy
+-- must be greater.
+SELECT assert(
+  array_upper(s.training_loss, 1) = 5 AND     -- 5 loss values were recorded
+  array_upper(s.training_metrics, 1) = 5,     -- 5 metric values were recorded
+  'metrics compute frequency must be 1.')
+FROM iris_model_summary AS s;
+
+SELECT assert(
+  s.training_loss[5] - s.training_loss[1] < 0 AND        -- loss of the 5th entry is below the 1st
+  s.training_metrics[5] - s.training_metrics[1] > 0,     -- metric (accuracy) of the 5th entry is above the 1st
+    'The loss and accuracy should have improved with more iterations.'
+)
+FROM iris_model_summary AS s;
+
+-- Snapshot the final loss and metric of this first run, so the warm start
+-- and transfer learning runs further down can be compared against them.
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT s.training_loss_final, s.training_metrics_final
+FROM iris_model_summary AS s;
+
+-- Save the weights learnt by the previous run into model arch 2, for transfer
+-- learning. Copy them now, because using warm_start will overwrite iris_model.
+UPDATE iris_model_arch SET model_weights = (SELECT model_data FROM iris_model) WHERE model_id = 2;
+
+-- Warm start test: refit into the existing iris_model table (not dropped) with warm start enabled.
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                        'iris_model',         -- model output table
+                        'iris_model_arch',    -- model arch table
+                        2,                    -- model arch id
+                        $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                        $$ batch_size=5, epochs=3 $$,  -- fit_params
+                        2,                    -- num_iterations
+                        NULL, NULL, 1,        -- last value: metrics_compute_frequency
+                        true                  -- warm start
+                       );
+
+SELECT assert(
+  array_upper(s.training_loss, 1) = 2 AND     -- 2 loss values were recorded
+  array_upper(s.training_metrics, 1) = 2,     -- 2 metric values were recorded
+  'metrics compute frequency must be 1.')
+FROM iris_model_summary AS s;
+
+SELECT assert(  -- both recorded entries must match the final loss/metric saved from the first run
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first CROSS JOIN iris_model_summary AS second;  -- explicit cross join instead of a comma join
+
+-- Transfer learning test: fit into a fresh output table using model arch 2, whose model_weights were set earlier in this file.
+DROP TABLE IF EXISTS iris_model_transfer, iris_model_transfer_summary;
+SELECT madlib_keras_fit('iris_data_packed',     -- source table
+                        'iris_model_transfer',  -- model output table
+                        'iris_model_arch',      -- model arch table
+                        2,                      -- model arch id
+                        $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                        $$ batch_size=5, epochs=3 $$,  -- fit_params
+                        2,                      -- num_iterations
+                        NULL, NULL, 1           -- last value: metrics_compute_frequency
+                       );
+
+SELECT assert(
+  array_upper(s.training_loss, 1) = 2 AND     -- 2 loss values were recorded
+  array_upper(s.training_metrics, 1) = 2,     -- 2 metric values were recorded
+  'metrics compute frequency must be 1.')
+FROM iris_model_transfer_summary AS s;
+
+SELECT assert(  -- both recorded entries must match the final loss/metric saved from the first run
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'Transfer learning test failed because training loss and metrics don''t match the expected value.')
+FROM iris_model_first_run AS first CROSS JOIN iris_model_transfer_summary AS second;  -- explicit cross join instead of a comma join

Reply via email to