This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 374145f  DL: Add model_arch column to model data table
374145f is described below

commit 374145f2214cb36b5e328f9d98779e91d45c0eaf
Author: Ekta Khanna <ekha...@pivotal.io>
AuthorDate: Fri May 17 15:35:13 2019 -0700

    DL: Add model_arch column to model data table
    
    JIRA: MADLIB-1347
    
    This commit adds the 'model_arch' column to the model data table so that
    predict and evaluate can directly get the model_arch from the model data
    table instead of the model arch table.
    
    We have made changes to the predict code in this commit but the changes
    for evaluate will be done in a future PR.
    
    closes #394
    
    Co-authored-by: Nikhil Kak <n...@pivotal.io>
---
 .../modules/deep_learning/madlib_keras.py_in       |  7 +++--
 .../deep_learning/madlib_keras_validator.py_in     | 11 ++++++--
 .../deep_learning/predict_input_params.py_in       | 32 ++++------------------
 .../modules/deep_learning/test/madlib_keras.sql_in | 17 +++++++++++-
 4 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 82384d4..087838f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -320,9 +320,10 @@ def fit(schema_madlib, source_table, model,model_arch_table,
                   description, aggregate_runtime, class_values])
 
     create_output_table = plpy.prepare("""
-        CREATE TABLE {0} AS
-        SELECT $1 as model_data""".format(model), ["bytea"])
-    plpy.execute(create_output_table, [model_state])
+        CREATE TABLE {0} AS SELECT
+        $1 as model_data,
+        $2 as {1}""".format(model, Format.MODEL_ARCH), ["bytea", "json"])
+    plpy.execute(create_output_table, [model_state, model_arch])
 
     if is_platform_pg():
         clear_keras_session()
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 9210550..00426dd 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -18,6 +18,7 @@
 # under the License.
 
 import plpy
+from keras_model_arch_table import Format
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
@@ -107,11 +108,17 @@ class PredictInputValidator:
 
     def _validate_model_data_col(self):
         _assert(is_var_valid(self.model_table, MODEL_DATA_COLNAME),
-                "{module_name} error: invalid model_data "
-                "('{model_data}') in model table ({table}).".format(
+                "{module_name} error: column '{model_data}' "
+                "does not exist in model table '{table}'.".format(
                     module_name=self.module_name,
                     model_data=MODEL_DATA_COLNAME,
                     table=self.model_table))
+        _assert(is_var_valid(self.model_table, Format.MODEL_ARCH),
+                "{module_name} error: column '{model_arch}' "
+                "does not exist in model table '{table}'.".format(
+                    module_name=self.module_name,
+                    model_arch=Format.MODEL_ARCH,
+                    table=self.model_table))
 
     def _validate_test_tbl_cols(self):
         _assert(is_var_valid(self.test_table, self.independent_varname),
diff --git a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
index aba6dce..7f88020 100644
--- a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
+++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
@@ -25,8 +25,6 @@ from utilities.validate_args import input_tbl_valid
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
-from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
-from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
 from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 
@@ -36,28 +34,11 @@ class PredictParamsProcessor:
         self.model_table = model_table
         self.model_summary_table = add_postfix(self.model_table, '_summary')
         input_tbl_valid(self.model_summary_table, self.module_name)
-        self.model_summary_dict = self._get_model_summary_dict()
-        self.model_arch_dict = self._get_model_arch_dict()
+        self.model_summary_dict = self._get_dict_for_table(self.model_summary_table)
+        self.model_arch_dict = self._get_dict_for_table(self.model_table)
 
-    def _get_model_summary_dict(self):
-        return plpy.execute("SELECT * FROM {0}".format(
-            self.model_summary_table))[0]
-
-    def _get_model_arch_dict(self):
-        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_COLNAME]
-        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_COLNAME]
-        input_tbl_valid(model_arch_table, self.module_name)
-        model_arch_query = """
-            SELECT {0}
-            FROM {1}
-            WHERE {2} = {3}
-        """.format(Format.MODEL_ARCH, model_arch_table, Format.MODEL_ID,
-                   model_arch_id)
-        query_result = plpy.execute(model_arch_query)
-        if not query_result or len(query_result) == 0:
-            plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-                self.module_name, model_arch_table, model_arch_id))
-        return query_result[0]
+    def _get_dict_for_table(self, table_name):
+        return plpy.execute("SELECT * FROM {0}".format(table_name), 1)[0]
 
     def get_class_values(self):
         return self.model_summary_dict[CLASS_VALUES_COLNAME]
@@ -72,10 +53,7 @@ class PredictParamsProcessor:
         return self.model_arch_dict[Format.MODEL_ARCH]
 
     def get_model_data(self):
-        return plpy.execute("""
-                SELECT {0} FROM {1}
-            """.format(MODEL_DATA_COLNAME, self.model_table)
-                            )[0][MODEL_DATA_COLNAME]
+        return self.model_arch_dict[MODEL_DATA_COLNAME]
 
     def get_normalizing_const(self):
         return self.model_summary_dict[NORMALIZING_CONST_COLNAME]
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index b5aaa6d..93ca513 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -151,7 +151,10 @@ SELECT assert(
         'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM keras_saved_out_summary) summary;
 
-SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_saved_out) k;
+SELECT assert(
+        model_data IS NOT NULL AND
+        model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
+FROM (SELECT * FROM keras_saved_out) k;
 
 
 -- Verify number of iterations for which metrics and loss are computed
@@ -743,3 +746,15 @@ SELECT assert(trap_error($TRAP$madlib_keras_predict(
         'prob',
         0);$TRAP$) = 1,
     'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.');
+
+-- Test model_arch is retrieved from model data table and not model architecture
+DROP TABLE IF EXISTS model_arch;
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_test_shape',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob',
+    0);

Reply via email to