[
https://issues.apache.org/jira/browse/MADLIB-1222?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16426112#comment-16426112
]
Frank McQuillan commented on MADLIB-1222:
-----------------------------------------
For minibatch, this seems to work OK. For example, continuing the modified
version of the user docs example from above:
{code:sql}
-- Pack iris_data into minibatch buffers for MLP training.
-- Clear the packed table and its companion tables from any previous run.
DROP TABLE IF EXISTS
    iris_data_packed,
    iris_data_packed_standardization,
    iris_data_packed_summary;
SELECT madlib.minibatch_preprocessor(
    'iris_data',         -- Source table
    'iris_data_packed',  -- Output (packed) table
    'class_integer',     -- Dependent variable (already one-hot encoded array)
    'attributes',        -- Independent variables
    10                   -- NOTE(review): presumably buffer size; confirm against the installed MADlib signature
);
{code}
{code:sql}
-- Train an MLP classifier on the packed (minibatched) iris data.
DROP TABLE IF EXISTS
    mlp_model,
    mlp_model_summary,
    mlp_model_standardization;
-- Set seed so results are reproducible
SELECT setseed(0);
SELECT madlib.mlp_classification(
    'iris_data_packed',     -- Source table (output of the preprocessor)
    'mlp_model',            -- Destination table
    'independent_varname',  -- Input features
    'dependent_varname',    -- Label
    ARRAY[5],               -- Number of units per layer
    'learning_rate_init=0.003,
n_iterations=500,
tolerance=0',               -- Optimizer params
    'tanh',                 -- Activation function
    NULL,                   -- Default weight (1)
    FALSE,                  -- No warm start
    FALSE                   -- Not verbose
);
{code}
{code:sql}
-- Score the original (unpacked) iris_data with the trained model,
-- emitting a predicted class per row.
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'response'         -- Output classes, not probabilities
);
-- Show each prediction alongside its source row.
SELECT *
FROM mlp_prediction
INNER JOIN iris_data USING (id)
ORDER BY id;
{code}
produces
{code}
id | estimated_class_integer | attributes | class_integer | class | state
----+-------------------------+-------------------+---------------+-------+-----------
1 | {1,0} | {5.0,3.2,1.2,0.2} | {1,0} | 1 | Alaska
2 | {1,0} | {5.5,3.5,1.3,0.2} | {1,0} | 1 | Alaska
3 | {1,0} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Alaska
4 | {1,0} | {4.4,3.0,1.3,0.2} | {1,0} | 1 | Alaska
5 | {1,0} | {5.1,3.4,1.5,0.2} | {1,0} | 1 | Alaska
6 | {1,0} | {5.0,3.5,1.3,0.3} | {1,0} | 1 | Alaska
7 | {1,0} | {4.5,2.3,1.3,0.3} | {1,0} | 1 | Alaska
8 | {1,0} | {4.4,3.2,1.3,0.2} | {1,0} | 1 | Alaska
9 | {1,0} | {5.0,3.5,1.6,0.6} | {1,0} | 1 | Alaska
10 | {1,0} | {5.1,3.8,1.9,0.4} | {1,0} | 1 | Alaska
11 | {1,0} | {4.8,3.0,1.4,0.3} | {1,0} | 1 | Alaska
12 | {1,0} | {5.1,3.8,1.6,0.2} | {1,0} | 1 | Alaska
13 | {0,1} | {5.7,2.8,4.5,1.3} | {0,1} | 2 | Alaska
14 | {0,1} | {6.3,3.3,4.7,1.6} | {0,1} | 2 | Alaska
15 | {0,1} | {4.9,2.4,3.3,1.0} | {0,1} | 2 | Alaska
16 | {0,1} | {6.6,2.9,4.6,1.3} | {0,1} | 2 | Alaska
17 | {0,1} | {5.2,2.7,3.9,1.4} | {0,1} | 2 | Alaska
18 | {0,1} | {5.0,2.0,3.5,1.0} | {0,1} | 2 | Alaska
19 | {0,1} | {5.9,3.0,4.2,1.5} | {0,1} | 2 | Alaska
20 | {0,1} | {6.0,2.2,4.0,1.0} | {0,1} | 2 | Alaska
21 | {0,1} | {6.1,2.9,4.7,1.4} | {0,1} | 2 | Alaska
22 | {0,1} | {5.6,2.9,3.6,1.3} | {0,1} | 2 | Alaska
23 | {0,1} | {6.7,3.1,4.4,1.4} | {0,1} | 2 | Alaska
24 | {0,1} | {5.6,3.0,4.5,1.5} | {0,1} | 2 | Alaska
25 | {0,1} | {5.8,2.7,4.1,1.0} | {0,1} | 2 | Alaska
26 | {0,1} | {6.2,2.2,4.5,1.5} | {0,1} | 2 | Alaska
27 | {0,1} | {5.6,2.5,3.9,1.1} | {0,1} | 2 | Alaska
28 | {1,0} | {5.0,3.4,1.5,0.2} | {1,0} | 1 | Tennessee
29 | {1,0} | {4.4,2.9,1.4,0.2} | {1,0} | 1 | Tennessee
30 | {1,0} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Tennessee
31 | {1,0} | {5.4,3.7,1.5,0.2} | {1,0} | 1 | Tennessee
32 | {1,0} | {4.8,3.4,1.6,0.2} | {1,0} | 1 | Tennessee
33 | {1,0} | {4.8,3.0,1.4,0.1} | {1,0} | 1 | Tennessee
34 | {1,0} | {4.3,3.0,1.1,0.1} | {1,0} | 1 | Tennessee
35 | {1,0} | {5.8,4.0,1.2,0.2} | {1,0} | 1 | Tennessee
36 | {1,0} | {5.7,4.4,1.5,0.4} | {1,0} | 1 | Tennessee
37 | {1,0} | {5.4,3.9,1.3,0.4} | {1,0} | 1 | Tennessee
38 | {0,1} | {6.0,2.9,4.5,1.5} | {0,1} | 2 | Tennessee
39 | {0,1} | {5.7,2.6,3.5,1.0} | {0,1} | 2 | Tennessee
40 | {0,1} | {5.5,2.4,3.8,1.1} | {0,1} | 2 | Tennessee
41 | {0,1} | {5.5,2.4,3.7,1.0} | {0,1} | 2 | Tennessee
42 | {0,1} | {5.8,2.7,3.9,1.2} | {0,1} | 2 | Tennessee
43 | {0,1} | {6.0,2.7,5.1,1.6} | {0,1} | 2 | Tennessee
44 | {0,1} | {5.4,3.0,4.5,1.5} | {0,1} | 2 | Tennessee
45 | {0,1} | {6.0,3.4,4.5,1.6} | {0,1} | 2 | Tennessee
46 | {0,1} | {6.7,3.1,4.7,1.5} | {0,1} | 2 | Tennessee
47 | {0,1} | {6.3,2.3,4.4,1.3} | {0,1} | 2 | Tennessee
48 | {0,1} | {5.6,3.0,4.1,1.3} | {0,1} | 2 | Tennessee
49 | {0,1} | {5.5,2.5,4.0,1.3} | {0,1} | 2 | Tennessee
50 | {0,1} | {5.5,2.6,4.4,1.2} | {0,1} | 2 | Tennessee
51 | {0,1} | {6.1,3.0,4.6,1.4} | {0,1} | 2 | Tennessee
52 | {0,1} | {5.8,2.6,4.0,1.2} | {0,1} | 2 | Tennessee
(52 rows)
{code}
{code:sql}
-- Re-score the original (unpacked) iris_data with the trained model,
-- this time emitting per-class probabilities instead of a single class.
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'prob'             -- Output probabilities, not classes
);
-- Show each probability vector alongside its source row.
SELECT *
FROM mlp_prediction
INNER JOIN iris_data USING (id)
ORDER BY id;
{code}
produces
{code}
id | estimated_prob | attributes | class_integer | class | state
----+----------------------------------------+-------------------+---------------+-------+-----------
1 | {0.930759252672095,0.069240747327905} | {5.0,3.2,1.2,0.2} | {1,0} | 1 | Alaska
2 | {0.929395372110727,0.0706046278892731} | {5.5,3.5,1.3,0.2} | {1,0} | 1 | Alaska
3 | {0.92275296493747,0.0772470350625298} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Alaska
4 | {0.92923853862346,0.0707614613765397} | {4.4,3.0,1.3,0.2} | {1,0} | 1 | Alaska
5 | {0.930203943536138,0.0697960564638618} | {5.1,3.4,1.5,0.2} | {1,0} | 1 | Alaska
6 | {0.937097480813401,0.062902519186599} | {5.0,3.5,1.3,0.3} | {1,0} | 1 | Alaska
7 | {0.809864020154205,0.190135979845795} | {4.5,2.3,1.3,0.3} | {1,0} | 1 | Alaska
8 | {0.938492444302248,0.0615075556977523} | {4.4,3.2,1.3,0.2} | {1,0} | 1 | Alaska
9 | {0.909421618572682,0.090578381427318} | {5.0,3.5,1.6,0.6} | {1,0} | 1 | Alaska
10 | {0.927170837453955,0.0728291625460452} | {5.1,3.8,1.9,0.4} | {1,0} | 1 | Alaska
11 | {0.907769148253907,0.0922308517460933} | {4.8,3.0,1.4,0.3} | {1,0} | 1 | Alaska
12 | {0.943518017066475,0.0564819829335253} | {5.1,3.8,1.6,0.2} | {1,0} | 1 | Alaska
13 | {0.0529094443610184,0.947090555638982} | {5.7,2.8,4.5,1.3} | {0,1} | 2 | Alaska
14 | {0.0529742392448023,0.947025760755198} | {6.3,3.3,4.7,1.6} | {0,1} | 2 | Alaska
15 | {0.154232916835593,0.845767083164407} | {4.9,2.4,3.3,1.0} | {0,1} | 2 | Alaska
16 | {0.0432082742886866,0.956791725711313} | {6.6,2.9,4.6,1.3} | {0,1} | 2 | Alaska
17 | {0.0848279782808559,0.915172021719144} | {5.2,2.7,3.9,1.4} | {0,1} | 2 | Alaska
18 | {0.0757044751883623,0.924295524811638} | {5.0,2.0,3.5,1.0} | {0,1} | 2 | Alaska
19 | {0.0611931643454561,0.938806835654544} | {5.9,3.0,4.2,1.5} | {0,1} | 2 | Alaska
20 | {0.0449649419417731,0.955035058058227} | {6.0,2.2,4.0,1.0} | {0,1} | 2 | Alaska
21 | {0.0430757587622325,0.956924241237768} | {6.1,2.9,4.7,1.4} | {0,1} | 2 | Alaska
22 | {0.111330143272174,0.888669856727826} | {5.6,2.9,3.6,1.3} | {0,1} | 2 | Alaska
23 | {0.0517875328297457,0.948212467170254} | {6.7,3.1,4.4,1.4} | {0,1} | 2 | Alaska
24 | {0.0610712779633371,0.938928722036663} | {5.6,3.0,4.5,1.5} | {0,1} | 2 | Alaska
25 | {0.0697058912971787,0.930294108702821} | {5.8,2.7,4.1,1.0} | {0,1} | 2 | Alaska
26 | {0.0300465449544714,0.969953455045529} | {6.2,2.2,4.5,1.5} | {0,1} | 2 | Alaska
27 | {0.0641965800166526,0.935803419983347} | {5.6,2.5,3.9,1.1} | {0,1} | 2 | Alaska
28 | {0.932679530162975,0.0673204698370254} | {5.0,3.4,1.5,0.2} | {1,0} | 1 | Tennessee
29 | {0.91913460018541,0.0808653998145895} | {4.4,2.9,1.4,0.2} | {1,0} | 1 | Tennessee
30 | {0.92275296493747,0.0772470350625298} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Tennessee
31 | {0.936685220371634,0.0633147796283663} | {5.4,3.7,1.5,0.2} | {1,0} | 1 | Tennessee
32 | {0.934032404740506,0.0659675952594938} | {4.8,3.4,1.6,0.2} | {1,0} | 1 | Tennessee
33 | {0.922226922202426,0.0777730777975738} | {4.8,3.0,1.4,0.1} | {1,0} | 1 | Tennessee
34 | {0.94042684548622,0.0595731545137805} | {4.3,3.0,1.1,0.1} | {1,0} | 1 | Tennessee
35 | {0.943820498537346,0.0561795014626537} | {5.8,4.0,1.2,0.2} | {1,0} | 1 | Tennessee
36 | {0.942322282886469,0.0576777171135306} | {5.7,4.4,1.5,0.4} | {1,0} | 1 | Tennessee
37 | {0.938684928938641,0.0613150710613592} | {5.4,3.9,1.3,0.4} | {1,0} | 1 | Tennessee
38 | {0.0457032748591934,0.954296725140807} | {6.0,2.9,4.5,1.5} | {0,1} | 2 | Tennessee
39 | {0.0944184541813754,0.905581545818625} | {5.7,2.6,3.5,1.0} | {0,1} | 2 | Tennessee
40 | {0.0645243589381724,0.935475641061828} | {5.5,2.4,3.8,1.1} | {0,1} | 2 | Tennessee
41 | {0.0739865590316946,0.926013440968305} | {5.5,2.4,3.7,1.0} | {0,1} | 2 | Tennessee
42 | {0.0665047837634499,0.93349521623655} | {5.8,2.7,3.9,1.2} | {0,1} | 2 | Tennessee
43 | {0.0315714539349891,0.968428546065011} | {6.0,2.7,5.1,1.6} | {0,1} | 2 | Tennessee
44 | {0.0700314082679038,0.929968591732096} | {5.4,3.0,4.5,1.5} | {0,1} | 2 | Tennessee
45 | {0.0769778072718228,0.923022192728177} | {6.0,3.4,4.5,1.6} | {0,1} | 2 | Tennessee
46 | {0.0432280654691233,0.956771934530877} | {6.7,3.1,4.7,1.5} | {0,1} | 2 | Tennessee
47 | {0.0340971244056786,0.965902875594321} | {6.3,2.3,4.4,1.3} | {0,1} | 2 | Tennessee
48 | {0.0911111408301112,0.908888859169889} | {5.6,3.0,4.1,1.3} | {0,1} | 2 | Tennessee
49 | {0.0556041814589958,0.944395818541004} | {5.5,2.5,4.0,1.3} | {0,1} | 2 | Tennessee
50 | {0.0542022706221145,0.945797729377886} | {5.5,2.6,4.4,1.2} | {0,1} | 2 | Tennessee
51 | {0.0490083738977815,0.950991626102219} | {6.1,3.0,4.6,1.4} | {0,1} | 2 | Tennessee
52 | {0.0567530672794966,0.943246932720503} | {5.8,2.6,4.0,1.2} | {0,1} | 2 | Tennessee
(52 rows)
{code}
> Support already encoded arrays for dependent var in MLP classification
> ----------------------------------------------------------------------
>
> Key: MADLIB-1222
> URL: https://issues.apache.org/jira/browse/MADLIB-1222
> Project: Apache MADlib
> Issue Type: New Feature
> Components: Module: Neural Networks
> Reporter: Nandish Jayaram
> Priority: Major
> Fix For: v1.14
>
>
> MLP classification currently supports only scalar dependent variables. If a
> user has already one-hot encoded a categorical dependent variable, it will be
> an array and hence unusable with mlp_classification. This feature request is
> to allow one-hot encoded arrays as dependent variables in MLP classification.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)