[
https://issues.apache.org/jira/browse/MADLIB-1222?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16426074#comment-16426074
]
Frank McQuillan commented on MADLIB-1222:
-----------------------------------------
For IGD (without minibatch) this seems to work OK, e.g., using a modified
version of the user docs example:
[http://madlib.apache.org/docs/latest/group__grp__nn.html#example]
{code:sql}
-- (Re)create the example table holding a small two-class subset of iris.
DROP TABLE IF EXISTS iris_data;
CREATE TABLE iris_data (
    id            SERIAL,     -- row identifier, used as the id column by mlp_predict
    attributes    NUMERIC[],  -- feature vector for the row
    class_integer INTEGER[],  -- label as an already one-hot encoded array
    class         INTEGER,    -- the same label as a scalar
    state         VARCHAR     -- free-text state name
);
-- Load 52 fixture rows. class_integer is the one-hot encoding of class:
-- {1,0} for class 1 and {0,1} for class 2. The training call uses this
-- array column as the dependent variable. Ids are given explicitly rather
-- than taken from the serial sequence.
INSERT INTO iris_data(id, attributes, class_integer, class, state) VALUES
-- Alaska, class 1
(1,ARRAY[5.0,3.2,1.2,0.2], ARRAY[1,0],1,'Alaska'),
(2,ARRAY[5.5,3.5,1.3,0.2], ARRAY[1,0],1,'Alaska'),
(3,ARRAY[4.9,3.1,1.5,0.1], ARRAY[1,0],1,'Alaska'),
(4,ARRAY[4.4,3.0,1.3,0.2], ARRAY[1,0],1,'Alaska'),
(5,ARRAY[5.1,3.4,1.5,0.2], ARRAY[1,0],1,'Alaska'),
(6,ARRAY[5.0,3.5,1.3,0.3], ARRAY[1,0],1,'Alaska'),
(7,ARRAY[4.5,2.3,1.3,0.3], ARRAY[1,0],1,'Alaska'),
(8,ARRAY[4.4,3.2,1.3,0.2], ARRAY[1,0],1,'Alaska'),
(9,ARRAY[5.0,3.5,1.6,0.6], ARRAY[1,0],1,'Alaska'),
(10,ARRAY[5.1,3.8,1.9,0.4], ARRAY[1,0],1,'Alaska'),
(11,ARRAY[4.8,3.0,1.4,0.3], ARRAY[1,0],1,'Alaska'),
(12,ARRAY[5.1,3.8,1.6,0.2], ARRAY[1,0],1,'Alaska'),
-- Alaska, class 2
(13,ARRAY[5.7,2.8,4.5,1.3], ARRAY[0,1],2,'Alaska'),
(14,ARRAY[6.3,3.3,4.7,1.6], ARRAY[0,1],2,'Alaska'),
(15,ARRAY[4.9,2.4,3.3,1.0], ARRAY[0,1],2,'Alaska'),
(16,ARRAY[6.6,2.9,4.6,1.3], ARRAY[0,1],2,'Alaska'),
(17,ARRAY[5.2,2.7,3.9,1.4], ARRAY[0,1],2,'Alaska'),
(18,ARRAY[5.0,2.0,3.5,1.0], ARRAY[0,1],2,'Alaska'),
(19,ARRAY[5.9,3.0,4.2,1.5], ARRAY[0,1],2,'Alaska'),
(20,ARRAY[6.0,2.2,4.0,1.0], ARRAY[0,1],2,'Alaska'),
(21,ARRAY[6.1,2.9,4.7,1.4], ARRAY[0,1],2,'Alaska'),
(22,ARRAY[5.6,2.9,3.6,1.3], ARRAY[0,1],2,'Alaska'),
(23,ARRAY[6.7,3.1,4.4,1.4], ARRAY[0,1],2,'Alaska'),
(24,ARRAY[5.6,3.0,4.5,1.5], ARRAY[0,1],2,'Alaska'),
(25,ARRAY[5.8,2.7,4.1,1.0], ARRAY[0,1],2,'Alaska'),
(26,ARRAY[6.2,2.2,4.5,1.5], ARRAY[0,1],2,'Alaska'),
(27,ARRAY[5.6,2.5,3.9,1.1], ARRAY[0,1],2,'Alaska'),
-- Tennessee, class 1
(28,ARRAY[5.0,3.4,1.5,0.2], ARRAY[1,0],1,'Tennessee'),
(29,ARRAY[4.4,2.9,1.4,0.2], ARRAY[1,0],1,'Tennessee'),
(30,ARRAY[4.9,3.1,1.5,0.1], ARRAY[1,0],1,'Tennessee'),
(31,ARRAY[5.4,3.7,1.5,0.2], ARRAY[1,0],1,'Tennessee'),
(32,ARRAY[4.8,3.4,1.6,0.2], ARRAY[1,0],1,'Tennessee'),
(33,ARRAY[4.8,3.0,1.4,0.1], ARRAY[1,0],1,'Tennessee'),
(34,ARRAY[4.3,3.0,1.1,0.1], ARRAY[1,0],1,'Tennessee'),
(35,ARRAY[5.8,4.0,1.2,0.2], ARRAY[1,0],1,'Tennessee'),
(36,ARRAY[5.7,4.4,1.5,0.4], ARRAY[1,0],1,'Tennessee'),
(37,ARRAY[5.4,3.9,1.3,0.4], ARRAY[1,0],1,'Tennessee'),
-- Tennessee, class 2
(38,ARRAY[6.0,2.9,4.5,1.5], ARRAY[0,1],2,'Tennessee'),
(39,ARRAY[5.7,2.6,3.5,1.0], ARRAY[0,1],2,'Tennessee'),
(40,ARRAY[5.5,2.4,3.8,1.1], ARRAY[0,1],2,'Tennessee'),
(41,ARRAY[5.5,2.4,3.7,1.0], ARRAY[0,1],2,'Tennessee'),
(42,ARRAY[5.8,2.7,3.9,1.2], ARRAY[0,1],2,'Tennessee'),
(43,ARRAY[6.0,2.7,5.1,1.6], ARRAY[0,1],2,'Tennessee'),
(44,ARRAY[5.4,3.0,4.5,1.5], ARRAY[0,1],2,'Tennessee'),
(45,ARRAY[6.0,3.4,4.5,1.6], ARRAY[0,1],2,'Tennessee'),
(46,ARRAY[6.7,3.1,4.7,1.5], ARRAY[0,1],2,'Tennessee'),
(47,ARRAY[6.3,2.3,4.4,1.3], ARRAY[0,1],2,'Tennessee'),
(48,ARRAY[5.6,3.0,4.1,1.3], ARRAY[0,1],2,'Tennessee'),
(49,ARRAY[5.5,2.5,4.0,1.3], ARRAY[0,1],2,'Tennessee'),
(50,ARRAY[5.5,2.6,4.4,1.2], ARRAY[0,1],2,'Tennessee'),
(51,ARRAY[6.1,3.0,4.6,1.4], ARRAY[0,1],2,'Tennessee'),
(52,ARRAY[5.8,2.6,4.0,1.2], ARRAY[0,1],2,'Tennessee');
{code}
{code:sql}
-- Clear the model table and its companion tables from any earlier run.
DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;

-- Seed the session RNG so that repeated runs give the same results.
SELECT setseed(0);

-- Train the classifier, using the one-hot encoded array column as the label.
SELECT madlib.mlp_classification(
    'iris_data',      -- table to train on
    'mlp_model',      -- table that receives the trained model
    'attributes',     -- feature column
    'class_integer',  -- label column (one-hot encoded array)
    ARRAY[5],         -- units per hidden layer: a single layer of 5
    'learning_rate_init=0.003,
n_iterations=500,
tolerance=0',         -- optimizer parameters
    'tanh',           -- activation function
    NULL,             -- no weights column, so every row has weight 1
    FALSE,            -- start from scratch (warm start off)
    FALSE             -- quiet (verbose off)
);
{code}
{code:sql}
-- Predict hard class labels for every row of iris_data.
-- 'response' asks for classes rather than probabilities. This matches the
-- comment below and the listing shown for this query, whose prediction
-- column is estimated_class_integer. (This query previously passed 'prob',
-- which contradicted both.)
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'response'         -- Output classes, not probabilities
);
-- Put each prediction next to the row it was computed from.
SELECT * FROM mlp_prediction INNER JOIN iris_data USING (id) ORDER BY id;
{code}
produces
{code}
 id | estimated_class_integer |    attributes     | class_integer | class |   state
