kaknikhil commented on a change in pull request #360: Deep Learning: Add support for one-hot encoded dep var
URL: https://github.com/apache/madlib/pull/360#discussion_r270604537
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
##########
@@ -75,9 +89,24 @@ def internal_keras_predict(x_test, model_arch, model_data, input_shape, compile_
     model_shapes = []
     for weight_arr in model.get_weights():
         model_shapes.append(weight_arr.shape)
-    _,_,_, model_weights = deserialize_weights(model_data, model_shapes)
+    _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
+        model_data, model_shapes)
     model.set_weights(model_weights)
     x_test = np.array(x_test).reshape(1, *input_shape)
     x_test /= 255
-    res = model.predict_classes(x_test)
-    return res
+    proba_argmax = model.predict_classes(x_test)
+    # proba_argmax is a list with exactly one element in it. That element
+    # refers to the index containing the largest probability value in the
+    # output of Keras' predict function.
+    return _get_class_label(class_values, proba_argmax[0])
+
+def _get_class_label(class_values, class_index):
+    if class_values:
Review comment:
Might be a good idea to add a docstring for this function. We should also add
a comment explaining the sort order of class_values.