This is an automated email from the ASF dual-hosted git repository.

indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 68bc9b7  Updated capsnet example (#12934)
68bc9b7 is described below

commit 68bc9b7f444e76e42c02adfa97ec12149ba0d996
Author: Thomas Delteil <thomas.delte...@gmail.com>
AuthorDate: Thu Nov 8 08:38:42 2018 -0800

    Updated capsnet example (#12934)

    * Updated capsnet
    * trigger CI
    * Update README.md
---
 example/capsnet/README.md     | 132 ++++----
 example/capsnet/capsulenet.py | 695 +++++++++++++++++++++---------------
 2 files changed, 413 insertions(+), 414 deletions(-)

diff --git a/example/capsnet/README.md b/example/capsnet/README.md
index 49a6dd1..500c7df 100644
--- a/example/capsnet/README.md
+++ b/example/capsnet/README.md
@@ -1,66 +1,66 @@
-**CapsNet-MXNet**
-=========================================
-
-This example is MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):
-Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
-- The current `best test error is 0.29%` and `average test error is 0.303%`
-- The `average test error on paper is 0.25%`
-
-Log files for the error rate are uploaded in [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).
-* * *
-## **Usage**
-Install scipy with pip
-```
-pip install scipy
-```
-Install tensorboard with pip
-```
-pip install tensorboard
-```
-
-On Single gpu
-```
-python capsulenet.py --devices gpu0
-```
-On Multi gpus
-```
-python capsulenet.py --devices gpu0,gpu1
-```
-Full arguments
-```
-python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
-```
-
-* * *
-## **Prerequisities**
-
-MXNet version above (0.11.0)
-scipy version above (0.19.0)
-
-***
-## **Results**
-Train time takes about 36 seconds for each epoch (batch_size=100, 2 gtx 1080 gpus)
-
-CapsNet classification test error on MNIST
-
-```
-python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
-```
-
-![](result.PNG)
-
-| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
-| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
-| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
-| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
-
-We achieved `the best test error rate=0.29%` and `average test error=0.303%`. It is the best accuracy and fastest training time result among other implementations(Keras, Tensorflow at 2017-11-23).
-The result on paper is `0.25% (average test error rate)`.
-
-| Implementation| test err(%) | ※train time/epoch | GPU Used|
-| :---: | :---: | :---: |:---: |
-| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
-| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
-| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
+**CapsNet-MXNet**
+=========================================
+
+This example is MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):
+Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
+- The current `best test error is 0.29%` and `average test error is 0.303%`
+- The `average test error on paper is 0.25%`
+
+Log files for the error rate are uploaded in [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).
+* * *
+## **Usage**
+Install scipy with pip
+```
+pip install scipy
+```
+Install tensorboard and mxboard with pip
+```
+pip install mxboard tensorflow
+```
+
+On Single gpu
+```
+python capsulenet.py --devices gpu0
+```
+On Multi gpus
+```
+python capsulenet.py --devices gpu0,gpu1
+```
+Full arguments
+```
+python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
+```
+
+* * *
+## **Prerequisities**
+
+MXNet version above (1.2.0)
+scipy version above (0.19.0)
+
+***
+## **Results**
+Train time takes about 36 seconds for each epoch (batch_size=100, 2 gtx 1080 gpus)
+
+CapsNet classification test error on MNIST:
+
+```
+python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
+```
+
+![](result.PNG)
+
+| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
+| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
+| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
+| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
+
+We achieved `the best test error rate=0.29%` and `average test error=0.303%`. It is the best accuracy and fastest training time result among other implementations(Keras, Tensorflow at 2017-11-23).
+The result on paper is `0.25% (average test error rate)`.
+
+| Implementation| test err(%) | ※train time/epoch | GPU Used|
+| :---: | :---: | :---: |:---: |
+| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
+| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
+| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
diff --git a/example/capsnet/capsulenet.py b/example/capsnet/capsulenet.py
index 6b44c3d..6710875 100644
--- a/example/capsnet/capsulenet.py
+++ b/example/capsnet/capsulenet.py
@@ -1,348 +1,347 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import mxnet as mx -import numpy as np -import os -import re -import urllib -import gzip -import struct -import scipy.ndimage as ndi -from capsulelayers import primary_caps, CapsuleLayer - -from tensorboard import SummaryWriter - -def margin_loss(y_true, y_pred): - loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\ - 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1)) - return mx.sym.mean(data=mx.sym.sum(loss, 1)) - - -def capsnet(batch_size, n_class, num_routing,recon_loss_weight): - # data.shape = [batch_size, 1, 28, 28] - data = mx.sym.Variable('data') - - input_shape = (1, 28, 28) - # Conv2D layer - # net.shape = [batch_size, 256, 20, 20] - conv1 = mx.sym.Convolution(data=data, - num_filter=256, - kernel=(9, 9), - layout='NCHW', - name='conv1') - conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act') - # net.shape = [batch_size, 256, 6, 6] - - primarycaps = primary_caps(data=conv1, - dim_vector=8, - n_channels=32, - kernel=(9, 9), - strides=[2, 2], - name='primarycaps') - primarycaps.infer_shape(data=(batch_size, 1, 28, 28)) - # CapsuleLayer - kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3) - bias_initializer = mx.init.Zero() - digitcaps = CapsuleLayer(num_capsule=10, - dim_vector=16, - batch_size=batch_size, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - num_routing=num_routing)(primarycaps) - - # out_caps : (batch_size, 10) - out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2)) - out_caps.infer_shape(data=(batch_size, 1, 28, 28)) - - y = mx.sym.Variable('softmax_label', shape=(batch_size,)) - y_onehot = mx.sym.one_hot(y, n_class) - y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1)) - y_reshaped.infer_shape(softmax_label=(batch_size,)) - - # inputs_masked : (batch_size, 16) - inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True) - inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0)) - x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon') - x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act') - x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2') - x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2') - x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3') - x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3') - - data_flatten = mx.sym.flatten(data=data) - squared_error = mx.sym.square(x_recon-data_flatten) - recon_error = mx.sym.mean(squared_error) - recon_error_stopped = recon_error - recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped) - loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error) - - out_caps_blocked = out_caps - out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked) - return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped]) - - -def download_data(url, force_download=False): - fname = url.split("/")[-1] - if force_download or not os.path.exists(fname): - urllib.urlretrieve(url, fname) - return fname - - -def read_data(label_url, image_url): - with gzip.open(download_data(label_url)) as flbl: - magic, num = struct.unpack(">II", flbl.read(8)) - label = np.fromstring(flbl.read(), dtype=np.int8) - with gzip.open(download_data(image_url), 'rb') as fimg: - magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16)) - image 
= np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols) - return label, image - - -def to4d(img): - return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255 - - -class LossMetric(mx.metric.EvalMetric): - def __init__(self, batch_size, num_gpu): - super(LossMetric, self).__init__('LossMetric') - self.batch_size = batch_size - self.num_gpu = num_gpu - self.sum_metric = 0 - self.num_inst = 0 - self.loss = 0.0 - self.batch_sum_metric = 0 - self.batch_num_inst = 0 - self.batch_loss = 0.0 - self.recon_loss = 0.0 - self.n_batch = 0 - - def update(self, labels, preds): - batch_sum_metric = 0 - batch_num_inst = 0 - for label, pred_outcaps in zip(labels[0], preds[0]): - label_np = int(label.asnumpy()) - pred_label = int(np.argmax(pred_outcaps.asnumpy())) - batch_sum_metric += int(label_np == pred_label) - batch_num_inst += 1 - batch_loss = preds[1].asnumpy() - recon_loss = preds[2].asnumpy() - self.sum_metric += batch_sum_metric - self.num_inst += batch_num_inst - self.loss += batch_loss - self.recon_loss += recon_loss - self.batch_sum_metric = batch_sum_metric - self.batch_num_inst = batch_num_inst - self.batch_loss = batch_loss - self.n_batch += 1 - - def get_name_value(self): - acc = float(self.sum_metric)/float(self.num_inst) - mean_loss = self.loss / float(self.n_batch) - mean_recon_loss = self.recon_loss / float(self.n_batch) - return acc, mean_loss, mean_recon_loss - - def get_batch_log(self, n_batch): - print("n_batch :"+str(n_batch)+" batch_acc:" + - str(float(self.batch_sum_metric) / float(self.batch_num_inst)) + - ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst))) - self.batch_sum_metric = 0 - self.batch_num_inst = 0 - self.batch_loss = 0.0 - - def reset(self): - self.sum_metric = 0 - self.num_inst = 0 - self.loss = 0.0 - self.recon_loss = 0.0 - self.n_batch = 0 - - -class SimpleLRScheduler(mx.lr_scheduler.LRScheduler): - """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr` - dynamically based on performance on the validation set. 
- """ - - def __init__(self, learning_rate=0.001): - super(SimpleLRScheduler, self).__init__() - self.learning_rate = learning_rate - - def __call__(self, num_update): - return self.learning_rate - - -def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay): - summary_writer = SummaryWriter(args.tblog_dir) - lr_scheduler = SimpleLRScheduler(learning_rate) - optimizer_params = {'lr_scheduler': lr_scheduler} - module.init_params() - module.init_optimizer(kvstore=kvstore, - optimizer=optimizer, - optimizer_params=optimizer_params) - n_epoch = 0 - while True: - if n_epoch >= num_epoch: - break - train_iter.reset() - val_iter.reset() - loss_metric.reset() - for n_batch, data_batch in enumerate(train_iter): - module.forward_backward(data_batch) - module.update() - module.update_metric(loss_metric, data_batch.label) - loss_metric.get_batch_log(n_batch) - train_acc, train_loss, train_recon_err = loss_metric.get_name_value() - loss_metric.reset() - for n_batch, data_batch in enumerate(val_iter): - module.forward(data_batch) - module.update_metric(loss_metric, data_batch.label) - loss_metric.get_batch_log(n_batch) - val_acc, val_loss, val_recon_err = loss_metric.get_name_value() - - summary_writer.add_scalar('train_acc', train_acc, n_epoch) - summary_writer.add_scalar('train_loss', train_loss, n_epoch) - summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch) - summary_writer.add_scalar('val_acc', val_acc, n_epoch) - summary_writer.add_scalar('val_loss', val_loss, n_epoch) - summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch) - - print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err)) - print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err)) - print('SAVE CHECKPOINT') - - module.save_checkpoint(prefix=model_prefix, epoch=n_epoch) - n_epoch += 1 - lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch) - - -def apply_transform(x, - transform_matrix, - fill_mode='nearest', - cval=0.): - x = np.rollaxis(x, 0, 0) - final_affine_matrix = transform_matrix[:2, :2] - final_offset = transform_matrix[:2, 2] - channel_images = [ndi.interpolation.affine_transform( - x_channel, - final_affine_matrix, - final_offset, - order=0, - mode=fill_mode, - cval=cval) for x_channel in x] - x = np.stack(channel_images, axis=0) - x = np.rollaxis(x, 0, 0 + 1) - return x - - -def random_shift(x, width_shift_fraction, height_shift_fraction): - tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2] - ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1] - shift_matrix = np.array([[1, 0, tx], - [0, 1, ty], - [0, 0, 1]]) - x = apply_transform(x, shift_matrix, 'nearest') - return x - -def _shuffle(data, idx): - """Shuffle the data.""" - shuffle_data = [] - - for k, v in data: - shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context))) - - return shuffle_data - -class MNISTCustomIter(mx.io.NDArrayIter): - - def reset(self): - # shuffle data - if self.is_train: - np.random.shuffle(self.idx) - self.data = _shuffle(self.data, self.idx) - self.label = _shuffle(self.label, self.idx) - if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data: - self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size - else: - self.cursor = -self.batch_size - def set_is_train(self, is_train): - self.is_train = is_train - def next(self): - if self.iter_next(): - if self.is_train: - 
data_raw_list = self.getdata() - data_shifted = [] - for data_raw in data_raw_list[0]: - data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1)) - return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(), - pad=self.getpad(), index=None) - else: - return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \ - pad=self.getpad(), index=None) - - else: - raise StopIteration - - -if __name__ == "__main__": - # Read mnist data set - path = 'http://yann.lecun.com/exdb/mnist/' - (train_lbl, train_img) = read_data( - path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz') - (val_lbl, val_img) = read_data( - path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz') - # set batch size - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--batch_size', default=100, type=int) - parser.add_argument('--devices', default='gpu0', type=str) - parser.add_argument('--num_epoch', default=100, type=int) - parser.add_argument('--lr', default=0.001, type=float) - parser.add_argument('--num_routing', default=3, type=int) - parser.add_argument('--model_prefix', default='capsnet', type=str) - parser.add_argument('--decay', default=0.9, type=float) - parser.add_argument('--tblog_dir', default='tblog', type=str) - parser.add_argument('--recon_loss_weight', default=0.392, type=float) - args = parser.parse_args() - for k, v in sorted(vars(args).items()): - print("{0}: {1}".format(k, v)) - contexts = re.split(r'\W+', args.devices) - for i, ctx in enumerate(contexts): - if ctx[:3] == 'gpu': - contexts[i] = mx.context.gpu(int(ctx[3:])) - else: - contexts[i] = mx.context.cpu() - num_gpu = len(contexts) - - if args.batch_size % num_gpu != 0: - raise Exception('num_gpu should be positive divisor of batch_size') - - # generate train_iter, val_iter - train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=args.batch_size, shuffle=True) - train_iter.set_is_train(True) - val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=args.batch_size,) - val_iter.set_is_train(False) - # define capsnet - final_net = capsnet(batch_size=args.batch_size/num_gpu, n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight) - # set metric - loss_metric = LossMetric(args.batch_size/num_gpu, 1) - - # run model - module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',)) - module.bind(data_shapes=train_iter.provide_data, - label_shapes=val_iter.provide_label, - for_training=True) - do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr, - model_prefix=args.model_prefix, decay=args.decay) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import mxnet as mx +import numpy as np +import os +import re +import gzip +import struct +import scipy.ndimage as ndi +from capsulelayers import primary_caps, CapsuleLayer + +from mxboard import SummaryWriter + +def margin_loss(y_true, y_pred): + loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\ + 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1)) + return mx.sym.mean(data=mx.sym.sum(loss, 1)) + + +def capsnet(batch_size, n_class, num_routing,recon_loss_weight): + # data.shape = [batch_size, 1, 28, 28] + data = mx.sym.Variable('data') + + input_shape = (1, 28, 28) + # Conv2D layer + # net.shape = [batch_size, 256, 20, 20] + conv1 = mx.sym.Convolution(data=data, + num_filter=256, + kernel=(9, 9), + layout='NCHW', + name='conv1') + conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act') + # net.shape = [batch_size, 256, 6, 6] + + primarycaps = primary_caps(data=conv1, + dim_vector=8, + n_channels=32, + kernel=(9, 9), + strides=[2, 2], + name='primarycaps') + primarycaps.infer_shape(data=(batch_size, 1, 28, 28)) + # CapsuleLayer + kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3) + bias_initializer = mx.init.Zero() + digitcaps = CapsuleLayer(num_capsule=10, + dim_vector=16, + batch_size=batch_size, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + num_routing=num_routing)(primarycaps) + + # out_caps : (batch_size, 10) + out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2)) + out_caps.infer_shape(data=(batch_size, 1, 28, 28)) + + y = mx.sym.Variable('softmax_label', shape=(batch_size,)) + y_onehot = mx.sym.one_hot(y, n_class) + y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1)) + y_reshaped.infer_shape(softmax_label=(batch_size,)) + + # inputs_masked : (batch_size, 16) + inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True) + inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0)) + x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon') + x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act') + x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2') + x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2') + x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3') + x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3') + + data_flatten = mx.sym.flatten(data=data) + squared_error = mx.sym.square(x_recon-data_flatten) + recon_error = mx.sym.mean(squared_error) + recon_error_stopped = recon_error + recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped) + loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error) + + out_caps_blocked = out_caps + out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked) + return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped]) + + +def download_data(url, force_download=False): + fname = url.split("/")[-1] + if force_download or not os.path.exists(fname): + mx.test_utils.download(url, fname) + return fname + + +def read_data(label_url, image_url): + with gzip.open(download_data(label_url)) as flbl: + magic, num = struct.unpack(">II", flbl.read(8)) + label = np.fromstring(flbl.read(), dtype=np.int8) + with gzip.open(download_data(image_url), 'rb') as fimg: + magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16)) + image = 
np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols) + return label, image + + +def to4d(img): + return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255 + + +class LossMetric(mx.metric.EvalMetric): + def __init__(self, batch_size, num_gpu): + super(LossMetric, self).__init__('LossMetric') + self.batch_size = batch_size + self.num_gpu = num_gpu + self.sum_metric = 0 + self.num_inst = 0 + self.loss = 0.0 + self.batch_sum_metric = 0 + self.batch_num_inst = 0 + self.batch_loss = 0.0 + self.recon_loss = 0.0 + self.n_batch = 0 + + def update(self, labels, preds): + batch_sum_metric = 0 + batch_num_inst = 0 + for label, pred_outcaps in zip(labels[0], preds[0]): + label_np = int(label.asnumpy()) + pred_label = int(np.argmax(pred_outcaps.asnumpy())) + batch_sum_metric += int(label_np == pred_label) + batch_num_inst += 1 + batch_loss = preds[1].asnumpy() + recon_loss = preds[2].asnumpy() + self.sum_metric += batch_sum_metric + self.num_inst += batch_num_inst + self.loss += batch_loss + self.recon_loss += recon_loss + self.batch_sum_metric = batch_sum_metric + self.batch_num_inst = batch_num_inst + self.batch_loss = batch_loss + self.n_batch += 1 + + def get_name_value(self): + acc = float(self.sum_metric)/float(self.num_inst) + mean_loss = self.loss / float(self.n_batch) + mean_recon_loss = self.recon_loss / float(self.n_batch) + return acc, mean_loss, mean_recon_loss + + def get_batch_log(self, n_batch): + print("n_batch :"+str(n_batch)+" batch_acc:" + + str(float(self.batch_sum_metric) / float(self.batch_num_inst)) + + ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst))) + self.batch_sum_metric = 0 + self.batch_num_inst = 0 + self.batch_loss = 0.0 + + def reset(self): + self.sum_metric = 0 + self.num_inst = 0 + self.loss = 0.0 + self.recon_loss = 0.0 + self.n_batch = 0 + + +class SimpleLRScheduler(mx.lr_scheduler.LRScheduler): + """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr` + dynamically based on performance on the validation set. 
+ """ + + def __init__(self, learning_rate=0.001): + super(SimpleLRScheduler, self).__init__() + self.learning_rate = learning_rate + + def __call__(self, num_update): + return self.learning_rate + + +def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay): + summary_writer = SummaryWriter(args.tblog_dir) + lr_scheduler = SimpleLRScheduler(learning_rate) + optimizer_params = {'lr_scheduler': lr_scheduler} + module.init_params() + module.init_optimizer(kvstore=kvstore, + optimizer=optimizer, + optimizer_params=optimizer_params) + n_epoch = 0 + while True: + if n_epoch >= num_epoch: + break + train_iter.reset() + val_iter.reset() + loss_metric.reset() + for n_batch, data_batch in enumerate(train_iter): + module.forward_backward(data_batch) + module.update() + module.update_metric(loss_metric, data_batch.label) + loss_metric.get_batch_log(n_batch) + train_acc, train_loss, train_recon_err = loss_metric.get_name_value() + loss_metric.reset() + for n_batch, data_batch in enumerate(val_iter): + module.forward(data_batch) + module.update_metric(loss_metric, data_batch.label) + loss_metric.get_batch_log(n_batch) + val_acc, val_loss, val_recon_err = loss_metric.get_name_value() + + summary_writer.add_scalar('train_acc', train_acc, n_epoch) + summary_writer.add_scalar('train_loss', train_loss, n_epoch) + summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch) + summary_writer.add_scalar('val_acc', val_acc, n_epoch) + summary_writer.add_scalar('val_loss', val_loss, n_epoch) + summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch) + + print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err)) + print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err)) + print('SAVE CHECKPOINT') + + module.save_checkpoint(prefix=model_prefix, epoch=n_epoch) + n_epoch += 1 + lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch) + + +def apply_transform(x, + transform_matrix, + fill_mode='nearest', + cval=0.): + x = np.rollaxis(x, 0, 0) + final_affine_matrix = transform_matrix[:2, :2] + final_offset = transform_matrix[:2, 2] + channel_images = [ndi.interpolation.affine_transform( + x_channel, + final_affine_matrix, + final_offset, + order=0, + mode=fill_mode, + cval=cval) for x_channel in x] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, 0 + 1) + return x + + +def random_shift(x, width_shift_fraction, height_shift_fraction): + tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2] + ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1] + shift_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + x = apply_transform(x, shift_matrix, 'nearest') + return x + +def _shuffle(data, idx): + """Shuffle the data.""" + shuffle_data = [] + + for k, v in data: + shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context))) + + return shuffle_data + +class MNISTCustomIter(mx.io.NDArrayIter): + + def reset(self): + # shuffle data + if self.is_train: + np.random.shuffle(self.idx) + self.data = _shuffle(self.data, self.idx) + self.label = _shuffle(self.label, self.idx) + if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data: + self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size + else: + self.cursor = -self.batch_size + def set_is_train(self, is_train): + self.is_train = is_train + def next(self): + if self.iter_next(): + if self.is_train: + 
data_raw_list = self.getdata() + data_shifted = [] + for data_raw in data_raw_list[0]: + data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1)) + return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(), + pad=self.getpad(), index=None) + else: + return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \ + pad=self.getpad(), index=None) + + else: + raise StopIteration + + +if __name__ == "__main__": + # Read mnist data set + path = 'http://yann.lecun.com/exdb/mnist/' + (train_lbl, train_img) = read_data( + path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz') + (val_lbl, val_img) = read_data( + path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz') + # set batch size + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=100, type=int) + parser.add_argument('--devices', default='gpu0', type=str) + parser.add_argument('--num_epoch', default=100, type=int) + parser.add_argument('--lr', default=0.001, type=float) + parser.add_argument('--num_routing', default=3, type=int) + parser.add_argument('--model_prefix', default='capsnet', type=str) + parser.add_argument('--decay', default=0.9, type=float) + parser.add_argument('--tblog_dir', default='tblog', type=str) + parser.add_argument('--recon_loss_weight', default=0.392, type=float) + args = parser.parse_args() + for k, v in sorted(vars(args).items()): + print("{0}: {1}".format(k, v)) + contexts = re.split(r'\W+', args.devices) + for i, ctx in enumerate(contexts): + if ctx[:3] == 'gpu': + contexts[i] = mx.context.gpu(int(ctx[3:])) + else: + contexts[i] = mx.context.cpu() + num_gpu = len(contexts) + + if args.batch_size % num_gpu != 0: + raise Exception('num_gpu should be positive divisor of batch_size') + + # generate train_iter, val_iter + train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=int(args.batch_size), shuffle=True) + train_iter.set_is_train(True) + val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=int(args.batch_size),) + val_iter.set_is_train(False) + # define capsnet + final_net = capsnet(batch_size=int(args.batch_size/num_gpu), n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight) + # set metric + loss_metric = LossMetric(args.batch_size/num_gpu, 1) + + # run model + module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',)) + module.bind(data_shapes=train_iter.provide_data, + label_shapes=val_iter.provide_label, + for_training=True) + do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr, + model_prefix=args.model_prefix, decay=args.decay)
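
With this change the example logs through mxboard instead of the old `tensorboard` pip package, and the event files written under `--tblog_dir` (default `tblog`) stay TensorBoard-compatible. A minimal sketch of inspecting them after a training run, assuming TensorBoard is available from the `pip install mxboard tensorflow` step in the updated README and that training was launched from `example/capsnet`:
```
# Point TensorBoard at the mxboard log directory written by capsulenet.py
# (./tblog is the default value of --tblog_dir; adjust if you passed another path)
tensorboard --logdir ./tblog
```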