----+-------------------------+-------------------+---------------+-------+-----------
  1 | {1,0}                   | {5.0,3.2,1.2,0.2} | {1,0}         |     1 | Alaska
  2 | {1,0}                   | {5.5,3.5,1.3,0.2} | {1,0}         |     1 | Alaska
  3 | {1,0}                   | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | Alaska
  4 | {1,0}                   | {4.4,3.0,1.3,0.2} | {1,0}         |     1 | Alaska
  5 | {1,0}                   | {5.1,3.4,1.5,0.2} | {1,0}         |     1 | Alaska
  6 | {1,0}                   | {5.0,3.5,1.3,0.3} | {1,0}         |     1 | Alaska
  7 | {1,0}                   | {4.5,2.3,1.3,0.3} | {1,0}         |     1 | Alaska
  8 | {1,0}                   | {4.4,3.2,1.3,0.2} | {1,0}         |     1 | Alaska
  9 | {1,0}                   | {5.0,3.5,1.6,0.6} | {1,0}         |     1 | Alaska
 10 | {1,0}                   | {5.1,3.8,1.9,0.4} | {1,0}         |     1 | Alaska
 11 | {1,0}                   | {4.8,3.0,1.4,0.3} | {1,0}         |     1 | Alaska
 12 | {1,0}                   | {5.1,3.8,1.6,0.2} | {1,0}         |     1 | Alaska
 13 | {0,1}                   | {5.7,2.8,4.5,1.3} | {0,1}         |     2 | Alaska
 14 | {0,1}                   | {6.3,3.3,4.7,1.6} | {0,1}         |     2 | Alaska
 15 | {0,1}                   | {4.9,2.4,3.3,1.0} | {0,1}         |     2 | Alaska
 16 | {0,1}                   | {6.6,2.9,4.6,1.3} | {0,1}         |     2 | Alaska
 17 | {0,1}                   | {5.2,2.7,3.9,1.4} | {0,1}         |     2 | Alaska
 18 | {0,1}                   | {5.0,2.0,3.5,1.0} | {0,1}         |     2 | Alaska
 19 | {0,1}                   | {5.9,3.0,4.2,1.5} | {0,1}         |     2 | Alaska
 20 | {0,1}                   | {6.0,2.2,4.0,1.0} | {0,1}         |     2 | Alaska
 21 | {0,1}                   | {6.1,2.9,4.7,1.4} | {0,1}         |     2 | Alaska
 22 | {0,1}                   | {5.6,2.9,3.6,1.3} | {0,1}         |     2 | Alaska
 23 | {0,1}                   | {6.7,3.1,4.4,1.4} | {0,1}         |     2 | Alaska
 24 | {0,1}                   | {5.6,3.0,4.5,1.5} | {0,1}         |     2 | Alaska
 25 | {0,1}                   | {5.8,2.7,4.1,1.0} | {0,1}         |     2 | Alaska
 26 | {0,1}                   | {6.2,2.2,4.5,1.5} | {0,1}         |     2 | Alaska
 27 | {0,1}                   | {5.6,2.5,3.9,1.1} | {0,1}         |     2 | Alaska
 28 | {1,0}                   | {5.0,3.4,1.5,0.2} | {1,0}         |     1 | Tennessee
 29 | {1,0}                   | {4.4,2.9,1.4,0.2} | {1,0}         |     1 | Tennessee
 30 | {1,0}                   | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | Tennessee
 31 | {1,0}                   | {5.4,3.7,1.5,0.2} | {1,0}         |     1 | Tennessee
 32 | {1,0}                   | {4.8,3.4,1.6,0.2} | {1,0}         |     1 | Tennessee
 33 | {1,0}                   | {4.8,3.0,1.4,0.1} | {1,0}         |     1 | Tennessee
 34 | {1,0}                   | {4.3,3.0,1.1,0.1} | {1,0}         |     1 | Tennessee
 35 | {1,0}                   | {5.8,4.0,1.2,0.2} | {1,0}         |     1 | Tennessee
 36 | {1,0}                   | {5.7,4.4,1.5,0.4} | {1,0}         |     1 | Tennessee
 37 | {1,0}                   | {5.4,3.9,1.3,0.4} | {1,0}         |     1 | Tennessee
 38 | {0,1}                   | {6.0,2.9,4.5,1.5} | {0,1}         |     2 | Tennessee
 39 | {0,1}                   | {5.7,2.6,3.5,1.0} | {0,1}         |     2 | Tennessee
 40 | {0,1}                   | {5.5,2.4,3.8,1.1} | {0,1}         |     2 | Tennessee
 41 | {0,1}                   | {5.5,2.4,3.7,1.0} | {0,1}         |     2 | Tennessee
 42 | {0,1}                   | {5.8,2.7,3.9,1.2} | {0,1}         |     2 | Tennessee
 43 | {0,1}                   | {6.0,2.7,5.1,1.6} | {0,1}         |     2 | Tennessee
 44 | {0,1}                   | {5.4,3.0,4.5,1.5} | {0,1}         |     2 | Tennessee
 45 | {0,1}                   | {6.0,3.4,4.5,1.6} | {0,1}         |     2 | Tennessee
 46 | {0,1}                   | {6.7,3.1,4.7,1.5} | {0,1}         |     2 | Tennessee
 47 | {0,1}                   | {6.3,2.3,4.4,1.3} | {0,1}         |     2 | Tennessee
 48 | {0,1}                   | {5.6,3.0,4.1,1.3} | {0,1}         |     2 | Tennessee
 49 | {0,1}                   | {5.5,2.5,4.0,1.3} | {0,1}         |     2 | Tennessee
 50 | {0,1}                   | {5.5,2.6,4.4,1.2} | {0,1}         |     2 | Tennessee
 51 | {0,1}                   | {6.1,3.0,4.6,1.4} | {0,1}         |     2 | Tennessee
 52 | {0,1}                   | {5.8,2.6,4.0,1.2} | {0,1}         |     2 | Tennessee
