[ 
https://issues.apache.org/jira/browse/MADLIB-1357?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Frank McQuillan closed MADLIB-1357.
-----------------------------------
    Resolution: Fixed

> Predict for deep learning not handling NULL class values properly
> -----------------------------------------------------------------
>
>                 Key: MADLIB-1357
>                 URL: https://issues.apache.org/jira/browse/MADLIB-1357
>             Project: Apache MADlib
>          Issue Type: Bug
>          Components: Module: Neural Networks
>            Reporter: Frank McQuillan
>            Assignee: Nandish Jayaram
>            Priority: Major
>             Fix For: v1.16
>
>
> If you set a class to NULL, predict does not seem to work (see the
> example below).
> Please also check "NULL" and 'NULL', and fix them too
> if there is a problem.
> Load iris data set with NULL for one of the classes:
> {code}
> DROP TABLE IF EXISTS iris_data;
> CREATE TABLE iris_data(
>     id serial,
>     attributes numeric[],
>     class_text varchar
> );
> INSERT INTO iris_data(id, attributes, class_text) VALUES
> (1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),
> (2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),
> (3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),
> (4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),
> (5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),
> (6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),
> (7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),
> (8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),
> (9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),
> (10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
> (11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),
> (12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),
> (13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),
> (14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),
> (15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),
> (16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),
> (17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),
> (18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),
> (19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),
> (20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),
> (21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),
> (22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),
> (23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),
> (24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),
> (25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),
> (26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),
> (27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),
> (28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),
> (29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),
> (30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),
> (31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),
> (32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),
> (33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),
> (34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),
> (35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
> (36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),
> (37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),
> (38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
> (39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),
> (40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),
> (41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),
> (42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),
> (43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),
> (44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),
> (45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),
> (46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),
> (47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),
> (48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),
> (49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),
> (50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),
> (51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),
> (52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),
> (53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),
> (54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),
> (55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),
> (56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),
> (57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),
> (58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),
> (59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),
> (60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),
> (61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),
> (62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),
> (63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),
> (64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),
> (65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),
> (66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),
> (67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),
> (68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),
> (69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),
> (70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),
> (71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),
> (72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),
> (73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),
> (74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),
> (75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),
> (76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),
> (77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),
> (78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),
> (79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),
> (80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),
> (81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),
> (82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),
> (83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),
> (84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),
> (85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),
> (86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),
> (87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),
> (88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),
> (89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),
> (90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),
> (91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),
> (92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),
> (93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),
> (94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),
> (95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),
> (96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),
> (97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),
> (98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),
> (99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),
> (100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),
> (101,ARRAY[6.3,3.3,6.0,2.5],NULL),
> (102,ARRAY[5.8,2.7,5.1,1.9],NULL),
> (103,ARRAY[7.1,3.0,5.9,2.1],NULL),
> (104,ARRAY[6.3,2.9,5.6,1.8],NULL),
> (105,ARRAY[6.5,3.0,5.8,2.2],NULL),
> (106,ARRAY[7.6,3.0,6.6,2.1],NULL),
> (107,ARRAY[4.9,2.5,4.5,1.7],NULL),
> (108,ARRAY[7.3,2.9,6.3,1.8],NULL),
> (109,ARRAY[6.7,2.5,5.8,1.8],NULL),
> (110,ARRAY[7.2,3.6,6.1,2.5],NULL),
> (111,ARRAY[6.5,3.2,5.1,2.0],NULL),
> (112,ARRAY[6.4,2.7,5.3,1.9],NULL),
> (113,ARRAY[6.8,3.0,5.5,2.1],NULL),
> (114,ARRAY[5.7,2.5,5.0,2.0],NULL),
> (115,ARRAY[5.8,2.8,5.1,2.4],NULL),
> (116,ARRAY[6.4,3.2,5.3,2.3],NULL),
> (117,ARRAY[6.5,3.0,5.5,1.8],NULL),
> (118,ARRAY[7.7,3.8,6.7,2.2],NULL),
> (119,ARRAY[7.7,2.6,6.9,2.3],NULL),
> (120,ARRAY[6.0,2.2,5.0,1.5],NULL),
> (121,ARRAY[6.9,3.2,5.7,2.3],NULL),
> (122,ARRAY[5.6,2.8,4.9,2.0],NULL),
> (123,ARRAY[7.7,2.8,6.7,2.0],NULL),
> (124,ARRAY[6.3,2.7,4.9,1.8],NULL),
> (125,ARRAY[6.7,3.3,5.7,2.1],NULL),
> (126,ARRAY[7.2,3.2,6.0,1.8],NULL),
> (127,ARRAY[6.2,2.8,4.8,1.8],NULL),
> (128,ARRAY[6.1,3.0,4.9,1.8],NULL),
> (129,ARRAY[6.4,2.8,5.6,2.1],NULL),
> (130,ARRAY[7.2,3.0,5.8,1.6],NULL),
> (131,ARRAY[7.4,2.8,6.1,1.9],NULL),
> (132,ARRAY[7.9,3.8,6.4,2.0],NULL),
> (133,ARRAY[6.4,2.8,5.6,2.2],NULL),
> (134,ARRAY[6.3,2.8,5.1,1.5],NULL),
> (135,ARRAY[6.1,2.6,5.6,1.4],NULL),
> (136,ARRAY[7.7,3.0,6.1,2.3],NULL),
> (137,ARRAY[6.3,3.4,5.6,2.4],NULL),
> (138,ARRAY[6.4,3.1,5.5,1.8],NULL),
> (139,ARRAY[6.0,3.0,4.8,1.8],NULL),
> (140,ARRAY[6.9,3.1,5.4,2.1],NULL),
> (141,ARRAY[6.7,3.1,5.6,2.4],NULL),
> (142,ARRAY[6.9,3.1,5.1,2.3],NULL),
> (143,ARRAY[5.8,2.7,5.1,1.9],NULL),
> (144,ARRAY[6.8,3.2,5.9,2.3],NULL),
> (145,ARRAY[6.7,3.3,5.7,2.5],NULL),
> (146,ARRAY[6.7,3.0,5.2,2.3],NULL),
> (147,ARRAY[6.3,2.5,5.0,1.9],NULL),
> (148,ARRAY[6.5,3.0,5.2,2.0],NULL),
> (149,ARRAY[6.2,3.4,5.4,2.3],NULL),
> (150,ARRAY[5.9,3.0,5.1,1.8],NULL);
> SELECT * FROM iris_data ORDER BY id;
> {code}
> Create a test/validation dataset from the training data
> {code}
> DROP TABLE IF EXISTS iris_train, iris_test;
> -- Set seed so results are reproducible
> SELECT setseed(0);
> SELECT madlib.train_test_split('iris_data',     -- Source table
>                                'iris',          -- Output table root name
>                                 0.8,            -- Train proportion
>                                 NULL,           -- Test proportion (0.2)
>                                 NULL,           -- Strata definition
>                                 NULL,           -- Output all columns
>                                 NULL,           -- Sample without replacement
>                                 TRUE            -- Separate output tables
>                               );
> SELECT COUNT(*) FROM iris_train;
> {code}
> Preprocess
> {code}
> DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
> SELECT madlib.training_preprocessor_dl('iris_train',         -- Source table
>                                        'iris_train_packed',  -- Output table
>                                        'class_text',        -- Dependent variable
>                                        'attributes'         -- Independent variable
>                                         ); 
> DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;
> SELECT madlib.validation_preprocessor_dl('iris_test',          -- Source table
>                                          'iris_test_packed',   -- Output table
>                                          'class_text',         -- Dependent variable
>                                          'attributes',         -- Independent variable
>                                          'iris_train_packed'   -- From training preprocessor step
>                                           ); 
> {code}
> Load model arch
> {code}
> DROP TABLE IF EXISTS model_arch_library;
> SELECT madlib.load_keras_model('model_arch_library',  -- Output table,
>                                
> $$
> {"class_name": "Sequential", "keras_version": "2.1.6", "config": 
> [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
> "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, 
> "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": 
> null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", 
> "activation": "relu", "trainable": true, "kernel_regularizer": null, 
> "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, 
> "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": 
> null}}, {"class_name": "Dense", "config": {"kernel_initializer": 
> {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
> "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_2", 
> "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, 
> "activation": "relu", "trainable": true, "kernel_regularizer": null, 
> "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, 
> "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", 
> "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": 
> {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, 
> "name": "dense_3", "kernel_constraint": null, "bias_regularizer": null, 
> "bias_constraint": null, "activation": "softmax", "trainable": true, 
> "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", 
> "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}], 
> "backend": "tensorflow"}
> $$
> ::json,         -- JSON blob
>                                NULL,                  -- Weights
>                                'Sophie',              -- Name
>                                'A simple model'       -- Descr
> );
> SELECT * FROM model_arch_library;
> {code}
> Train
> {code}
> DROP TABLE IF EXISTS iris_model, iris_model_summary;
> SELECT madlib.madlib_keras_fit('iris_train_packed',   -- source table
>                                'iris_model',          -- model output table
>                                'model_arch_library',  -- model arch table
>                                 1,                    -- model arch id
>                                 $$ loss='categorical_crossentropy', 
> optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
>                                 $$ batch_size=5, epochs=3 $$,  -- fit_params
>                                 10                    -- num_iterations
>                               );
> {code}
> Evaluate
> {code}
> DROP TABLE IF EXISTS iris_validate;
> SELECT madlib.madlib_keras_evaluate('iris_model',       -- model
>                                    'iris_test_packed',  -- test table
>                                    'iris_validate'      -- output table
>                                    );
> SELECT * FROM iris_validate;
>        loss        |      metric       | metrics_type 
> -------------------+-------------------+--------------
>  0.589864671230316 | 0.600000023841858 | {accuracy}
> (1 row)
> {code}
> Predict
> {code}
> DROP TABLE IF EXISTS iris_predict;
> SELECT madlib.madlib_keras_predict('iris_model', -- model
>                                    'iris_test',  -- test_table
>                                    'id',  -- id column
>                                    'attributes', -- independent var
>                                    'iris_predict'  -- output table
>                                    );
> SELECT * FROM iris_predict ORDER BY id;
>  id  | estimated_class_text 
> -----+----------------------
>    1 | Iris-setosa
>    3 | Iris-setosa
>   10 | Iris-setosa
>   11 | Iris-setosa
>   18 | Iris-setosa
>   19 | Iris-setosa
>   34 | Iris-setosa
>   35 | Iris-setosa
>   37 | Iris-setosa
>   42 | Iris-setosa
>   48 | Iris-setosa
>   58 | 
>   60 | 
>   61 | 
>   63 | 
>   65 | 
>   73 | 
>   77 | 
>   81 | 
>   84 | 
>   85 | 
>   88 | 
>   99 | 
>  117 | 
>  123 | 
>  130 | 
>  134 | 
>  137 | 
>  141 | 
>  149 | 
> (30 rows)
> {code}
> The prediction output above looks wrong when compared with the test dataset:
> {code}
> select * from iris_test order by id;
>  id  |    attributes     |   class_text    
> -----+-------------------+-----------------
>    1 | {5.1,3.5,1.4,0.2} | Iris-setosa
>    3 | {4.7,3.2,1.3,0.2} | Iris-setosa
>   10 | {4.9,3.1,1.5,0.1} | Iris-setosa
>   11 | {5.4,3.7,1.5,0.2} | Iris-setosa
>   18 | {5.1,3.5,1.4,0.3} | Iris-setosa
>   19 | {5.7,3.8,1.7,0.3} | Iris-setosa
>   34 | {5.5,4.2,1.4,0.2} | Iris-setosa
>   35 | {4.9,3.1,1.5,0.1} | Iris-setosa
>   37 | {5.5,3.5,1.3,0.2} | Iris-setosa
>   42 | {4.5,2.3,1.3,0.3} | Iris-setosa
>   48 | {4.6,3.2,1.4,0.2} | Iris-setosa
>   58 | {4.9,2.4,3.3,1.0} | Iris-versicolor
>   60 | {5.2,2.7,3.9,1.4} | Iris-versicolor
>   61 | {5.0,2.0,3.5,1.0} | Iris-versicolor
>   63 | {6.0,2.2,4.0,1.0} | Iris-versicolor
>   65 | {5.6,2.9,3.6,1.3} | Iris-versicolor
>   73 | {6.3,2.5,4.9,1.5} | Iris-versicolor
>   77 | {6.8,2.8,4.8,1.4} | Iris-versicolor
>   81 | {5.5,2.4,3.8,1.1} | Iris-versicolor
>   84 | {6.0,2.7,5.1,1.6} | Iris-versicolor
>   85 | {5.4,3.0,4.5,1.5} | Iris-versicolor
>   88 | {6.3,2.3,4.4,1.3} | Iris-versicolor
>   99 | {5.1,2.5,3.0,1.1} | Iris-versicolor
>  117 | {6.5,3.0,5.5,1.8} | 
>  123 | {7.7,2.8,6.7,2.0} | 
>  130 | {7.2,3.0,5.8,1.6} | 
>  134 | {6.3,2.8,5.1,1.5} | 
>  137 | {6.3,3.4,5.6,2.4} | 
>  141 | {6.7,3.1,5.6,2.4} | 
>  149 | {6.2,3.4,5.4,2.3} | 
> {code}
> Percent accuracy
> {code}
> SELECT round(count(*)*100/(150*0.2),2) as test_accuracy_percent from
>     (select iris_test.class_text as actual, iris_predict.estimated_class_text 
> as estimated
>      from iris_predict inner join iris_test
>      on iris_test.id=iris_predict.id) q
> WHERE q.actual=q.estimated;
>  test_accuracy_percent 
> -----------------------
>                  36.67
> (1 row)
> {code}
> Expected 60% accuracy.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to