kaknikhil commented on a change in pull request #360: Deep Learning: Add
support for one-hot encoded dep var
URL: https://github.com/apache/madlib/pull/360#discussion_r270604601
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
##########
@@ -75,9 +89,24 @@ def internal_keras_predict(x_test, model_arch, model_data,
input_shape, compile_
model_shapes = []
for weight_arr in model.get_weights():
model_shapes.append(weight_arr.shape)
- _,_,_, model_weights = deserialize_weights(model_data, model_shapes)
+ _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
+ model_data, model_shapes)
model.set_weights(model_weights)
x_test = np.array(x_test).reshape(1, *input_shape)
x_test /= 255
- res = model.predict_classes(x_test)
- return res
+ proba_argmax = model.predict_classes(x_test)
+ # proba_argmax is a list with exactly one element in it. That element
+ # refers to the index containing the largest probability value in the
+ # output of Keras' predict function.
+ return _get_class_label(class_values, proba_argmax[0])
+
+def _get_class_label(class_values, class_index):
+ if class_values:
Review comment:
Also, consider adding unit tests for this function.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services