(52 rows)
{code}
{code:sql}
-- Predict per-class probabilities for every row of iris_data.
-- 'prob' asks for probabilities rather than hard classes. This matches the
-- listing shown for this query, whose prediction column is estimated_prob.
-- (This query previously passed 'response', and its comment claimed class
-- output, which contradicted that listing.)
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'prob'             -- Output probabilities, not classes
);
-- Put each prediction next to the row it was computed from.
SELECT * FROM mlp_prediction INNER JOIN iris_data USING (id) ORDER BY id;
{code}
produces
{code}
 id |             estimated_prob              |    attributes     | class_integer | class |   state
----+-----------------------------------------+-------------------+---------------+-------+-----------
  1 | {0.994917710195306,0.00508228980469395} | {5.0,3.2,1.2,0.2} | {1,0}         |     1 | Alaska
  2 | {0.995159885858941,0.00484011414105887} | {5.5,3.5,1.3,0.2} | {1,0}         |     1 | Alaska
  3 | {0.994318306218412,0.00568169378158784} | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | Alaska
  4 | {0.994618026320835,0.00538197367916474} | {4.4,3.0,1.3,0.2} | {1,0}         |     1 | Alaska
  5 | {0.994870293358612,0.00512970664138773} | {5.1,3.4,1.5,0.2} | {1,0}         |     1 | Alaska
  6 | {0.99531633079363,0.00468366920636989}  | {5.0,3.5,1.3,0.3} | {1,0}         |     1 | Alaska
  7 | {0.98418299604323,0.0158170039567704}   | {4.5,2.3,1.3,0.3} | {1,0}         |     1 | Alaska
  8 | {0.995196642987703,0.00480335701229707} | {4.4,3.2,1.3,0.2} | {1,0}         |     1 | Alaska
  9 | {0.992720652969015,0.00727934703098487} | {5.0,3.5,1.6,0.6} | {1,0}         |     1 | Alaska
 10 | {0.994520287818222,0.00547971218177782} | {5.1,3.8,1.9,0.4} | {1,0}         |     1 | Alaska
 11 | {0.99301979697995,0.00698020302005045}  | {4.8,3.0,1.4,0.3} | {1,0}         |     1 | Alaska
 12 | {0.995776047105716,0.00422395289428424} | {5.1,3.8,1.6,0.2} | {1,0}         |     1 | Alaska
 13 | {0.00223106616944025,0.99776893383056}  | {5.7,2.8,4.5,1.3} | {0,1}         |     2 | Alaska
 14 | {0.00182750530822185,0.998172494691778} | {6.3,3.3,4.7,1.6} | {0,1}         |     2 | Alaska
 15 | {0.0232861025970201,0.97671389740298}   | {4.9,2.4,3.3,1.0} | {0,1}         |     2 | Alaska
 16 | {0.00174866608296875,0.998251333917031} | {6.6,2.9,4.6,1.3} | {0,1}         |     2 | Alaska
 17 | {0.0038505116920885,0.996149488307912}  | {5.2,2.7,3.9,1.4} | {0,1}         |     2 | Alaska
 18 | {0.0049907726168201,0.99500922738318}   | {5.0,2.0,3.5,1.0} | {0,1}         |     2 | Alaska
 19 | {0.00236757504327753,0.997632424956722} | {5.9,3.0,4.2,1.5} | {0,1}         |     2 | Alaska
 20 | {0.0022022382208288,0.997797761779171}  | {6.0,2.2,4.0,1.0} | {0,1}         |     2 | Alaska
 21 | {0.00172433217610708,0.998275667823893} | {6.1,2.9,4.7,1.4} | {0,1}         |     2 | Alaska
 22 | {0.00698273385715221,0.993017266142848} | {5.6,2.9,3.6,1.3} | {0,1}         |     2 | Alaska
 23 | {0.00193627184803815,0.998063728151962} | {6.7,3.1,4.4,1.4} | {0,1}         |     2 | Alaska
 24 | {0.00227660985655827,0.997723390143442} | {5.6,3.0,4.5,1.5} | {0,1}         |     2 | Alaska
 25 | {0.00422604973235051,0.995773950267649} | {5.8,2.7,4.1,1.0} | {0,1}         |     2 | Alaska
 26 | {0.0012841752883848,0.998715824711615}  | {6.2,2.2,4.5,1.5} | {0,1}         |     2 | Alaska
 27 | {0.00356897386123549,0.996431026138765} | {5.6,2.5,3.9,1.1} | {0,1}         |     2 | Alaska
 28 | {0.994986190795732,0.00501380920426748} | {5.0,3.4,1.5,0.2} | {1,0}         |     1 | Tennessee
 29 | {0.99394398852259,0.0060560114774095}   | {4.4,2.9,1.4,0.2} | {1,0}         |     1 | Tennessee
 30 | {0.994318306218412,0.00568169378158784} | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | Tennessee
 31 | {0.995521219175118,0.00447878082488187} | {5.4,3.7,1.5,0.2} | {1,0}         |     1 | Tennessee
 32 | {0.994970147701769,0.00502985229823091} | {4.8,3.4,1.6,0.2} | {1,0}         |     1 | Tennessee
 33 | {0.994286918345676,0.00571308165432426} | {4.8,3.0,1.4,0.1} | {1,0}         |     1 | Tennessee
 34 | {0.995346091954601,0.00465390804539922} | {4.3,3.0,1.1,0.1} | {1,0}         |     1 | Tennessee
 35 | {0.996278913102386,0.00372108689761437} | {5.8,4.0,1.2,0.2} | {1,0}         |     1 | Tennessee
 36 | {0.996223751852219,0.00377624814778116} | {5.7,4.4,1.5,0.4} | {1,0}         |     1 | Tennessee
 37 | {0.995721034181034,0.00427896581896577} | {5.4,3.9,1.3,0.4} | {1,0}         |     1 | Tennessee
 38 | {0.00177592816347139,0.998224071836529} | {6.0,2.9,4.5,1.5} | {0,1}         |     2 | Tennessee
 39 | {0.00828060448377906,0.991719395516221} | {5.7,2.6,3.5,1.0} | {0,1}         |     2 | Tennessee
 40 | {0.0036452266897908,0.996354773310209}  | {5.5,2.4,3.8,1.1} | {0,1}         |     2 | Tennessee
 41 | {0.0051056630637088,0.994894336936291}  | {5.5,2.4,3.7,1.0} | {0,1}         |     2 | Tennessee
 42 | {0.00343988163885057,0.996560118361149} | {5.8,2.7,3.9,1.2} | {0,1}         |     2 | Tennessee
 43 | {0.00129331249512001,0.99870668750488}  | {6.0,2.7,5.1,1.6} | {0,1}         |     2 | Tennessee
 44 | {0.00256724016890917,0.997432759831091} | {5.4,3.0,4.5,1.5} | {0,1}         |     2 | Tennessee
 45 | {0.00255900504882963,0.99744099495117}  | {6.0,3.4,4.5,1.6} | {0,1}         |     2 | Tennessee
 46 | {0.00161519300379958,0.9983848069962}   | {6.7,3.1,4.7,1.5} | {0,1}         |     2 | Tennessee
 47 | {0.00146581473536153,0.998534185264638} | {6.3,2.3,4.4,1.3} | {0,1}         |     2 | Tennessee
 48 | {0.00452800962952035,0.99547199037048}  | {5.6,3.0,4.1,1.3} | {0,1}         |     2 | Tennessee
 49 | {0.00249834272514067,0.997501657274859} | {5.5,2.5,4.0,1.3} | {0,1}         |     2 | Tennessee
 50 | {0.00242578006971178,0.997574219930288} | {5.5,2.6,4.4,1.2} | {0,1}         |     2 | Tennessee
 51 | {0.00193399291827211,0.998066007081728} | {6.1,3.0,4.6,1.4} | {0,1}         |     2 | Tennessee
 52 | {0.00275581131696433,0.997244188683036} | {5.8,2.6,4.0,1.2} | {0,1}         |     2 | Tennessee
(52 rows)
{code}
> Support already encoded arrays for dependent var in MLP classification
> ----------------------------------------------------------------------
>
> Key: MADLIB-1222
> URL: https://issues.apache.org/jira/browse/MADLIB-1222
> Project: Apache MADlib
> Issue Type: New Feature
> Components: Module: Neural Networks
> Reporter: Nandish Jayaram
> Priority: Major
> Fix For: v1.14
>
>
> MLP currently supports only scalar dependent variables for
> classification. If a user has already one-hot encoded a categorical
> dependent variable, it will be an array and hence unusable with
> mlp_classification. This feature request is to allow a one-hot
> encoded array to be used as the dependent variable in MLP classification.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)