Repository: incubator-madlib-site Updated Branches: refs/heads/asf-site d147759e6 -> 7fa3b7965
http://git-wip-us.apache.org/repos/asf/incubator-madlib-site/blob/7fa3b796/community-artifacts/mlp-v1.ipynb ---------------------------------------------------------------------- diff --git a/community-artifacts/mlp-v1.ipynb b/community-artifacts/mlp-v1.ipynb new file mode 100644 index 0000000..eaba7fb --- /dev/null +++ b/community-artifacts/mlp-v1.ipynb @@ -0,0 +1,1525 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multilayer Perceptron" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated. You should import from traitlets.config instead.\n", + " \"You should import from traitlets.config instead.\", ShimWarning)\n", + "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", + " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n" + ] + } + ], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "u'Connected: gpdbchina@madlib'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Greenplum 4.3.10.0\n", + "%sql postgresql://gpdbchina@10.194.10.68:61000/madlib\n", + " \n", + "# PostgreSQL local\n", + "#%sql postgresql://fmcquillan@localhost:5432/madlib\n", + "\n", + "# Greenplum 4.2.3.0\n", + "#%sql postgresql://gpdbchina@10.194.10.68:55000/madlib" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + 
"text/html": [ + "<table>\n", + " <tr>\n", + " <th>version</th>\n", + " </tr>\n", + " <tr>\n", + " <td>MADlib version: 1.12-dev, git revision: rel/v1.11-48-gff1b0f8, cmake configuration time: Tue Aug 15 00:36:04 UTC 2017, build type: Release, build system: Linux-2.6.18-238.27.1.el5.hotfix.bz516490, C compiler: gcc 4.4.0, C++ compiler: g++ 4.4.0</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(u'MADlib version: 1.12-dev, git revision: rel/v1.11-48-gff1b0f8, cmake configuration time: Tue Aug 15 00:36:04 UTC 2017, build type: Release, build system: Linux-2.6.18-238.27.1.el5.hotfix.bz516490, C compiler: gcc 4.4.0, C++ compiler: g++ 4.4.0',)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql select madlib.version();\n", + "#%sql select version();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. Create input table for classification" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "20 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>attributes</th>\n", + " <th>class_text</th>\n", + " <th>class</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), 
Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>[Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>[Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>[Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>[Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')]</td>\n", + " 
<td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (2, [Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (3, [Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (4, [Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (5, [Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (6, [Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')], u'Iris-setosa', 1),\n", + " (7, [Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa', 1),\n", + " (8, [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (9, [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (10, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa', 1),\n", + " (11, 
[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor', 2),\n", + " (12, [Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (13, [Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (14, [Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (15, [Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (16, [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (17, [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris-versicolor', 2),\n", + " (18, [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor', 2),\n", + " (19, [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (20, [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris-versicolor', 2)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql \n", + "DROP TABLE IF EXISTS iris_data;\n", + "\n", + "CREATE TABLE iris_data(\n", + " id integer,\n", + " attributes numeric[],\n", + " class_text varchar,\n", + " class integer\n", + ");\n", + "\n", + "INSERT INTO iris_data VALUES\n", + "(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa',1),\n", + "(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa',1),\n", + "(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa',1),\n", + "(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa',1),\n", + "(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa',1),\n", + "(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa',1),\n", + "(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa',1),\n", + "(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa',1),\n", + "(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa',1),\n", + "(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa',1),\n", + "(11,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor',2),\n", + "(12,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor',2),\n", + 
"(13,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor',2),\n", + "(14,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor',2),\n", + "(15,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor',2),\n", + "(16,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor',2),\n", + "(17,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor',2),\n", + "(18,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor',2),\n", + "(19,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor',2),\n", + "(20,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor',2);\n", + "\n", + "SELECT * FROM iris_data ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Classification\n", + "\n", + "Generate a multilayer perceptron with a single hidden layer of 5 units. Use the attributes column as the independent variables, and use the class column as the classification. Set the tolerance to 0 so that 500 iterations will be run. Use a hyperbolic tangent activation function. The model will be written to mlp_model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>mlp_classification</th>\n", + " </tr>\n", + " <tr>\n", + " <td></td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[('',)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_model, mlp_model_summary;\n", + "\n", + "-- Set seed so results are reproducible\n", + "SELECT setseed(0);\n", + "SELECT madlib.mlp_classification(\n", + " 'iris_data', -- Source table\n", + " 'mlp_model', -- Destination table\n", + " 'attributes', -- Input features\n", + " 'class_text', -- Label\n", + " ARRAY[5], -- Number of units per layer\n", + " 'learning_rate_init=0.003,\n", + " n_iterations=500,\n", + " tolerance=0', -- Optimizer params\n", + " 'tanh', -- Activation function\n", + " 
NULL, -- Default weight (1)\n", + " FALSE, -- No warm start\n", + " TRUE -- Verbose\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. View the classification model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coeff</th>\n", + " <th>loss</th>\n", + " <th>num_iterations</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[0.134859643584, 0.200285896402, -0.281831690249, 0.755199357168, 0.857042304782, -0.188536767412, -0.291668520498, 0.343686800435, -0.399408727166, -0.179921593947, -0.010662340824, 0.23622232339, -0.257390617236, 0.213182376685, 0.576459373081, 0.306524087619, 0.248260630252, 0.050175145813, -0.101614469022, 0.281200318932, -0.391835525435, -0.0953767781907, -0.384721111012, 0.402854448732, -0.122585128952, 0.110591514785, -1.40623002748, 0.177662074116, -0.247743897977, -0.258200774495, -0.203357963386, 0.122486844237, 1.1677668027, -0.728672776285, 0.837006898515, 0.0198429228676, 0.209390034006]</td>\n", + " <td>0.0320841110315</td>\n", + " <td>500</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([0.134859643584, 0.200285896402, -0.281831690249, 0.755199357168, 0.857042304782, -0.188536767412, -0.291668520498, 0.343686800435, -0.399408727166, -0.179921593947, -0.010662340824, 0.23622232339, -0.257390617236, 0.213182376685, 0.576459373081, 0.306524087619, 0.248260630252, 0.050175145813, -0.101614469022, 0.281200318932, -0.391835525435, -0.0953767781907, -0.384721111012, 0.402854448732, -0.122585128952, 0.110591514785, -1.40623002748, 0.177662074116, -0.247743897977, -0.258200774495, -0.203357963386, 0.122486844237, 1.1677668027, -0.728672776285, 0.837006898515, 0.0198429228676, 0.209390034006], Decimal('0.0320841110315'), 500)]" + ] + }, + "execution_count": 6, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_model;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. Create input data for regression\n", + "\n", + "This dataset contains housing prices." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "20 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>x</th>\n", + " <th>grp_by_col</th>\n", + " <th>y</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98]</td>\n", + " <td>1</td>\n", + " <td>24.0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>[0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14]</td>\n", + " <td>1</td>\n", + " <td>21.6</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>[0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03]</td>\n", + " <td>1</td>\n", + " <td>34.7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>[0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94]</td>\n", + " <td>1</td>\n", + " <td>33.4</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>[0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33]</td>\n", + " <td>1</td>\n", + " <td>36.2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>[0.02985, 0.0, 2.18, 0.0, 0.458, 6.43, 58.7, 6.0622, 3.0, 222.0, 18.7, 394.12, 5.21]</td>\n", + " <td>1</td>\n", + " <td>28.7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>[0.08829, 12.5, 7.87, 0.0, 0.524, 6.012, 66.6, 5.5605, 5.0, 311.0, 15.2, 395.6, 12.43]</td>\n", + " <td>1</td>\n", + " <td>22.9</td>\n", + " </tr>\n", + " 
<tr>\n", + " <td>8</td>\n", + " <td>[0.14455, 12.5, 7.87, 0.0, 0.524, 6.172, 96.1, 5.9505, 5.0, 311.0, 15.2, 396.9, 19.15]</td>\n", + " <td>1</td>\n", + " <td>27.1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>[0.21124, 12.5, 7.87, 0.0, 0.524, 5.631, 100.0, 6.0821, 5.0, 311.0, 15.2, 386.63, 29.93]</td>\n", + " <td>1</td>\n", + " <td>16.5</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>[0.17004, 12.5, 7.87, 0.0, 0.524, 6.004, 85.9, 6.5921, 5.0, 311.0, 15.2, 386.71, 17.1]</td>\n", + " <td>1</td>\n", + " <td>18.9</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>[0.22489, 12.5, 7.87, 0.0, 0.524, 6.377, 94.3, 6.3467, 5.0, 311.0, 15.2, 392.52, 20.45]</td>\n", + " <td>1</td>\n", + " <td>15.0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>[0.11747, 12.5, 7.87, 0.0, 0.524, 6.009, 82.9, 6.2267, 5.0, 311.0, 15.2, 396.9, 13.27]</td>\n", + " <td>1</td>\n", + " <td>18.9</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>[0.09378, 12.5, 7.87, 0.0, 0.524, 5.889, 39.0, 5.4509, 5.0, 311.0, 15.2, 390.5, 15.71]</td>\n", + " <td>1</td>\n", + " <td>21.7</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>[0.62976, 0.0, 8.14, 0.0, 0.538, 5.949, 61.8, 4.7075, 4.0, 307.0, 21.0, 396.9, 8.26]</td>\n", + " <td>1</td>\n", + " <td>20.4</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>[0.63796, 0.0, 8.14, 0.0, 0.538, 6.096, 84.5, 4.4619, 4.0, 307.0, 21.0, 380.02, 10.26]</td>\n", + " <td>1</td>\n", + " <td>18.2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>[0.62739, 0.0, 8.14, 0.0, 0.538, 5.834, 56.5, 4.4986, 4.0, 307.0, 21.0, 395.62, 8.47]</td>\n", + " <td>1</td>\n", + " <td>19.9</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>[1.05393, 0.0, 8.14, 0.0, 0.538, 5.935, 29.3, 4.4986, 4.0, 307.0, 21.0, 386.85, 6.58]</td>\n", + " <td>1</td>\n", + " <td>23.1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>[0.7842, 0.0, 8.14, 0.0, 0.538, 5.99, 81.7, 
4.2579, 4.0, 307.0, 21.0, 386.75, 14.67]</td>\n", + " <td>1</td>\n", + " <td>17.5</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>[0.80271, 0.0, 8.14, 0.0, 0.538, 5.456, 36.6, 3.7965, 4.0, 307.0, 21.0, 288.99, 11.69]</td>\n", + " <td>1</td>\n", + " <td>20.2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>[0.7258, 0.0, 8.14, 0.0, 0.538, 5.727, 69.5, 3.7965, 4.0, 307.0, 21.0, 390.95, 11.28]</td>\n", + " <td>1</td>\n", + " <td>18.2</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, [0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98], 1, 24.0),\n", + " (2, [0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14], 1, 21.6),\n", + " (3, [0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03], 1, 34.7),\n", + " (4, [0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94], 1, 33.4),\n", + " (5, [0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33], 1, 36.2),\n", + " (6, [0.02985, 0.0, 2.18, 0.0, 0.458, 6.43, 58.7, 6.0622, 3.0, 222.0, 18.7, 394.12, 5.21], 1, 28.7),\n", + " (7, [0.08829, 12.5, 7.87, 0.0, 0.524, 6.012, 66.6, 5.5605, 5.0, 311.0, 15.2, 395.6, 12.43], 1, 22.9),\n", + " (8, [0.14455, 12.5, 7.87, 0.0, 0.524, 6.172, 96.1, 5.9505, 5.0, 311.0, 15.2, 396.9, 19.15], 1, 27.1),\n", + " (9, [0.21124, 12.5, 7.87, 0.0, 0.524, 5.631, 100.0, 6.0821, 5.0, 311.0, 15.2, 386.63, 29.93], 1, 16.5),\n", + " (10, [0.17004, 12.5, 7.87, 0.0, 0.524, 6.004, 85.9, 6.5921, 5.0, 311.0, 15.2, 386.71, 17.1], 1, 18.9),\n", + " (11, [0.22489, 12.5, 7.87, 0.0, 0.524, 6.377, 94.3, 6.3467, 5.0, 311.0, 15.2, 392.52, 20.45], 1, 15.0),\n", + " (12, [0.11747, 12.5, 7.87, 0.0, 0.524, 6.009, 82.9, 6.2267, 5.0, 311.0, 15.2, 396.9, 13.27], 1, 18.9),\n", + " (13, [0.09378, 12.5, 7.87, 0.0, 0.524, 5.889, 39.0, 5.4509, 5.0, 311.0, 15.2, 390.5, 15.71], 1, 21.7),\n", + " (14, [0.62976, 0.0, 8.14, 0.0, 0.538, 5.949, 
61.8, 4.7075, 4.0, 307.0, 21.0, 396.9, 8.26], 1, 20.4),\n", + " (15, [0.63796, 0.0, 8.14, 0.0, 0.538, 6.096, 84.5, 4.4619, 4.0, 307.0, 21.0, 380.02, 10.26], 1, 18.2),\n", + " (16, [0.62739, 0.0, 8.14, 0.0, 0.538, 5.834, 56.5, 4.4986, 4.0, 307.0, 21.0, 395.62, 8.47], 1, 19.9),\n", + " (17, [1.05393, 0.0, 8.14, 0.0, 0.538, 5.935, 29.3, 4.4986, 4.0, 307.0, 21.0, 386.85, 6.58], 1, 23.1),\n", + " (18, [0.7842, 0.0, 8.14, 0.0, 0.538, 5.99, 81.7, 4.2579, 4.0, 307.0, 21.0, 386.75, 14.67], 1, 17.5),\n", + " (19, [0.80271, 0.0, 8.14, 0.0, 0.538, 5.456, 36.6, 3.7965, 4.0, 307.0, 21.0, 288.99, 11.69], 1, 20.2),\n", + " (20, [0.7258, 0.0, 8.14, 0.0, 0.538, 5.727, 69.5, 3.7965, 4.0, 307.0, 21.0, 390.95, 11.28], 1, 18.2)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS lin_housing;\n", + "\n", + "CREATE TABLE lin_housing (id serial, \n", + " x float8[], \n", + " grp_by_col int, \n", + " y float8);\n", + "\n", + "INSERT INTO lin_housing VALUES\n", + "(1,ARRAY[0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98],1,24.00),\n", + "(2,ARRAY[0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14],1,21.60),\n", + "(3,ARRAY[0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03],1,34.70),\n", + "(4,ARRAY[0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94],1,33.40),\n", + "(5,ARRAY[0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33],1,36.20),\n", + "(6,ARRAY[0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21],1,28.70),\n", + "(7,ARRAY[0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43],1,22.90),\n", + "(8,ARRAY[0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15],1,27.10),\n", + "(9,ARRAY[0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93],1,16.50),\n", + 
"(10,ARRAY[0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10],1,18.90),\n", + "(11,ARRAY[0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45],1,15.00),\n", + "(12,ARRAY[0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27],1,18.90),\n", + "(13,ARRAY[0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71],1,21.70),\n", + "(14,ARRAY[0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26],1,20.40),\n", + "(15,ARRAY[0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26],1,18.20),\n", + "(16,ARRAY[0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47],1,19.90),\n", + "(17,ARRAY[1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58],1, 23.10),\n", + "(18,ARRAY[0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67],1,17.50),\n", + "(19,ARRAY[0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69],1,20.20),\n", + "(20,ARRAY[0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28],1,18.20);\n", + "\n", + "SELECT * FROM lin_housing ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Regression\n", + "\n", + "Now train a regression model using a multilayer perceptron with 2 hidden layers of 25 nodes each." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>mlp_regression</th>\n", + " </tr>\n", + " <tr>\n", + " <td></td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[('',)]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_regress;\n", + "DROP TABLE IF EXISTS mlp_regress_summary;\n", + "SELECT setseed(0);\n", + "SELECT madlib.mlp_regression(\n", + " 'lin_housing', -- Source table\n", + " 'mlp_regress', -- Desination table\n", + " 'x', -- Input features\n", + " 'y', -- Dependent variable\n", + " ARRAY[25,25], -- Number of units per layer\n", + " 'learning_rate_init=0.001,\n", + " n_iterations=500,\n", + " lambda=0.001,\n", + " tolerance=0',\n", + " 'relu',\n", + " NULL, -- Default weight (1)\n", + " FALSE, -- No warm start\n", + " TRUE -- Verbose\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6.0 View the regression model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coeff</th>\n", + " <th>loss</th>\n", + " <th>num_iterations</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[0.122856566817, -0.109511383377, 0.10922147183, 0.142820008065, 0.162756253894, -0.104436945183, 0.112036676092, 0.0467338442275, -0.16453372472, 0.0827992259598, 0.0137404038273, 0.0565353617298, -0.0246679212553, -0.042838429926, 0.381777598727, 0.276073801059, 0.00907788148497, 0.21191870498, -0.141702416989, 0.181936732347, -0.136061155165, -0.721381760449, -0.0560067549982, 0.203545531143, 
-0.00995912442451, 0.0613143120688, -0.0984596213056, -0.204436524678, 0.470767785122, -0.435311003078, -0.0424606276832, 0.160892203008, 0.0445352965619, -0.0879636030809, 0.154146037853, 0.0288780423027, -0.317282313525, 0.385922416222, -0.0664872317641, 0.146757511446, 0.133876976675, -0.116256543538, -0.0213680210249, 0.20015749221, -0.103457503483, -0.0425956948505, 0.121667463986, 0.204968623833, -0.159914333708, 0.149243857206, 0.0456901480449, -0.136975914168, -0.090899493726, 0.103816779262, 0.157303539757, -0.0813411072267, -0.186218530515, -0.18296662406, -0.01 06273344335, -0.15268221771, -0.103478313334, 0.199984721908, 0.140135789025, 0.143966006971, -0.0970102681415, 0.0261086412469, -0.0330555393771, 0.0995076031524, 0.00130931383947, 0.0778756667869, -0.0527086470232, -0.600133053147, 0.0859133658584, 0.24113974288, 0.170332657512, 0.0676027349796, -0.21148456397, 0.248468295235, -0.436384189411, 0.175863820799, 0.123778486759, -0.282581458533, 0.0191701197018, 0.280887575877, 0.21926286976, -0.24666316945, -0.0262117477515, 0.268417193439, -0.0591642733853, 0.168335924156, 0.567262570173, -0.0244295268223, -0.161746761037, 0.239464436128, 0.0927106954455, 0.108399901337, -0.00684509495191, 0.0942018933997, 0.27681768162, -0.0200546475364, -0.0178865707443, 0.0590165895996, 0.178019822826, 0.270901895834, -0.180136838054, 0.201151867097, -0.269229552029, 0.162901821652, 0.176989876451, -0.0562545583224, 0.302891594425, -0.107643708081, 0.124875778612, -0.0717199493045, -0.256875036503, -0.104361508745, 0.0223166114387, -0.07222004733 6, 0.048066243185, 0.0917036420782, -0.0401521607798, -0.166103131, -0.0598667085507, 0.226655306977, 0.196020672052, -0.0019893331732, 0.298047303143, 0.0610271888562, -0.123756854539, -0.138425394609, -0.0813844369306, -0.211061219263, 0.280266006616, -0.456958897687, 0.0957847778583, 0.10211765487, -0.280257589355, 0.31403698235, -0.0988827746221, -0.315188453188, 0.0323558432208, 0.30938173242, -0.337991301037, 
-0.0746337727327, -0.128019139853, -0.107135213031, 0.401895277966, -0.0856407552996, 0.338660876519, 0.182709918475, 0.00468663392494, 0.296164404708, -0.149328993889, -0.455192142697, 0.123695517097, -0.101042506058, -0.299448111606, -0.121418328349, -0.015253490999, 0.105945563874, 0.206274179032, 0.0110203190195, -0.0198463256646, -0.122433280535, 0.159641322926, 0.0111190972641, 0.146024395139, -0.3292631924, 0.383620965193, -0.407475374511, 0.124656819426, 0.00814682875215, -0.196043525121, 0.259447828442, -0.0110058837001, -0.0911647683094, -0.495353629738, 0.15832 4486446, 0.284649054294, 0.159015240494, 0.127990452507, -0.0450611236751, 0.100307630204, -0.178426360702, 0.0795949564344, -0.0935834137435, 0.0121807964771, 0.13280000068, -0.0780024566724, 0.268183360375, -0.143863232988, 0.229637060781, 0.143828509532, 0.0467453366612, -0.101445117614, 0.110812332279, 0.274695568803, 0.0886704115438, 0.0920752574129, 0.230045146137, 0.189557642894, 0.175650286388, 0.612810225963, 0.445301543545, -0.115078741068, 0.157970457258, -0.0479338968094, 0.212561173909, 0.180880194985, -0.0870642554799, 0.142339596596, -0.151003784829, 0.0688697053283, -0.128056785862, 0.0793206934645, -0.118269324067, -0.0332473245304, -0.311866015062, -0.117922286012, 0.138892683591, -0.13204744327, 0.0179514948766, -0.019965660049, 0.132229635306, -0.0313989569283, -0.130315567213, 0.0845857500725, -0.159680095558, -0.0112748210112, -0.00992759022784, 0.164452997275, 0.0322838283196, -0.270242562524, 0.162183925006, -0.0574346808757, -0.22116233726, -0.047400722579, 0.00863094243228, 0.211311350784, -0.163996331401, -0.0477683759558, -0.124856946878, -0.00714217581038, -0.266411668298, 0.313991233853, -0.225971650736, 0.161591101292, -0.188041216058, -0.214157824743, 0.118347597707, 0.00614700038529, -0.16123423931, -0.102840533896, -0.102239903301, -0.000257003830486, 0.0511787840226, 0.0912252762164, -0.0751847257127, 0.0846024895256, -0.157788057275, 0.0988942818557, 
0.0264955057078, 0.0417793769197, -0.0540072711509, -0.00632082438876, -0.180617287286, -0.0526100222044, 0.118036208602, -0.000661796862675, -0.254004752387, -0.153693011234, -0.171988783244, 0.164053555891, -0.0169405183209, 0.142598359067, -0.085909030833, -0.00680028505823, -0.0510574641112, -0.10346364174, -0.25172618999, 0.457311106971, -0.873025286844, 0.183853172995, -0.124620606266, -0.123278406783, -0.191666220417, 0.159077818728, 0.284815904677, -0.735489991194, 0.351899702528, -0.235562757335, -0.0727504017353, 0.040328813679, -0.125064567205, 0.314741875925, 0.03317 54869107, 0.0884781096995, -0.178033350198, 0.116105414459, -0.0304993356093, -0.27746854463, 0.264004909184, 0.0563022180518, 0.201355642027, -0.0517272948481, 0.0748041423242, -0.0262047282685, -0.092904739373, 0.115647917976, 0.0175972187488, 0.0193950400413, 0.140792231397, -0.00159044852904, -0.00590168585226, -0.139502824903, 0.0410346350032, 0.0413657977828, -0.0633203110196, 0.111349814682, 0.00747228750879, 0.140373814921, 0.102221513338, -0.116981487287, -0.0693222177589, 0.0413462607849, -0.105398620292, -0.143032508267, -0.0178459843317, -0.0731008179726, 0.103677863645, 0.245558735393, 0.0864260843503, 0.180994293221, 0.0917188124414, -0.0121479315063, -0.0382354365149, 0.07547910454, -0.165737544578, 0.0807115289326, -0.134231077856, -0.137895944349, -0.23171311458, -0.163296061142, 0.116367060077, -0.275327082108, 0.000654881341552, -0.101458765282, -0.230568790277, 0.00328713579536, 0.121407554293, 0.270474969032, -0.0332705761859, 0.195253595642, 0.202199456653, 0.0 240241261406, -0.117138257744, 0.262847956454, 0.167488405552, 0.0791663226224, -0.0850169677771, 0.127015057297, 0.0881663075291, 0.157579936939, 0.0790931902296, -0.0350584878084, 0.0687973505113, 0.127800875594, 0.0990169958988, 0.261759911012, -0.0209958907533, 0.178141387084, 0.0802477036446, 0.0375667573926, -0.156527532652, 0.09595744853, 0.144071902168, 0.0272982692838, -0.107756789849, 0.191712610983, 
0.131812031368, -0.150147659255, -0.0193064462346, 0.248057351446, 0.132636431219, -0.0656579288641, -0.0102056610199, 0.166573513737, -0.089580189987, -0.13267377489, 0.00765881272357, -0.0294600407492, 0.109945889046, 0.127286865292, -0.103663816189, -0.13701947915, -0.140194114889, -0.00273323080925, 0.193293649137, 0.111701285793, -0.00816508369269, -0.128614542133, 0.0578130660316, 0.0262319837289, 0.00944211414361, 0.0349982062763, -0.0497391325644, -0.0684296167909, 0.133638380272, -0.00513533046615, -0.113005170272, 0.0371435244753, 0.00685550220379, 0.0345504421305, 0 .0291314878, -0.0963381594242, 0.113578331339, -0.148312320923, -0.139277275139, 0.145797861416, -0.115402530254, -0.0141028671065, -0.0964248545817, 0.109089742884, -0.0746598414841, -0.0469864710904, 0.130259197266, -0.0537714811321, 0.110238509759, 0.0570749324965, -0.159819021215, -0.0835962095075, 0.0959104137424, 0.0437050342893, 0.116577502313, -0.0661250670803, -0.0961631710138, -0.103486980341, 0.0392920895058, -0.134266639255, 0.0601833600592, 0.0995456197992, 0.0780624339557, -0.103283278004, -0.0337175173846, -0.0631874494788, -0.021864548369, -0.0910979985263, -0.0393805762934, 0.0124725373216, -0.118269891962, 0.0190946263827, -0.167315972761, -0.041287897759, -0.0402292813632, -0.124074804749, 0.0743012733141, -0.0889323559481, 0.0291352210856, 0.0152283683786, 0.137832612062, 0.0299226271789, 0.0973969148803, 0.0455620046357, -0.0690846277006, 0.0463525595754, -0.170969427729, -0.00530098180334, 0.0860765294519, 0.0200025350096, 0.0408217087641, 0.000674906042282, -0 .0484522805343, -0.063904387141, 0.0680186364286, 0.0956646181429, 0.0211126410967, -0.135956703385, -0.173184730333, -0.153592609448, 0.0515420758792, -0.106504865816, 0.0427090308662, 0.0588571345969, 0.038628020038, -0.167495451323, -0.171995990523, -0.0746036469661, -0.0893353944191, 0.0769671432324, 0.129290424004, -0.115018423904, -0.0608217250046, 0.0576405168275, 0.138811008142, 0.0484048069161, 
-0.0648970775095, 0.12408920527, -0.116670120942, -0.11461248868, -0.000556289750656, 0.113501171408, 0.073011667822, 0.157085701474, -0.103935971103, -0.0285974926781, 0.0668750198788, 0.495352461847, 0.102622017207, 0.343493376056, 0.356233832875, -0.0474178679926, 0.143445828453, 0.276335059268, 0.396289343536, 0.131457852813, 0.0343149415343, 0.343782464698, 0.327289025174, 0.176853423997, 0.24349898742, 0.149069604566, 0.418153967013, 0.235895507356, 0.033694810931, 0.267228687328, 0.0814095665435, 0.120164220897, 0.462571446373, 0.222535013115, 0.11424010583, -0.0906197141098, 0.129458603213, 0.108931721147, -0.0890778578148, 0.191923971562, 0.17901983388, 0.151099630479, 0.0566764331123, 0.0489030212844, 0.073160549313, 0.085809192495, 0.0485886118323, 0.0215963694964, -0.0834544297183, 0.0327265042333, -0.0931091459474, 0.112069055149, 0.043039352435, 0.130235543317, 0.0299101055543, -0.149377273183, 0.00646975988585, -0.000554735025032, -0.0831731375459, 0.13563300492, -0.0165566474276, 0.0485383902634, 0.00888025767832, -0.139154446785, -0.0353843799165, -0.0871916612905, -0.04686546779, -0.166337955736, 0.0750002808334, 0.0168899088843, 0.104283523831, 0.0187089188179, -0.0251945516132, -0.0131352016429, 0.170830441379, -0.154809267796, -0.155763940099, 0.0443056586881, 0.0274827557274, -0.0977496234606, -0.0970180773028, 0.0350993020096, 0.146745061397, 0.078815475605, -0.0389907438712, -0.0335619958692, 0.118211074784, 0.152649314105, 0.147007119673, 0.103915506114, 0.0890045904999, -0.0362673918145, 0.0113676661504, -0.0975364965898, -0.0103831091 45, 0.179494474761, -0.00826024925394, 0.139510966287, 0.132164679775, -0.133918793256, -0.00635655788897, 0.0767882832999, 0.10726612898, 0.142674161454, -0.0883857724845, 0.0855403411171, 0.165220673799, -0.112104374468, -0.0380395831084, 0.00501932226928, 0.00263269004584, 0.0356991274653, -0.113333317511, -0.0124000028152, -0.0515325939146, 0.373876728674, 0.207596737771, 0.0813373842185, 0.183371106337, 
0.165925307432, -0.135548407527, 0.12660809146, 0.0536288727356, 0.147716796788, 0.0568205245395, 0.197828755587, 0.176202897769, -0.101477342881, 0.159869918913, 0.0979205147731, 0.0746576092262, -0.103825525431, 0.00928188760788, -0.0504819433487, 0.191625701797, 0.0852383221501, 0.358258019876, 0.227735982591, 0.176158268625, -0.111316520456, 0.0455282389086, -0.0614696960162, 0.0843890679754, 0.0463263728094, 0.119850225612, -0.11927794611, -0.0713883054891, 0.0924825570733, -0.0624807055815, 0.115190061413, 0.0844764278027, -0.0132342347732, -0.0849093011178, 0.152923763906 , -0.148325245058, 0.105231772707, 0.0283105691055, 0.0720970891556, 0.132436487291, -0.10922576116, -0.0815515390022, -0.114722256636, -0.170915126812, 0.0175996345112, -0.124412352668, -0.13304515302, -0.106477616928, 0.120199855771, -0.174671087798, -0.122672321399, -0.144846427759, -0.0154300950821, 0.0507292688075, 0.126551952929, -0.143439054826, 0.0299006158402, -0.15005716624, -0.0401343045517, -0.00252760290587, -0.0237907281236, -0.0797753025531, -0.0991644050106, 0.134308019727, 0.0200117589182, -0.112084247031, -0.167995452395, -0.0295671374508, 0.107351765084, 0.0632551569846, -0.0582422444596, 0.16366750329, -0.162479674609, 0.00214073446262, -0.103278921502, 0.0558777777097, 0.0723732610543, -0.0613630222551, 0.135128811096, 0.050568989534, -0.0546392828733, -0.153971534091, 0.0924751592423, 0.101706791108, 0.0607050649968, 0.0596870445781, 0.136883549105, -0.0637011610473, 0.0871060796234, -0.0681990890046, 0.104202328594, -0.108179302086, 0.0223725516522, -0.1555482 26486, -0.141998314601, -0.122743124803, -0.0875520919781, -0.125286855988, 0.0307176179015, -0.153042075499, 0.23696229498, 0.177058796447, -0.082108841359, 0.20329677785, 0.00175169627491, -0.0846871909768, 0.099835299916, -0.0481316144897, 0.0661170188435, -0.00346217036623, 0.161053868398, 0.207674939326, 0.0281098351705, 0.0860739426955, 0.105192953584, -0.0285058349374, -0.0580393494223, 0.0714749978916, 
0.0512850741134, -0.152503380977, 0.168472501535, 0.0485058104422, -0.0672229715629, 0.0117263745766, -0.0365743846743, -0.0412353731283, 0.0291119464053, 0.0441411422804, -0.0316612818875, -0.087567406123, 0.063411170627, -0.0589250368343, -0.107599666557, 0.081908821462, -0.0935361527796, 0.0577467210631, -0.168351581892, 0.11573638204, 0.115117344323, 0.0244735425375, -0.070952849692, 0.0485741149425, 0.00733481525825, 0.116528669326, -0.090380421955, -0.0854344432891, -0.148897840337, -0.00584468517666, 0.125584338188, 0.0303923394942, 0.00866505023759, 0.107045771132, 0.1 04507667791, 0.0759993673344, -0.0560582027212, -0.0601905601019, -0.162052677836, -0.165142098544, -0.171051382284, -0.00747175762698, -0.0725351668202, 0.0574896106535, 0.0883795833561, -0.00115905040131, -0.0208072413433, -0.146708687264, -0.106512542328, -0.0148684889684, 0.135203007488, -0.159366976488, -0.185727272921, -0.0986049216812, 0.0663781391027, 0.00337193126349, -0.143224040776, 0.151377444068, 0.122900039992, -0.123876153218, -0.0616559523895, 0.07474867887, 0.0759827169541, -0.196835030676, 0.0240963945802, 0.0815042565879, 0.0813777417882, 0.122173760283, -0.149634521852, 0.0950858936023, 0.147034831794, -0.143746034738, -0.0783443886749, -0.100895898859, 0.090389376971, -0.163667864143, 0.0799367976539, -0.080443647121, -0.112871360843, 0.147539108095, 0.0899145483935, -0.161477148966, 0.138928943053, 0.110902779082, -0.0666307526333, 0.0729760796573, 0.0290997225314, -0.0156054718084, 0.111971399191, -0.0448196144982, 0.0547285642038, -0.12753315933, -0.144334282 657, -0.0106939193609, -0.0986907128641, 0.0600266162345, -0.120230639837, 0.16669020756, 0.0439133843093, -0.0724689439423, 0.0629139032688, 0.0504403844796, -0.068271656679, 0.156640645958, 0.0919725761945, -0.127266025876, -0.159289153952, 0.00546200555327, -0.0223831026984, -0.0931915620749, -0.0163207716589, -0.12419155542, 0.116520477065, 0.0153427362007, -0.129060892518, 0.157096091442, -0.129280966926, 
-0.0137527695116, -0.0171380709271, 0.195707287676, 0.142599751827, 0.072152904103, 0.252011620265, 0.207324038903, -0.125405121498, -0.100013512174, 0.0583275733665, -0.0146841391113, 0.196379045132, -0.0875501697175, 0.078958660421, 0.0740925533203, 0.0902143324496, 0.29527471964, 0.073563228215, -0.0180613309394, 0.115906200451, 0.0786093503948, 1.306422046, 0.137688222641, 0.756774881139, 0.496403150097, 0.0972268426637, -0.1337106213, 0.739166706063, 0.550671589035, 0.441561444758, 0.231081592497, 0.622412508822, 0.644482827381, 0.282956642749, 0.587122887501, 0.276576675 538, 0.65481113446, 0.229873153345, 0.271528882808, 0.506621646929, 0.00565366103103, 0.282388297824, 1.02057546753, 0.217085174554, -0.0379775424244, -0.10514999404, 0.332566472108, 0.332449111037, 0.124367878135, 0.0141927018472, 0.15855771012, -0.038603288579, -0.135492381545, 0.254667963224, -0.0234027088103, 0.0593833730872, -0.0644732316305, -0.00845219720435, 0.028657566833, 0.0936179669519, 0.254740895255, -0.0461609993025, 0.120413358916, 0.138933025039, 0.196582134527, -0.0400133811542, -0.040323701499, 0.0581573904481, 0.23554382053, 0.192417275699, -0.0776195707746, 0.0592937520315, 0.089886264069, 0.0539065543803, 0.09886462098, 0.0667640588722, -0.18084306698, 0.0222659539242, 0.0403979776829, 0.0775484353974, 0.0791672960253, -0.0899907599431, -0.150064269247, 0.124833764386, -0.120770820436, -0.114178974003, -0.0639780687205, -0.0336269431702, -0.0608694034386, 0.0467895651684, -0.141267333139, -0.064375264314, 0.121018439709, 0.0366265122228, -0.154133636165, 0.0859 110581911, 0.0756296178076, -0.0424145003869, -0.0456028602463, -0.0407671725434, -0.0791181207568, 0.0288097129653, -0.151513578519, 0.138406753718, -0.0233697746503, 0.127112027349, 0.0588020231249, -0.015333084669, -0.0097756075345, -0.0715451307032, -0.0949891422502, -0.0875247946453, 0.0222796529221, -0.0696183219033, -0.128883685169, 0.0827929784601, -0.00779955596619, -0.00487182386305, -0.106664395522, 
0.123694800488, -0.129353816579, -0.071810234853, -0.110089746206, -0.171366525529, 0.136882534904, -0.111206506502, 0.0982967459967, 0.0298035370415, 0.0446501099772, -0.117237913284, -0.156925357628, 0.13528309057, 0.0847468661361, -0.150103993267, 0.117264719501, -0.114274591469, 0.157554316643, 0.00546812137529, 0.038074635687, -0.0263940535714, 0.101334013433, 0.128183877645, 0.0670206279241, -0.0442752598958, -0.109288889221, 0.111823141465, -0.148855288965, 0.0475096756177, -0.0695066497011, -0.0537726255674, 0.000632083516048, 1.57413626503, 0.55559495427, 0.3146391992 95, 0.0363237197077, 0.0298504461036, -0.224219935066, -0.157329470601, 1.21592745382, 0.234224608374, -0.0740354122049, 0.165430831517, 0.66674980751, -0.12077633358, -0.0337960848896, 0.00708649021609, 0.293365487636, 0.000863544472662, -0.199669889395, -0.210671570015, 0.180015272246, 0.467349281518, 2.58471060843, 0.516690045711, -0.0632937912186, -0.15288034917, -0.0418765986663]</td>\n", + " <td>1.48283239226</td>\n", + " <td>500</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([0.122856566817, -0.109511383377, 0.10922147183, 0.142820008065, 0.162756253894, -0.104436945183, 0.112036676092, 0.0467338442275, -0.16453372472, 0.0827992259598, 0.0137404038273, 0.0565353617298, -0.0246679212553, -0.042838429926, 0.381777598727, 0.276073801059, 0.00907788148497, 0.21191870498, -0.141702416989, 0.181936732347, -0.136061155165, -0.721381760449, -0.0560067549982, 0.203545531143, -0.00995912442451, 0.0613143120688, -0.0984596213056, -0.204436524678, 0.470767785122, -0.435311003078, -0.0424606276832, 0.160892203008, 0.0445352965619, -0.0879636030809, 0.154146037853, 0.0288780423027, -0.317282313525, 0.385922416222, -0.0664872317641, 0.146757511446, 0.133876976675, -0.116256543538, -0.0213680210249, 0.20015749221, -0.103457503483, -0.0425956948505, 0.121667463986, 0.204968623833, -0.159914333708, 0.149243857206, 0.0456901480449, -0.136975914168, -0.090899493726, 0.103816779262, 
0.157303539757, -0.0813411072267, -0.186218530515, -0.18296662406, -0.010627334433 5, -0.15268221771, -0.103478313334, 0.199984721908, 0.140135789025, 0.143966006971, -0.0970102681415, 0.0261086412469, -0.0330555393771, 0.0995076031524, 0.00130931383947, 0.0778756667869, -0.0527086470232, -0.600133053147, 0.0859133658584, 0.24113974288, 0.170332657512, 0.0676027349796, -0.21148456397, 0.248468295235, -0.436384189411, 0.175863820799, 0.123778486759, -0.282581458533, 0.0191701197018, 0.280887575877, 0.21926286976, -0.24666316945, -0.0262117477515, 0.268417193439, -0.0591642733853, 0.168335924156, 0.567262570173, -0.0244295268223, -0.161746761037, 0.239464436128, 0.0927106954455, 0.108399901337, -0.00684509495191, 0.0942018933997, 0.27681768162, -0.0200546475364, -0.0178865707443, 0.0590165895996, 0.178019822826, 0.270901895834, -0.180136838054, 0.201151867097, -0.269229552029, 0.162901821652, 0.176989876451, -0.0562545583224, 0.302891594425, -0.107643708081, 0.124875778612, -0.0717199493045, -0.256875036503, -0.104361508745, 0.0223166114387, -0.072220047336, 0.04806 6243185, 0.0917036420782, -0.0401521607798, -0.166103131, -0.0598667085507, 0.226655306977, 0.196020672052, -0.0019893331732, 0.298047303143, 0.0610271888562, -0.123756854539, -0.138425394609, -0.0813844369306, -0.211061219263, 0.280266006616, -0.456958897687, 0.0957847778583, 0.10211765487, -0.280257589355, 0.31403698235, -0.0988827746221, -0.315188453188, 0.0323558432208, 0.30938173242, -0.337991301037, -0.0746337727327, -0.128019139853, -0.107135213031, 0.401895277966, -0.0856407552996, 0.338660876519, 0.182709918475, 0.00468663392494, 0.296164404708, -0.149328993889, -0.455192142697, 0.123695517097, -0.101042506058, -0.299448111606, -0.121418328349, -0.015253490999, 0.105945563874, 0.206274179032, 0.0110203190195, -0.0198463256646, -0.122433280535, 0.159641322926, 0.0111190972641, 0.146024395139, -0.3292631924, 0.383620965193, -0.407475374511, 0.124656819426, 0.00814682875215, -0.196043525121, 
0.259447828442, -0.0110058837001, -0.0911647683094, -0.495353629738, 0.158324486446, 0 .284649054294, 0.159015240494, 0.127990452507, -0.0450611236751, 0.100307630204, -0.178426360702, 0.0795949564344, -0.0935834137435, 0.0121807964771, 0.13280000068, -0.0780024566724, 0.268183360375, -0.143863232988, 0.229637060781, 0.143828509532, 0.0467453366612, -0.101445117614, 0.110812332279, 0.274695568803, 0.0886704115438, 0.0920752574129, 0.230045146137, 0.189557642894, 0.175650286388, 0.612810225963, 0.445301543545, -0.115078741068, 0.157970457258, -0.0479338968094, 0.212561173909, 0.180880194985, -0.0870642554799, 0.142339596596, -0.151003784829, 0.0688697053283, -0.128056785862, 0.0793206934645, -0.118269324067, -0.0332473245304, -0.311866015062, -0.117922286012, 0.138892683591, -0.13204744327, 0.0179514948766, -0.019965660049, 0.132229635306, -0.0313989569283, -0.130315567213, 0.0845857500725, -0.159680095558, -0.0112748210112, -0.00992759022784, 0.164452997275, 0.0322838283196, -0.270242562524, 0.162183925006, -0.0574346808757, -0.22116233726, -0.047400722579, 0.00863094 243228, 0.211311350784, -0.163996331401, -0.0477683759558, -0.124856946878, -0.00714217581038, -0.266411668298, 0.313991233853, -0.225971650736, 0.161591101292, -0.188041216058, -0.214157824743, 0.118347597707, 0.00614700038529, -0.16123423931, -0.102840533896, -0.102239903301, -0.000257003830486, 0.0511787840226, 0.0912252762164, -0.0751847257127, 0.0846024895256, -0.157788057275, 0.0988942818557, 0.0264955057078, 0.0417793769197, -0.0540072711509, -0.00632082438876, -0.180617287286, -0.0526100222044, 0.118036208602, -0.000661796862675, -0.254004752387, -0.153693011234, -0.171988783244, 0.164053555891, -0.0169405183209, 0.142598359067, -0.085909030833, -0.00680028505823, -0.0510574641112, -0.10346364174, -0.25172618999, 0.457311106971, -0.873025286844, 0.183853172995, -0.124620606266, -0.123278406783, -0.191666220417, 0.159077818728, 0.284815904677, -0.735489991194, 0.351899702528, -0.235562757335, 
-0.0727504017353, 0.040328813679, -0.125064567205, 0.314741875925, 0.0331754869107, 0.0884781096995, -0.178033350198, 0.116105414459, -0.0304993356093, -0.27746854463, 0.264004909184, 0.0563022180518, 0.201355642027, -0.0517272948481, 0.0748041423242, -0.0262047282685, -0.092904739373, 0.115647917976, 0.0175972187488, 0.0193950400413, 0.140792231397, -0.00159044852904, -0.00590168585226, -0.139502824903, 0.0410346350032, 0.0413657977828, -0.0633203110196, 0.111349814682, 0.00747228750879, 0.140373814921, 0.102221513338, -0.116981487287, -0.0693222177589, 0.0413462607849, -0.105398620292, -0.143032508267, -0.0178459843317, -0.0731008179726, 0.103677863645, 0.245558735393, 0.0864260843503, 0.180994293221, 0.0917188124414, -0.0121479315063, -0.0382354365149, 0.07547910454, -0.165737544578, 0.0807115289326, -0.134231077856, -0.137895944349, -0.23171311458, -0.163296061142, 0.116367060077, -0.275327082108, 0.000654881341552, -0.101458765282, -0.230568790277, 0.00328713579536, 0.121407554293, 0.270474969032, -0.0332705761859, 0.195253595642, 0.202199456653, 0.02402412614 06, -0.117138257744, 0.262847956454, 0.167488405552, 0.0791663226224, -0.0850169677771, 0.127015057297, 0.0881663075291, 0.157579936939, 0.0790931902296, -0.0350584878084, 0.0687973505113, 0.127800875594, 0.0990169958988, 0.261759911012, -0.0209958907533, 0.178141387084, 0.0802477036446, 0.0375667573926, -0.156527532652, 0.09595744853, 0.144071902168, 0.0272982692838, -0.107756789849, 0.191712610983, 0.131812031368, -0.150147659255, -0.0193064462346, 0.248057351446, 0.132636431219, -0.0656579288641, -0.0102056610199, 0.166573513737, -0.089580189987, -0.13267377489, 0.00765881272357, -0.0294600407492, 0.109945889046, 0.127286865292, -0.103663816189, -0.13701947915, -0.140194114889, -0.00273323080925, 0.193293649137, 0.111701285793, -0.00816508369269, -0.128614542133, 0.0578130660316, 0.0262319837289, 0.00944211414361, 0.0349982062763, -0.0497391325644, -0.0684296167909, 0.133638380272, -0.00513533046615, 
-0.113005170272, 0.0371435244753, 0.00685550220379, 0.0345504421305, 0.029131487 8, -0.0963381594242, 0.113578331339, -0.148312320923, -0.139277275139, 0.145797861416, -0.115402530254, -0.0141028671065, -0.0964248545817, 0.109089742884, -0.0746598414841, -0.0469864710904, 0.130259197266, -0.0537714811321, 0.110238509759, 0.0570749324965, -0.159819021215, -0.0835962095075, 0.0959104137424, 0.0437050342893, 0.116577502313, -0.0661250670803, -0.0961631710138, -0.103486980341, 0.0392920895058, -0.134266639255, 0.0601833600592, 0.0995456197992, 0.0780624339557, -0.103283278004, -0.0337175173846, -0.0631874494788, -0.021864548369, -0.0910979985263, -0.0393805762934, 0.0124725373216, -0.118269891962, 0.0190946263827, -0.167315972761, -0.041287897759, -0.0402292813632, -0.124074804749, 0.0743012733141, -0.0889323559481, 0.0291352210856, 0.0152283683786, 0.137832612062, 0.0299226271789, 0.0973969148803, 0.0455620046357, -0.0690846277006, 0.0463525595754, -0.170969427729, -0.00530098180334, 0.0860765294519, 0.0200025350096, 0.0408217087641, 0.000674906042282, -0.048452280 5343, -0.063904387141, 0.0680186364286, 0.0956646181429, 0.0211126410967, -0.135956703385, -0.173184730333, -0.153592609448, 0.0515420758792, -0.106504865816, 0.0427090308662, 0.0588571345969, 0.038628020038, -0.167495451323, -0.171995990523, -0.0746036469661, -0.0893353944191, 0.0769671432324, 0.129290424004, -0.115018423904, -0.0608217250046, 0.0576405168275, 0.138811008142, 0.0484048069161, -0.0648970775095, 0.12408920527, -0.116670120942, -0.11461248868, -0.000556289750656, 0.113501171408, 0.073011667822, 0.157085701474, -0.103935971103, -0.0285974926781, 0.0668750198788, 0.495352461847, 0.102622017207, 0.343493376056, 0.356233832875, -0.0474178679926, 0.143445828453, 0.276335059268, 0.396289343536, 0.131457852813, 0.0343149415343, 0.343782464698, 0.327289025174, 0.176853423997, 0.24349898742, 0.149069604566, 0.418153967013, 0.235895507356, 0.033694810931, 0.267228687328, 0.0814095665435, 
0.120164220897, 0.462571446373, 0.222535013115, 0.11424010583, -0.0906197141098, 0.12945860 3213, 0.108931721147, -0.0890778578148, 0.191923971562, 0.17901983388, 0.151099630479, 0.0566764331123, 0.0489030212844, 0.073160549313, 0.085809192495, 0.0485886118323, 0.0215963694964, -0.0834544297183, 0.0327265042333, -0.0931091459474, 0.112069055149, 0.043039352435, 0.130235543317, 0.0299101055543, -0.149377273183, 0.00646975988585, -0.000554735025032, -0.0831731375459, 0.13563300492, -0.0165566474276, 0.0485383902634, 0.00888025767832, -0.139154446785, -0.0353843799165, -0.0871916612905, -0.04686546779, -0.166337955736, 0.0750002808334, 0.0168899088843, 0.104283523831, 0.0187089188179, -0.0251945516132, -0.0131352016429, 0.170830441379, -0.154809267796, -0.155763940099, 0.0443056586881, 0.0274827557274, -0.0977496234606, -0.0970180773028, 0.0350993020096, 0.146745061397, 0.078815475605, -0.0389907438712, -0.0335619958692, 0.118211074784, 0.152649314105, 0.147007119673, 0.103915506114, 0.0890045904999, -0.0362673918145, 0.0113676661504, -0.0975364965898, -0.010383109145, 0.1794 94474761, -0.00826024925394, 0.139510966287, 0.132164679775, -0.133918793256, -0.00635655788897, 0.0767882832999, 0.10726612898, 0.142674161454, -0.0883857724845, 0.0855403411171, 0.165220673799, -0.112104374468, -0.0380395831084, 0.00501932226928, 0.00263269004584, 0.0356991274653, -0.113333317511, -0.0124000028152, -0.0515325939146, 0.373876728674, 0.207596737771, 0.0813373842185, 0.183371106337, 0.165925307432, -0.135548407527, 0.12660809146, 0.0536288727356, 0.147716796788, 0.0568205245395, 0.197828755587, 0.176202897769, -0.101477342881, 0.159869918913, 0.0979205147731, 0.0746576092262, -0.103825525431, 0.00928188760788, -0.0504819433487, 0.191625701797, 0.0852383221501, 0.358258019876, 0.227735982591, 0.176158268625, -0.111316520456, 0.0455282389086, -0.0614696960162, 0.0843890679754, 0.0463263728094, 0.119850225612, -0.11927794611, -0.0713883054891, 0.0924825570733, -0.0624807055815, 
0.115190061413, 0.0844764278027, -0.0132342347732, -0.0849093011178, 0.152923763906, -0.14832 5245058, 0.105231772707, 0.0283105691055, 0.0720970891556, 0.132436487291, -0.10922576116, -0.0815515390022, -0.114722256636, -0.170915126812, 0.0175996345112, -0.124412352668, -0.13304515302, -0.106477616928, 0.120199855771, -0.174671087798, -0.122672321399, -0.144846427759, -0.0154300950821, 0.0507292688075, 0.126551952929, -0.143439054826, 0.0299006158402, -0.15005716624, -0.0401343045517, -0.00252760290587, -0.0237907281236, -0.0797753025531, -0.0991644050106, 0.134308019727, 0.0200117589182, -0.112084247031, -0.167995452395, -0.0295671374508, 0.107351765084, 0.0632551569846, -0.0582422444596, 0.16366750329, -0.162479674609, 0.00214073446262, -0.103278921502, 0.0558777777097, 0.0723732610543, -0.0613630222551, 0.135128811096, 0.050568989534, -0.0546392828733, -0.153971534091, 0.0924751592423, 0.101706791108, 0.0607050649968, 0.0596870445781, 0.136883549105, -0.0637011610473, 0.0871060796234, -0.0681990890046, 0.104202328594, -0.108179302086, 0.0223725516522, -0.155548226486, -0. 
141998314601, -0.122743124803, -0.0875520919781, -0.125286855988, 0.0307176179015, -0.153042075499, 0.23696229498, 0.177058796447, -0.082108841359, 0.20329677785, 0.00175169627491, -0.0846871909768, 0.099835299916, -0.0481316144897, 0.0661170188435, -0.00346217036623, 0.161053868398, 0.207674939326, 0.0281098351705, 0.0860739426955, 0.105192953584, -0.0285058349374, -0.0580393494223, 0.0714749978916, 0.0512850741134, -0.152503380977, 0.168472501535, 0.0485058104422, -0.0672229715629, 0.0117263745766, -0.0365743846743, -0.0412353731283, 0.0291119464053, 0.0441411422804, -0.0316612818875, -0.087567406123, 0.063411170627, -0.0589250368343, -0.107599666557, 0.081908821462, -0.0935361527796, 0.0577467210631, -0.168351581892, 0.11573638204, 0.115117344323, 0.0244735425375, -0.070952849692, 0.0485741149425, 0.00733481525825, 0.116528669326, -0.090380421955, -0.0854344432891, -0.148897840337, -0.00584468517666, 0.125584338188, 0.0303923394942, 0.00866505023759, 0.107045771132, 0.10450766779 1, 0.0759993673344, -0.0560582027212, -0.0601905601019, -0.162052677836, -0.165142098544, -0.171051382284, -0.00747175762698, -0.0725351668202, 0.0574896106535, 0.0883795833561, -0.00115905040131, -0.0208072413433, -0.146708687264, -0.106512542328, -0.0148684889684, 0.135203007488, -0.159366976488, -0.185727272921, -0.0986049216812, 0.0663781391027, 0.00337193126349, -0.143224040776, 0.151377444068, 0.122900039992, -0.123876153218, -0.0616559523895, 0.07474867887, 0.0759827169541, -0.196835030676, 0.0240963945802, 0.0815042565879, 0.0813777417882, 0.122173760283, -0.149634521852, 0.0950858936023, 0.147034831794, -0.143746034738, -0.0783443886749, -0.100895898859, 0.090389376971, -0.163667864143, 0.0799367976539, -0.080443647121, -0.112871360843, 0.147539108095, 0.0899145483935, -0.161477148966, 0.138928943053, 0.110902779082, -0.0666307526333, 0.0729760796573, 0.0290997225314, -0.0156054718084, 0.111971399191, -0.0448196144982, 0.0547285642038, -0.12753315933, -0.144334282657, -0.01 
06939193609, -0.0986907128641, 0.0600266162345, -0.120230639837, 0.16669020756, 0.0439133843093, -0.0724689439423, 0.0629139032688, 0.0504403844796, -0.068271656679, 0.156640645958, 0.0919725761945, -0.127266025876, -0.159289153952, 0.00546200555327, -0.0223831026984, -0.0931915620749, -0.0163207716589, -0.12419155542, 0.116520477065, 0.0153427362007, -0.129060892518, 0.157096091442, -0.129280966926, -0.0137527695116, -0.0171380709271, 0.195707287676, 0.142599751827, 0.072152904103, 0.252011620265, 0.207324038903, -0.125405121498, -0.100013512174, 0.0583275733665, -0.0146841391113, 0.196379045132, -0.0875501697175, 0.078958660421, 0.0740925533203, 0.0902143324496, 0.29527471964, 0.073563228215, -0.0180613309394, 0.115906200451, 0.0786093503948, 1.306422046, 0.137688222641, 0.756774881139, 0.496403150097, 0.0972268426637, -0.1337106213, 0.739166706063, 0.550671589035, 0.441561444758, 0.231081592497, 0.622412508822, 0.644482827381, 0.282956642749, 0.587122887501, 0.276576675538, 0.654 81113446, 0.229873153345, 0.271528882808, 0.506621646929, 0.00565366103103, 0.282388297824, 1.02057546753, 0.217085174554, -0.0379775424244, -0.10514999404, 0.332566472108, 0.332449111037, 0.124367878135, 0.0141927018472, 0.15855771012, -0.038603288579, -0.135492381545, 0.254667963224, -0.0234027088103, 0.0593833730872, -0.0644732316305, -0.00845219720435, 0.028657566833, 0.0936179669519, 0.254740895255, -0.0461609993025, 0.120413358916, 0.138933025039, 0.196582134527, -0.0400133811542, -0.040323701499, 0.0581573904481, 0.23554382053, 0.192417275699, -0.0776195707746, 0.0592937520315, 0.089886264069, 0.0539065543803, 0.09886462098, 0.0667640588722, -0.18084306698, 0.0222659539242, 0.0403979776829, 0.0775484353974, 0.0791672960253, -0.0899907599431, -0.150064269247, 0.124833764386, -0.120770820436, -0.114178974003, -0.0639780687205, -0.0336269431702, -0.0608694034386, 0.0467895651684, -0.141267333139, -0.064375264314, 0.121018439709, 0.0366265122228, -0.154133636165, 0.0859110581911, 
0.0756296178076, -0.0424145003869, -0.0456028602463, -0.0407671725434, -0.0791181207568, 0.0288097129653, -0.151513578519, 0.138406753718, -0.0233697746503, 0.127112027349, 0.0588020231249, -0.015333084669, -0.0097756075345, -0.0715451307032, -0.0949891422502, -0.0875247946453, 0.0222796529221, -0.0696183219033, -0.128883685169, 0.0827929784601, -0.00779955596619, -0.00487182386305, -0.106664395522, 0.123694800488, -0.129353816579, -0.071810234853, -0.110089746206, -0.171366525529, 0.136882534904, -0.111206506502, 0.0982967459967, 0.0298035370415, 0.0446501099772, -0.117237913284, -0.156925357628, 0.13528309057, 0.0847468661361, -0.150103993267, 0.117264719501, -0.114274591469, 0.157554316643, 0.00546812137529, 0.038074635687, -0.0263940535714, 0.101334013433, 0.128183877645, 0.0670206279241, -0.0442752598958, -0.109288889221, 0.111823141465, -0.148855288965, 0.0475096756177, -0.0695066497011, -0.0537726255674, 0.000632083516048, 1.57413626503, 0.55559495427, 0.314639199295, 0.0363 237197077, 0.0298504461036, -0.224219935066, -0.157329470601, 1.21592745382, 0.234224608374, -0.0740354122049, 0.165430831517, 0.66674980751, -0.12077633358, -0.0337960848896, 0.00708649021609, 0.293365487636, 0.000863544472662, -0.199669889395, -0.210671570015, 0.180015272246, 0.467349281518, 2.58471060843, 0.516690045711, -0.0632937912186, -0.15288034917, -0.0418765986663], Decimal('1.48283239226'), 500)]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_regress;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Prediction for classification\n", + "\n", + "In the following examples we will use the training data set for prediction as well, which is not typical practice but serves to illustrate the syntax. First we will test the classification example. The prediction is in the estimated_class_text column, with the actual value in the class_text column." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>estimated_class_text</th>\n", + " <th>attributes</th>\n", + " <th>class_text</th>\n", + " <th>class</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), 
Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>Iris-setosa</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris-setosa</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('6.4'), Decimal('3.2'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " 
<tr>\n", + " <td>18</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n", + " <td>Iris-versicolor</td>\n", + " <td>2</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, u'Iris-setosa', [Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (2, u'Iris-setosa', [Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (3, u'Iris-setosa', [Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (4, u'Iris-setosa', [Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (5, u'Iris-setosa', [Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (6, u'Iris-setosa', [Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')], u'Iris-setosa', 1),\n", + " (7, u'Iris-setosa', [Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')], u'Iris-setosa', 1),\n", + " (8, u'Iris-setosa', [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (9, u'Iris-setosa', [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris-setosa', 1),\n", + " (10, u'Iris-setosa', [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris-setosa', 1),\n", + " (11, u'Iris-versicolor', [Decimal('7.0'), Decimal('3.2'), Decimal('4.7'), Decimal('1.4')], u'Iris-versicolor', 2),\n", + " (12, u'Iris-versicolor', [Decimal('6.4'), Decimal('3.2'), 
Decimal('4.5'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (13, u'Iris-versicolor', [Decimal('6.9'), Decimal('3.1'), Decimal('4.9'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (14, u'Iris-versicolor', [Decimal('5.5'), Decimal('2.3'), Decimal('4.0'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (15, u'Iris-versicolor', [Decimal('6.5'), Decimal('2.8'), Decimal('4.6'), Decimal('1.5')], u'Iris-versicolor', 2),\n", + " (16, u'Iris-versicolor', [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (17, u'Iris-versicolor', [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris-versicolor', 2),\n", + " (18, u'Iris-versicolor', [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris-versicolor', 2),\n", + " (19, u'Iris-versicolor', [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris-versicolor', 2),\n", + " (20, u'Iris-versicolor', [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris-versicolor', 2)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_prediction;\n", + "\n", + "SELECT madlib.mlp_predict(\n", + " 'mlp_model', -- Model table\n", + " 'iris_data', -- Test data table\n", + " 'id', -- Id column in test table\n", + " 'mlp_prediction', -- Output table for predictions\n", + " 'response' -- Output classes, not probabilities\n", + " );\n", + "\n", + "SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count the misclassifications" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>count</th>\n", + " </tr>\n", + " <tr>\n", + " <td>0</td>\n", + " 
</tr>\n", + "</table>" + ], + "text/plain": [ + "[(0L,)]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM mlp_prediction JOIN iris_data USING (id) \n", + "WHERE mlp_prediction.estimated_class_text != iris_data.class_text;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8. Prediction for regression" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>x</th>\n", + " <th>grp_by_col</th>\n", + " <th>y</th>\n", + " <th>estimated_y</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98]</td>\n", + " <td>1</td>\n", + " <td>24.0</td>\n", + " <td>23.997693578</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>[0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14]</td>\n", + " <td>1</td>\n", + " <td>21.6</td>\n", + " <td>22.0225551504</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>[0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03]</td>\n", + " <td>1</td>\n", + " <td>34.7</td>\n", + " <td>34.3269436787</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>[0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94]</td>\n", + " <td>1</td>\n", + " <td>33.4</td>\n", + " <td>34.7421700033</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>[0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33]</td>\n", + " <td>1</td>\n", + " <td>36.2</td>\n", + " <td>35.1914922401</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>[0.02985, 0.0, 
2.18, 0.0, 0.458, 6.43, 58.7, 6.0622, 3.0, 222.0, 18.7, 394.12, 5.21]</td>\n", + " <td>1</td>\n", + " <td>28.7</td>\n", + " <td>29.5286073544</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>[0.08829, 12.5, 7.87, 0.0, 0.524, 6.012, 66.6, 5.5605, 5.0, 311.0, 15.2, 395.6, 12.43]</td>\n", + " <td>1</td>\n", + " <td>22.9</td>\n", + " <td>23.2022360304</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>[0.14455, 12.5, 7.87, 0.0, 0.524, 6.172, 96.1, 5.9505, 5.0, 311.0, 15.2, 396.9, 19.15]</td>\n", + " <td>1</td>\n", + " <td>27.1</td>\n", + " <td>23.364906529</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>[0.21124, 12.5, 7.87, 0.0, 0.524, 5.631, 100.0, 6.0821, 5.0, 311.0, 15.2, 386.63, 29.93]</td>\n", + " <td>1</td>\n", + " <td>16.5</td>\n", + " <td>17.7779926867</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>[0.17004, 12.5, 7.87, 0.0, 0.524, 6.004, 85.9, 6.5921, 5.0, 311.0, 15.2, 386.71, 17.1]</td>\n", + " <td>1</td>\n", + " <td>18.9</td>\n", + " <td>13.9266690258</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>[0.22489, 12.5, 7.87, 0.0, 0.524, 6.377, 94.3, 6.3467, 5.0, 311.0, 15.2, 392.52, 20.45]</td>\n", + " <td>1</td>\n", + " <td>15.0</td>\n", + " <td>18.5049155839</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>[0.11747, 12.5, 7.87, 0.0, 0.524, 6.009, 82.9, 6.2267, 5.0, 311.0, 15.2, 396.9, 13.27]</td>\n", + " <td>1</td>\n", + " <td>18.9</td>\n", + " <td>18.4287114359</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>[0.09378, 12.5, 7.87, 0.0, 0.524, 5.889, 39.0, 5.4509, 5.0, 311.0, 15.2, 390.5, 15.71]</td>\n", + " <td>1</td>\n", + " <td>21.7</td>\n", + " <td>22.6228336115</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>[0.62976, 0.0, 8.14, 0.0, 0.538, 5.949, 61.8, 4.7075, 4.0, 307.0, 21.0, 396.9, 8.26]</td>\n", + " <td>1</td>\n", + " <td>20.4</td>\n", + " <td>20.1083536059</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>[0.63796, 
0.0, 8.14, 0.0, 0.538, 6.096, 84.5, 4.4619, 4.0, 307.0, 21.0, 380.02, 10.26]</td>\n", + " <td>1</td>\n", + " <td>18.2</td>\n", + " <td>18.8935467873</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>[0.62739, 0.0, 8.14, 0.0, 0.538, 5.834, 56.5, 4.4986, 4.0, 307.0, 21.0, 395.62, 8.47]</td>\n", + " <td>1</td>\n", + " <td>19.9</td>\n", + " <td>19.8383202293</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>[1.05393, 0.0, 8.14, 0.0, 0.538, 5.935, 29.3, 4.4986, 4.0, 307.0, 21.0, 386.85, 6.58]</td>\n", + " <td>1</td>\n", + " <td>23.1</td>\n", + " <td>23.1604635402</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>[0.7842, 0.0, 8.14, 0.0, 0.538, 5.99, 81.7, 4.2579, 4.0, 307.0, 21.0, 386.75, 14.67]</td>\n", + " <td>1</td>\n", + " <td>17.5</td>\n", + " <td>16.8540384346</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>[0.80271, 0.0, 8.14, 0.0, 0.538, 5.456, 36.6, 3.7965, 4.0, 307.0, 21.0, 288.99, 11.69]</td>\n", + " <td>1</td>\n", + " <td>20.2</td>\n", + " <td>20.3628760581</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>[0.7258, 0.0, 8.14, 0.0, 0.538, 5.727, 69.5, 3.7965, 4.0, 307.0, 21.0, 390.95, 11.28]</td>\n", + " <td>1</td>\n", + " <td>18.2</td>\n", + " <td>18.1198369917</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, [0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98], 1, 24.0, 23.9976935779896),\n", + " (2, [0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14], 1, 21.6, 22.0225551503712),\n", + " (3, [0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03], 1, 34.7, 34.3269436787012),\n", + " (4, [0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94], 1, 33.4, 34.7421700032985),\n", + " (5, [0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33], 1, 36.2, 35.1914922401243),\n", + " (6, [0.02985, 0.0, 2.18, 0.0, 0.458, 6.43, 
58.7, 6.0622, 3.0, 222.0, 18.7, 394.12, 5.21], 1, 28.7, 29.5286073543722),\n", + " (7, [0.08829, 12.5, 7.87, 0.0, 0.524, 6.012, 66.6, 5.5605, 5.0, 311.0, 15.2, 395.6, 12.43], 1, 22.9, 23.2022360304219),\n", + " (8, [0.14455, 12.5, 7.87, 0.0, 0.524, 6.172, 96.1, 5.9505, 5.0, 311.0, 15.2, 396.9, 19.15], 1, 27.1, 23.3649065290002),\n", + " (9, [0.21124, 12.5, 7.87, 0.0, 0.524, 5.631, 100.0, 6.0821, 5.0, 311.0, 15.2, 386.63, 29.93], 1, 16.5, 17.7779926866502),\n", + " (10, [0.17004, 12.5, 7.87, 0.0, 0.524, 6.004, 85.9, 6.5921, 5.0, 311.0, 15.2, 386.71, 17.1], 1, 18.9, 13.9266690257803),\n", + " (11, [0.22489, 12.5, 7.87, 0.0, 0.524, 6.377, 94.3, 6.3467, 5.0, 311.0, 15.2, 392.52, 20.45], 1, 15.0, 18.5049155838719),\n", + " (12, [0.11747, 12.5, 7.87, 0.0, 0.524, 6.009, 82.9, 6.2267, 5.0, 311.0, 15.2, 396.9, 13.27], 1, 18.9, 18.4287114359317),\n", + " (13, [0.09378, 12.5, 7.87, 0.0, 0.524, 5.889, 39.0, 5.4509, 5.0, 311.0, 15.2, 390.5, 15.71], 1, 21.7, 22.6228336114696),\n", + " (14, [0.62976, 0.0, 8.14, 0.0, 0.538, 5.949, 61.8, 4.7075, 4.0, 307.0, 21.0, 396.9, 8.26], 1, 20.4, 20.1083536059151),\n", + " (15, [0.63796, 0.0, 8.14, 0.0, 0.538, 6.096, 84.5, 4.4619, 4.0, 307.0, 21.0, 380.02, 10.26], 1, 18.2, 18.8935467873061),\n", + " (16, [0.62739, 0.0, 8.14, 0.0, 0.538, 5.834, 56.5, 4.4986, 4.0, 307.0, 21.0, 395.62, 8.47], 1, 19.9, 19.8383202293121),\n", + " (17, [1.05393, 0.0, 8.14, 0.0, 0.538, 5.935, 29.3, 4.4986, 4.0, 307.0, 21.0, 386.85, 6.58], 1, 23.1, 23.160463540176),\n", + " (18, [0.7842, 0.0, 8.14, 0.0, 0.538, 5.99, 81.7, 4.2579, 4.0, 307.0, 21.0, 386.75, 14.67], 1, 17.5, 16.8540384345856),\n", + " (19, [0.80271, 0.0, 8.14, 0.0, 0.538, 5.456, 36.6, 3.7965, 4.0, 307.0, 21.0, 288.99, 11.69], 1, 20.2, 20.3628760580577),\n", + " (20, [0.7258, 0.0, 8.14, 0.0, 0.538, 5.727, 69.5, 3.7965, 4.0, 307.0, 21.0, 390.95, 11.28], 1, 18.2, 18.1198369917265)]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + 
"DROP TABLE IF EXISTS mlp_regress_prediction;\n", + "\n", + "SELECT madlib.mlp_predict(\n", + " 'mlp_regress', -- Model table\n", + " 'lin_housing', -- Test data table\n", + " 'id', -- Id column in test table\n", + " 'mlp_regress_prediction', -- Output table for predictions\n", + " 'response' -- Output values, not probabilities\n", + " );\n", + "\n", + "SELECT * FROM lin_housing JOIN mlp_regress_prediction USING (id) ORDER BY id;" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>x</th>\n", + " <th>grp_by_col</th>\n", + " <th>y</th>\n", + " <th>estimated_y</th>\n", + " <th>abs_diff</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98]</td>\n", + " <td>1</td>\n", + " <td>24.0</td>\n", + " <td>23.997693578</td>\n", + " <td>0.00230642201043</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>[0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14]</td>\n", + " <td>1</td>\n", + " <td>21.6</td>\n", + " <td>22.0225551504</td>\n", + " <td>0.422555150371</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>[0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03]</td>\n", + " <td>1</td>\n", + " <td>34.7</td>\n", + " <td>34.3269436787</td>\n", + " <td>0.373056321299</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>[0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94]</td>\n", + " <td>1</td>\n", + " <td>33.4</td>\n", + " <td>34.7421700033</td>\n", + " <td>1.3421700033</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>[0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33]</td>\n", + " 
<td>1</td>\n", + " <td>36.2</td>\n", + " <td>35.1914922401</td>\n", + " <td>1.00850775988</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>[0.02985, 0.0, 2.18, 0.0, 0.458, 6.43, 58.7, 6.0622, 3.0, 222.0, 18.7, 394.12, 5.21]</td>\n", + " <td>1</td>\n", + " <td>28.7</td>\n", + " <td>29.5286073544</td>\n", + " <td>0.828607354372</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>[0.08829, 12.5, 7.87, 0.0, 0.524, 6.012, 66.6, 5.5605, 5.0, 311.0, 15.2, 395.6, 12.43]</td>\n", + " <td>1</td>\n", + " <td>22.9</td>\n", + " <td>23.2022360304</td>\n", + " <td>0.302236030422</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>[0.144 <TRUNCATED>