[ https://issues.apache.org/jira/browse/MADLIB-1357?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Nandish Jayaram reassigned MADLIB-1357: --------------------------------------- Assignee: Nandish Jayaram > Predict for deep learning not handling NULL class values properly > ----------------------------------------------------------------- > > Key: MADLIB-1357 > URL: https://issues.apache.org/jira/browse/MADLIB-1357 > Project: Apache MADlib > Issue Type: Bug > Components: Module: Neural Networks > Reporter: Frank McQuillan > Assignee: Nandish Jayaram > Priority: Major > Fix For: v1.16 > > > If you set a class to NULL, predict does not seem to work. > See below. > Can you also check "NULL" and 'NULL' and fix them too > if there is a problem. > Load iris data set with NULL for one of the classes: > {code} > DROP TABLE IF EXISTS iris_data; > CREATE TABLE iris_data( > id serial, > attributes numeric[], > class_text varchar > ); > INSERT INTO iris_data(id, attributes, class_text) VALUES > (1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'), > (2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'), > (3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'), > (4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'), > (5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'), > (6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'), > (7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'), > (8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'), > (9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'), > (10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'), > (11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'), > (12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'), > (13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'), > (14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'), > (15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'), > (16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'), > (17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'), > (18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'), > (19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'), > (20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'), > (21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'), > (22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'), > (23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'), > (24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'), > 
(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'), > (26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'), > (27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'), > (28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'), > (29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'), > (30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'), > (31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'), > (32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'), > (33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'), > (34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'), > (35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'), > (36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'), > (37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'), > (38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'), > (39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'), > (40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'), > (41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'), > (42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'), > (43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'), > (44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'), > (45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'), > (46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'), > (47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'), > (48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'), > (49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'), > (50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'), > (51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'), > (52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'), > (53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'), > (54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'), > (55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'), > (56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'), > (57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'), > (58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'), > (59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'), > (60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'), > (61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'), > (62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'), > (63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'), > (64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'), > (65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'), > (66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'), > 
(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'), > (68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'), > (69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'), > (70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'), > (71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'), > (72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'), > (73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'), > (74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'), > (75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'), > (76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'), > (77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'), > (78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'), > (79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'), > (80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'), > (81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'), > (82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'), > (83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'), > (84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'), > (85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'), > (86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'), > (87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'), > (88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'), > (89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'), > (90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'), > (91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'), > (92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'), > (93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'), > (94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'), > (95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'), > (96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'), > (97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'), > (98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'), > (99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'), > (100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'), > (101,ARRAY[6.3,3.3,6.0,2.5],NULL), > (102,ARRAY[5.8,2.7,5.1,1.9],NULL), > (103,ARRAY[7.1,3.0,5.9,2.1],NULL), > (104,ARRAY[6.3,2.9,5.6,1.8],NULL), > (105,ARRAY[6.5,3.0,5.8,2.2],NULL), > (106,ARRAY[7.6,3.0,6.6,2.1],NULL), > (107,ARRAY[4.9,2.5,4.5,1.7],NULL), > (108,ARRAY[7.3,2.9,6.3,1.8],NULL), > (109,ARRAY[6.7,2.5,5.8,1.8],NULL), > 
(110,ARRAY[7.2,3.6,6.1,2.5],NULL), > (111,ARRAY[6.5,3.2,5.1,2.0],NULL), > (112,ARRAY[6.4,2.7,5.3,1.9],NULL), > (113,ARRAY[6.8,3.0,5.5,2.1],NULL), > (114,ARRAY[5.7,2.5,5.0,2.0],NULL), > (115,ARRAY[5.8,2.8,5.1,2.4],NULL), > (116,ARRAY[6.4,3.2,5.3,2.3],NULL), > (117,ARRAY[6.5,3.0,5.5,1.8],NULL), > (118,ARRAY[7.7,3.8,6.7,2.2],NULL), > (119,ARRAY[7.7,2.6,6.9,2.3],NULL), > (120,ARRAY[6.0,2.2,5.0,1.5],NULL), > (121,ARRAY[6.9,3.2,5.7,2.3],NULL), > (122,ARRAY[5.6,2.8,4.9,2.0],NULL), > (123,ARRAY[7.7,2.8,6.7,2.0],NULL), > (124,ARRAY[6.3,2.7,4.9,1.8],NULL), > (125,ARRAY[6.7,3.3,5.7,2.1],NULL), > (126,ARRAY[7.2,3.2,6.0,1.8],NULL), > (127,ARRAY[6.2,2.8,4.8,1.8],NULL), > (128,ARRAY[6.1,3.0,4.9,1.8],NULL), > (129,ARRAY[6.4,2.8,5.6,2.1],NULL), > (130,ARRAY[7.2,3.0,5.8,1.6],NULL), > (131,ARRAY[7.4,2.8,6.1,1.9],NULL), > (132,ARRAY[7.9,3.8,6.4,2.0],NULL), > (133,ARRAY[6.4,2.8,5.6,2.2],NULL), > (134,ARRAY[6.3,2.8,5.1,1.5],NULL), > (135,ARRAY[6.1,2.6,5.6,1.4],NULL), > (136,ARRAY[7.7,3.0,6.1,2.3],NULL), > (137,ARRAY[6.3,3.4,5.6,2.4],NULL), > (138,ARRAY[6.4,3.1,5.5,1.8],NULL), > (139,ARRAY[6.0,3.0,4.8,1.8],NULL), > (140,ARRAY[6.9,3.1,5.4,2.1],NULL), > (141,ARRAY[6.7,3.1,5.6,2.4],NULL), > (142,ARRAY[6.9,3.1,5.1,2.3],NULL), > (143,ARRAY[5.8,2.7,5.1,1.9],NULL), > (144,ARRAY[6.8,3.2,5.9,2.3],NULL), > (145,ARRAY[6.7,3.3,5.7,2.5],NULL), > (146,ARRAY[6.7,3.0,5.2,2.3],NULL), > (147,ARRAY[6.3,2.5,5.0,1.9],NULL), > (148,ARRAY[6.5,3.0,5.2,2.0],NULL), > (149,ARRAY[6.2,3.4,5.4,2.3],NULL), > (150,ARRAY[5.9,3.0,5.1,1.8],NULL); > SELECT * FROM iris_data ORDER BY id; > {code} > Create a test/validation dataset from the training data > {code} > DROP TABLE IF EXISTS iris_train, iris_test; > -- Set seed so results are reproducible > SELECT setseed(0); > SELECT madlib.train_test_split('iris_data', -- Source table > 'iris', -- Output table root name > 0.8, -- Train proportion > NULL, -- Test proportion (0.2) > NULL, -- Strata definition > NULL, -- Output all columns > NULL, -- Sample without replacement > 
TRUE -- Separate output tables > ); > SELECT COUNT(*) FROM iris_train; > {code} > Preprocess > {code} > DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary; > SELECT madlib.training_preprocessor_dl('iris_train', -- Source table > 'iris_train_packed', -- Output table > 'class_text', -- Dependent > variable > 'attributes' -- Independent > variable > ); > DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary; > SELECT madlib.validation_preprocessor_dl('iris_test', -- Source table > 'iris_test_packed', -- Output table > 'class_text', -- Dependent > variable > 'attributes', -- Independent > variable > 'iris_train_packed' -- From > training preprocessor step > ); > {code} > Load model arch > {code} > DROP TABLE IF EXISTS model_arch_library; > SELECT madlib.load_keras_model('model_arch_library', -- Output table, > > $$ > {"class_name": "Sequential", "keras_version": "2.1.6", "config": > [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": > "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, > "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": > null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", > "activation": "relu", "trainable": true, "kernel_regularizer": null, > "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, > "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": > null}}, {"class_name": "Dense", "config": {"kernel_initializer": > {"class_name": "VarianceScaling", "config": {"distribution": "uniform", > "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_2", > "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, > "activation": "relu", "trainable": true, "kernel_regularizer": null, > "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, > "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", > "config": {"kernel_initializer": 
{"class_name": "VarianceScaling", "config": > {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, > "name": "dense_3", "kernel_constraint": null, "bias_regularizer": null, > "bias_constraint": null, "activation": "softmax", "trainable": true, > "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", > "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}], > "backend": "tensorflow"} > $$ > ::json, -- JSON blob > NULL, -- Weights > 'Sophie', -- Name > 'A simple model' -- Descr > ); > SELECT * FROM model_arch_library; > {code} > Train > {code} > DROP TABLE IF EXISTS iris_model, iris_model_summary; > SELECT madlib.madlib_keras_fit('iris_train_packed', -- source table > 'iris_model', -- model output table > 'model_arch_library', -- model arch table > 1, -- model arch id > $$ loss='categorical_crossentropy', > optimizer='adam', metrics=['accuracy'] $$, -- compile_params > $$ batch_size=5, epochs=3 $$, -- fit_params > 10 -- num_iterations > ); > {code} > Evaluate > {code} > DROP TABLE IF EXISTS iris_validate; > SELECT madlib.madlib_keras_evaluate('iris_model', -- model > 'iris_test_packed', -- test table > 'iris_validate' -- output table > ); > SELECT * FROM iris_validate; > loss | metric | metrics_type > -------------------+-------------------+-------------- > 0.589864671230316 | 0.600000023841858 | {accuracy} > (1 row) > {code} > Predict > {code} > DROP TABLE IF EXISTS iris_predict; > SELECT madlib.madlib_keras_predict('iris_model', -- model > 'iris_test', -- test_table > 'id', -- id column > 'attributes', -- independent var > 'iris_predict' -- output table > ); > SELECT * FROM iris_predict ORDER BY id; > id | estimated_class_text > -----+---------------------- > 1 | Iris-setosa > 3 | Iris-setosa > 10 | Iris-setosa > 11 | Iris-setosa > 18 | Iris-setosa > 19 | Iris-setosa > 34 | Iris-setosa > 35 | Iris-setosa > 37 | Iris-setosa > 42 | Iris-setosa > 48 | Iris-setosa > 58 | > 60 | > 61 | > 63 | > 65 | > 73 | > 
77 | > 81 | > 84 | > 85 | > 88 | > 99 | > 117 | > 123 | > 130 | > 134 | > 137 | > 141 | > 149 | > (30 rows) > {code} > This looks wrong ^^^ if you look at the test dataset: > {code} > select * from iris_test order by id; > id | attributes | class_text > -----+-------------------+----------------- > 1 | {5.1,3.5,1.4,0.2} | Iris-setosa > 3 | {4.7,3.2,1.3,0.2} | Iris-setosa > 10 | {4.9,3.1,1.5,0.1} | Iris-setosa > 11 | {5.4,3.7,1.5,0.2} | Iris-setosa > 18 | {5.1,3.5,1.4,0.3} | Iris-setosa > 19 | {5.7,3.8,1.7,0.3} | Iris-setosa > 34 | {5.5,4.2,1.4,0.2} | Iris-setosa > 35 | {4.9,3.1,1.5,0.1} | Iris-setosa > 37 | {5.5,3.5,1.3,0.2} | Iris-setosa > 42 | {4.5,2.3,1.3,0.3} | Iris-setosa > 48 | {4.6,3.2,1.4,0.2} | Iris-setosa > 58 | {4.9,2.4,3.3,1.0} | Iris-versicolor > 60 | {5.2,2.7,3.9,1.4} | Iris-versicolor > 61 | {5.0,2.0,3.5,1.0} | Iris-versicolor > 63 | {6.0,2.2,4.0,1.0} | Iris-versicolor > 65 | {5.6,2.9,3.6,1.3} | Iris-versicolor > 73 | {6.3,2.5,4.9,1.5} | Iris-versicolor > 77 | {6.8,2.8,4.8,1.4} | Iris-versicolor > 81 | {5.5,2.4,3.8,1.1} | Iris-versicolor > 84 | {6.0,2.7,5.1,1.6} | Iris-versicolor > 85 | {5.4,3.0,4.5,1.5} | Iris-versicolor > 88 | {6.3,2.3,4.4,1.3} | Iris-versicolor > 99 | {5.1,2.5,3.0,1.1} | Iris-versicolor > 117 | {6.5,3.0,5.5,1.8} | > 123 | {7.7,2.8,6.7,2.0} | > 130 | {7.2,3.0,5.8,1.6} | > 134 | {6.3,2.8,5.1,1.5} | > 137 | {6.3,3.4,5.6,2.4} | > 141 | {6.7,3.1,5.6,2.4} | > 149 | {6.2,3.4,5.4,2.3} | > {code} > Percent error > {code} > SELECT round(count(*)*100/(150*0.2),2) as test_accuracy_percent from > (select iris_test.class_text as actual, iris_predict.estimated_class_text > as estimated > from iris_predict inner join iris_test > on iris_test.id=iris_predict.id) q > WHERE q.actual=q.estimated; > test_accuracy_percent > ----------------------- > 36.67 > (1 row) > {code} > Expected 60% accuracy. -- This message was sent by Atlassian JIRA (v7.6.3#76005)