This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 9933c1dd9d70bc9a48d98d3a1aa2cbcfd0beabc9
Author: Nikhil Kak <n...@pivotal.io>
AuthorDate: Mon Jun 3 12:11:20 2019 -0700

    DL: Add dev check test for loading model weights
    
    This commit adds a dev check test for calling the load_keras_model UDF
    from within a python UDF.
---
 .../test/keras_model_arch_table.sql_in             | 29 ++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
index e3d589c..1f2009b 100644
--- a/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in
@@ -123,3 +123,32 @@ SELECT assert(name IS NULL AND description IS NULL, 'Name or description is not
 FROM test_keras_model_arch_table WHERE model_id = 1;
 SELECT assert(name = 'my name' AND description = 'my desc', 'Incorrect name or description in the model arch table.')
 FROM test_keras_model_arch_table WHERE model_id = 2;
+
+
+--------------------------- Test calling the UDF from python ---------------------------------
+CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$
+from keras.layers import Conv2D
+from keras import Sequential
+import numpy as np
+import plpy
+
+model = Sequential()
+model.add(Conv2D(1, kernel_size=(1, 1), activation='relu', input_shape=(1,1,1,)))
+weights = model.get_weights()
+weights_flat = np.concatenate([w.flatten() for w in weights])
+# Overwrite the random initial weights with ones so the stored bytes are deterministic.
+weights1d = np.ones_like(weights_flat)
+weights_bytea = weights1d.tobytes()
+
+load_query = plpy.prepare("""SELECT load_keras_model(
+                        'test_keras_model_arch_table',
+                        $1, $2)
+                    """, ['json','bytea'])
+plpy.execute(load_query, [model.to_json(), weights_bytea])
+$$ LANGUAGE plpythonu VOLATILE;
+
+DROP TABLE IF EXISTS test_keras_model_arch_table;
+SELECT create_model_arch_transfer_learning();
+-- Expect exactly one row holding two float32 1.0 values; counting matches avoids a vacuous pass when no row exists.
+SELECT assert(COUNT(*) = 1, 'loading weights from udf failed')
+FROM test_keras_model_arch_table WHERE model_weights = '\000\000\200?\000\000\200?';

Reply via email to