This is an automated email from the ASF dual-hosted git repository.

indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 68bc9b7  Updated capsnet example (#12934)
68bc9b7 is described below

commit 68bc9b7f444e76e42c02adfa97ec12149ba0d996
Author: Thomas Delteil <thomas.delte...@gmail.com>
AuthorDate: Thu Nov 8 08:38:42 2018 -0800

    Updated capsnet example (#12934)
    
    * Updated capsnet
    
    * trigger CI
    
    * Update README.md
---
 example/capsnet/README.md     | 132 ++++----
 example/capsnet/capsulenet.py | 695 +++++++++++++++++++++---------------------
 2 files changed, 413 insertions(+), 414 deletions(-)

diff --git a/example/capsnet/README.md b/example/capsnet/README.md
index 49a6dd1..500c7df 100644
--- a/example/capsnet/README.md
+++ b/example/capsnet/README.md
@@ -1,66 +1,66 @@
-**CapsNet-MXNet**
-=========================================
-
-This example is an MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):  
-Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
-- The current `best test error is 0.29%` and `average test error is 0.303%`
-- The `average test error on paper is 0.25%`  
-
-Log files for the error rate are uploaded to the [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).  
-* * *
-## **Usage**
-Install scipy with pip  
-```
-pip install scipy
-```
-Install tensorboard with pip
-```
-pip install tensorboard
-```
-
-On a single GPU
-```
-python capsulenet.py --devices gpu0
-```
-On multiple GPUs
-```
-python capsulenet.py --devices gpu0,gpu1
-```
-Full arguments  
-```
-python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
-```  
-
-* * *
-## **Prerequisites**
-
-MXNet version 0.11.0 or above  
-scipy version 0.19.0 or above
-
-***
-## **Results**  
-Training takes about 36 seconds per epoch (batch_size=100, 2 GTX 1080 GPUs)  
-
-CapsNet classification test error on MNIST  
-
-```
-python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
-```
-
-![](result.PNG)
-
-| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
-| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
-| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
-| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
-
-We achieved a `best test error rate of 0.29%` and an `average test error of 0.303%`. This is the best accuracy and the fastest training time among other implementations (Keras, TensorFlow, as of 2017-11-23).
-The result in the paper is `0.25% (average test error rate)`.
-
-| Implementation | test err(%) | ※train time/epoch | GPU used |
-| :---: | :---: | :---: |:---: |
-| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
-| TensorFlow | 0.49 | ※ 10 min | Unknown (4GB memory) |
-| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
+**CapsNet-MXNet**
+=========================================
+
+This example is an MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):  
+Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
+- The current `best test error is 0.29%` and `average test error is 0.303%`
+- The `average test error on paper is 0.25%`  
+
+Log files for the error rate are uploaded to the [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).  
+* * *
+## **Usage**
+Install scipy with pip  
+```
+pip install scipy
+```
+Install mxboard and tensorflow (which provides tensorboard) with pip
+```
+pip install mxboard tensorflow
+```
+
+On a single GPU
+```
+python capsulenet.py --devices gpu0
+```
+On multiple GPUs
+```
+python capsulenet.py --devices gpu0,gpu1
+```
+Full arguments  
+```
+python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
+```  
+
+* * *
+## **Prerequisites**
+
+MXNet version 1.2.0 or above  
+scipy version 0.19.0 or above
+
+***
+## **Results**  
+Training takes about 36 seconds per epoch (batch_size=100, 2 GTX 1080 GPUs)  
+
+CapsNet classification test error on MNIST:
+
+```
+python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
+```
+
+![](result.PNG)
+
+| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
+| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
+| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
+| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
+
+We achieved a `best test error rate of 0.29%` and an `average test error of 0.303%`. This is the best accuracy and the fastest training time among other implementations (Keras, TensorFlow, as of 2017-11-23).
+The result in the paper is `0.25% (average test error rate)`.
+
+| Implementation | test err(%) | ※train time/epoch | GPU used |
+| :---: | :---: | :---: |:---: |
+| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
+| TensorFlow | 0.49 | ※ 10 min | Unknown (4GB memory) |
+| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
diff --git a/example/capsnet/capsulenet.py b/example/capsnet/capsulenet.py
index 6b44c3d..6710875 100644
--- a/example/capsnet/capsulenet.py
+++ b/example/capsnet/capsulenet.py
@@ -1,348 +1,347 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import mxnet as mx
-import numpy as np
-import os
-import re
-import urllib
-import gzip
-import struct
-import scipy.ndimage as ndi
-from capsulelayers import primary_caps, CapsuleLayer
-
-from tensorboard import SummaryWriter
-
-def margin_loss(y_true, y_pred):
-    loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
-        0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
-    return mx.sym.mean(data=mx.sym.sum(loss, 1))
-
-
-def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
-    # data.shape = [batch_size, 1, 28, 28]
-    data = mx.sym.Variable('data')
-
-    input_shape = (1, 28, 28)
-    # Conv2D layer
-    # net.shape = [batch_size, 256, 20, 20]
-    conv1 = mx.sym.Convolution(data=data,
-                               num_filter=256,
-                               kernel=(9, 9),
-                               layout='NCHW',
-                               name='conv1')
-    conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
-    # net.shape = [batch_size, 256, 6, 6]
-
-    primarycaps = primary_caps(data=conv1,
-                               dim_vector=8,
-                               n_channels=32,
-                               kernel=(9, 9),
-                               strides=[2, 2],
-                               name='primarycaps')
-    primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
-    # CapsuleLayer
-    kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
-    bias_initializer = mx.init.Zero()
-    digitcaps = CapsuleLayer(num_capsule=10,
-                             dim_vector=16,
-                             batch_size=batch_size,
-                             kernel_initializer=kernel_initializer,
-                             bias_initializer=bias_initializer,
-                             num_routing=num_routing)(primarycaps)
-
-    # out_caps : (batch_size, 10)
-    out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
-    out_caps.infer_shape(data=(batch_size, 1, 28, 28))
-
-    y = mx.sym.Variable('softmax_label', shape=(batch_size,))
-    y_onehot = mx.sym.one_hot(y, n_class)
-    y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
-    y_reshaped.infer_shape(softmax_label=(batch_size,))
-
-    # inputs_masked : (batch_size, 16)
-    inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
-    inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
-    x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
-    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
-    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
-    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
-    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
-    x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
-
-    data_flatten = mx.sym.flatten(data=data)
-    squared_error = mx.sym.square(x_recon-data_flatten)
-    recon_error = mx.sym.mean(squared_error)
-    recon_error_stopped = recon_error
-    recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
-    loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
-
-    out_caps_blocked = out_caps
-    out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
-    return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
-
-
-def download_data(url, force_download=False):
-    fname = url.split("/")[-1]
-    if force_download or not os.path.exists(fname):
-        urllib.urlretrieve(url, fname)
-    return fname
-
-
-def read_data(label_url, image_url):
-    with gzip.open(download_data(label_url)) as flbl:
-        magic, num = struct.unpack(">II", flbl.read(8))
-        label = np.fromstring(flbl.read(), dtype=np.int8)
-    with gzip.open(download_data(image_url), 'rb') as fimg:
-        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
-        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
-    return label, image
-
-
-def to4d(img):
-    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
-
-
-class LossMetric(mx.metric.EvalMetric):
-    def __init__(self, batch_size, num_gpu):
-        super(LossMetric, self).__init__('LossMetric')
-        self.batch_size = batch_size
-        self.num_gpu = num_gpu
-        self.sum_metric = 0
-        self.num_inst = 0
-        self.loss = 0.0
-        self.batch_sum_metric = 0
-        self.batch_num_inst = 0
-        self.batch_loss = 0.0
-        self.recon_loss = 0.0
-        self.n_batch = 0
-
-    def update(self, labels, preds):
-        batch_sum_metric = 0
-        batch_num_inst = 0
-        for label, pred_outcaps in zip(labels[0], preds[0]):
-            label_np = int(label.asnumpy())
-            pred_label = int(np.argmax(pred_outcaps.asnumpy()))
-            batch_sum_metric += int(label_np == pred_label)
-            batch_num_inst += 1
-        batch_loss = preds[1].asnumpy()
-        recon_loss = preds[2].asnumpy()
-        self.sum_metric += batch_sum_metric
-        self.num_inst += batch_num_inst
-        self.loss += batch_loss
-        self.recon_loss += recon_loss
-        self.batch_sum_metric = batch_sum_metric
-        self.batch_num_inst = batch_num_inst
-        self.batch_loss = batch_loss
-        self.n_batch += 1 
-
-    def get_name_value(self):
-        acc = float(self.sum_metric)/float(self.num_inst)
-        mean_loss = self.loss / float(self.n_batch)
-        mean_recon_loss = self.recon_loss / float(self.n_batch)
-        return acc, mean_loss, mean_recon_loss
-
-    def get_batch_log(self, n_batch):
-        print("n_batch :"+str(n_batch)+" batch_acc:" +
-              str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
-              ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
-        self.batch_sum_metric = 0
-        self.batch_num_inst = 0
-        self.batch_loss = 0.0
-
-    def reset(self):
-        self.sum_metric = 0
-        self.num_inst = 0
-        self.loss = 0.0
-        self.recon_loss = 0.0
-        self.n_batch = 0
-
-
-class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
-    """A simple lr schedule that simply return `dynamic_lr`. We will set 
`dynamic_lr`
-    dynamically based on performance on the validation set.
-    """
-
-    def __init__(self, learning_rate=0.001):
-        super(SimpleLRScheduler, self).__init__()
-        self.learning_rate = learning_rate
-
-    def __call__(self, num_update):
-        return self.learning_rate
-
-
-def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
-    summary_writer = SummaryWriter(args.tblog_dir)
-    lr_scheduler = SimpleLRScheduler(learning_rate)
-    optimizer_params = {'lr_scheduler': lr_scheduler}
-    module.init_params()
-    module.init_optimizer(kvstore=kvstore,
-                          optimizer=optimizer,
-                          optimizer_params=optimizer_params)
-    n_epoch = 0
-    while True:
-        if n_epoch >= num_epoch:
-            break
-        train_iter.reset()
-        val_iter.reset()
-        loss_metric.reset()
-        for n_batch, data_batch in enumerate(train_iter):
-            module.forward_backward(data_batch)
-            module.update()
-            module.update_metric(loss_metric, data_batch.label)
-            loss_metric.get_batch_log(n_batch)
-        train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
-        loss_metric.reset()
-        for n_batch, data_batch in enumerate(val_iter):
-            module.forward(data_batch)
-            module.update_metric(loss_metric, data_batch.label)
-            loss_metric.get_batch_log(n_batch)
-        val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
-
-        summary_writer.add_scalar('train_acc', train_acc, n_epoch)
-        summary_writer.add_scalar('train_loss', train_loss, n_epoch)
-        summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
-        summary_writer.add_scalar('val_acc', val_acc, n_epoch)
-        summary_writer.add_scalar('val_loss', val_loss, n_epoch)
-        summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
-
-        print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
-        print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
-        print('SAVE CHECKPOINT')
-
-        module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
-        n_epoch += 1
-        lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
-
-
-def apply_transform(x,
-                    transform_matrix,
-                    fill_mode='nearest',
-                    cval=0.):
-    x = np.rollaxis(x, 0, 0)
-    final_affine_matrix = transform_matrix[:2, :2]
-    final_offset = transform_matrix[:2, 2]
-    channel_images = [ndi.interpolation.affine_transform(
-        x_channel,
-        final_affine_matrix,
-        final_offset,
-        order=0,
-        mode=fill_mode,
-        cval=cval) for x_channel in x]
-    x = np.stack(channel_images, axis=0)
-    x = np.rollaxis(x, 0, 0 + 1)
-    return x
-
-
-def random_shift(x, width_shift_fraction, height_shift_fraction):
-    tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
-    ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
-    shift_matrix = np.array([[1, 0, tx],
-                             [0, 1, ty],
-                             [0, 0, 1]])
-    x = apply_transform(x, shift_matrix, 'nearest')
-    return x
-
-def _shuffle(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
-
-class MNISTCustomIter(mx.io.NDArrayIter):
-    
-    def reset(self):
-        # shuffle data
-        if self.is_train:
-            np.random.shuffle(self.idx)
-            self.data = _shuffle(self.data, self.idx)
-            self.label = _shuffle(self.label, self.idx)
-        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
-            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
-        else:
-            self.cursor = -self.batch_size
-    def set_is_train(self, is_train):
-        self.is_train = is_train
-    def next(self):
-        if self.iter_next():
-            if self.is_train:
-                data_raw_list = self.getdata()
-                data_shifted = []
-                for data_raw in data_raw_list[0]:
-                    data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
-                return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
-                                       pad=self.getpad(), index=None)
-            else:
-                 return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
-                                  pad=self.getpad(), index=None)
-
-        else:
-            raise StopIteration
-
-
-if __name__ == "__main__":
-    # Read mnist data set
-    path = 'http://yann.lecun.com/exdb/mnist/'
-    (train_lbl, train_img) = read_data(
-        path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
-    (val_lbl, val_img) = read_data(
-        path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
-    # set batch size
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=100, type=int)
-    parser.add_argument('--devices', default='gpu0', type=str)
-    parser.add_argument('--num_epoch', default=100, type=int)
-    parser.add_argument('--lr', default=0.001, type=float)
-    parser.add_argument('--num_routing', default=3, type=int)
-    parser.add_argument('--model_prefix', default='capsnet', type=str)
-    parser.add_argument('--decay', default=0.9, type=float)
-    parser.add_argument('--tblog_dir', default='tblog', type=str)
-    parser.add_argument('--recon_loss_weight', default=0.392, type=float)
-    args = parser.parse_args()
-    for k, v in sorted(vars(args).items()):
-        print("{0}: {1}".format(k, v))
-    contexts = re.split(r'\W+', args.devices)
-    for i, ctx in enumerate(contexts):
-        if ctx[:3] == 'gpu':
-            contexts[i] = mx.context.gpu(int(ctx[3:]))
-        else:
-            contexts[i] = mx.context.cpu()
-    num_gpu = len(contexts)
-
-    if args.batch_size % num_gpu != 0:
-        raise Exception('num_gpu should be positive divisor of batch_size')
-
-    # generate train_iter, val_iter
-    train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=args.batch_size, shuffle=True)
-    train_iter.set_is_train(True)
-    val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=args.batch_size,)
-    val_iter.set_is_train(False)
-    # define capsnet
-    final_net = capsnet(batch_size=args.batch_size/num_gpu, n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
-    # set metric
-    loss_metric = LossMetric(args.batch_size/num_gpu, 1)
-
-    # run model
-    module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
-    module.bind(data_shapes=train_iter.provide_data,
-                label_shapes=val_iter.provide_label,
-                for_training=True)
-    do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
-                model_prefix=args.model_prefix, decay=args.decay)
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import numpy as np
+import os
+import re
+import gzip
+import struct
+import scipy.ndimage as ndi
+from capsulelayers import primary_caps, CapsuleLayer
+
+from mxboard import SummaryWriter
+
+def margin_loss(y_true, y_pred):
+    loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
+        0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
+    return mx.sym.mean(data=mx.sym.sum(loss, 1))
+
+
+def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
+    # data.shape = [batch_size, 1, 28, 28]
+    data = mx.sym.Variable('data')
+
+    input_shape = (1, 28, 28)
+    # Conv2D layer
+    # net.shape = [batch_size, 256, 20, 20]
+    conv1 = mx.sym.Convolution(data=data,
+                               num_filter=256,
+                               kernel=(9, 9),
+                               layout='NCHW',
+                               name='conv1')
+    conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
+    # net.shape = [batch_size, 256, 6, 6]
+
+    primarycaps = primary_caps(data=conv1,
+                               dim_vector=8,
+                               n_channels=32,
+                               kernel=(9, 9),
+                               strides=[2, 2],
+                               name='primarycaps')
+    primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
+    # CapsuleLayer
+    kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
+    bias_initializer = mx.init.Zero()
+    digitcaps = CapsuleLayer(num_capsule=10,
+                             dim_vector=16,
+                             batch_size=batch_size,
+                             kernel_initializer=kernel_initializer,
+                             bias_initializer=bias_initializer,
+                             num_routing=num_routing)(primarycaps)
+
+    # out_caps : (batch_size, 10)
+    out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
+    out_caps.infer_shape(data=(batch_size, 1, 28, 28))
+
+    y = mx.sym.Variable('softmax_label', shape=(batch_size,))
+    y_onehot = mx.sym.one_hot(y, n_class)
+    y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
+    y_reshaped.infer_shape(softmax_label=(batch_size,))
+
+    # inputs_masked : (batch_size, 16)
+    inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
+    inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
+    x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
+    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
+    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
+
+    data_flatten = mx.sym.flatten(data=data)
+    squared_error = mx.sym.square(x_recon-data_flatten)
+    recon_error = mx.sym.mean(squared_error)
+    recon_error_stopped = recon_error
+    recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
+    loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
+
+    out_caps_blocked = out_caps
+    out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
+    return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
+
+
+def download_data(url, force_download=False):
+    fname = url.split("/")[-1]
+    if force_download or not os.path.exists(fname):
+        mx.test_utils.download(url, fname)
+    return fname
+
+
+def read_data(label_url, image_url):
+    with gzip.open(download_data(label_url)) as flbl:
+        magic, num = struct.unpack(">II", flbl.read(8))
+        label = np.fromstring(flbl.read(), dtype=np.int8)
+    with gzip.open(download_data(image_url), 'rb') as fimg:
+        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+    return label, image
+
+
+def to4d(img):
+    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+
+class LossMetric(mx.metric.EvalMetric):
+    def __init__(self, batch_size, num_gpu):
+        super(LossMetric, self).__init__('LossMetric')
+        self.batch_size = batch_size
+        self.num_gpu = num_gpu
+        self.sum_metric = 0
+        self.num_inst = 0
+        self.loss = 0.0
+        self.batch_sum_metric = 0
+        self.batch_num_inst = 0
+        self.batch_loss = 0.0
+        self.recon_loss = 0.0
+        self.n_batch = 0
+
+    def update(self, labels, preds):
+        batch_sum_metric = 0
+        batch_num_inst = 0
+        for label, pred_outcaps in zip(labels[0], preds[0]):
+            label_np = int(label.asnumpy())
+            pred_label = int(np.argmax(pred_outcaps.asnumpy()))
+            batch_sum_metric += int(label_np == pred_label)
+            batch_num_inst += 1
+        batch_loss = preds[1].asnumpy()
+        recon_loss = preds[2].asnumpy()
+        self.sum_metric += batch_sum_metric
+        self.num_inst += batch_num_inst
+        self.loss += batch_loss
+        self.recon_loss += recon_loss
+        self.batch_sum_metric = batch_sum_metric
+        self.batch_num_inst = batch_num_inst
+        self.batch_loss = batch_loss
+        self.n_batch += 1 
+
+    def get_name_value(self):
+        acc = float(self.sum_metric)/float(self.num_inst)
+        mean_loss = self.loss / float(self.n_batch)
+        mean_recon_loss = self.recon_loss / float(self.n_batch)
+        return acc, mean_loss, mean_recon_loss
+
+    def get_batch_log(self, n_batch):
+        print("n_batch :"+str(n_batch)+" batch_acc:" +
+              str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
+              ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
+        self.batch_sum_metric = 0
+        self.batch_num_inst = 0
+        self.batch_loss = 0.0
+
+    def reset(self):
+        self.sum_metric = 0
+        self.num_inst = 0
+        self.loss = 0.0
+        self.recon_loss = 0.0
+        self.n_batch = 0
+
+
+class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
+    """A simple lr schedule that simply return `dynamic_lr`. We will set 
`dynamic_lr`
+    dynamically based on performance on the validation set.
+    """
+
+    def __init__(self, learning_rate=0.001):
+        super(SimpleLRScheduler, self).__init__()
+        self.learning_rate = learning_rate
+
+    def __call__(self, num_update):
+        return self.learning_rate
+
+
+def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
+    summary_writer = SummaryWriter(args.tblog_dir)
+    lr_scheduler = SimpleLRScheduler(learning_rate)
+    optimizer_params = {'lr_scheduler': lr_scheduler}
+    module.init_params()
+    module.init_optimizer(kvstore=kvstore,
+                          optimizer=optimizer,
+                          optimizer_params=optimizer_params)
+    n_epoch = 0
+    while True:
+        if n_epoch >= num_epoch:
+            break
+        train_iter.reset()
+        val_iter.reset()
+        loss_metric.reset()
+        for n_batch, data_batch in enumerate(train_iter):
+            module.forward_backward(data_batch)
+            module.update()
+            module.update_metric(loss_metric, data_batch.label)
+            loss_metric.get_batch_log(n_batch)
+        train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
+        loss_metric.reset()
+        for n_batch, data_batch in enumerate(val_iter):
+            module.forward(data_batch)
+            module.update_metric(loss_metric, data_batch.label)
+            loss_metric.get_batch_log(n_batch)
+        val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
+
+        summary_writer.add_scalar('train_acc', train_acc, n_epoch)
+        summary_writer.add_scalar('train_loss', train_loss, n_epoch)
+        summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
+        summary_writer.add_scalar('val_acc', val_acc, n_epoch)
+        summary_writer.add_scalar('val_loss', val_loss, n_epoch)
+        summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
+
+        print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
+        print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
+        print('SAVE CHECKPOINT')
+
+        module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
+        n_epoch += 1
+        lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
+
+
+def apply_transform(x,
+                    transform_matrix,
+                    fill_mode='nearest',
+                    cval=0.):
+    x = np.rollaxis(x, 0, 0)
+    final_affine_matrix = transform_matrix[:2, :2]
+    final_offset = transform_matrix[:2, 2]
+    channel_images = [ndi.interpolation.affine_transform(
+        x_channel,
+        final_affine_matrix,
+        final_offset,
+        order=0,
+        mode=fill_mode,
+        cval=cval) for x_channel in x]
+    x = np.stack(channel_images, axis=0)
+    x = np.rollaxis(x, 0, 0 + 1)
+    return x
+
+
+def random_shift(x, width_shift_fraction, height_shift_fraction):
+    tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
+    ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
+    shift_matrix = np.array([[1, 0, tx],
+                             [0, 1, ty],
+                             [0, 0, 1]])
+    x = apply_transform(x, shift_matrix, 'nearest')
+    return x
+
+def _shuffle(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
+
+class MNISTCustomIter(mx.io.NDArrayIter):
+    
+    def reset(self):
+        # shuffle data
+        if self.is_train:
+            np.random.shuffle(self.idx)
+            self.data = _shuffle(self.data, self.idx)
+            self.label = _shuffle(self.label, self.idx)
+        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
+            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+        else:
+            self.cursor = -self.batch_size
+    def set_is_train(self, is_train):
+        self.is_train = is_train
+    def next(self):
+        if self.iter_next():
+            if self.is_train:
+                data_raw_list = self.getdata()
+                data_shifted = []
+                for data_raw in data_raw_list[0]:
+                    data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
+                return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
+                                       pad=self.getpad(), index=None)
+            else:
+                 return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
+                                  pad=self.getpad(), index=None)
+
+        else:
+            raise StopIteration
+
+
+if __name__ == "__main__":
+    # Read mnist data set
+    path = 'http://yann.lecun.com/exdb/mnist/'
+    (train_lbl, train_img) = read_data(
+        path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
+    (val_lbl, val_img) = read_data(
+        path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
+    # set batch size
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=100, type=int)
+    parser.add_argument('--devices', default='gpu0', type=str)
+    parser.add_argument('--num_epoch', default=100, type=int)
+    parser.add_argument('--lr', default=0.001, type=float)
+    parser.add_argument('--num_routing', default=3, type=int)
+    parser.add_argument('--model_prefix', default='capsnet', type=str)
+    parser.add_argument('--decay', default=0.9, type=float)
+    parser.add_argument('--tblog_dir', default='tblog', type=str)
+    parser.add_argument('--recon_loss_weight', default=0.392, type=float)
+    args = parser.parse_args()
+    for k, v in sorted(vars(args).items()):
+        print("{0}: {1}".format(k, v))
+    contexts = re.split(r'\W+', args.devices)
+    for i, ctx in enumerate(contexts):
+        if ctx[:3] == 'gpu':
+            contexts[i] = mx.context.gpu(int(ctx[3:]))
+        else:
+            contexts[i] = mx.context.cpu()
+    num_gpu = len(contexts)
+
+    if args.batch_size % num_gpu != 0:
+        raise Exception('num_gpu should be positive divisor of batch_size')
+
+    # generate train_iter, val_iter
+    train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=int(args.batch_size), shuffle=True)
+    train_iter.set_is_train(True)
+    val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=int(args.batch_size),)
+    val_iter.set_is_train(False)
+    # define capsnet
+    final_net = capsnet(batch_size=int(args.batch_size/num_gpu), n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
+    # set metric
+    loss_metric = LossMetric(args.batch_size/num_gpu, 1)
+
+    # run model
+    module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
+    module.bind(data_shapes=train_iter.provide_data,
+                label_shapes=val_iter.provide_label,
+                for_training=True)
+    do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
+                model_prefix=args.model_prefix, decay=args.decay)
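
For context on the main functional change in this patch (the switch from the tensorboard package to mxboard for writing TensorBoard event files), a minimal usage sketch of the mxboard SummaryWriter API the updated script relies on. This sketch is not part of the commit; the toy epoch loop and values are illustrative only, and 'tblog' simply mirrors the script's --tblog_dir default.

    # Sketch: log a few scalars with mxboard, the way do_training() does once per epoch.
    from mxboard import SummaryWriter

    sw = SummaryWriter('tblog')  # positional logdir, as in SummaryWriter(args.tblog_dir)
    for epoch in range(3):       # hypothetical values, for illustration only
        sw.add_scalar('train_acc', 0.90 + 0.01 * epoch, epoch)
        sw.add_scalar('val_acc', 0.88 + 0.01 * epoch, epoch)
    sw.close()                   # flush the event files; view them with: tensorboard --logdir tblog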
