http://git-wip-us.apache.org/repos/asf/madlib-site/blob/acd339f6/community-artifacts/Elastic-net-v3.ipynb
----------------------------------------------------------------------
diff --git a/community-artifacts/Elastic-net-v3.ipynb
b/community-artifacts/Elastic-net-v3.ipynb
new file mode 100644
index 0000000..7592fe6
--- /dev/null
+++ b/community-artifacts/Elastic-net-v3.ipynb
@@ -0,0 +1,2049 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Elastic net (MADlib v1.10+)\n",
+ "Demonstrates elastic net regression in MADlib, including these updates:\n",
+ "- MADlib 1.10: grouping and cross validation were introduced.\n",
+ "- MADlib 1.13: reports the negative root mean squared error instead of the negative mean squared error."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The sql extension is already loaded. To reload it, use:\n",
+ " %reload_ext sql\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the `sql` IPython extension (ipython-sql), which provides the %sql and %%sql magics used in the cells below.\n",
+ "%load_ext sql"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "u'Connected: gpadmin@madlib'"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Greenplum Database 5.4.0 on GCP (demo machine)\n",
+ "%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n",
+ " \n",
+ "# PostgreSQL local\n",
+ "#%sql postgresql://fmcquillan@localhost:5432/madlib"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>version</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>MADlib version: 1.15-dev, git revision:
rc/1.14-rc1-23-gabafa66, cmake configuration time: Wed Jul 11 00:36:05 UTC
2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C
compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-gabafa66,
cmake configuration time: Wed Jul 11 00:36:05 UTC 2018, build type: release,
build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++
compiler: g++ 4.4.7',)]"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%sql select madlib.version();\n",
+ "#%sql select version();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Create data set\n",
+ "A small table of prices and characteristics for 27 houses in two zip codes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Done.\n",
+ "Done.\n",
+ "27 rows affected.\n",
+ "27 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>id</th>\n",
+ " <th>tax</th>\n",
+ " <th>bedroom</th>\n",
+ " <th>bath</th>\n",
+ " <th>price</th>\n",
+ " <th>size</th>\n",
+ " <th>lot</th>\n",
+ " <th>zipcode</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>590</td>\n",
+ " <td>2</td>\n",
+ " <td>1.0</td>\n",
+ " <td>50000</td>\n",
+ " <td>770</td>\n",
+ " <td>22100</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>1050</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>85000</td>\n",
+ " <td>1410</td>\n",
+ " <td>12000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>20</td>\n",
+ " <td>3</td>\n",
+ " <td>1.0</td>\n",
+ " <td>22500</td>\n",
+ " <td>1060</td>\n",
+ " <td>3500</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>870</td>\n",
+ " <td>2</td>\n",
+ " <td>2.0</td>\n",
+ " <td>90000</td>\n",
+ " <td>1300</td>\n",
+ " <td>17500</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>1320</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>133000</td>\n",
+ " <td>1500</td>\n",
+ " <td>30000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>1350</td>\n",
+ " <td>2</td>\n",
+ " <td>1.0</td>\n",
+ " <td>90500</td>\n",
+ " <td>820</td>\n",
+ " <td>25700</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>2790</td>\n",
+ " <td>3</td>\n",
+ " <td>2.5</td>\n",
+ " <td>260000</td>\n",
+ " <td>2130</td>\n",
+ " <td>25000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>680</td>\n",
+ " <td>2</td>\n",
+ " <td>1.0</td>\n",
+ " <td>142500</td>\n",
+ " <td>1170</td>\n",
+ " <td>22000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>1840</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>160000</td>\n",
+ " <td>1500</td>\n",
+ " <td>19000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>3680</td>\n",
+ " <td>4</td>\n",
+ " <td>2.0</td>\n",
+ " <td>240000</td>\n",
+ " <td>2790</td>\n",
+ " <td>20000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>1660</td>\n",
+ " <td>3</td>\n",
+ " <td>1.0</td>\n",
+ " <td>87000</td>\n",
+ " <td>1030</td>\n",
+ " <td>17500</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>1620</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>118600</td>\n",
+ " <td>1250</td>\n",
+ " <td>20000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>3100</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>140000</td>\n",
+ " <td>1760</td>\n",
+ " <td>38000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>2070</td>\n",
+ " <td>2</td>\n",
+ " <td>3.0</td>\n",
+ " <td>148000</td>\n",
+ " <td>1550</td>\n",
+ " <td>14000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>650</td>\n",
+ " <td>3</td>\n",
+ " <td>1.5</td>\n",
+ " <td>65000</td>\n",
+ " <td>1450</td>\n",
+ " <td>12000</td>\n",
+ " <td>94301</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>770</td>\n",
+ " <td>2</td>\n",
+ " <td>2.0</td>\n",
+ " <td>91000</td>\n",
+ " <td>1300</td>\n",
+ " <td>17500</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>1220</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>132300</td>\n",
+ " <td>1500</td>\n",
+ " <td>30000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>1150</td>\n",
+ " <td>2</td>\n",
+ " <td>1.0</td>\n",
+ " <td>91100</td>\n",
+ " <td>820</td>\n",
+ " <td>25700</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>2690</td>\n",
+ " <td>3</td>\n",
+ " <td>2.5</td>\n",
+ " <td>260011</td>\n",
+ " <td>2130</td>\n",
+ " <td>25000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>780</td>\n",
+ " <td>2</td>\n",
+ " <td>1.0</td>\n",
+ " <td>141800</td>\n",
+ " <td>1170</td>\n",
+ " <td>22000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>1910</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>160900</td>\n",
+ " <td>1500</td>\n",
+ " <td>19000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>3600</td>\n",
+ " <td>4</td>\n",
+ " <td>2.0</td>\n",
+ " <td>239000</td>\n",
+ " <td>2790</td>\n",
+ " <td>20000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>1600</td>\n",
+ " <td>3</td>\n",
+ " <td>1.0</td>\n",
+ " <td>81010</td>\n",
+ " <td>1030</td>\n",
+ " <td>17500</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>1590</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>117910</td>\n",
+ " <td>1250</td>\n",
+ " <td>20000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>3200</td>\n",
+ " <td>3</td>\n",
+ " <td>2.0</td>\n",
+ " <td>141100</td>\n",
+ " <td>1760</td>\n",
+ " <td>38000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>2270</td>\n",
+ " <td>2</td>\n",
+ " <td>3.0</td>\n",
+ " <td>148011</td>\n",
+ " <td>1550</td>\n",
+ " <td>14000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>750</td>\n",
+ " <td>3</td>\n",
+ " <td>1.5</td>\n",
+ " <td>66000</td>\n",
+ " <td>1450</td>\n",
+ " <td>12000</td>\n",
+ " <td>76010</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(1, 590, 2, 1.0, 50000, 770, 22100, 94301),\n",
+ " (2, 1050, 3, 2.0, 85000, 1410, 12000, 94301),\n",
+ " (3, 20, 3, 1.0, 22500, 1060, 3500, 94301),\n",
+ " (4, 870, 2, 2.0, 90000, 1300, 17500, 94301),\n",
+ " (5, 1320, 3, 2.0, 133000, 1500, 30000, 94301),\n",
+ " (6, 1350, 2, 1.0, 90500, 820, 25700, 94301),\n",
+ " (7, 2790, 3, 2.5, 260000, 2130, 25000, 94301),\n",
+ " (8, 680, 2, 1.0, 142500, 1170, 22000, 94301),\n",
+ " (9, 1840, 3, 2.0, 160000, 1500, 19000, 94301),\n",
+ " (10, 3680, 4, 2.0, 240000, 2790, 20000, 94301),\n",
+ " (11, 1660, 3, 1.0, 87000, 1030, 17500, 94301),\n",
+ " (12, 1620, 3, 2.0, 118600, 1250, 20000, 94301),\n",
+ " (13, 3100, 3, 2.0, 140000, 1760, 38000, 94301),\n",
+ " (14, 2070, 2, 3.0, 148000, 1550, 14000, 94301),\n",
+ " (15, 650, 3, 1.5, 65000, 1450, 12000, 94301),\n",
+ " (16, 770, 2, 2.0, 91000, 1300, 17500, 76010),\n",
+ " (17, 1220, 3, 2.0, 132300, 1500, 30000, 76010),\n",
+ " (18, 1150, 2, 1.0, 91100, 820, 25700, 76010),\n",
+ " (19, 2690, 3, 2.5, 260011, 2130, 25000, 76010),\n",
+ " (20, 780, 2, 1.0, 141800, 1170, 22000, 76010),\n",
+ " (21, 1910, 3, 2.0, 160900, 1500, 19000, 76010),\n",
+ " (22, 3600, 4, 2.0, 239000, 2790, 20000, 76010),\n",
+ " (23, 1600, 3, 1.0, 81010, 1030, 17500, 76010),\n",
+ " (24, 1590, 3, 2.0, 117910, 1250, 20000, 76010),\n",
+ " (25, 3200, 3, 2.0, 141100, 1760, 38000, 76010),\n",
+ " (26, 2270, 2, 3.0, 148011, 1550, 14000, 76010),\n",
+ " (27, 750, 3, 1.5, 66000, 1450, 12000, 76010)]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql \n",
+ "-- Recreate the toy data set from scratch so this cell can be re-run.\n",
+ "DROP TABLE IF EXISTS houses;\n",
+ "\n",
+ "CREATE TABLE houses ( id INT,\n",
+ " tax INT,\n",
+ " bedroom INT,\n",
+ " bath FLOAT,\n",
+ " price INT,\n",
+ " size INT,\n",
+ " lot INT,\n",
+ " zipcode INT);\n",
+ "\n",
+ "-- 27 houses, ids 1-15 are in zip code 94301 and ids 16-27 are in zip code 76010 (zipcode is the grouping column in section 4).\n",
+ "INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode)
VALUES\n",
+ "(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301),\n",
+ "(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301),\n",
+ "(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301),\n",
+ "(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301),\n",
+ "(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301),\n",
+ "(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301),\n",
+ "(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301),\n",
+ "(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301),\n",
+ "(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301),\n",
+ "(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301),\n",
+ "(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301),\n",
+ "(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301),\n",
+ "(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301),\n",
+ "(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301),\n",
+ "(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301),\n",
+ "(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010),\n",
+ "(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010),\n",
+ "(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010),\n",
+ "(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010),\n",
+ "(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010),\n",
+ "(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010),\n",
+ "(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010),\n",
+ "(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010),\n",
+ "(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010),\n",
+ "(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010),\n",
+ "(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010),\n",
+ "(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010);\n",
+ "\n",
+ "SELECT * FROM houses ORDER BY id;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Train the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Done.\n",
+ "1 rows affected.\n",
+ "1 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>family</th>\n",
+ " <th>features</th>\n",
+ " <th>features_selected</th>\n",
+ " <th>coef_nonzero</th>\n",
+ " <th>coef_all</th>\n",
+ " <th>intercept</th>\n",
+ " <th>log_likelihood</th>\n",
+ " <th>standardize</th>\n",
+ " <th>iteration_run</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>gaussian</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n",
+ " <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n",
+ " <td>-7798.78310728</td>\n",
+ " <td>-512248641.97</td>\n",
+ " <td>True</td>\n",
+ " <td>10000</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'],
[22.7851318679, 10707.9553682, 54.7961166559], [22.7851318679, 10707.9553682,
54.7961166559], -7798.78310728, -512248641.97, True, 10000)]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "DROP TABLE IF EXISTS houses_en, houses_en_summary;\n",
+ "\n",
+ "-- Elastic net regression of price on tax, bath and size, one model for all rows (no grouping).\n",
+ "SELECT madlib.elastic_net_train(\n",
+ "    'houses',                  -- source table\n",
+ "    'houses_en',               -- result (model) table\n",
+ "    'price',                   -- dependent variable\n",
+ "    'array[tax, bath, size]',  -- independent variables\n",
+ "    'gaussian',                -- regression family\n",
+ "    0.5,                       -- alpha value\n",
+ "    0.1,                       -- lambda value\n",
+ "    TRUE,                      -- standardize\n",
+ "    NULL,                      -- grouping column(s)\n",
+ "    'fista',                   -- optimizer\n",
+ "    '',                        -- optimizer parameters\n",
+ "    NULL,                      -- excluded columns\n",
+ "    10000,                     -- maximum iterations\n",
+ "    1e-6                       -- tolerance value\n",
+ ");\n",
+ "\n",
+ "SELECT * FROM houses_en;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Prediction\n",
+ "Predict prices with the model trained above and evaluate the residuals."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "27 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>id</th>\n",
+ " <th>price</th>\n",
+ " <th>predict</th>\n",
+ " <th>residual</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>50000</td>\n",
+ " <td>58545.409888</td>\n",
+ " <td>-8545.40988802</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>85000</td>\n",
+ " <td>114804.040575</td>\n",
+ " <td>-29804.0405752</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>22500</td>\n",
+ " <td>61448.7585535</td>\n",
+ " <td>-38948.7585535</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>90000</td>\n",
+ " <td>104675.144007</td>\n",
+ " <td>-14675.1440069</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>133000</td>\n",
+ " <td>125887.676679</td>\n",
+ " <td>7112.3233214</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>90500</td>\n",
+ " <td>78601.9159404</td>\n",
+ " <td>11898.0840596</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>260000</td>\n",
+ " <td>199257.351702</td>\n",
+ " <td>60742.6482983</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>142500</td>\n",
+ " <td>82514.5184185</td>\n",
+ " <td>59985.4815815</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>160000</td>\n",
+ " <td>137735.94525</td>\n",
+ " <td>22264.0547501</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>240000</td>\n",
+ " <td>250347.578373</td>\n",
+ " <td>-10347.578373</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>87000</td>\n",
+ " <td>97172.4913172</td>\n",
+ " <td>-10172.4913172</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>118600</td>\n",
+ " <td>119024.187075</td>\n",
+ " <td>-424.187074993</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>140000</td>\n",
+ " <td>180692.201734</td>\n",
+ " <td>-40692.201734</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>148000</td>\n",
+ " <td>156424.286781</td>\n",
+ " <td>-8424.28678052</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>65000</td>\n",
+ " <td>102527.85481</td>\n",
+ " <td>-37527.8548102</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>91000</td>\n",
+ " <td>102396.63082</td>\n",
+ " <td>-11396.6308201</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>132300</td>\n",
+ " <td>123609.163492</td>\n",
+ " <td>8690.83650819</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>91100</td>\n",
+ " <td>74044.8895668</td>\n",
+ " <td>17055.1104332</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>260011</td>\n",
+ " <td>196978.838515</td>\n",
+ " <td>63032.1614851</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>141800</td>\n",
+ " <td>84793.0316053</td>\n",
+ " <td>57006.9683947</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>160900</td>\n",
+ " <td>139330.904481</td>\n",
+ " <td>21569.0955193</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>239000</td>\n",
+ " <td>248524.767824</td>\n",
+ " <td>-9524.76782352</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>81010</td>\n",
+ " <td>95805.3834051</td>\n",
+ " <td>-14795.3834051</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>117910</td>\n",
+ " <td>118340.633119</td>\n",
+ " <td>-430.633118956</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>141100</td>\n",
+ " <td>182970.714921</td>\n",
+ " <td>-41870.7149208</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>148011</td>\n",
+ " <td>160981.313154</td>\n",
+ " <td>-12970.3131541</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>66000</td>\n",
+ " <td>104806.367997</td>\n",
+ " <td>-38806.367997</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(1, 50000, 58545.409888024, -8545.409888024),\n",
+ " (2, 85000, 114804.040575234, -29804.040575234),\n",
+ " (3, 22500, 61448.758553532, -38948.758553532),\n",
+ " (4, 90000, 104675.144006863, -14675.144006863),\n",
+ " (5, 133000, 125887.676678598, 7112.323321402),\n",
+ " (6, 90500, 78601.915940423, 11898.084059577),\n",
+ " (7, 260000, 199257.351701728, 60742.648298272),\n",
+ " (8, 142500, 82514.518418495, 59985.481581505),\n",
+ " (9, 160000, 137735.945249906, 22264.054750094),\n",
+ " (10, 240000, 250347.578372953, -10347.578372953),\n",
+ " (11, 87000, 97172.491317211, -10172.491317211),\n",
+ " (12, 118600, 119024.187074993, -424.187074992995),\n",
+ " (13, 140000, 180692.201733994, -40692.201733994),\n",
+ " (14, 148000, 156424.286780518, -8424.28678051801),\n",
+ " (15, 65000, 102527.85481021, -37527.85481021),\n",
+ " (16, 91000, 102396.630820073, -11396.630820073),\n",
+ " (17, 132300, 123609.163491808, 8690.83650819201),\n",
+ " (18, 91100, 74044.889566843, 17055.110433157),\n",
+ " (19, 260011, 196978.838514938, 63032.161485062),\n",
+ " (20, 141800, 84793.031605285, 57006.968394715),\n",
+ " (21, 160900, 139330.904480659, 21569.095519341),\n",
+ " (22, 239000, 248524.767823521, -9524.76782352099),\n",
+ " (23, 81010, 95805.383405137, -14795.383405137),\n",
+ " (24, 117910, 118340.633118956, -430.633118956001),\n",
+ " (25, 141100, 182970.714920784, -41870.714920784),\n",
+ " (26, 148011, 160981.313154098, -12970.313154098),\n",
+ " (27, 66000, 104806.367997, -38806.367997)]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "-- Score every house with the single model stored in houses_en, then compute residuals.\n",
+ "SELECT scored.id,\n",
+ "       scored.price,\n",
+ "       scored.predict,\n",
+ "       scored.price - scored.predict AS residual\n",
+ "FROM (\n",
+ "    SELECT h.*,\n",
+ "           madlib.elastic_net_gaussian_predict(\n",
+ "               en_model.coef_all,             -- coefficients\n",
+ "               en_model.intercept,            -- intercept\n",
+ "               ARRAY[h.tax, h.bath, h.size]   -- features, in the same order as the coefficients\n",
+ "           ) AS predict\n",
+ "    FROM houses AS h\n",
+ "    CROSS JOIN houses_en AS en_model\n",
+ ") AS scored\n",
+ "ORDER BY scored.id;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Grouping\n",
+ "Train a separate model for each zip code by passing `zipcode` as the grouping column."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Done.\n",
+ "1 rows affected.\n",
+ "2 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>zipcode</th>\n",
+ " <th>family</th>\n",
+ " <th>features</th>\n",
+ " <th>features_selected</th>\n",
+ " <th>coef_nonzero</th>\n",
+ " <th>coef_all</th>\n",
+ " <th>intercept</th>\n",
+ " <th>log_likelihood</th>\n",
+ " <th>standardize</th>\n",
+ " <th>iteration_run</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>76010</td>\n",
+ " <td>gaussian</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n",
+ " <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n",
+ " <td>14.7294468096</td>\n",
+ " <td>-525667117.987</td>\n",
+ " <td>True</td>\n",
+ " <td>10000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>94301</td>\n",
+ " <td>gaussian</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n",
+ " <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n",
+ " <td>-11145.5017384</td>\n",
+ " <td>-520358795.785</td>\n",
+ " <td>True</td>\n",
+ " <td>10000</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(76010, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath',
u'size'], [14.9802020928, 9133.17041265, 62.8225614522], [14.9802020928,
9133.17041265, 62.8225614522], 14.7294468096, -525667117.987, True, 10000),\n",
+ " (94301, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath',
u'size'], [27.6945649037, 11509.010807, 49.0945476263], [27.6945649037,
11509.010807, 49.0945476263], -11145.5017384, -520358795.785, True, 10000)]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "DROP TABLE IF EXISTS houses_en1, houses_en1_summary;\n",
+ "\n",
+ "-- Same settings as the ungrouped houses_en model, but one model is fit per zipcode.\n",
+ "SELECT madlib.elastic_net_train(\n",
+ "    'houses',                  -- source table\n",
+ "    'houses_en1',              -- result (model) table\n",
+ "    'price',                   -- dependent variable\n",
+ "    'array[tax, bath, size]',  -- independent variables\n",
+ "    'gaussian',                -- regression family\n",
+ "    0.5,                       -- alpha value\n",
+ "    0.1,                       -- lambda value\n",
+ "    TRUE,                      -- standardize\n",
+ "    'zipcode',                 -- grouping column(s)\n",
+ "    'fista',                   -- optimizer\n",
+ "    '',                        -- optimizer parameters\n",
+ "    NULL,                      -- excluded columns\n",
+ "    10000,                     -- maximum iterations\n",
+ "    1e-6                       -- tolerance value\n",
+ ");\n",
+ "\n",
+ "SELECT * FROM houses_en1;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
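+ "Predict with the `elastic_net_predict()` function, which writes its predictions to a table, then evaluate the residuals."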
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1 rows affected.\n",
+ "27 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>id</th>\n",
+ " <th>price</th>\n",
+ " <th>prediction</th>\n",
+ " <th>residual</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>50000</td>\n",
+ " <td>54506.104034</td>\n",
+ " <td>-4506.10403403</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>85000</td>\n",
+ " <td>110175.125178</td>\n",
+ " <td>-25175.1251776</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>22500</td>\n",
+ " <td>52957.6208506</td>\n",
+ " <td>-30457.6208506</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>90000</td>\n",
+ " <td>99789.703256</td>\n",
+ " <td>-9789.70325601</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>133000</td>\n",
+ " <td>122071.166988</td>\n",
+ " <td>10928.8330121</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>90500</td>\n",
+ " <td>78008.7007422</td>\n",
+ " <td>12491.2992578</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>260000</td>\n",
+ " <td>199466.247804</td>\n",
+ " <td>60533.7521956</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>142500</td>\n",
+ " <td>76636.4339259</td>\n",
+ " <td>65863.5660741</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>160000</td>\n",
+ " <td>136472.340738</td>\n",
+ " <td>23527.6592621</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>240000</td>\n",
+ " <td>250762.306599</td>\n",
+ " <td>-10762.3065986</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>87000</td>\n",
+ " <td>96903.8708638</td>\n",
+ " <td>-9903.87086383</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>118600</td>\n",
+ " <td>118105.899552</td>\n",
+ " <td>494.100447531</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>140000</td>\n",
+ " <td>184132.074899</td>\n",
+ " <td>-44132.0748994</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>148000</td>\n",
+ " <td>156805.828854</td>\n",
+ " <td>-8805.82885402</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>65000</td>\n",
+ " <td>95306.5757176</td>\n",
+ " <td>-30306.5757176</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>91000</td>\n",
+ " <td>111485.155771</td>\n",
+ " <td>-20485.1557714</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>132300</td>\n",
+ " <td>130790.759004</td>\n",
+ " <td>1509.24099637</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>91100</td>\n",
+ " <td>77889.632657</td>\n",
+ " <td>13210.367343</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>260011</td>\n",
+ " <td>196956.455001</td>\n",
+ " <td>63054.5449987</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>141800</td>\n",
+ " <td>94334.8543909</td>\n",
+ " <td>47465.1456091</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>160900</td>\n",
+ " <td>141127.098448</td>\n",
+ " <td>19772.9015523</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>239000</td>\n",
+ " <td>247484.744258</td>\n",
+ " <td>-8484.74425783</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>81010</td>\n",
+ " <td>97823.4615037</td>\n",
+ " <td>-16813.4615037</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>117910</td>\n",
+ " <td>120627.793415</td>\n",
+ " <td>-2717.79341491</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>141100</td>\n",
+ " <td>176785.425125</td>\n",
+ " <td>-35685.4251249</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>148011</td>\n",
+ " <td>158794.269686</td>\n",
+ " <td>-10783.2696863</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>66000</td>\n",
+ " <td>116042.350741</td>\n",
+ " <td>-50042.3507411</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(1, 50000, 54506.104034034, -4506.104034034),\n",
+ " (2, 85000, 110175.125177568, -25175.125177568),\n",
+ " (3, 22500, 52957.620850552, -30457.620850552),\n",
+ " (4, 90000, 99789.703256009, -9789.703256009),\n",
+ " (5, 133000, 122071.166987934, 10928.833012066),\n",
+ " (6, 90500, 78008.700742161, 12491.299257839),\n",
+ " (7, 260000, 199466.247804442, 60533.752195558),\n",
+ " (8, 142500, 76636.433925887, 65863.566074113),\n",
+ " (9, 160000, 136472.340737858, 23527.659262142),\n",
+ " (10, 240000, 250762.306598593, -10762.306598593),\n",
+ " (11, 87000, 96903.870863831, -9903.87086383101),\n",
+ " (12, 118600, 118105.899552469, 494.100447531004),\n",
+ " (13, 140000, 184132.074899358, -44132.074899358),\n",
+ " (14, 148000, 156805.828854024, -8805.828854024),\n",
+ " (15, 65000, 95306.57571764, -30306.57571764),\n",
+ " (16, 91000, 111485.155771426, -20485.1557714256),\n",
+ " (17, 132300, 130790.759003626, 1509.2409963744),\n",
+ " (18, 91100, 77889.6326569836, 13210.3673430164),\n",
+ " (19, 260011, 196956.455001253, 63054.5449987474),\n",
+ " (20, 141800, 94334.8543909176, 47465.1456090824),\n",
+ " (21, 160900, 141127.098447658, 19772.9015523424),\n",
+ " (22, 239000, 247484.744257828, -8484.74425782761),\n",
+ " (23, 81010, 97823.4615037056, -16813.4615037056),\n",
+ " (24, 117910, 120627.793414912, -2717.7934149116),\n",
+ " (25, 141100, 176785.425124942, -35685.4251249416),\n",
+ " (26, 148011, 158794.269686326, -10783.2696863256),\n",
+ " (27, 66000, 116042.350741075, -50042.3507410746)]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "-- elastic_net_predict() creates its output table and MADlib will not overwrite an existing one,\n",
+ "-- so drop it first to keep this cell re-runnable (like the other cells that drop their outputs).\n",
+ "DROP TABLE IF EXISTS houses_en1_prediction;\n",
+ "\n",
+ "SELECT madlib.elastic_net_predict(\n",
+ "    'houses_en1',             -- model table (one model per zipcode)\n",
+ "    'houses',                 -- new source data table\n",
+ "    'id',                     -- unique ID associated with each row\n",
+ "    'houses_en1_prediction'   -- table to store prediction result\n",
+ ");\n",
+ "\n",
+ "SELECT houses.id,\n",
+ "       houses.price,\n",
+ "       houses_en1_prediction.prediction,\n",
+ "       houses.price - houses_en1_prediction.prediction AS residual\n",
+ "FROM houses_en1_prediction, houses\n",
+ "WHERE houses.id = houses_en1_prediction.id ORDER BY id;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
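+ "## 5. When coef_nonzero is different from coef_all\n",
+ "Train with alpha = 1 (pure L1/lasso penalty) and a large lambda so that some coefficients are shrunk to exactly zero. `coef_all` keeps a 0.0 entry for each dropped feature, while `coef_nonzero` contains only the coefficients of `features_selected`."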
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Done.\n",
+ "1 rows affected.\n",
+ "1 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>family</th>\n",
+ " <th>features</th>\n",
+ " <th>features_selected</th>\n",
+ " <th>coef_nonzero</th>\n",
+ " <th>coef_all</th>\n",
+ " <th>intercept</th>\n",
+ " <th>log_likelihood</th>\n",
+ " <th>standardize</th>\n",
+ " <th>iteration_run</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>gaussian</td>\n",
+ " <td>[u'tax', u'bath', u'size']</td>\n",
+ " <td>[u'tax', u'size']</td>\n",
+ " <td>[6.94383308191, 29.7206857861]</td>\n",
+ " <td>[6.94383308191, 0.0, 29.7206857861]</td>\n",
+ " <td>74441.4573381</td>\n",
+ " <td>-1635348584.1</td>\n",
+ " <td>True</td>\n",
+ " <td>173</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'size'],
[6.94383308191, 29.7206857861], [6.94383308191, 0.0, 29.7206857861],
74441.4573381, -1635348584.1, True, 173)]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "DROP TABLE IF EXISTS houses_en2, houses_en2_summary;\n",
+ "\n",
+ "-- Alpha 1 with a large lambda, strong enough regularization to zero out some coefficients.\n",
+ "SELECT madlib.elastic_net_train(\n",
+ "    'houses',                  -- source table\n",
+ "    'houses_en2',              -- result (model) table\n",
+ "    'price',                   -- dependent variable\n",
+ "    'array[tax, bath, size]',  -- independent variables\n",
+ "    'gaussian',                -- regression family\n",
+ "    1,                         -- alpha value\n",
+ "    30000,                     -- lambda value\n",
+ "    TRUE,                      -- standardize\n",
+ "    NULL,                      -- grouping column(s)\n",
+ "    'fista',                   -- optimizer\n",
+ "    '',                        -- optimizer parameters\n",
+ "    NULL,                      -- excluded columns\n",
+ "    10000,                     -- maximum iterations\n",
+ "    1e-6                       -- tolerance value\n",
+ ");\n",
+ "\n",
+ "SELECT * FROM houses_en2;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the prediction function with `coef_all`, which lines up one-to-one with the full feature array, to evaluate residuals."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "27 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>id</th>\n",
+ " <th>price</th>\n",
+ " <th>predict</th>\n",
+ " <th>residual</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>50000</td>\n",
+ " <td>101423.246912</td>\n",
+ " <td>-51423.2469117</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>85000</td>\n",
+ " <td>123638.649033</td>\n",
+ " <td>-38638.6490325</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>22500</td>\n",
+ " <td>106084.260933</td>\n",
+ " <td>-83584.260933</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>90000</td>\n",
+ " <td>119119.483641</td>\n",
+ " <td>-29119.4836413</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>133000</td>\n",
+ " <td>128188.345685</td>\n",
+ " <td>4811.65431463</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>90500</td>\n",
+ " <td>108186.594343</td>\n",
+ " <td>-17686.5943433</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>260000</td>\n",
+ " <td>157119.812361</td>\n",
+ " <td>102880.187639</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>142500</td>\n",
+ " <td>113936.466204</td>\n",
+ " <td>28563.5337965</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>160000</td>\n",
+ " <td>131799.138888</td>\n",
+ " <td>28200.861112</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>240000</td>\n",
+ " <td>182915.476423</td>\n",
+ " <td>57084.5235773</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>87000</td>\n",
+ " <td>116580.526614</td>\n",
+ " <td>-29580.5266138</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>118600</td>\n",
+ " <td>122841.324163</td>\n",
+ " <td>-4241.32416342</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>140000</td>\n",
+ " <td>148275.746876</td>\n",
+ " <td>-8275.74687556</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>148000</td>\n",
+ " <td>134882.254786</td>\n",
+ " <td>13117.7452139</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>65000</td>\n",
+ " <td>122049.943231</td>\n",
+ " <td>-57049.9432312</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>91000</td>\n",
+ " <td>118425.100333</td>\n",
+ " <td>-27425.1003331</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>132300</td>\n",
+ " <td>127493.962377</td>\n",
+ " <td>4806.03762282</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>91100</td>\n",
+ " <td>106797.827727</td>\n",
+ " <td>-15697.8277269</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>260011</td>\n",
+ " <td>156425.429053</td>\n",
+ " <td>103585.570947</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>141800</td>\n",
+ " <td>114630.849512</td>\n",
+ " <td>27169.1504883</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>160900</td>\n",
+ " <td>132285.207204</td>\n",
+ " <td>28614.7927963</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>239000</td>\n",
+ " <td>182359.969776</td>\n",
+ " <td>56640.0302238</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>81010</td>\n",
+ " <td>116163.896629</td>\n",
+ " <td>-35153.8966288</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>117910</td>\n",
+ " <td>122633.009171</td>\n",
+ " <td>-4723.00917096</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>141100</td>\n",
+ " <td>148970.130184</td>\n",
+ " <td>-7870.13018375</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>148011</td>\n",
+ " <td>136271.021402</td>\n",
+ " <td>11739.9785975</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>66000</td>\n",
+ " <td>122744.326539</td>\n",
+ " <td>-56744.3265394</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(1, 50000, 101423.246911724, -51423.2469117239),\n",
+ " (2, 85000, 123638.649032506, -38638.6490325065),\n",
+ " (3, 22500, 106084.260933004, -83584.2609330042),\n",
+ " (4, 90000, 119119.483641292, -29119.4836412917),\n",
+ " (5, 133000, 128188.345685371, 4811.6543146288),\n",
+ " (6, 90500, 108186.59434328, -17686.5943432805),\n",
+ " (7, 260000, 157119.812361022, 102880.187638978),\n",
+ " (8, 142500, 113936.466203536, 28563.5337964642),\n",
+ " (9, 160000, 131799.138887964, 28200.8611120356),\n",
+ " (10, 240000, 182915.476422748, 57084.5235772522),\n",
+ " (11, 87000, 116580.526613754, -29580.5266137536),\n",
+ " (12, 118600, 122841.324163419, -4241.32416341919),\n",
+ " (13, 140000, 148275.746875557, -8275.746875557),\n",
+ " (14, 148000, 134882.254786109, 13117.7452138913),\n",
+ " (15, 65000, 122049.943231186, -57049.9432311865),\n",
+ " (16, 91000, 118425.100333101, -27425.1003331007),\n",
+ " (17, 132300, 127493.96237718, 4806.03762281981),\n",
+ " (18, 91100, 106797.827726898, -15697.8277268985),\n",
+ " (19, 260011, 156425.429052831, 103585.570947169),\n",
+ " (20, 141800, 114630.849511727, 27169.1504882732),\n",
+ " (21, 160900, 132285.207203698, 28614.7927963019),\n",
+ " (22, 239000, 182359.969776195, 56640.030223805),\n",
+ " (23, 81010, 116163.896628839, -35153.896628839),\n",
+ " (24, 117910, 122633.009170962, -4723.00917096189),\n",
+ " (25, 141100, 148970.130183748, -7870.130183748),\n",
+ " (26, 148011, 136271.021402491, 11739.9785975093),\n",
+ " (27, 66000, 122744.326539377, -56744.3265393775)]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "SELECT id, price, predict, price - predict AS residual\n",
+ "FROM (\n",
+ " SELECT\n",
+ " houses.*,\n",
+ " madlib.elastic_net_gaussian_predict(\n",
+ " m.coef_all, -- All coefficients\n",
+ " m.intercept, -- Intercept\n",
+ " ARRAY[tax,bath,size] -- All features\n",
+ " ) AS predict\n",
+ " FROM houses, houses_en2 m) s\n",
+ "ORDER BY id;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can speed up prediction by passing coef_nonzero instead of coef_all when evaluating residuals. This requires the user to examine the features_selected column in the result table (here, tax and size) to construct the correct set of independent variables to provide to the prediction function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "27 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>id</th>\n",
+ " <th>price</th>\n",
+ " <th>predict</th>\n",
+ " <th>residual</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>50000</td>\n",
+ " <td>101423.246912</td>\n",
+ " <td>-51423.2469117</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>85000</td>\n",
+ " <td>123638.649033</td>\n",
+ " <td>-38638.6490325</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>22500</td>\n",
+ " <td>106084.260933</td>\n",
+ " <td>-83584.260933</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>90000</td>\n",
+ " <td>119119.483641</td>\n",
+ " <td>-29119.4836413</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>133000</td>\n",
+ " <td>128188.345685</td>\n",
+ " <td>4811.65431463</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>90500</td>\n",
+ " <td>108186.594343</td>\n",
+ " <td>-17686.5943433</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>260000</td>\n",
+ " <td>157119.812361</td>\n",
+ " <td>102880.187639</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>142500</td>\n",
+ " <td>113936.466204</td>\n",
+ " <td>28563.5337965</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>160000</td>\n",
+ " <td>131799.138888</td>\n",
+ " <td>28200.861112</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>240000</td>\n",
+ " <td>182915.476423</td>\n",
+ " <td>57084.5235773</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>11</td>\n",
+ " <td>87000</td>\n",
+ " <td>116580.526614</td>\n",
+ " <td>-29580.5266138</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>12</td>\n",
+ " <td>118600</td>\n",
+ " <td>122841.324163</td>\n",
+ " <td>-4241.32416342</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>13</td>\n",
+ " <td>140000</td>\n",
+ " <td>148275.746876</td>\n",
+ " <td>-8275.74687556</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>14</td>\n",
+ " <td>148000</td>\n",
+ " <td>134882.254786</td>\n",
+ " <td>13117.7452139</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>15</td>\n",
+ " <td>65000</td>\n",
+ " <td>122049.943231</td>\n",
+ " <td>-57049.9432312</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>16</td>\n",
+ " <td>91000</td>\n",
+ " <td>118425.100333</td>\n",
+ " <td>-27425.1003331</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>17</td>\n",
+ " <td>132300</td>\n",
+ " <td>127493.962377</td>\n",
+ " <td>4806.03762282</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>18</td>\n",
+ " <td>91100</td>\n",
+ " <td>106797.827727</td>\n",
+ " <td>-15697.8277269</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>19</td>\n",
+ " <td>260011</td>\n",
+ " <td>156425.429053</td>\n",
+ " <td>103585.570947</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>20</td>\n",
+ " <td>141800</td>\n",
+ " <td>114630.849512</td>\n",
+ " <td>27169.1504883</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>21</td>\n",
+ " <td>160900</td>\n",
+ " <td>132285.207204</td>\n",
+ " <td>28614.7927963</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>22</td>\n",
+ " <td>239000</td>\n",
+ " <td>182359.969776</td>\n",
+ " <td>56640.0302238</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>23</td>\n",
+ " <td>81010</td>\n",
+ " <td>116163.896629</td>\n",
+ " <td>-35153.8966288</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>24</td>\n",
+ " <td>117910</td>\n",
+ " <td>122633.009171</td>\n",
+ " <td>-4723.00917096</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>25</td>\n",
+ " <td>141100</td>\n",
+ " <td>148970.130184</td>\n",
+ " <td>-7870.13018375</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>26</td>\n",
+ " <td>148011</td>\n",
+ " <td>136271.021402</td>\n",
+ " <td>11739.9785975</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>27</td>\n",
+ " <td>66000</td>\n",
+ " <td>122744.326539</td>\n",
+ " <td>-56744.3265394</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(1, 50000, 101423.246911724, -51423.2469117239),\n",
+ " (2, 85000, 123638.649032506, -38638.6490325065),\n",
+ " (3, 22500, 106084.260933004, -83584.2609330042),\n",
+ " (4, 90000, 119119.483641292, -29119.4836412917),\n",
+ " (5, 133000, 128188.345685371, 4811.6543146288),\n",
+ " (6, 90500, 108186.59434328, -17686.5943432805),\n",
+ " (7, 260000, 157119.812361022, 102880.187638978),\n",
+ " (8, 142500, 113936.466203536, 28563.5337964642),\n",
+ " (9, 160000, 131799.138887964, 28200.8611120356),\n",
+ " (10, 240000, 182915.476422748, 57084.5235772522),\n",
+ " (11, 87000, 116580.526613754, -29580.5266137536),\n",
+ " (12, 118600, 122841.324163419, -4241.32416341919),\n",
+ " (13, 140000, 148275.746875557, -8275.746875557),\n",
+ " (14, 148000, 134882.254786109, 13117.7452138913),\n",
+ " (15, 65000, 122049.943231186, -57049.9432311865),\n",
+ " (16, 91000, 118425.100333101, -27425.1003331007),\n",
+ " (17, 132300, 127493.96237718, 4806.03762281981),\n",
+ " (18, 91100, 106797.827726898, -15697.8277268985),\n",
+ " (19, 260011, 156425.429052831, 103585.570947169),\n",
+ " (20, 141800, 114630.849511727, 27169.1504882732),\n",
+ " (21, 160900, 132285.207203698, 28614.7927963019),\n",
+ " (22, 239000, 182359.969776195, 56640.030223805),\n",
+ " (23, 81010, 116163.896628839, -35153.896628839),\n",
+ " (24, 117910, 122633.009170962, -4723.00917096189),\n",
+ " (25, 141100, 148970.130183748, -7870.130183748),\n",
+ " (26, 148011, 136271.021402491, 11739.9785975093),\n",
+ " (27, 66000, 122744.326539377, -56744.3265393775)]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "SELECT id, price, predict, price - predict AS residual\n",
+ "FROM (\n",
+ " SELECT\n",
+ " houses.*,\n",
+ " madlib.elastic_net_gaussian_predict(\n",
+ " m.coef_nonzero, -- Non-zero coefficients\n",
+ " m.intercept, -- Intercept\n",
+ " ARRAY[tax,size] -- Features corresponding to
non-zero coefficients\n",
+ " ) AS predict\n",
+ " FROM houses, houses_en2 m) s\n",
+ "ORDER BY id;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
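+ "## 6. Cross validation\n",
+ "Reuse the houses table above. Here we use 3-fold cross validation with 3 automatically generated lambda values and 3 specified alpha values. (This can take some time to run, since elastic net is effectively called 27 times for these combinations, then a 28th time on the whole dataset.) Note that with cross validation enabled, the final model is trained with the best-scoring alpha and lambda from the validation table rather than the alpha passed as an argument (compare the summary table below with the top row of the cross validation results)."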
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sql\n",
+ "DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n",
+ "SELECT madlib.elastic_net_train( 'houses', -- Source
table\n",
+ " 'houses_en3', -- Result
table\n",
+ " 'price', -- Dependent
variable\n",
+ " 'array[tax, bath, size]', --
Independent variable\n",
+ " 'gaussian', -- Regression
family\n",
+ " 0.5, -- Alpha
value\n",
+ " 0.1, -- Lambda
value\n",
+ " TRUE, --
Standardize\n",
+ " NULL, -- Grouping
column(s)\n",
+ " 'fista', --
Optimizer\n",
+ " $$ n_folds = 3, -- Optimizer
parameters\n",
+ " validation_result=houses_en3_cv,\n",
+ " n_lambdas = 3, \n",
+ " alpha = {0, 0.1, 1}\n",
+ " $$, \n",
+ " NULL, -- Excluded
columns\n",
+ " 10000, -- Maximum
iterations\n",
+ " 1e-6 -- Tolerance
value\n",
+ " );\n",
+ "SELECT * FROM houses_en3;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Details of the cross validation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>alpha</th>\n",
+ " <th>lambda_value</th>\n",
+ " <th>mean_neg_loss</th>\n",
+ " <th>std_neg_loss</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.0</td>\n",
+ " <td>0.1</td>\n",
+ " <td>-36094.4685768</td>\n",
+ " <td>10524.4473253</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.1</td>\n",
+ " <td>0.1</td>\n",
+ " <td>-36136.2448004</td>\n",
+ " <td>10682.4136993</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1.0</td>\n",
+ " <td>100.0</td>\n",
+ " <td>-37007.9496501</td>\n",
+ " <td>12679.3781975</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1.0</td>\n",
+ " <td>0.1</td>\n",
+ " <td>-37018.1019927</td>\n",
+ " <td>12716.7438015</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.1</td>\n",
+ " <td>100.0</td>\n",
+ " <td>-59275.6940173</td>\n",
+ " <td>9764.50064237</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.0</td>\n",
+ " <td>100.0</td>\n",
+ " <td>-59380.252681</td>\n",
+ " <td>9763.26373034</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>1.0</td>\n",
+ " <td>100000.0</td>\n",
+ " <td>-60353.0220769</td>\n",
+ " <td>9748.10305107</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.1</td>\n",
+ " <td>100000.0</td>\n",
+ "
<td>-143513752113000000000000000000000000000000000000000000</td>\n",
+ "
<td>157073834312000000000000000000000000000000000000000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.0</td>\n",
+ " <td>100000.0</td>\n",
+ "
<td>-11248884473800000000000000000000000000000000000000000000</td>\n",
+ "
<td>9490568229990000000000000000000000000000000000000000000</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(Decimal('0.0'), Decimal('0.1'), Decimal('-36094.4685768'),
Decimal('10524.4473253')),\n",
+ " (Decimal('0.1'), Decimal('0.1'), Decimal('-36136.2448004'),
Decimal('10682.4136993')),\n",
+ " (Decimal('1.0'), Decimal('100.0'), Decimal('-37007.9496501'),
Decimal('12679.3781975')),\n",
+ " (Decimal('1.0'), Decimal('0.1'), Decimal('-37018.1019927'),
Decimal('12716.7438015')),\n",
+ " (Decimal('0.1'), Decimal('100.0'), Decimal('-59275.6940173'),
Decimal('9764.50064237')),\n",
+ " (Decimal('0.0'), Decimal('100.0'), Decimal('-59380.252681'),
Decimal('9763.26373034')),\n",
+ " (Decimal('1.0'), Decimal('100000.0'), Decimal('-60353.0220769'),
Decimal('9748.10305107')),\n",
+ " (Decimal('0.1'), Decimal('100000.0'),
Decimal('-143513752113000000000000000000000000000000000000000000'),
Decimal('157073834312000000000000000000000000000000000000000000')),\n",
+ " (Decimal('0.0'), Decimal('100000.0'),
Decimal('-11248884473800000000000000000000000000000000000000000000'),
Decimal('9490568229990000000000000000000000000000000000000000000'))]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "SELECT * FROM houses_en3_cv ORDER BY mean_neg_loss DESC;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>method</th>\n",
+ " <th>source_table</th>\n",
+ " <th>out_table</th>\n",
+ " <th>dependent_varname</th>\n",
+ " <th>independent_varname</th>\n",
+ " <th>family</th>\n",
+ " <th>alpha</th>\n",
+ " <th>lambda_value</th>\n",
+ " <th>grouping_col</th>\n",
+ " <th>num_all_groups</th>\n",
+ " <th>num_failed_groups</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>elastic_net</td>\n",
+ " <td>houses</td>\n",
+ " <td>houses_en3</td>\n",
+ " <td>price</td>\n",
+ " <td>array[tax, bath, size]</td>\n",
+ " <td>gaussian</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.1</td>\n",
+ " <td>NULL</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(u'elastic_net', u'houses', u'houses_en3', u'price', u'array[tax,
bath, size]', u'gaussian', 0.0, 0.1, u'NULL', 1, 0)]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "SELECT * FROM houses_en3_summary;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
+ "source": [
+ "## 6a. Cross validation\n",
+ "Here we use 3-fold cross validation with 3 automatically generated lambda values and a single alpha value, so elastic net is effectively called 9 times for these combinations, then once more on the whole dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sql\n",
+ "DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n",
+ "SELECT madlib.elastic_net_train( 'houses', -- Source
table\n",
+ " 'houses_en3', -- Result
table\n",
+ " 'price', -- Dependent
variable\n",
+ " 'array[tax, bath, size]', --
Independent variable\n",
+ " 'gaussian', -- Regression
family\n",
+ " 0.5, -- Alpha
value\n",
+ " 0.1, -- Lambda
value\n",
+ " TRUE, --
Standardize\n",
+ " NULL, -- Grouping
column(s)\n",
+ " 'fista', --
Optimizer\n",
+ " $$ n_folds = 3, -- Optimizer
parameters\n",
+ " validation_result=houses_en3_cv,\n",
+ " n_lambdas = 3\n",
+ " $$, \n",
+ " NULL, -- Excluded
columns\n",
+ " 10000, -- Maximum
iterations\n",
+ " 1e-6 -- Tolerance
value\n",
+ " );\n",
+ "SELECT * FROM houses_en3;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Details of the cross validation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3 rows affected.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<table>\n",
+ " <tr>\n",
+ " <th>lambda_value</th>\n",
+ " <th>mean_neg_loss</th>\n",
+ " <th>std_neg_loss</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>100000.0</td>\n",
+ " <td>-255543791799000000000000000000000000000000000</td>\n",
+ " <td>442158712729000000000000000000000000000000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>100.0</td>\n",
+ " <td>-59332.2198813</td>\n",
+ " <td>8220.8755071</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>0.1</td>\n",
+ " <td>-51938.9613421</td>\n",
+ " <td>28946.523247</td>\n",
+ " </tr>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ "[(Decimal('100000.0'),
Decimal('-255543791799000000000000000000000000000000000'),
Decimal('442158712729000000000000000000000000000000000')),\n",
+ " (Decimal('100.0'), Decimal('-59332.2198813'),
Decimal('8220.8755071')),\n",
+ " (Decimal('0.1'), Decimal('-51938.9613421'), Decimal('28946.523247'))]"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%sql\n",
+ "SELECT * FROM houses_en3_cv ORDER BY mean_neg_loss DESC;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "%%sql\n",
+ "SELECT * FROM houses_en3_summary;"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}