http://git-wip-us.apache.org/repos/asf/madlib-site/blob/acd339f6/community-artifacts/MLP-v4.ipynb ---------------------------------------------------------------------- diff --git a/community-artifacts/MLP-v4.ipynb b/community-artifacts/MLP-v4.ipynb new file mode 100644 index 0000000..a6b62d6 --- /dev/null +++ b/community-artifacts/MLP-v4.ipynb @@ -0,0 +1,4588 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multilayer Perceptron\n", + "\n", + "Multilayer Perceptron (MLP) is a type of neural network that can be used for regression and classification.\n", + "\n", + "This version of the workbook includes mini-batching added in 1.14 and momentum added in 1.15" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated. 
You should import from traitlets.config instead.\n", + " \"You should import from traitlets.config instead.\", ShimWarning)\n", + "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", + " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n" + ] + } + ], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "u'Connected: gpadmin@madlib'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Greenplum Database 5.4.0 on GCP (demo machine)\n", + "%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n", + " \n", + "# PostgreSQL local\n", + "#%sql postgresql://fmcquillan@localhost:5432/madlib\n", + "\n", + "# Greenplum Database 4.3.10.0\n", + "#%sql postgresql://gpdbchina@10.194.10.68:61000/madlib" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>version</th>\n", + " </tr>\n", + " <tr>\n", + " <td>MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-g5c4331d, cmake configuration time: Thu Jul 5 17:46:06 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-g5c4331d, cmake configuration time: Thu Jul 5 17:46:06 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql select 
madlib.version();\n", + "#%sql select version();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Classification without Mini-Batching\n", + "\n", + "# 1. Create input table for classification" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "Done.\n", + "52 rows affected.\n", + "52 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>attributes</th>\n", + " <th>class_text</th>\n", + " <th>class</th>\n", + " <th>state</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>[Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " 
<td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), 
Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>21</td>\n", + " <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>22</td>\n", + " <td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>23</td>\n", + " <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>24</td>\n", + " <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>25</td>\n", + " <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>26</td>\n", + " <td>[Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>27</td>\n", + " 
<td>[Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>28</td>\n", + " <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>29</td>\n", + " <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>30</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>31</td>\n", + " <td>[Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>32</td>\n", + " <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>33</td>\n", + " <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>34</td>\n", + " <td>[Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>35</td>\n", + " <td>[Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>36</td>\n", + " <td>[Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " 
</tr>\n", + " <tr>\n", + " <td>37</td>\n", + " <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>38</td>\n", + " <td>[Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>39</td>\n", + " <td>[Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>40</td>\n", + " <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>41</td>\n", + " <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>42</td>\n", + " <td>[Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>43</td>\n", + " <td>[Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>44</td>\n", + " <td>[Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>45</td>\n", + " <td>[Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>46</td>\n", + " <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')]</td>\n", + " 
<td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>47</td>\n", + " <td>[Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>48</td>\n", + " <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>49</td>\n", + " <td>[Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>50</td>\n", + " <td>[Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>51</td>\n", + " <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>52</td>\n", + " <td>[Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, [Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (2, [Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (3, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris_setosa', 1, u'Alaska'),\n", + " (4, [Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (5, [Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (6, [Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')], 
u'Iris_setosa', 1, u'Alaska'),\n", + " (7, [Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')], u'Iris_setosa', 1, u'Alaska'),\n", + " (8, [Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (9, [Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')], u'Iris_setosa', 1, u'Alaska'),\n", + " (10, [Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')], u'Iris_setosa', 1, u'Alaska'),\n", + " (11, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')], u'Iris_setosa', 1, u'Alaska'),\n", + " (12, [Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (13, [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (14, [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (15, [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (16, [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (17, [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (18, [Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (19, [Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (20, [Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (21, [Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (22, [Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (23, [Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (24, [Decimal('5.6'), Decimal('3.0'), 
Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (25, [Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (26, [Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (27, [Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (28, [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (29, [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (30, [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (31, [Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (32, [Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (33, [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (34, [Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (35, [Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (36, [Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (37, [Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (38, [Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (39, [Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (40, [Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (41, [Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')], u'Iris_versicolor', 2, 
u'Tennessee'),\n", + " (42, [Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (43, [Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (44, [Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (45, [Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (46, [Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (47, [Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (48, [Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (49, [Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (50, [Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (51, [Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (52, [Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee')]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql \n", + "DROP TABLE IF EXISTS iris_data;\n", + "\n", + "CREATE TABLE iris_data(\n", + " id serial,\n", + " attributes numeric[],\n", + " class_text varchar,\n", + " class integer,\n", + " state varchar\n", + ");\n", + "\n", + "INSERT INTO iris_data(id, attributes, class_text, class, state) VALUES\n", + "(1,ARRAY[5.0,3.2,1.2,0.2],'Iris_setosa',1,'Alaska'),\n", + "(2,ARRAY[5.5,3.5,1.3,0.2],'Iris_setosa',1,'Alaska'),\n", + "(3,ARRAY[4.9,3.1,1.5,0.1],'Iris_setosa',1,'Alaska'),\n", + "(4,ARRAY[4.4,3.0,1.3,0.2],'Iris_setosa',1,'Alaska'),\n", + 
"(5,ARRAY[5.1,3.4,1.5,0.2],'Iris_setosa',1,'Alaska'),\n", + "(6,ARRAY[5.0,3.5,1.3,0.3],'Iris_setosa',1,'Alaska'),\n", + "(7,ARRAY[4.5,2.3,1.3,0.3],'Iris_setosa',1,'Alaska'),\n", + "(8,ARRAY[4.4,3.2,1.3,0.2],'Iris_setosa',1,'Alaska'),\n", + "(9,ARRAY[5.0,3.5,1.6,0.6],'Iris_setosa',1,'Alaska'),\n", + "(10,ARRAY[5.1,3.8,1.9,0.4],'Iris_setosa',1,'Alaska'),\n", + "(11,ARRAY[4.8,3.0,1.4,0.3],'Iris_setosa',1,'Alaska'),\n", + "(12,ARRAY[5.1,3.8,1.6,0.2],'Iris_setosa',1,'Alaska'),\n", + "(13,ARRAY[5.7,2.8,4.5,1.3],'Iris_versicolor',2,'Alaska'),\n", + "(14,ARRAY[6.3,3.3,4.7,1.6],'Iris_versicolor',2,'Alaska'),\n", + "(15,ARRAY[4.9,2.4,3.3,1.0],'Iris_versicolor',2,'Alaska'),\n", + "(16,ARRAY[6.6,2.9,4.6,1.3],'Iris_versicolor',2,'Alaska'),\n", + "(17,ARRAY[5.2,2.7,3.9,1.4],'Iris_versicolor',2,'Alaska'),\n", + "(18,ARRAY[5.0,2.0,3.5,1.0],'Iris_versicolor',2,'Alaska'),\n", + "(19,ARRAY[5.9,3.0,4.2,1.5],'Iris_versicolor',2,'Alaska'),\n", + "(20,ARRAY[6.0,2.2,4.0,1.0],'Iris_versicolor',2,'Alaska'),\n", + "(21,ARRAY[6.1,2.9,4.7,1.4],'Iris_versicolor',2,'Alaska'),\n", + "(22,ARRAY[5.6,2.9,3.6,1.3],'Iris_versicolor',2,'Alaska'),\n", + "(23,ARRAY[6.7,3.1,4.4,1.4],'Iris_versicolor',2,'Alaska'),\n", + "(24,ARRAY[5.6,3.0,4.5,1.5],'Iris_versicolor',2,'Alaska'),\n", + "(25,ARRAY[5.8,2.7,4.1,1.0],'Iris_versicolor',2,'Alaska'),\n", + "(26,ARRAY[6.2,2.2,4.5,1.5],'Iris_versicolor',2,'Alaska'),\n", + "(27,ARRAY[5.6,2.5,3.9,1.1],'Iris_versicolor',2,'Alaska'),\n", + "(28,ARRAY[5.0,3.4,1.5,0.2],'Iris_setosa',1,'Tennessee'),\n", + "(29,ARRAY[4.4,2.9,1.4,0.2],'Iris_setosa',1,'Tennessee'),\n", + "(30,ARRAY[4.9,3.1,1.5,0.1],'Iris_setosa',1,'Tennessee'),\n", + "(31,ARRAY[5.4,3.7,1.5,0.2],'Iris_setosa',1,'Tennessee'),\n", + "(32,ARRAY[4.8,3.4,1.6,0.2],'Iris_setosa',1,'Tennessee'),\n", + "(33,ARRAY[4.8,3.0,1.4,0.1],'Iris_setosa',1,'Tennessee'),\n", + "(34,ARRAY[4.3,3.0,1.1,0.1],'Iris_setosa',1,'Tennessee'),\n", + "(35,ARRAY[5.8,4.0,1.2,0.2],'Iris_setosa',1,'Tennessee'),\n", + 
"(36,ARRAY[5.7,4.4,1.5,0.4],'Iris_setosa',1,'Tennessee'),\n", + "(37,ARRAY[5.4,3.9,1.3,0.4],'Iris_setosa',1,'Tennessee'),\n", + "(38,ARRAY[6.0,2.9,4.5,1.5],'Iris_versicolor',2,'Tennessee'),\n", + "(39,ARRAY[5.7,2.6,3.5,1.0],'Iris_versicolor',2,'Tennessee'),\n", + "(40,ARRAY[5.5,2.4,3.8,1.1],'Iris_versicolor',2,'Tennessee'),\n", + "(41,ARRAY[5.5,2.4,3.7,1.0],'Iris_versicolor',2,'Tennessee'),\n", + "(42,ARRAY[5.8,2.7,3.9,1.2],'Iris_versicolor',2,'Tennessee'),\n", + "(43,ARRAY[6.0,2.7,5.1,1.6],'Iris_versicolor',2,'Tennessee'),\n", + "(44,ARRAY[5.4,3.0,4.5,1.5],'Iris_versicolor',2,'Tennessee'),\n", + "(45,ARRAY[6.0,3.4,4.5,1.6],'Iris_versicolor',2,'Tennessee'),\n", + "(46,ARRAY[6.7,3.1,4.7,1.5],'Iris_versicolor',2,'Tennessee'),\n", + "(47,ARRAY[6.3,2.3,4.4,1.3],'Iris_versicolor',2,'Tennessee'),\n", + "(48,ARRAY[5.6,3.0,4.1,1.3],'Iris_versicolor',2,'Tennessee'),\n", + "(49,ARRAY[5.5,2.5,4.0,1.3],'Iris_versicolor',2,'Tennessee'),\n", + "(50,ARRAY[5.5,2.6,4.4,1.2],'Iris_versicolor',2,'Tennessee'),\n", + "(51,ARRAY[6.1,3.0,4.6,1.4],'Iris_versicolor',2,'Tennessee'),\n", + "(52,ARRAY[5.8,2.6,4.0,1.2],'Iris_versicolor',2,'Tennessee');\n", + "\n", + "SELECT * FROM iris_data ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Classification model\n", + "\n", + "Generate a multilayer perceptron with a single hidden layer of 5 units. Use the attributes column as the independent variables, and use the class_text column as the classification label. Set the tolerance to 0 so that 500 iterations will be run. Use a hyperbolic tangent activation function. The model will be written to mlp_model."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>mlp_classification</th>\n", + " </tr>\n", + " <tr>\n", + " <td></td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[('',)]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;\n", + "\n", + "-- Set seed so results are reproducible\n", + "SELECT setseed(0);\n", + "\n", + "SELECT madlib.mlp_classification(\n", + " 'iris_data', -- Source table\n", + " 'mlp_model', -- Destination table\n", + " 'attributes', -- Input features\n", + " 'class_text', -- Label\n", + " ARRAY[5], -- Number of units per layer\n", + " 'learning_rate_init=0.003,\n", + " n_iterations=500,\n", + " tolerance=0', -- Optimizer params\n", + " 'tanh', -- Activation function\n", + " NULL, -- Default weight (1)\n", + " FALSE, -- No warm start\n", + " FALSE -- Not verbose\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the classification model:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coeff</th>\n", + " <th>loss</th>\n", + " <th>num_iterations</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[-0.241222466342704, 0.249761165436226, 0.0303614415534361, 0.175507685036152, 0.0815734316744983, -0.139212167184273, -0.523351688827021, 0.682480785215137, -0.621908347655988, -0.469472551411455, 0.0147756731570967, -0.147278260748286, 0.305555593978066, -0.56090735505522, -0.161234544160645, 0.271024873678476, 
0.353270530669186, -0.584992635094265, 0.805209136021446, 0.680671510984583, 0.263358536565038, 0.416395782728773, -0.536749491420233, 0.401066730627837, 0.851835204311541, -0.118513505811856, -0.0923782150443473, 1.08817076327241, 0.581060997166745, -1.65259364782558, -1.05618327045111, 0.319893429079889, 0.227676950437848, -1.23886857988594, -0.403615235762293, 1.2788824239611, 1.35644018931145]</td>\n", + " <td>0.000759051397192</td>\n", + " <td>500</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([-0.241222466342704, 0.249761165436226, 0.0303614415534361, 0.175507685036152, 0.0815734316744983, -0.139212167184273, -0.523351688827021, 0.682480785215137, -0.621908347655988, -0.469472551411455, 0.0147756731570967, -0.147278260748286, 0.305555593978066, -0.56090735505522, -0.161234544160645, 0.271024873678476, 0.353270530669186, -0.584992635094265, 0.805209136021446, 0.680671510984583, 0.263358536565038, 0.416395782728773, -0.536749491420233, 0.401066730627837, 0.851835204311541, -0.118513505811856, -0.0923782150443473, 1.08817076327241, 0.581060997166745, -1.65259364782558, -1.05618327045111, 0.319893429079889, 0.227676950437848, -1.23886857988594, -0.403615235762293, 1.2788824239611, 1.35644018931145], 0.000759051397191927, 500)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_model;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the model summary table:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>source_table</th>\n", + " <th>independent_varname</th>\n", + " <th>dependent_varname</th>\n", + " <th>dependent_vartype</th>\n", + " <th>tolerance</th>\n", + " <th>learning_rate_init</th>\n", + " 
<th>learning_rate_policy</th>\n", + " <th>momentum</th>\n", + " <th>nesterov</th>\n", + " <th>n_iterations</th>\n", + " <th>n_tries</th>\n", + " <th>layer_sizes</th>\n", + " <th>activation</th>\n", + " <th>is_classification</th>\n", + " <th>classes</th>\n", + " <th>weights</th>\n", + " <th>grouping_col</th>\n", + " </tr>\n", + " <tr>\n", + " <td>iris_data</td>\n", + " <td>attributes</td>\n", + " <td>class_text</td>\n", + " <td>character varying</td>\n", + " <td>0.0</td>\n", + " <td>0.003</td>\n", + " <td>constant</td>\n", + " <td>0.9</td>\n", + " <td>True</td>\n", + " <td>500</td>\n", + " <td>1</td>\n", + " <td>[4, 5, 2]</td>\n", + " <td>tanh</td>\n", + " <td>True</td>\n", + " <td>[u'Iris_setosa', u'Iris_versicolor']</td>\n", + " <td>1</td>\n", + " <td>NULL</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(u'iris_data', u'attributes', u'class_text', u'character varying', 0.0, 0.003, u'constant', 0.9, True, 500, 1, [4, 5, 2], u'tanh', True, [u'Iris_setosa', u'Iris_versicolor'], u'1', u'NULL')]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_model_summary;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the standardization table:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>mean</th>\n", + " <th>std</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[5.45961538461539, 2.99807692307692, 3.025, 0.851923076923077]</td>\n", + " <td>[0.598799958694505, 0.498262513685689, 1.41840579525043, 0.550346179381454]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([5.45961538461539, 2.99807692307692, 3.025, 0.851923076923077], [0.598799958694505, 0.498262513685689, 1.41840579525043, 0.550346179381454])]" + ] + }, + 
"execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_model_standardization;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Prediction for classification\n", + "\n", + "Now let's use the model to predict. In the following example we will use the training data set for prediction as well, which is not usual but serves to show the syntax. The prediction is in the estimated_class_text column with the actual value in the class_text column." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "52 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>estimated_class_text</th>\n", + " <th>attributes</th>\n", + " <th>class_text</th>\n", + " <th>class</th>\n", + " <th>state</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>Iris_setosa</td>\n", + " 
<td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " 
</tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>21</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>22</td>\n", + " <td>Iris_versicolor</td>\n", + " 
<td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>23</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>24</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>25</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>26</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>27</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>28</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>29</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>30</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + 
" <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>31</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>32</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>33</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>34</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>35</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>36</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>37</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>38</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>39</td>\n", + " <td>Iris_versicolor</td>\n", + " 
<td>[Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>40</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>41</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>42</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>43</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>44</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>45</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>46</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>47</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), 
Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>48</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>49</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>50</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>51</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + " <tr>\n", + " <td>52</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Tennessee</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(1, u'Iris_setosa', [Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (2, u'Iris_setosa', [Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (3, u'Iris_setosa', [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris_setosa', 1, u'Alaska'),\n", + " (4, u'Iris_setosa', [Decimal('4.4'), Decimal('3.0'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (5, u'Iris_setosa', [Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (6, 
u'Iris_setosa', [Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')], u'Iris_setosa', 1, u'Alaska'),\n", + " (7, u'Iris_setosa', [Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')], u'Iris_setosa', 1, u'Alaska'),\n", + " (8, u'Iris_setosa', [Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (9, u'Iris_setosa', [Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')], u'Iris_setosa', 1, u'Alaska'),\n", + " (10, u'Iris_setosa', [Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')], u'Iris_setosa', 1, u'Alaska'),\n", + " (11, u'Iris_setosa', [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')], u'Iris_setosa', 1, u'Alaska'),\n", + " (12, u'Iris_setosa', [Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')], u'Iris_setosa', 1, u'Alaska'),\n", + " (13, u'Iris_versicolor', [Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (14, u'Iris_versicolor', [Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (15, u'Iris_versicolor', [Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (16, u'Iris_versicolor', [Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (17, u'Iris_versicolor', [Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (18, u'Iris_versicolor', [Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (19, u'Iris_versicolor', [Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (20, u'Iris_versicolor', [Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (21, u'Iris_versicolor', [Decimal('6.1'), 
Decimal('2.9'), Decimal('4.7'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (22, u'Iris_versicolor', [Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (23, u'Iris_versicolor', [Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (24, u'Iris_versicolor', [Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (25, u'Iris_versicolor', [Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (26, u'Iris_versicolor', [Decimal('6.2'), Decimal('2.2'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (27, u'Iris_versicolor', [Decimal('5.6'), Decimal('2.5'), Decimal('3.9'), Decimal('1.1')], u'Iris_versicolor', 2, u'Alaska'),\n", + " (28, u'Iris_setosa', [Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (29, u'Iris_setosa', [Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (30, u'Iris_setosa', [Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (31, u'Iris_setosa', [Decimal('5.4'), Decimal('3.7'), Decimal('1.5'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (32, u'Iris_setosa', [Decimal('4.8'), Decimal('3.4'), Decimal('1.6'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (33, u'Iris_setosa', [Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (34, u'Iris_setosa', [Decimal('4.3'), Decimal('3.0'), Decimal('1.1'), Decimal('0.1')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (35, u'Iris_setosa', [Decimal('5.8'), Decimal('4.0'), Decimal('1.2'), Decimal('0.2')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (36, u'Iris_setosa', [Decimal('5.7'), Decimal('4.4'), Decimal('1.5'), 
Decimal('0.4')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (37, u'Iris_setosa', [Decimal('5.4'), Decimal('3.9'), Decimal('1.3'), Decimal('0.4')], u'Iris_setosa', 1, u'Tennessee'),\n", + " (38, u'Iris_versicolor', [Decimal('6.0'), Decimal('2.9'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (39, u'Iris_versicolor', [Decimal('5.7'), Decimal('2.6'), Decimal('3.5'), Decimal('1.0')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (40, u'Iris_versicolor', [Decimal('5.5'), Decimal('2.4'), Decimal('3.8'), Decimal('1.1')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (41, u'Iris_versicolor', [Decimal('5.5'), Decimal('2.4'), Decimal('3.7'), Decimal('1.0')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (42, u'Iris_versicolor', [Decimal('5.8'), Decimal('2.7'), Decimal('3.9'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (43, u'Iris_versicolor', [Decimal('6.0'), Decimal('2.7'), Decimal('5.1'), Decimal('1.6')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (44, u'Iris_versicolor', [Decimal('5.4'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (45, u'Iris_versicolor', [Decimal('6.0'), Decimal('3.4'), Decimal('4.5'), Decimal('1.6')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (46, u'Iris_versicolor', [Decimal('6.7'), Decimal('3.1'), Decimal('4.7'), Decimal('1.5')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (47, u'Iris_versicolor', [Decimal('6.3'), Decimal('2.3'), Decimal('4.4'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (48, u'Iris_versicolor', [Decimal('5.6'), Decimal('3.0'), Decimal('4.1'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (49, u'Iris_versicolor', [Decimal('5.5'), Decimal('2.5'), Decimal('4.0'), Decimal('1.3')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (50, u'Iris_versicolor', [Decimal('5.5'), Decimal('2.6'), Decimal('4.4'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (51, u'Iris_versicolor', 
[Decimal('6.1'), Decimal('3.0'), Decimal('4.6'), Decimal('1.4')], u'Iris_versicolor', 2, u'Tennessee'),\n", + " (52, u'Iris_versicolor', [Decimal('5.8'), Decimal('2.6'), Decimal('4.0'), Decimal('1.2')], u'Iris_versicolor', 2, u'Tennessee')]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_prediction;\n", + "\n", + "SELECT madlib.mlp_predict(\n", + " 'mlp_model', -- Model table\n", + " 'iris_data', -- Test data table\n", + " 'id', -- Id column in test table\n", + " 'mlp_prediction', -- Output table for predictions\n", + " 'response' -- Output classes, not probabilities\n", + " );\n", + "\n", + "SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count the misclassifications:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>count</th>\n", + " </tr>\n", + " <tr>\n", + " <td>0</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(0L,)]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM mlp_prediction JOIN iris_data USING (id) \n", + "WHERE mlp_prediction.estimated_class_text != iris_data.class_text;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Classification with Mini-Batching\n", + "\n", + "# 1. Call mini-batch preprocessor\n", + "\n", + "Use the same data set as above." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "2 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>__id__</th>\n", + " <th>dependent_varname</th>\n", + " <th>independent_varname</th>\n", + " </tr>\n", + " <tr>\n", + " <td>0</td>\n", + " <td>[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]</td>\n", + " <td>[[0.0674425821070736, -1.20032494247982, 0.475886380512869, 0.269061417385597], [-0.767560815504508, -2.00311461461654, 0.334883008509056, 0.269061417385597], [-0.433559456459875, -0.598232688377286, 0.616889752516682, 0.995876674738521], [-0.934561495026824, 0.20455698375943, -1.07515071152907, -1.36627291165848], [0.568444620674023, -0.798930106411465, 0.687391438518589, 0.632469046062059], [-0.767560815504508, 0.806649237861967, -1.07515071152907, -1.18456909732025], [-0.934561495026824, -1.20032494247982, 0.193879636505243, 0.269061417385597], [0.902445979718656, -0.196837852308928, 1.03989986852812, 1.17758048907675], [0.0674425821070736, 1.00734665589615, -1.21615408353289, -1.18456909732025], [1.23644733876329, -1.60171977854818, 1.03989986852812, 1.17758048907675], [1.06944665924097, 0.00385956572525086, 1.11040155453003, 0.995876674738521], [-1.76956489263841, 0.00385956572525086, -1.21615408353289, -1.18456909732025], [1.4034480182856, 0.605951819827788, 1.18090324053193, 1.35928430341498], [1.4034480182856, -1.401022360514, 0.969398182526215, 0.81417286040029], [-0.934561495026824, 0.20455698375943, -1.07515071152907, -1.36627291165848], [-0.0995580974152422, 0.00385956572525086, 1.03989986852812,
1.17758048907675], [-1.76956489263841, 0.405254401793609, -1.21615408353289, -1.18456909732025], [-1.93656557216072, 0.00385956572525086, -1.3571574555367, -1.36627291165848], [0.902445979718656, 0.806649237861967, 1.03989986852812, 1.35928430341498], [-0.767560815504508, 1.00734665589615, -1.00464902552717, -0.457753839967327], [-1.10156217454914, 0.00385956572525086, -1.14565239753098, -1.36627291165848], [-1.60256421311609, -1.401022360514, -1.21615408353289, -1.00286528298202], [0.401443941151707, 2.81362341820376, -1.07515071152907, -0.821161468643789], [-0.767560815504508, 1.00734665589615, -1.21615408353289, -1.00286528298202], [0.568444620674023, -0.598232688377286, 0.757893124520495, 0.269061417385597], [0.568444620674023, 2.01083374606704, -1.28665576953479, -1.18456909732025]]</td>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>[[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]</td>\n", + " <td>[[-1.76956489263841, -0.196837852308928, -1.14565239753098, -1.18456909732025], [0.401443941151707, -0.397535270343108, 1.03989986852812, 0.81417286040029], [0.902445979718656, -1.60171977854818, 0.687391438518589, 0.269061417385597], [0.234443261629389, 0.00385956572525086, 1.03989986852812, 1.17758048907675], [-0.600560135982193, 1.60943890999868, -0.793143967521448, -0.821161468643789], [-0.600560135982193, 0.806649237861967, -1.07515071152907, -1.18456909732025], [0.401443941151707, -0.798930106411465, 0.334883008509056, 0.269061417385597], [0.73544530019634, 0.00385956572525086, 0.828394810522402, 1.17758048907675], [-0.600560135982193, 1.60943890999868, -1.00464902552717, -1.18456909732025], [1.06944665924097, -0.196837852308928, 1.18090324053193, 0.995876674738521], [0.234443261629389,
-0.196837852308928, 0.405384694510963, 0.81417286040029], [-0.767560815504508, 0.405254401793609, -1.28665576953479, -1.18456909732025], [0.0674425821070736, -0.999627524445644, 0.687391438518589, 0.81417286040029], [0.902445979718656, -0.598232688377286, 1.46290998453956, 1.35928430341498], [0.568444620674023, -0.598232688377286, 0.616889752516682, 0.632469046062059], [-0.0995580974152422, 1.81013632803286, -1.21615408353289, -0.821161468643789], [1.90445005685255, -0.196837852308928, 1.11040155453003, 0.81417286040029], [-1.10156217454914, 0.806649237861967, -1.00464902552717, -1.18456909732025], [-1.10156217454914, 0.00385956572525086, -1.14565239753098, -1.00286528298202], [0.234443261629389, -0.999627524445644, 0.616889752516682, 0.450765231723828], [0.0674425821070736, -1.20032494247982, 0.546388066514775, 0.450765231723828], [2.07145073637487, 0.20455698375943, 0.969398182526215, 0.995876674738521], [0.0674425821070736, -0.798930106411465, 0.969398182526215, 0.632469046062059], [2.07145073637487, 0.20455698375943, 1.18090324053193, 1.17758048907675], [-0.0995580974152422, 1.4087414919645, -1.07515071152907, -1.18456909732025], [0.234443261629389, 0.00385956572525086, 0.757893124520495, 0.81417286040029]]</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[(0L, [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], [[0.0674425821070736, -1.20032494247982, 0.475886380512869, 0.269061417385597], [-0.767560815504508, -2.00311461461654, 0.334883008509056, 0.269061417385597], [-0.433559456459875, -0.598232688377286, 0.616889752516682, 0.995876674738521], [-0.934561495026824, 0.20455698375943, -1.07515071152907, -1.36627291165848], [0.568444620674023, -0.798930106411465, 0.687391438518589,
0.632469046062059], [-0.767560815504508, 0.806649237861967, -1.07515071152907, -1.18456909732025], [-0.934561495026824, -1.20032494247982, 0.193879636505243, 0.269061417385597], [0.902445979718656, -0.196837852308928, 1.03989986852812, 1.17758048907675], [0.0674425821070736, 1.00734665589615, -1.21615408353289, -1.18456909732025], [1.23644733876329, -1.60171977854818, 1.03989986852812, 1.17758048907675], [1.06944665924097, 0.00385956572525086, 1.11040155453003, 0.995876674738521], [-1.76956489263841, 0.00385956572525086, -1.21615408353289, -1.18456909732025], [1.4034480182856, 0.605951819827788, 1.18090324053193, 1.35928430341498], [1.4034480182856, -1.401022360514, 0.969398182526215, 0.81417286040029], [-0.934561495026824, 0.20455698375943, -1.07515071152907, -1.36627291165848], [-0.0995580974152422, 0.00385956572525086, 1.03989986852812, 1.17758048907675], [-1.76956489263841, 0.405254401793609, -1.21615408353289, -1.18456909732025], [-1.93656557216072, 0.00385956572525086, -1.3571574555367, -1.36627291165848], [0.902445979718656, 0.806649237861967, 1.03989986852812, 1.35928430341498], [-0.767560815504508, 1.00734665589615, -1.00464902552717, -0.457753839967327], [-1.10156217454914, 0.00385956572525086, -1.14565239753098, -1.36627291165848], [-1.60256421311609, -1.401022360514, -1.21615408353289, -1.00286528298202], [0.401443941151707, 2.81362341820376, -1.07515071152907, -0.821161468643789], [-0.767560815504508, 1.00734665589615, -1.21615408353289, -1.00286528298202], [0.568444620674023, -0.598232688377286, 0.757893124520495, 0.269061417385597], [0.568444620674023, 2.01083374606704, -1.28665576953479, -1.18456909732025]]),\n", + " (1L, [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
[[-1.76956489263841, -0.196837852308928, -1.14565239753098, -1.18456909732025], [0.401443941151707, -0.397535270343108, 1.03989986852812, 0.81417286040029], [0.902445979718656, -1.60171977854818, 0.687391438518589, 0.269061417385597], [0.234443261629389, 0.00385956572525086, 1.03989986852812, 1.17758048907675], [-0.600560135982193, 1.60943890999868, -0.793143967521448, -0.821161468643789], [-0.600560135982193, 0.806649237861967, -1.07515071152907, -1.18456909732025], [0.401443941151707, -0.798930106411465, 0.334883008509056, 0.269061417385597], [0.73544530019634, 0.00385956572525086, 0.828394810522402, 1.17758048907675], [-0.600560135982193, 1.60943890999868, -1.00464902552717, -1.18456909732025], [1.06944665924097, -0.196837852308928, 1.18090324053193, 0.995876674738521], [0.234443261629389, -0.196837852308928, 0.405384694510963, 0.81417286040029], [-0.767560815504508, 0.405254401793609, -1.28665576953479, -1.18456909732025], [0.0674425821070736, -0.999627524445644, 0.687391438518589, 0.81417286040029], [0.902445979718656, -0.598232688377286, 1.46290998453956, 1.35928430341498], [0.568444620674023, -0.598232688377286, 0.616889752516682, 0.632469046062059], [-0.0995580974152422, 1.81013632803286, -1.21615408353289, -0.821161468643789], [1.90445005685255, -0.196837852308928, 1.11040155453003, 0.81417286040029], [-1.10156217454914, 0.806649237861967, -1.00464902552717, -1.18456909732025], [-1.10156217454914, 0.00385956572525086, -1.14565239753098, -1.00286528298202], [0.234443261629389, -0.999627524445644, 0.616889752516682, 0.450765231723828], [0.0674425821070736, -1.20032494247982, 0.546388066514775, 0.450765231723828], [2.07145073637487, 0.20455698375943, 0.969398182526215, 0.995876674738521], [0.0674425821070736, -0.798930106411465, 0.969398182526215, 0.632469046062059], [2.07145073637487, 0.20455698375943, 1.18090324053193, 1.17758048907675], [-0.0995580974152422, 1.4087414919645, -1.07515071152907, -1.18456909732025], [0.234443261629389, 0.00385956572525086,
0.757893124520495, 0.81417286040029]])]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_summary, iris_data_packed_standardization;\n", + "\n", + "SELECT madlib.minibatch_preprocessor('iris_data', -- Source table\n", + " 'iris_data_packed', -- Output table\n", + " 'class_text', -- Dependent variable\n", + " 'attributes' -- Independent variables\n", + " );\n", + "\n", + "SELECT * FROM iris_data_packed ORDER BY __id__;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Classification model with mini-batching\n", + "Use similar parameters as before:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>mlp_classification</th>\n", + " </tr>\n", + " <tr>\n", + " <td></td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[('',)]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;\n", + "-- Set seed so results are reproducible\n", + "SELECT setseed(0);\n", + "\n", + "SELECT madlib.mlp_classification(\n", + " 'iris_data_packed', -- Output table from mini-batch preprocessor\n", + " 'mlp_model', -- Destination table\n", + " 'independent_varname', -- Hardcode to this, from table iris_data_packed\n", + " 'dependent_varname', -- Hardcode to this, from table iris_data_packed\n", + " ARRAY[5], -- Number of units per layer\n", + " 'learning_rate_init=0.1,\n", + " n_iterations=500,\n", + " tolerance=0', -- Optimizer params\n", + " 'tanh', -- Activation function\n", + " NULL, -- Default weight (1)\n", + " FALSE, -- No warm
start\n", + " FALSE -- Not verbose\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View the classification model:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>coeff</th>\n", + " <th>loss</th>\n", + " <th>num_iterations</th>\n", + " </tr>\n", + " <tr>\n", + " <td>[-0.13252534498649, -0.37896100676424, 0.503403120152183, -1.08679816249876, -0.561814480613823, 0.402548045485862, 0.532705053534283, -0.552630848296953, 0.67855516747366, 0.413920498590441, -0.256766228396609, -0.310783793721915, 0.473025380015013, -0.764847565888928, -0.633685437636607, 0.331118371614658, 0.68785202110586, -0.730905401730095, 0.812276423458694, 1.0028080447692, -0.247150080195493, -0.326479660450041, 0.272027165497577, -0.783484532590993, -0.3240005914012, -0.292052260754316, 0.942848112860559, -0.611370901320803, 0.727726656356208, -1.80532276028114, 0.859042273992032, 0.245507590486603, -1.36472202886508, 0.750248279831834, -0.892308360996602, 1.34236183516907, -0.528363560505027]</td>\n", + " <td>0.000217511008854</td>\n", + " <td>500</td>\n", + " </tr>\n", + "</table>" + ], + "text/plain": [ + "[([-0.13252534498649, -0.37896100676424, 0.503403120152183, -1.08679816249876, -0.561814480613823, 0.402548045485862, 0.532705053534283, -0.552630848296953, 0.67855516747366, 0.413920498590441, -0.256766228396609, -0.310783793721915, 0.473025380015013, -0.764847565888928, -0.633685437636607, 0.331118371614658, 0.68785202110586, -0.730905401730095, 0.812276423458694, 1.0028080447692, -0.247150080195493, -0.326479660450041, 0.272027165497577, -0.783484532590993, -0.3240005914012, -0.292052260754316, 0.942848112860559, -0.611370901320803, 0.727726656356208, -1.80532276028114, 0.859042273992032, 0.245507590486603, -1.36472202886508, 
0.750248279831834, -0.892308360996602, 1.34236183516907, -0.528363560505027], 0.000217511008854446, 500)]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT * FROM mlp_model;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Prediction for classification\n", + "\n", + "Now predict using the model we built with mini-batching. As before, we will use the training data set for prediction as well, which is not usual but serves to show the syntax. The prediction is in the estimated_class_text column, with the actual value in the class_text column. Note that the prediction function is exactly the same whether you use mini-batching or not:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done.\n", + "1 rows affected.\n", + "52 rows affected.\n" + ] + }, + { + "data": { + "text/html": [ + "<table>\n", + " <tr>\n", + " <th>id</th>\n", + " <th>estimated_class_text</th>\n", + " <th>attributes</th>\n", + " <th>class_text</th>\n", + " <th>class</th>\n", + " <th>state</th>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.2'), Decimal('1.2'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.5'), Decimal('3.5'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('3.0'), 
Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.5'), Decimal('2.3'), Decimal('1.3'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.4'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.0'), Decimal('3.5'), Decimal('1.6'), Decimal('0.6')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.9'), Decimal('0.4')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('4.8'), Decimal('3.0'), Decimal('1.4'), Decimal('0.3')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>Iris_setosa</td>\n", + " <td>[Decimal('5.1'), Decimal('3.8'), Decimal('1.6'), Decimal('0.2')]</td>\n", + " <td>Iris_setosa</td>\n", + " <td>1</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " 
<td>Iris_versicolor</td>\n", + " <td>[Decimal('5.7'), Decimal('2.8'), Decimal('4.5'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.3'), Decimal('3.3'), Decimal('4.7'), Decimal('1.6')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('4.9'), Decimal('2.4'), Decimal('3.3'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.6'), Decimal('2.9'), Decimal('4.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.2'), Decimal('2.7'), Decimal('3.9'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.0'), Decimal('2.0'), Decimal('3.5'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.9'), Decimal('3.0'), Decimal('4.2'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.0'), Decimal('2.2'), Decimal('4.0'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>21</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.1'), Decimal('2.9'), Decimal('4.7'), 
Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>22</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.6'), Decimal('2.9'), Decimal('3.6'), Decimal('1.3')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>23</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('6.7'), Decimal('3.1'), Decimal('4.4'), Decimal('1.4')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>24</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.6'), Decimal('3.0'), Decimal('4.5'), Decimal('1.5')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n", + " </tr>\n", + " <tr>\n", + " <td>25</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>[Decimal('5.8'), Decimal('2.7'), Decimal('4.1'), Decimal('1.0')]</td>\n", + " <td>Iris_versicolor</td>\n", + " <td>2</td>\n", + " <td>Alaska</td>\n
<TRUNCATED>