This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 9933c1dd9d70bc9a48d98d3a1aa2cbcfd0beabc9
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Mon Jun 3 12:11:20 2019 -0700

    DL: Add dev check test for loading model weights

    This commit adds a dev check test for calling the load_keras_model UDF from within a python UDF.
---
 .../test/keras_model_arch_table.sql_in | 29 ++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
index e3d589c..1f2009b 100644
--- a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
@@ -123,3 +123,32 @@ SELECT assert(name IS NULL AND description IS NULL, 'Name or description is not
 FROM test_keras_model_arch_table WHERE model_id = 1;
 SELECT assert(name = 'my name' AND description = 'my desc', 'Incorrect name or description in the model arch table.')
 FROM test_keras_model_arch_table WHERE model_id = 2;
+
+
+--------------------------- Test calling the UDF from python ---------------------------------
+CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$
+from keras.layers import *
+from keras import Sequential
+import numpy as np
+import plpy
+
+model = Sequential()
+model.add(Conv2D(1, kernel_size=(1, 1), activation='relu', input_shape=(1,1,1,)))
+weights = model.get_weights()
+weights_flat = [ w.flatten() for w in weights ]
+weights1d = np.array([j for sub in weights_flat for j in sub])
+weights1d = np.ones_like(weights1d)
+weights_bytea = weights1d.tostring()
+
+load_query = plpy.prepare("""SELECT load_keras_model(
+                        'test_keras_model_arch_table',
+                        $1, $2)
+                    """, ['json','bytea'])
+plpy.execute(load_query, [model.to_json(), weights_bytea])
+$$ LANGUAGE plpythonu VOLATILE;
+
+DROP TABLE IF EXISTS test_keras_model_arch_table;
+SELECT create_model_arch_transfer_learning();
+
+select assert(model_weights = '\000\000\200?\000\000\200?', 'loading weights from udf failed')
+from test_keras_model_arch_table;