This is an automated email from the ASF dual-hosted git repository. fmcquillan pushed a commit to branch automl in repository https://gitbox.apache.org/repos/asf/madlib-site.git
The following commit(s) were added to refs/heads/automl by this push: new 94a7f7e hyperband diagonal E2E update 94a7f7e is described below commit 94a7f7e81077ccd67710648850b696e2344e39d9 Author: Frank McQuillan <fmcquil...@pivotal.io> AuthorDate: Fri Nov 22 16:29:51 2019 -0800 hyperband diagonal E2E update --- .../hyperband_diag_v2_mnist-checkpoint.ipynb | 157 +++++++++------------ .../automl/hyperband_diag_v2_mnist.ipynb | 130 ++++++++--------- 2 files changed, 135 insertions(+), 152 deletions(-) diff --git a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb index 091e6fd..b62f8d5 100644 --- a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb +++ b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb @@ -30,19 +30,17 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n", - " \"You should import from traitlets.config instead.\", ShimWarning)\n", - "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", - " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n" + "The sql extension is already loaded. To reload it, use:\n", + " %reload_ext sql\n" ] } ], @@ -52,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -100,7 +98,7 @@ "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]" ] }, - "execution_count": 3, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -121,24 +119,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 20, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using TensorFlow backend.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Couldn't import dot_parser, loading of dot files will not be possible.\n" - ] - } - ], + "outputs": [], "source": [ "from __future__ import print_function\n", "\n", @@ -180,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -794,7 +777,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -821,7 +804,7 @@ "[]" ] }, - "execution_count": 17, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -896,7 +879,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -924,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -953,12 +936,13 @@ " self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n", " self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n", " sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n", + " \n", + " print (\" \")\n", + " print (\"Hyperband brackets\")\n", "\n", " #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n", " for s in reversed(range(self.s_max+1)):\n", - "\n", - " print (\" \")\n", - " print (\"Hyperband brackets\")\n", + " \n", " print (\" \")\n", " print (\"s=\" + str(s))\n", " print (\"n_i r_i\")\n", @@ -1040,40 +1024,44 @@ " # filter out early stops, if any\n", " \n", " # loop on brackets s desc to prune model selection table\n", - " print (\"loop on s desc to prune mst table:\")\n", - " for s in range(self.s_max, self.s_max-i-1, -1):\n", - " \n", - " k = int( self.n_vals[s][i] / self.eta)\n", + " # don't need to prune if finished last diagonal\n", + " if i < self.s_max:\n", + " print (\"loop on s desc to prune mst table:\")\n", + " for s in range(self.s_max, self.s_max-i-1, -1):\n", " \n", - " # temporarily re-run table names again due to weird scope issues\n", - " results_table = 'results_mnist'\n", + " # compute number of configs to keep\n", + " # remember i value is different for each bracket s on the diagonal\n", + " k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n", + " print (\"pruning s = {} with k = {}\".format(s, k))\n", "\n", - " output_table = 'mnist_multi_model'\n", - " output_table_info = '_'.join([output_table, 'info'])\n", - " output_table_summary = '_'.join([output_table, 'summary'])\n", + " # temporarily re-define table names due to weird Python scope issues\n", + " results_table = 'results_mnist'\n", "\n", - " mst_table = 'mst_table_hb_mnist'\n", - " mst_table_summary = '_'.join([mst_table, 'summary'])\n", + " output_table = 'mnist_multi_model'\n", + " output_table_info = '_'.join([output_table, 'info'])\n", + " output_table_summary = '_'.join([output_table, 'summary'])\n", "\n", - " mst_diag_table = 'mst_diag_table_hb_mnist'\n", - " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n", + " mst_table = 'mst_table_hb_mnist'\n", + " mst_table_summary = '_'.join([mst_table, 'summary'])\n", "\n", - " model_arch_table = 'model_arch_library_mnist'\n", - " \n", - " query = \"\"\"\n", - " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n", - " \"\"\".format(**locals())\n", - " cur.execute(query)\n", - " conn.commit()\n", - " #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", - "# %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", - " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n", + " mst_diag_table = 'mst_diag_table_hb_mnist'\n", + " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n", + "\n", + " model_arch_table = 'model_arch_library_mnist'\n", + " \n", + " query = \"\"\"\n", + " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n", + " \"\"\".format(**locals())\n", + " cur.execute(query)\n", + " conn.commit()\n", + " #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", + " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n", " return" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1129,7 +1117,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1163,7 +1151,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": { "scrolled": false }, @@ -1182,16 +1170,12 @@ "3.0 3.0\n", "1.0 9.0\n", " \n", - "Hyperband brackets\n", - " \n", "s=1\n", "n_i r_i\n", "------------\n", "3 3.0\n", "1.0 9.0\n", " \n", - "Hyperband brackets\n", - " \n", "s=0\n", "n_i r_i\n", "------------\n", @@ -1248,6 +1232,7 @@ "9 rows affected.\n", "9 rows affected.\n", "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 3\n", " \n", "i=1\n", "Done.\n", @@ -1264,29 +1249,35 @@ "6 rows affected.\n", "6 rows affected.\n", "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 1\n", + "pruning s = 1 with k = 1\n", " \n", "i=2\n", "Done.\n", "loop on s desc to create diagonal table:\n", "1 rows affected.\n", - "0 rows affected.\n", + "1 rows affected.\n", "3 rows affected.\n", "try params for i = 2\n", "Done.\n", "1 rows affected.\n", "Done.\n", - "4 rows affected.\n", + "5 rows affected.\n", "Done.\n", - "4 rows affected.\n", - "4 rows affected.\n", - "4 rows affected.\n", - "loop on s desc to prune mst table:\n" + "5 rows affected.\n", + "5 rows affected.\n", + "5 rows affected.\n", + "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 0\n", + "pruning s = 1 with k = 0\n", + "pruning s = 0 with k = 1\n" ] } ], "source": [ "hp = Hyperband_diagonal(get_params, try_params )\n", - "results = hp.run()" + "results = hp.run()\n", + "#hp.n_vals[1]" ] }, { @@ -1299,7 +1290,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -1330,7 +1321,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "15 rows affected.\n" + "10 rows affected.\n" ] }, { @@ -2116,7 +2107,7 @@ { "data": { "text/html": [ - "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...] + "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb1f3+X0se8ooTJ3GcxM5wEhJ2gSTMJJBCGYVCCJRSaKGD1VL6b/kBLSNAwygd0AE0ZbS0BcoKq0CBssPMYI8MMp3h7HhLsiX9n/fI15GH7CtZlq7s9zyPHyf2uWd8zpX11Xu/IyMUCoWgJgIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIi0KsEMiTE9SpfDS4CIiACIiACIiACIiACIiACIiACIiACIiAChoCEON0IIpDmBE477TTMnz8fP/7xj3H77bcndDeTJ0/GkiVL8Nvf/hb/93//l9CxNVjiCHz66afYd999zYBbt27FkCFDEje4RhIBERABERABEUgaAdl1HVHX1dWhsLDQ/GLRokWgfaomAiIgAulMQEJcOp+e1t5rBDI [...] ], "text/plain": [ "<IPython.core.display.HTML object>" @@ -2138,18 +2129,13 @@ "1 rows affected.\n", "1 rows affected.\n", "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", "1 rows affected.\n" ] } ], "source": [ "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n", - "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n", + "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 10;\n", "df_results = df_results.DataFrame()\n", "\n", "#set up plots\n", @@ -2199,7 +2185,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "15 rows affected.\n" + "10 rows affected.\n" ] }, { @@ -2985,7 +2971,7 @@ { "data": { "text/html": [ - "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...] + "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCZiT1fX/v8ns+zDMhgyLyOaKsoqyuAGCu2LVYkVFLdat1bb8q4IoLqVVa12x1V/FuiMq1YpAQRRQZFNEBUGQZZBhFph9yUyS//O9wztmMsnkTfImeSdz7vPMM0Duve+9n3tDTr733HMsTqfTCSlCQAgIASEgBISAEBACQkAICAEhIASEgBAQAkJACISUgEWEuJDylc6FgBAQAkJACAgBISAEhIAQEAJCQAgIASEgBISAIiBCnGwEIRBGAh999BEmTpyIhIQE1NfXt3pye6/pGWKw7fU8o7068+bNw80334wBAwZg27ZtwXYn7UNEQNYpRGClWyEgBISAEOh0BMSuM9+SR9oeNh8RGZEQEAJmJCBCnBlXRcZkOIEbb7wRL7zwArKysvD [...] ], "text/plain": [ "<IPython.core.display.HTML object>" @@ -3007,18 +2993,13 @@ "1 rows affected.\n", "1 rows affected.\n", "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", "1 rows affected.\n" ] } ], "source": [ "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n", - "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n", + "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 10;\n", "df_results = df_results.DataFrame()\n", "\n", "#set up plots\n", diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb index 091e6fd..171c9cd 100644 --- a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb +++ b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -100,7 +100,7 @@ "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -821,7 +821,7 @@ "[]" ] }, - "execution_count": 17, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -896,7 +896,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -924,7 +924,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -953,12 +953,13 @@ " self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n", " self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n", " sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n", + " \n", + " print (\" \")\n", + " print (\"Hyperband brackets\")\n", "\n", " #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n", " for s in reversed(range(self.s_max+1)):\n", - "\n", - " print (\" \")\n", - " print (\"Hyperband brackets\")\n", + " \n", " print (\" \")\n", " print (\"s=\" + str(s))\n", " print (\"n_i r_i\")\n", @@ -1040,40 +1041,44 @@ " # filter out early stops, if any\n", " \n", " # loop on brackets s desc to prune model selection table\n", - " print (\"loop on s desc to prune mst table:\")\n", - " for s in range(self.s_max, self.s_max-i-1, -1):\n", - " \n", - " k = int( self.n_vals[s][i] / self.eta)\n", + " # don't need to prune if finished last diagonal\n", + " if i < self.s_max:\n", + " print (\"loop on s desc to prune mst table:\")\n", + " for s in range(self.s_max, self.s_max-i-1, -1):\n", " \n", - " # temporarily re-run table names again due to weird scope issues\n", - " results_table = 'results_mnist'\n", + " # compute number of configs to keep\n", + " # remember i value is different for each bracket s on the diagonal\n", + " k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n", + " print (\"pruning s = {} with k = {}\".format(s, k))\n", "\n", - " output_table = 'mnist_multi_model'\n", - " output_table_info = '_'.join([output_table, 'info'])\n", - " output_table_summary = '_'.join([output_table, 'summary'])\n", + " # temporarily re-define table names due to weird Python scope issues\n", + " results_table = 'results_mnist'\n", "\n", - " mst_table = 'mst_table_hb_mnist'\n", - " mst_table_summary = '_'.join([mst_table, 'summary'])\n", + " output_table = 'mnist_multi_model'\n", + " output_table_info = '_'.join([output_table, 'info'])\n", + " output_table_summary = '_'.join([output_table, 'summary'])\n", "\n", - " mst_diag_table = 'mst_diag_table_hb_mnist'\n", - " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n", + " mst_table = 'mst_table_hb_mnist'\n", + " mst_table_summary = '_'.join([mst_table, 'summary'])\n", "\n", - " model_arch_table = 'model_arch_library_mnist'\n", - " \n", - " query = \"\"\"\n", - " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n", - " \"\"\".format(**locals())\n", - " cur.execute(query)\n", - " conn.commit()\n", - " #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", - "# %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", - " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n", + " mst_diag_table = 'mst_diag_table_hb_mnist'\n", + " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n", + "\n", + " model_arch_table = 'model_arch_library_mnist'\n", + " \n", + " query = \"\"\"\n", + " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n", + " \"\"\".format(**locals())\n", + " cur.execute(query)\n", + " conn.commit()\n", + " #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n", + " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n", " return" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1129,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1163,7 +1168,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": { "scrolled": false }, @@ -1182,16 +1187,12 @@ "3.0 3.0\n", "1.0 9.0\n", " \n", - "Hyperband brackets\n", - " \n", "s=1\n", "n_i r_i\n", "------------\n", "3 3.0\n", "1.0 9.0\n", " \n", - "Hyperband brackets\n", - " \n", "s=0\n", "n_i r_i\n", "------------\n", @@ -1248,6 +1249,7 @@ "9 rows affected.\n", "9 rows affected.\n", "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 3\n", " \n", "i=1\n", "Done.\n", @@ -1264,29 +1266,35 @@ "6 rows affected.\n", "6 rows affected.\n", "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 1\n", + "pruning s = 1 with k = 1\n", " \n", "i=2\n", "Done.\n", "loop on s desc to create diagonal table:\n", "1 rows affected.\n", - "0 rows affected.\n", + "1 rows affected.\n", "3 rows affected.\n", "try params for i = 2\n", "Done.\n", "1 rows affected.\n", "Done.\n", - "4 rows affected.\n", + "5 rows affected.\n", "Done.\n", - "4 rows affected.\n", - "4 rows affected.\n", - "4 rows affected.\n", - "loop on s desc to prune mst table:\n" + "5 rows affected.\n", + "5 rows affected.\n", + "5 rows affected.\n", + "loop on s desc to prune mst table:\n", + "pruning s = 2 with k = 0\n", + "pruning s = 1 with k = 0\n", + "pruning s = 0 with k = 1\n" ] } ], "source": [ "hp = Hyperband_diagonal(get_params, try_params )\n", - "results = hp.run()" + "results = hp.run()\n", + "#hp.n_vals[1]" ] }, { @@ -1299,7 +1307,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -1323,14 +1331,14 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "15 rows affected.\n" + "12 rows affected.\n" ] }, { @@ -2116,7 +2124,7 @@ { "data": { "text/html": [ - "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...] + "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb5dn+b0m2vDOc2Fl2diBAWVnsXSBQRinQSQu00EJ3oS39GIE2FMr3tdD2T1tmW7poWWXPllF2BrthZMcZjp3YTjwlW9L/ul/52LItyUfS0bB9v9fly4n1nnf83mPr0X2e4QqFQiGoiYAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIpJWAS0JcWvlqcBEQAREQAREQAREQAREQAREQAREQAREQAREwBCTE6UYQgSFO4KyzzsL999+Pb3zjG7j55psd3c2CBQuwcuVK/N///R++//3vOzq2BnOOwHvvvYd9993XDFhfX4/x48c7N7hGEgEREAEREAERyBgB2XUDUbe0tKCsrMy8sHz5ctA+VRMBERCBoUxAQtxQPj2tPW0EXC5X0mP [...] ], "text/plain": [ "<IPython.core.display.HTML object>" @@ -2140,16 +2148,13 @@ "1 rows affected.\n", "1 rows affected.\n", "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", "1 rows affected.\n" ] } ], "source": [ "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n", - "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n", + "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 12;\n", "df_results = df_results.DataFrame()\n", "\n", "#set up plots\n", @@ -2192,14 +2197,14 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "15 rows affected.\n" + "12 rows affected.\n" ] }, { @@ -2985,7 +2990,7 @@ { "data": { "text/html": [ - "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...] + "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhU1fnG31myrwQStgQQVMAN2aKAgKBW3Dds3dFqFWtrq/6rbVWkxaV0cWmtxbVq1boUq3VHZUchgAiiICoCCRAC2deZJJP/857kJpPJLPfO3FmSfOd58iQk5557zu+cYb5577dYWlpaWiBNCAgBISAEhIAQEAJCQAgIASEgBISAEBACQkAICIGwErCIEBdWvjK4EBACQkAICAEhIASEgBAQAkJACAgBISAEhIAQUAREiJODIAQiSOD999/HGWecgYSEBDQ0NHS6s7+/6ZliqNfruYe/PosWLcKNN96IkSNHYvv27aEOJ9eHiYDsU5jAyrBCQAgIASHQ6wiIXRd7Wx5tezj2iMiMhIAQiEUCIsTF4q7InEwn8JOf/ARPPfUUsrKysG/ [...] ], "text/plain": [ "<IPython.core.display.HTML object>" @@ -3009,16 +3014,13 @@ "1 rows affected.\n", "1 rows affected.\n", "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", - "1 rows affected.\n", "1 rows affected.\n" ] } ], "source": [ "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n", - "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n", + "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 12;\n", "df_results = df_results.DataFrame()\n", "\n", "#set up plots\n",