kaknikhil commented on a change in pull request #370: DL: Support response and 
prob prediction outputs
URL: https://github.com/apache/madlib/pull/370#discussion_r276755216
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
 ##########
 @@ -27,107 +27,113 @@ from keras.models import *
 from keras.optimizers import *
 import numpy as np
 
+from madlib_keras_helper import expand_input_dims
+from madlib_keras_helper import PredictParamsProcessor
+from madlib_keras_helper import MODEL_DATA_CNAME
+from madlib_keras_wrapper import compile_and_set_weights
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
-from utilities.validate_args import get_col_value_and_type
+from utilities.utilities import create_cols_from_array_sql_string
+from utilities.utilities import unique_string
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
-from madlib_keras_validator import CLASS_VALUES_COLNAME
-from keras_model_arch_table import Format
 
-from madlib_keras_wrapper import compile_and_set_weights
 import madlib_keras_serializer
 
 MODULE_NAME = 'madlib_keras_predict'
+
+def validate_pred_type(pred_type, class_values):
+    if not pred_type in ['prob', 'response']:
+        plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
+            "either response or prob.".format(MODULE_NAME, pred_type))
+    if pred_type == 'prob' and class_values and len(class_values)+1 >= 1600:
+        plpy.error({"{0}: The output will have {1} columns, exceeding the "\
+            " max number of columns that can be created (1600)".format(
+                MODULE_NAME, len(class_values)+1)})
+
 def predict(schema_madlib, model_table, test_table, id_col,
-            independent_varname, output_table, **kwargs):
+            independent_varname, output_table, pred_type, **kwargs):
+    # Refactor and add more validation as part of MADLIB-1312.
     input_tbl_valid(model_table, MODULE_NAME)
-    model_summary_table = add_postfix(model_table, '_summary')
-    input_tbl_valid(model_summary_table, MODULE_NAME)
     input_tbl_valid(test_table, MODULE_NAME)
     output_tbl_valid(output_table, MODULE_NAME)
-    model_summary_dict = plpy.execute("SELECT * FROM {0}".format(
-        model_summary_table))[0]
-    model_arch_table = model_summary_dict['model_arch_table']
-    model_arch_id = model_summary_dict['model_arch_id']
-    compile_params = model_summary_dict['compile_params']
-    input_tbl_valid(model_arch_table, MODULE_NAME)
-
-    model_data_query = "SELECT model_data from {0}".format(model_table)
-    model_data = plpy.execute(model_data_query)[0]['model_data']
+    param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
 
-    model_arch_query = """
-        SELECT {0}, {1}
-        FROM {2}
-        WHERE {3} = {4}
-        """.format(Format.MODEL_ARCH, Format.MODEL_WEIGHTS,model_arch_table,
-                   Format.MODEL_ID, model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-            MODULE_NAME, model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result[Format.MODEL_ARCH]
+    class_values = param_proc.get_class_values()
+    compile_params = param_proc.get_compile_params()
+    dependent_varname = param_proc.get_dependent_varname()
+    dependent_vartype = param_proc.get_dependent_vartype()
+    model_data = param_proc.get_model_data()
+    model_arch = param_proc.get_model_arch()
+    normalizing_const = param_proc.get_normalizing_const()
+    # TODO: Validate input shape as part of MADLIB-1312
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
-    model_summary_table = add_postfix(model_table, "_summary")
-    class_values, _ = get_col_value_and_type(model_summary_table,
-                                             CLASS_VALUES_COLNAME)
-    predict_query = plpy.prepare("""
+
+    validate_pred_type(pred_type, class_values)
+    is_response = True if pred_type == 'response' else False
+    intermediate_col = unique_string()
+    if is_response:
+        pred_col_name = add_postfix("estimated_", dependent_varname)
+        pred_col_type = dependent_vartype
+    else:
+        pred_col_name = "prob"
+        pred_col_type = 'double precision'
+
+    num_of_valid_class_values = 0
+    if class_values is not None:
+        for ele in class_values:
+            if ele is None and num_of_valid_class_values > 0:
+                break
+            num_of_valid_class_values += 1
+        # Pass only the valid class_values for creating columns
+        class_values = class_values[:num_of_valid_class_values]
+
+    prediction_select_clause = create_cols_from_array_sql_string(
+        class_values, intermediate_col, pred_col_name,
+        pred_col_type, is_response, MODULE_NAME)
+
+    plpy.execute("""
         CREATE TABLE {output_table} AS
-        SELECT {id_col},
-            ({schema_madlib}.internal_keras_predict
-                ({independent_varname},
-                 $MAD${model_arch}$MAD$,
-                 $1,ARRAY{input_shape},
-                 {compile_params},
-                 ARRAY{class_values}::TEXT[])
-            )[1] as prediction
-        from {test_table}""".format(**locals()), ["bytea"])
-    plpy.execute(predict_query, [model_data])
+        SELECT {id_col}, {prediction_select_clause}
+        FROM (
+            SELECT {test_table}.{id_col},
+                   ({schema_madlib}.internal_keras_predict
+                       ({independent_varname},
+                        $MAD${model_arch}$MAD$,
+                        {0},
+                        ARRAY{input_shape},
+                        {compile_params},
+                        {is_response},
+                        {normalizing_const})
+                   ) AS {intermediate_col}
+        FROM {test_table}, {model_table}
+        ) q
+        """.format(MODEL_DATA_CNAME, **locals()))
 
 def internal_keras_predict(x_test, model_arch, model_data, input_shape,
-                           compile_params, class_values):
+                           compile_params, is_response, normalizing_const):
     model = model_from_json(model_arch)
     device_name = '/cpu:0'
     os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
     model_shapes = madlib_keras_serializer.get_model_shapes(model)
     compile_and_set_weights(model, compile_params, device_name,
                             model_data, model_shapes)
-
-    x_test = np.array(x_test).reshape(1, *input_shape)
-    x_test /= 255
-    proba_argmax = model.predict_classes(x_test)
-    # proba_argmax is a list with exactly one element in it. That element
-    # refers to the index containing the largest probability value in the
-    # output of Keras' predict function.
-    return _get_class_label(class_values, proba_argmax[0])
-
-def _get_class_label(class_values, class_index):
-    """
-    Returns back the class label associated with the index returned by Keras'
-    predict_classes function. Keras' predict_classes function returns back
-    the index of the 1-hot encoded output that has the highest probability
-    value. We should infer the exact class label corresponding to the index
-    by looking at the class_values list (which is obtained from the
-    class_values column of the model summary table). If class_values is None,
-    we return the index as is.
-    Args:
-        @param class_values: list of class labels.
-        @param class_index: integer representing the index with max
-                            probability value.
-    Returns:
-        scalar. If class_values is None, returns class_index, else returns
-        class_values[class_index].
-    """
-    if not class_values:
-        return class_index
-    elif class_index != int(class_index):
-        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
-            " Index value must be an integer".format(MODULE_NAME, class_index))
-    elif class_index < 0 or class_index >= len(class_values):
-        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
-            " Index value must be less than {2}".format(
-                MODULE_NAME, class_index, len(class_values)))
+    # Since the test data isn't mini-batched,
+    # we have to make sure that the test data np array has the same
+    # number of dimensions as input_shape. So we a dimension to x.
 
 Review comment:
   Missing the word `add`: the last line of this comment should read "So we add a dimension to x."

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to