This is an automated email from the ASF dual-hosted git repository. kellen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 402e985 Update autoencoder example (#12933) 402e985 is described below commit 402e985b1482bac6e8b89896bab18024e3a9fcfc Author: Thomas Delteil <thomas.delte...@gmail.com> AuthorDate: Wed Jan 23 13:29:05 2019 -0800 Update autoencoder example (#12933) * Fixing the autoencoder example * adding pointer to VAE * fix typos * Update README.md * Updating notebook * Update after comments * Update README.md * Update README.md * Retrigger build * Updates after review --- example/autoencoder/README.md | 24 +- example/autoencoder/autoencoder.py | 206 -------- .../autoencoder/convolutional_autoencoder.ipynb | 543 +++++++++++++++++++++ example/autoencoder/data.py | 34 -- example/autoencoder/mnist_sae.py | 100 ---- example/autoencoder/model.py | 78 --- example/autoencoder/solver.py | 151 ------ 7 files changed, 557 insertions(+), 579 deletions(-) diff --git a/example/autoencoder/README.md b/example/autoencoder/README.md index 7efa30a..960636c 100644 --- a/example/autoencoder/README.md +++ b/example/autoencoder/README.md @@ -1,16 +1,20 @@ -# Example of Autencoder +# Example of a Convolutional Autoencoder -Autoencoder architecture is often used for unsupervised feature learning. This [link](http://ufldl.stanford.edu/tutorial/unsupervised/Autoencoders/) contains an introduction tutorial to autoencoders. This example illustrates a simple autoencoder using stack of fully-connected layers for both encoder and decoder. The number of hidden layers and size of each hidden layer can be customized using command line arguments. +Autoencoder architectures are often used for unsupervised feature learning. This [link](http://ufldl.stanford.edu/tutorial/unsupervised/Autoencoders/) contains an introduction tutorial to autoencoders. This example illustrates a simple autoencoder using a stack of convolutional layers for both the encoder and the decoder. -## Training Stages -This example uses a two-stage training. 
In the first stage, each layer of encoder and its corresponding decoder are trained separately in a layer-wise training loop. In the second stage the entire autoencoder network is fine-tuned end to end. + +![](https://cdn-images-1.medium.com/max/800/1*LSYNW5m3TN7xRX61BZhoZA.png) + +([Diagram source](https://towardsdatascience.com/autoencoders-introduction-and-implementation-3f40483b0a85)) + + +The idea of an autoencoder is to learn to use a bottleneck architecture to encode the input and then try to decode it to reproduce the original. By doing so, the network learns to effectively compress the information of the input; the resulting embedding representation can then be used in several domains. For example, as a featurized representation for visual search, or in anomaly detection. ## Dataset -The dataset used in this example is [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. This example uses scikit-learn module to download this dataset. -## Simple autoencoder example -mnist_sae.py: this example uses a simple auto-encoder architecture to encode and decode MNIST images with size of 28x28 pixels. It contains several command line arguments. Pass -h (or --help) to view all available options. To start the training on CPU (use --gpu option for training on GPU) using default options: +The dataset used in this example is the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. + +## Variational Autoencoder + +You can check an example of a variational autoencoder [here](https://gluon.mxnet.io/chapter13_unsupervised-learning/vae-gluon.html) -``` -python mnist_sae.py -``` diff --git a/example/autoencoder/autoencoder.py b/example/autoencoder/autoencoder.py deleted file mode 100644 index 47931e5..0000000 --- a/example/autoencoder/autoencoder.py +++ /dev/null @@ -1,206 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=missing-docstring, arguments-differ -from __future__ import print_function - -import logging - -import mxnet as mx -import numpy as np -import model -from solver import Solver, Monitor - - -class AutoEncoderModel(model.MXModel): - def setup(self, dims, sparseness_penalty=None, pt_dropout=None, - ft_dropout=None, input_act=None, internal_act='relu', output_act=None): - self.N = len(dims) - 1 - self.dims = dims - self.stacks = [] - self.pt_dropout = pt_dropout - self.ft_dropout = ft_dropout - self.input_act = input_act - self.internal_act = internal_act - self.output_act = output_act - - self.data = mx.symbol.Variable('data') - for i in range(self.N): - if i == 0: - decoder_act = input_act - idropout = None - else: - decoder_act = internal_act - idropout = pt_dropout - if i == self.N-1: - encoder_act = output_act - odropout = None - else: - encoder_act = internal_act - odropout = pt_dropout - istack, iargs, iargs_grad, iargs_mult, iauxs = self.make_stack( - i, self.data, dims[i], dims[i+1], sparseness_penalty, - idropout, odropout, encoder_act, decoder_act - ) - self.stacks.append(istack) - self.args.update(iargs) - self.args_grad.update(iargs_grad) - self.args_mult.update(iargs_mult) - self.auxs.update(iauxs) - self.encoder, self.internals = self.make_encoder( - 
self.data, dims, sparseness_penalty, ft_dropout, internal_act, output_act) - self.decoder = self.make_decoder( - self.encoder, dims, sparseness_penalty, ft_dropout, internal_act, input_act) - if input_act == 'softmax': - self.loss = self.decoder - else: - self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data) - - def make_stack(self, istack, data, num_input, num_hidden, sparseness_penalty=None, - idropout=None, odropout=None, encoder_act='relu', decoder_act='relu'): - x = data - if idropout: - x = mx.symbol.Dropout(data=x, p=idropout) - x = mx.symbol.FullyConnected(name='encoder_%d'%istack, data=x, num_hidden=num_hidden) - if encoder_act: - x = mx.symbol.Activation(data=x, act_type=encoder_act) - if encoder_act == 'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_encoder_%d' % istack, penalty=sparseness_penalty) - if odropout: - x = mx.symbol.Dropout(data=x, p=odropout) - x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input) - if decoder_act == 'softmax': - x = mx.symbol.Softmax(data=x, label=data, prob_label=True, act_type=decoder_act) - elif decoder_act: - x = mx.symbol.Activation(data=x, act_type=decoder_act) - if decoder_act == 'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_decoder_%d' % istack, penalty=sparseness_penalty) - x = mx.symbol.LinearRegressionOutput(data=x, label=data) - else: - x = mx.symbol.LinearRegressionOutput(data=x, label=data) - - args = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), - 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), - 'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), - 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} - args_grad = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), - 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), - 
'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), - 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} - args_mult = {'encoder_%d_weight'%istack: 1.0, - 'encoder_%d_bias'%istack: 2.0, - 'decoder_%d_weight'%istack: 1.0, - 'decoder_%d_bias'%istack: 2.0,} - auxs = {} - if encoder_act == 'sigmoid' and sparseness_penalty: - auxs['sparse_encoder_%d_moving_avg' % istack] = mx.nd.ones(num_hidden, self.xpu) * 0.5 - if decoder_act == 'sigmoid' and sparseness_penalty: - auxs['sparse_decoder_%d_moving_avg' % istack] = mx.nd.ones(num_input, self.xpu) * 0.5 - init = mx.initializer.Uniform(0.07) - for k, v in args.items(): - init(mx.initializer.InitDesc(k), v) - - return x, args, args_grad, args_mult, auxs - - def make_encoder(self, data, dims, sparseness_penalty=None, dropout=None, internal_act='relu', - output_act=None): - x = data - internals = [] - N = len(dims) - 1 - for i in range(N): - x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1]) - if internal_act and i < N-1: - x = mx.symbol.Activation(data=x, act_type=internal_act) - if internal_act == 'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) - elif output_act and i == N-1: - x = mx.symbol.Activation(data=x, act_type=output_act) - if output_act == 'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) - if dropout: - x = mx.symbol.Dropout(data=x, p=dropout) - internals.append(x) - return x, internals - - def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, - internal_act='relu', input_act=None): - x = feature - N = len(dims) - 1 - for i in reversed(range(N)): - x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i]) - if internal_act and i > 0: - x = mx.symbol.Activation(data=x, act_type=internal_act) - if internal_act == 
'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) - elif input_act and i == 0: - x = mx.symbol.Activation(data=x, act_type=input_act) - if input_act == 'sigmoid' and sparseness_penalty: - x = mx.symbol.IdentityAttachKLSparseReg( - data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) - if dropout and i > 0: - x = mx.symbol.Dropout(data=x, p=dropout) - return x - - def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, - lr_scheduler=None, print_every=1000): - def l2_norm(label, pred): - return np.mean(np.square(label-pred))/2.0 - solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, - lr_scheduler=lr_scheduler) - solver.set_metric(mx.metric.CustomMetric(l2_norm)) - solver.set_monitor(Monitor(print_every)) - data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, - last_batch_handle='roll_over') - for i in range(self.N): - if i == 0: - data_iter_i = data_iter - else: - X_i = list(model.extract_feature( - self.internals[i-1], self.args, self.auxs, data_iter, X.shape[0], - self.xpu).values())[0] - data_iter_i = mx.io.NDArrayIter({'data': X_i}, batch_size=batch_size, - last_batch_handle='roll_over') - logging.info('Pre-training layer %d...', i) - solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, self.auxs, - data_iter_i, 0, n_iter, {}, False) - - def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None, - print_every=1000): - def l2_norm(label, pred): - return np.mean(np.square(label-pred))/2.0 - solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, - lr_scheduler=lr_scheduler) - solver.set_metric(mx.metric.CustomMetric(l2_norm)) - solver.set_monitor(Monitor(print_every)) - data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, - last_batch_handle='roll_over') - logging.info('Fine tuning...') - 
solver.solve(self.xpu, self.loss, self.args, self.args_grad, self.auxs, data_iter, - 0, n_iter, {}, False) - - def eval(self, X): - batch_size = 100 - data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False, - last_batch_handle='pad') - Y = list(model.extract_feature( - self.loss, self.args, self.auxs, data_iter, X.shape[0], self.xpu).values())[0] - return np.mean(np.square(Y-X))/2.0 diff --git a/example/autoencoder/convolutional_autoencoder.ipynb b/example/autoencoder/convolutional_autoencoder.ipynb new file mode 100644 index 0000000..c42ad90 --- /dev/null +++ b/example/autoencoder/convolutional_autoencoder.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Convolutional Autoencoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://cdn-images-1.medium.com/max/800/1*LSYNW5m3TN7xRX61BZhoZA.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we will demonstrate how you can create a convolutional autoencoder in Gluon" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import mxnet as mx\n", + "from mxnet import autograd, gluon" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\n", + "\n", + "We will use the FashionMNIST dataset, which is of a similar format to MNIST but is richer and has more variance" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 512\n", + "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "transform = lambda x,y: (x.transpose((2,0,1)).astype('float32')/255., y)\n", + "\n", + "train_dataset = 
gluon.data.vision.FashionMNIST(train=True)\n", + "test_dataset = gluon.data.vision.FashionMNIST(train=False)\n", + "\n", + "train_dataset_t = train_dataset.transform(transform)\n", + "test_dataset_t = test_dataset.transform(transform)\n", + "\n", + "train_data = gluon.data.DataLoader(train_dataset_t, batch_size=batch_size, last_batch='rollover', shuffle=True, num_workers=5)\n", + "test_data = gluon.data.DataLoader(test_dataset_t, batch_size=batch_size, last_batch='rollover', shuffle=True, num_workers=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIEAAACBCAYAAABXearSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJztnXm4VmW5/2+q0+QQ5iwITkwOCKKghWLOSuZsNqk5HI+WiXoqT9ox09LqKr2wKK/UIjNLvRrMIU3AMENESHECkUkEHBFTGk51PH/8fjx9n297Pb1uNnu/77s+n7/utZ9nr7Xe9YxrXff3vnu9/vrrAQAAAAAAAAAA7c2bevoGAAAAAAAAAABg7cNHIAAAAAAAAACAGsBHIAAAAAAAAACAGsBHIAAAAAAAAACAGsBHIAAAAAAAAACAGsBHIAAAAAAAAACAGsBHIAAAAAAA [...] 
+ "text/plain": [ + "<Figure size 1440x720 with 10 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(20,10))\n", + "for i in range(10):\n", + " ax = plt.subplot(1, 10, i+1)\n", + " ax.imshow(train_dataset[i][0].squeeze().asnumpy(), cmap='gray')\n", + " ax.axis('off')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "net = gluon.nn.HybridSequential(prefix='autoencoder_')\n", + "with net.name_scope():\n", + " # Encoder 1x28x28 -> 32x1x1\n", + " encoder = gluon.nn.HybridSequential(prefix='encoder_')\n", + " with encoder.name_scope():\n", + " encoder.add(\n", + " gluon.nn.Conv2D(channels=4, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=8, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=16, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=32, kernel_size=3, padding=0, strides=(2,2),activation='relu'),\n", + " gluon.nn.BatchNorm()\n", + " )\n", + " decoder = gluon.nn.HybridSequential(prefix='decoder_')\n", + " # Decoder 32x1x1 -> 1x28x28\n", + " with decoder.name_scope():\n", + " decoder.add(\n", + " gluon.nn.Conv2D(channels=32, kernel_size=3, padding=2, activation='relu'),\n", + " gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, scale=2, sample_type='nearest')),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=16, kernel_size=3, padding=1, activation='relu'),\n", + " gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, scale=2, sample_type='nearest')),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=8, kernel_size=3, padding=2, activation='relu'),\n", + " gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, 
scale=2, sample_type='nearest')),\n", + " gluon.nn.BatchNorm(),\n", + " gluon.nn.Conv2D(channels=4, kernel_size=3, padding=1, activation='relu'),\n", + " gluon.nn.Conv2D(channels=1, kernel_size=3, padding=1, activation='sigmoid')\n", + " )\n", + " net.add(\n", + " encoder,\n", + " decoder\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "net.initialize(ctx=ctx)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================================\n", + " Input (1, 1, 28, 28) 0\n", + " Activation-1 <Symbol autoencoder_encoder_conv0_relu_fwd> 0\n", + " Activation-2 (1, 4, 14, 14) 0\n", + " Conv2D-3 (1, 4, 14, 14) 40\n", + " BatchNorm-4 (1, 4, 14, 14) 16\n", + " Activation-5 <Symbol autoencoder_encoder_conv1_relu_fwd> 0\n", + " Activation-6 (1, 8, 7, 7) 0\n", + " Conv2D-7 (1, 8, 7, 7) 296\n", + " BatchNorm-8 (1, 8, 7, 7) 32\n", + " Activation-9 <Symbol autoencoder_encoder_conv2_relu_fwd> 0\n", + " Activation-10 (1, 16, 4, 4) 0\n", + " Conv2D-11 (1, 16, 4, 4) 1168\n", + " BatchNorm-12 (1, 16, 4, 4) 64\n", + " Activation-13 <Symbol autoencoder_encoder_conv3_relu_fwd> 0\n", + " Activation-14 (1, 32, 1, 1) 0\n", + " Conv2D-15 (1, 32, 1, 1) 4640\n", + " BatchNorm-16 (1, 32, 1, 1) 128\n", + " Activation-17 <Symbol autoencoder_decoder_conv0_relu_fwd> 0\n", + " Activation-18 (1, 32, 3, 3) 0\n", + " Conv2D-19 (1, 32, 3, 3) 9248\n", + " HybridLambda-20 (1, 32, 6, 6) 0\n", + " BatchNorm-21 (1, 32, 6, 6) 128\n", + " Activation-22 <Symbol autoencoder_decoder_conv1_relu_fwd> 0\n", + " Activation-23 (1, 16, 6, 6) 0\n", + " Conv2D-24 (1, 16, 6, 6) 4624\n", + " HybridLambda-25 (1, 16, 12, 12) 0\n", + " BatchNorm-26 (1, 16, 12, 12) 
64\n", + " Activation-27 <Symbol autoencoder_decoder_conv2_relu_fwd> 0\n", + " Activation-28 (1, 8, 14, 14) 0\n", + " Conv2D-29 (1, 8, 14, 14) 1160\n", + " HybridLambda-30 (1, 8, 28, 28) 0\n", + " BatchNorm-31 (1, 8, 28, 28) 32\n", + " Activation-32 <Symbol autoencoder_decoder_conv3_relu_fwd> 0\n", + " Activation-33 (1, 4, 28, 28) 0\n", + " Conv2D-34 (1, 4, 28, 28) 292\n", + " Activation-35 <Symbol autoencoder_decoder_conv4_sigmoid_fwd> 0\n", + " Activation-36 (1, 1, 28, 28) 0\n", + " Conv2D-37 (1, 1, 28, 28) 37\n", + "================================================================================\n", + "Parameters in forward computation graph, duplicate included\n", + " Total params: 21969\n", + " Trainable params: 21737\n", + " Non-trainable params: 232\n", + "Shared params in forward computation graph: 0\n", + "Unique parameters in model: 21969\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "net.summary(test_dataset_t[0][0].expand_dims(axis=0).as_in_context(ctx))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the original image goes from 28x28 = 784 pixels to a vector of length 32. 
That is a ~25x information compression rate.\n", + "Then the decoder brings back this compressed information to the original shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "l2_loss = gluon.loss.L2Loss()\n", + "l1_loss = gluon.loss.L1Loss()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.001, 'wd':0.001})\n", + "net.hybridize(static_shape=True, static_alloc=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [0], Loss 0.2246280246310764\n", + "Epoch [1], Loss 0.14493223337026742\n", + "Epoch [2], Loss 0.13147933666522688\n", + "Epoch [3], Loss 0.12138325943906084\n", + "Epoch [4], Loss 0.11291297684367906\n", + "Epoch [5], Loss 0.10611823453741559\n", + "Epoch [6], Loss 0.09942417470817892\n", + "Epoch [7], Loss 0.09408332955124032\n", + "Epoch [8], Loss 0.08883619716024807\n", + "Epoch [9], Loss 0.08491455795418502\n", + "Epoch [10], Loss 0.0809355994402352\n", + "Epoch [11], Loss 0.07784551636785524\n", + "Epoch [12], Loss 0.07570812029716296\n", + "Epoch [13], Loss 0.07417513366438384\n", + "Epoch [14], Loss 0.07218785571236895\n", + "Epoch [15], Loss 0.07093704352944584\n", + "Epoch [16], Loss 0.0700181406787318\n", + "Epoch [17], Loss 0.0689836893326197\n", + "Epoch [18], Loss 0.06782063459738708\n", + "Epoch [19], Loss 0.06713279088338216\n" + ] + } + ], + "source": [ + "epochs = 20\n", + "for e in range(epochs):\n", + " curr_loss = 0.\n", + " for i, (data, _) in enumerate(train_data):\n", + " data = data.as_in_context(ctx)\n", + " with autograd.record():\n", + " output = net(data)\n", + " # Compute the L2 and L1 losses between the original 
and the generated image\n", + " l2 = l2_loss(output.flatten(), data.flatten())\n", + " l1 = l1_loss(output.flatten(), data.flatten())\n", + " l = l2 + l1 \n", + " l.backward()\n", + " trainer.step(data.shape[0])\n", + " \n", + " curr_loss += l.mean()\n", + "\n", + " print(\"Epoch [{}], Loss {}\".format(e, curr_loss.asscalar()/(i+1)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing reconstruction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We plot 10 images and their reconstruction by the autoencoder. The results are pretty good for a ~25x compression rate!" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIEAAAD4CAYAAAB7VPbbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzsvWe8XVXVvj1QsQKBkJCQQiqBhEQ6gRBaIIAoIE1BimAB8UcVBfUvIGIBFSygPHZ/IihFRVSU3psGQpASEtJIJw0CKPqIvB98mc8975w13Qmn7LPXdX0aO3Oetdeefa2Me4y1Xn311QAAAAAAAAAAgNbmDV19AwAAAAAAAAAA0PHwEggAAAAAAAAAoAbwEggAAAAAAAAAoAbwEggAAAAAAAAAoAbwEggAAAAAAAAAoAbwEggAAAAAAAAAoAbwEggAAAAAAAAAoAbwEggA [...] 
+ "text/plain": [ + "<Figure size 1440x288 with 20 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(20,4))\n", + "for i in range(10):\n", + " idx = random.randint(0, len(test_dataset))\n", + " img, _ = test_dataset[idx]\n", + " x, _ = test_dataset_t[idx]\n", + "\n", + " data = x.as_in_context(ctx).expand_dims(axis=0)\n", + " output = net(data)\n", + " \n", + " ax = plt.subplot(2, 10, i+1)\n", + " ax.imshow(img.squeeze().asnumpy(), cmap='gray')\n", + " ax.axis('off')\n", + " ax = plt.subplot(2, 10, 10+i+1)\n", + " ax.imshow((output[0].asnumpy() * 255.).transpose((1,2,0)).squeeze(), cmap='gray')\n", + " _ = ax.axis('off')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manipulating latent space" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now separately use the **encoder** that takes an image to a latent vector and the **decoder** that transforms a latent vector into an image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get two images from the testing set" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACsxJREFUeJztnduLFdkVxr9le7/ftdXWUdFRCUJkCMYEEaOo8zIP4hWCoOBLAgkEzEzyByiCeRCDIEYnD9EYiKAEYYjaAwbjoNHBqENPa7z1qPF+v7buPHR5sven59Q5fbbn1LG/HzRdX+06Vbu7V++9au1Vq8w5ByHKpVO1OyDeD2RIIgoyJBEFGZKIggxJREGGJKIgQxJRkCGJKJRlSGY238yazOysmX0aq1Oi9rD2RrbNrA7AtwDmAmgBcBTAMufcmQKfURi99rjpnBuSdlA5I9IP [...] 
+ "text/plain": [ + "<Figure size 144x144 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "<matplotlib.image.AxesImage at 0x7f04995adc50>" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACZRJREFUeJztnUtsVdcVhv+Feb8JD2Nsg4OwKjFAqhRVoFYC0SJoJmFUBUHEIBKTVmqlSCRph0zKpLNOkEDpoHJVqZWSQSSrRNSoUIE9iKgJAkwRD2Owzdvmadgd3Bv37D/xvde+y/eew/k/yeL851zfsxP93nudvddex0IIEKJaZtS7AeLNQEYSLshIwgUZSbggIwkXZCThgowkXJCRhAtVGcnMdpnZRTPrM7NPvBolsodNdWbbzBoAXAKwA8BNAN0A9oQQvinxO6mdRm9tbS15fWxs [...] + "text/plain": [ + "<Figure size 144x144 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "idx = random.randint(0, len(test_dataset))\n", + "img1, _ = test_dataset[idx]\n", + "x, _ = test_dataset_t[idx]\n", + "data1 = x.as_in_context(ctx).expand_dims(axis=0)\n", + "\n", + "idx = random.randint(0, len(test_dataset))\n", + "img2, _ = test_dataset[idx]\n", + "x, _ = test_dataset_t[idx]\n", + "data2 = x.as_in_context(ctx).expand_dims(axis=0)\n", + "\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(img1.squeeze().asnumpy(), cmap='gray')\n", + "plt.show()\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(img2.squeeze().asnumpy(), cmap='gray')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get the latent representations of the images by passing them through the network" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "latent1 = encoder(data1)\n", + "latent2 = encoder(data2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the latent vector is made of 32 components" + ] + }, + { + "cell_type": 
"code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 32, 1, 1)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "latent1.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We interpolate the two latent representations, vectors of 32 values, to get a new intermediate latent representation, pass it through the decoder and plot the resulting decoded image" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIEAAACBCAYAAABXearSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3WmMXlUdx/GDC/vSZVraaaEtpS1d6AIFoYSqBRUqpCAISCIkShBBUZREjb4QQnjhQmKMJmCihqAoCCIKmoJhL1vL0tZSoHTfpmVaWmhR2XyBHH7nN3MOT6fzzDzz3O/n1X9679zn9p57zr3z5Pz/Z4933nknAAAAAAAAoLl9qLdPAAAAAAAAAPXHl0AAAAAAAAAVwJdAAAAAAAAAFcCXQAAAAAAAABXAl0AAAAAAAAAVwJdAAAAAAAAAFcCXQAAAAAAAABXAl0AAAAAA [...] 
+ "text/plain": [ + "<Figure size 1440x360 with 10 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "num = 10\n", + "plt.figure(figsize=(20, 5))\n", + "\n", + "for i in range(int(num)):\n", + " \n", + " new_latent = latent2*(i+1)/num + latent1*(num-i)/num\n", + " output = decoder(new_latent)\n", + " \n", + " #plot result\n", + " ax = plt.subplot(1, num, i+1)\n", + " ax.imshow((output[0].asnumpy() * 255.).transpose((1,2,0)).squeeze(), cmap='gray')\n", + " _ = ax.axis('off')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the latent space learnt by the autoencoder is fairly smooth, there is no sudden jump from one shape to another" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/example/autoencoder/data.py b/example/autoencoder/data.py deleted file mode 100644 index 99dd4eb..0000000 --- a/example/autoencoder/data.py +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=missing-docstring -from __future__ import print_function - -import os -import numpy as np -from sklearn.datasets import fetch_mldata - - -def get_mnist(): - np.random.seed(1234) # set seed for deterministic ordering - data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - data_path = os.path.join(data_path, '../../data') - mnist = fetch_mldata('MNIST original', data_home=data_path) - p = np.random.permutation(mnist.data.shape[0]) - X = mnist.data[p].astype(np.float32)*0.02 - Y = mnist.target[p] - return X, Y diff --git a/example/autoencoder/mnist_sae.py b/example/autoencoder/mnist_sae.py deleted file mode 100644 index 886f2a1..0000000 --- a/example/autoencoder/mnist_sae.py +++ /dev/null @@ -1,100 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
# Command-line interface. Arguments are parsed at module level so that the
# option values below stay available as module-level names.
parser = argparse.ArgumentParser(description='Train an auto-encoder model for mnist dataset.',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--print-every', type=int, default=1000,
                    help='interval of printing during training.')
parser.add_argument('--batch-size', type=int, default=256,
                    help='batch size used for training.')
parser.add_argument('--pretrain-num-iter', type=int, default=50000,
                    help='number of iterations for pretraining.')
parser.add_argument('--finetune-num-iter', type=int, default=100000,
                    help='number of iterations for fine-tuning.')
parser.add_argument('--visualize', action='store_true',
                    help='whether to visualize the original image and the reconstructed one.')
# The trailing space after "encoder." matters: adjacent literals are
# concatenated verbatim, and without it the help text read "encoder.The".
parser.add_argument('--num-units', type=str, default="784,500,500,2000,10",
                    help='number of hidden units for the layers of the encoder. '
                         'The decoder layers are created in the reverse order. First dimension '
                         'must be 784 (28x28) to match mnist image dimension.')
parser.add_argument('--gpu', action='store_true',
                    help='whether to start training on GPU.')

# INFO is the configured verbosity; switch to logging.DEBUG for more output.
logging.basicConfig(level=logging.INFO)
opt = parser.parse_args()
logging.info(opt)
print_every = opt.print_every
batch_size = opt.batch_size
pretrain_num_iter = opt.pretrain_num_iter
finetune_num_iter = opt.finetune_num_iter
visualize = opt.visualize
gpu = opt.gpu
# Encoder layer sizes, e.g. "784,500,500,2000,10" -> [784, 500, 500, 2000, 10].
layers = [int(i) for i in opt.num_units.split(',')]


if __name__ == '__main__':
    xpu = mx.gpu() if gpu else mx.cpu()
    print("Training on {}".format("GPU" if gpu else "CPU"))

    ae_model = AutoEncoderModel(xpu, layers, pt_dropout=0.2, internal_act='relu',
                                output_act='relu')

    # The first 60000 rows are used for training, the remainder for validation.
    X, _ = data.get_mnist()
    train_X = X[:60000]
    val_X = X[60000:]

    ae_model.layerwise_pretrain(train_X, batch_size, pretrain_num_iter, 'sgd', l_rate=0.1,
                                decay=0.0, lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1),
                                print_every=print_every)
    ae_model.finetune(train_X, batch_size, finetune_num_iter, 'sgd', l_rate=0.1, decay=0.0,
                      lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1),
                      print_every=print_every)
    # Round-trip the parameters through disk before evaluating.
    ae_model.save('mnist_pt.arg')
    ae_model.load('mnist_pt.arg')
    print("Training error:", ae_model.eval(train_X))
    print("Validation error:", ae_model.eval(val_X))
    if visualize:
        # Visualization is best-effort: any ImportError raised while importing
        # or using matplotlib is reported instead of aborting the script.
        try:
            from matplotlib import pyplot as plt
            from model import extract_feature
            # sample a random image
            original_image = X[np.random.choice(X.shape[0]), :].reshape(1, 784)
            data_iter = mx.io.NDArrayIter({'data': original_image}, batch_size=1, shuffle=False,
                                          last_batch_handle='pad')
            # reconstruct the image
            features = extract_feature(ae_model.decoder, ae_model.args,
                                       ae_model.auxs, data_iter, 1,
                                       ae_model.xpu)
            # dict.values() is a non-subscriptable view on Python 3, so it
            # has to be materialized before taking the first output.
            reconstructed_image = list(features.values())[0]
            print("original image")
            plt.imshow(original_image.reshape((28, 28)))
            plt.show()
            print("reconstructed image")
            plt.imshow(reconstructed_image.reshape((28, 28)))
            plt.show()
        except ImportError:
            logging.info("matplotlib is required for visualization")
def extract_feature(sym, args, auxs, data_iter, N, xpu=None):
    """Run ``sym`` forward over ``data_iter`` and collect its outputs.

    Parameters
    ----------
    sym
        Symbol to evaluate; must provide ``bind`` and ``list_outputs``.
    args : dict
        Argument arrays keyed by name. Entries named like the iterator's
        inputs are overridden by freshly allocated input buffers (the dict
        passed in is not modified).
    auxs
        Auxiliary states, forwarded to ``sym.bind`` as ``aux_states``.
    data_iter
        Data iterator providing ``provide_data``, ``hard_reset()`` and
        batches with a ``data`` list.
    N : int
        Number of leading rows kept from each concatenated output.
    xpu
        Context used for the input buffers and the executor. ``None``
        (default) means ``mx.cpu()``.

    Returns
    -------
    dict
        Maps every name in ``sym.list_outputs()`` to a NumPy array.

    Raises
    ------
    ValueError
        If ``data_iter`` yields no batch at all.
    """
    # Resolve the default lazily instead of at definition time.
    if xpu is None:
        xpu = mx.cpu()

    input_names = [name for name, _ in data_iter.provide_data]
    input_buffs = [mx.nd.empty(shape, ctx=xpu) for _, shape in data_iter.provide_data]
    args = dict(args, **dict(zip(input_names, input_buffs)))
    exe = sym.bind(xpu, args=args, aux_states=auxs)
    outputs = [[] for _ in exe.outputs]
    output_buffs = None

    data_iter.hard_reset()
    for batch in data_iter:
        for data, buff in zip(batch.data, input_buffs):
            data.copyto(buff)
        exe.forward(is_train=False)
        if output_buffs is None:
            # First batch: allocate the CPU staging buffers.
            output_buffs = [mx.nd.empty(i.shape, ctx=mx.cpu()) for i in exe.outputs]
        else:
            # Save the previous batch's results before the staging buffers
            # are overwritten below.
            for out, buff in zip(outputs, output_buffs):
                out.append(buff.asnumpy())
        for out, buff in zip(exe.outputs, output_buffs):
            out.copyto(buff)

    if output_buffs is None:
        # Previously this surfaced as an opaque TypeError from zip(..., None).
        raise ValueError("data_iter yielded no batches; nothing to extract")

    # Flush the results of the last batch, which are still in the buffers.
    for out, buff in zip(outputs, output_buffs):
        out.append(buff.asnumpy())
    outputs = [np.concatenate(i, axis=0)[:N] for i in outputs]
    return dict(zip(sym.list_outputs(), outputs))


class MXModel(object):
    """Base class holding the parameter dictionaries of a model.

    Subclasses must override :meth:`setup`, which is called at the end of
    ``__init__`` with the remaining positional and keyword arguments.

    Parameters
    ----------
    xpu
        Context stored as ``self.xpu``. ``None`` (default) means ``mx.cpu()``.
    *args, **kwargs
        Forwarded unchanged to :meth:`setup`.
    """

    def __init__(self, xpu=None, *args, **kwargs):
        # Resolve the default lazily instead of at class-definition time.
        self.xpu = mx.cpu() if xpu is None else xpu
        self.loss = None
        self.args = {}
        self.args_grad = {}
        self.args_mult = {}
        self.auxs = {}
        self.setup(*args, **kwargs)

    def save(self, fname):
        """Pickle ``self.args`` (converted with ``asnumpy()``) to ``fname``."""
        args_save = {key: v.asnumpy() for key, v in self.args.items()}
        with open(fname, 'wb') as fout:
            pickle.dump(args_save, fout)

    def load(self, fname):
        """Restore values written by :meth:`save` into ``self.args``.

        Only keys already present in ``self.args`` are updated, and they are
        updated in place; unknown keys in the file are ignored.
        """
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load parameter files from trusted sources.
        with open(fname, 'rb') as fin:
            args_save = pickle.load(fin)
        for key, v in args_save.items():
            if key in self.args:
                self.args[key][:] = v

    def setup(self, *args, **kwargs):
        """Build the model; must be overridden by subclasses."""
        raise NotImplementedError("must override this")
class Monitor(object):
    """Periodically log a summary statistic of arrays during training.

    Parameters
    ----------
    interval : int
        Logging happens on iterations ``i`` with ``i % interval == 0``.
    level : int
        Logging level used for the per-array statistics
        (default ``logging.DEBUG``).
    stat : callable, optional
        Maps a NumPy array to the value to report. Defaults to the mean of
        the absolute values.
    """

    def __init__(self, interval, level=logging.DEBUG, stat=None):
        self.interval = interval
        self.level = level
        if stat is None:
            def mean_abs(x):
                return np.fabs(x).mean()
            self.stat = mean_abs
        else:
            self.stat = stat

    def _stat_name(self):
        """Return a printable label for ``self.stat``."""
        # functools.partial objects and callable instances have no __name__;
        # fall back to the type name instead of raising AttributeError.
        return getattr(self.stat, '__name__', type(self.stat).__name__)

    def forward_end(self, i, internals):
        """Log ``stat`` of every array in ``internals`` (sorted by key).

        Nothing is computed unless iteration ``i`` falls on the interval and
        the root logger is enabled for ``self.level``.
        """
        if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
            stat_name = self._stat_name()
            for key in sorted(internals):
                arr = internals[key]
                logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s',
                            i, key, stat_name, str(self.stat(arr.asnumpy())))

    def backward_end(self, i, weights, grads, metric=None):
        """Log weight/gradient statistics and, if given, the metric.

        For every key of ``grads`` (sorted) the statistic of ``weights[key]``
        and of the gradient is logged at ``self.level``. When ``metric`` is
        not ``None`` its value (``metric.get()[1]``) is logged at INFO level
        and the metric is reset. Both happen only on interval iterations.
        """
        on_interval = i % self.interval == 0
        if on_interval and logging.getLogger().isEnabledFor(self.level):
            stat_name = self._stat_name()
            for key in sorted(grads):
                arr = grads[key]
                logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s\t\tgrad_stat:%s',
                            i, key, stat_name,
                            str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy())))
        if on_interval and metric is not None:
            logging.log(logging.INFO, 'Iter:%d metric:%f', i, metric.get()[1])
            metric.reset()


