njayaram2 commented on a change in pull request #363: Add tests for madlib_keras_predict
URL: https://github.com/apache/madlib/pull/363#discussion_r271967957
##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
##########
@@ -173,3 +172,36 @@ SELECT madlib_keras_predict(
'x',
$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True),
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');
+
+-- Validate that prediction output table exists and has correct schema
+SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
+ FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+ AND attname = 'id';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+ 'DOUBLE PRECISION', 'prediction column should be DOUBLE PRECISION type')
+ FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+ AND attname = 'prediction';
+
+-- Validate correct number of rows returned.
+SELECT assert(COUNT(*) = 2, 'Output table of madlib_keras_predict should have two rows') FROM cifar10_predict;
+
+-- First test that all values are in set of class values; if this breaks, it's definitely a problem.
+SELECT assert(prediction in (0,1),'Predicted value not in set of defined class values for model') FROM cifar10_predict;
+
+-- Then test that each of the two images is correctly predicted. If this breaks, it's likely a different problem, or
+ -- possibly due to the test model we chose not being trained well enough. Based on current testing, it looks
+ -- reliable.
+SELECT assert(prediction=0,'Predicted value not in set of defined class values for model') FROM cifar10_predict WHERE id=1;
+SELECT assert(prediction=1,'Predicted value not in set of defined class values for model') FROM cifar10_predict WHERE id=2;
Review comment:
Might not be a great idea to assert on exact prediction values. The model
learned in this test scenario is most likely close to random, so different runs
may produce different predictions. We can instead assert that each prediction is
either 0 or 1, as the test at line 190 does.