class Solver(object):
    """Minimal training loop around an optimizer and a bound symbol.

    Parameters
    ----------
    optimizer : str or optimizer object
        A string is turned into an optimizer with
        ``mx.optimizer.create(optimizer, **kwargs)``; anything else is used
        as the optimizer directly.
    **kwargs
        Only used when ``optimizer`` is a string.
    """

    def __init__(self, optimizer, **kwargs):
        if isinstance(optimizer, str):
            self.optimizer = mx.optimizer.create(optimizer, **kwargs)
        else:
            self.optimizer = optimizer
        self.updater = mx.optimizer.get_updater(self.optimizer)
        self.monitor = None
        self.metric = None
        self.iter_end_callback = None
        self.iter_start_callback = None

    def set_metric(self, metric):
        """Set the metric updated after every iteration of :meth:`solve`."""
        self.metric = metric

    def set_monitor(self, monitor):
        """Set the monitor notified after each forward and backward pass."""
        self.monitor = monitor

    def set_iter_end_callback(self, callback):
        """Set ``callback(i)`` run at the end of each iteration.

        A truthy return value stops :meth:`solve`.
        """
        self.iter_end_callback = callback

    def set_iter_start_callback(self, callback):
        """Set ``callback(i)`` run at the start of each iteration.

        A truthy return value stops :meth:`solve`.
        """
        self.iter_start_callback = callback

    def solve(self, xpu, sym, args, args_grad, auxs,
              data_iter, begin_iter, end_iter, args_lrmult=None, debug=False):
        """Train ``sym`` for iterations ``begin_iter`` .. ``end_iter - 1``.

        Parameters
        ----------
        xpu
            Context for the input buffers and the executor.
        sym
            Symbol to optimize.
        args : dict
            Argument arrays by name; updated in place by the updater.
        args_grad
            Gradient arrays, forwarded to ``sym.bind``.
        auxs
            Auxiliary states, forwarded to ``sym.bind``.
        data_iter
            Iterator providing ``provide_data``, ``provide_label``,
            ``reset()`` and ``next()``. It is reset and restarted whenever it
            is exhausted.
        begin_iter, end_iter : int
            Iteration range (``end_iter`` exclusive).
        args_lrmult : dict, optional
            Passed to ``optimizer.set_lr_mult``; defaults to an empty dict.
        debug : bool
            When true, the internal symbols that are not arguments are bound
            as well, so that the monitor can inspect them.
        """
        if args_lrmult is None:
            args_lrmult = {}

        # One device buffer per data/label input; they shadow same-named
        # entries of ``args`` in a copy, leaving the caller's dict untouched.
        input_desc = data_iter.provide_data + data_iter.provide_label
        input_names = [name for name, _ in input_desc]
        input_buffs = [mx.nd.empty(shape, ctx=xpu) for _, shape in input_desc]
        args = dict(args, **dict(zip(input_names, input_buffs)))

        output_names = sym.list_outputs()
        if debug:
            sym_group = []
            for x in sym.get_internals():
                if x.name not in args:
                    # Internals that are not real outputs are wrapped in
                    # BlockGrad before being added to the group.
                    if x.name not in output_names:
                        x = mx.symbol.BlockGrad(x, name=x.name)
                    sym_group.append(x)
            sym = mx.symbol.Group(sym_group)
        exe = sym.bind(xpu, args=args, args_grad=args_grad, aux_states=auxs)

        assert len(sym.list_arguments()) == len(exe.grad_arrays)
        # Only arguments that actually have a gradient array are updated.
        update_dict = {
            name: nd for name, nd in zip(sym.list_arguments(), exe.grad_arrays) if nd is not None
        }
        batch_size = input_buffs[0].shape[0]
        self.optimizer.rescale_grad = 1.0/batch_size
        self.optimizer.set_lr_mult(args_lrmult)

        # Real outputs get a CPU staging buffer; everything else exposed by
        # the executor is handed to the monitor as an "internal".
        output_dict = {}
        output_buff = {}
        internal_dict = dict(zip(input_names, input_buffs))
        for key, arr in zip(sym.list_outputs(), exe.outputs):
            if key in output_names:
                output_dict[key] = arr
                output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
            else:
                internal_dict[key] = arr

        data_iter.reset()
        for i in range(begin_iter, end_iter):
            if self.iter_start_callback is not None and self.iter_start_callback(i):
                return
            try:
                batch = data_iter.next()
            except StopIteration:
                # Iterator exhausted: start over from the beginning.
                data_iter.reset()
                batch = data_iter.next()
            for data, buff in zip(batch.data+batch.label, input_buffs):
                data.copyto(buff)
            exe.forward(is_train=True)
            if self.monitor is not None:
                self.monitor.forward_end(i, internal_dict)
            for key in output_dict:
                output_dict[key].copyto(output_buff[key])

            exe.backward()
            for key, arr in update_dict.items():
                self.updater(key, arr, args[key])

            if self.metric is not None:
                # The last input buffer is compared with the first output.
                self.metric.update([input_buffs[-1]],
                                   [output_buff[output_names[0]]])

            if self.monitor is not None:
                self.monitor.backward_end(i, args, update_dict, self.metric)

            if self.iter_end_callback is not None and self.iter_end_callback(i):
                return
            exe.outputs[0].wait_to_read()