sandeep-krishnamurthy commented on a change in pull request #12376: [MXNET-854] 
SVRG Optimization in Python Module API
URL: https://github.com/apache/incubator-mxnet/pull/12376#discussion_r213412718
 
 

 ##########
 File path: contrib/svrg_optimization_python/src/svrg_module.py
 ##########
 @@ -0,0 +1,581 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A `SVRGModule` implements the `Module` API by wrapping an auxiliary module 
to perform
+SVRG optimization logic.
+"""
+
+import mxnet as mx
+import time
+import logging
+from svrg_optimizer import SVRGOptimizer
+from mxnet.module import Module
+
+
+class SVRGModule(Module):
+    """SVRGModule is a module that encapsulates two Modules to accommodate the 
SVRG optimization technique.
+    It is functionally the same as Module API, except it is implemented using 
SVRG optimization logic.
+
+    Parameters
+    ----------
+    symbol : Symbol
+    data_names : list of str
+        Defaults to `('data')` for a typical model used in image 
classification.
+    label_names : list of str
+        Defaults to `('softmax_label')` for a typical model used in image
+        classification.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    work_load_list : list of number
+        Default ``None``, indicating uniform workload.
+    fixed_param_names: list of str
+        Default ``None``, indicating no network parameters are fixed.
+    state_names : list of str
+        states are similar to data and label, but not provided by data 
iterator.
+        Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : dict of str to context or list of context,
+                 or list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context 
assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments 
depending
+        on the type of compression being used. For example, 2bit compression 
requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on 
gradient compression.
+    update_freq: int
+        Specifies the number of times to update the full gradients to be used 
in the SVRG optimization. For instance,
+        update_freq = 2 will calculate the gradients over all data every two 
epochs.
+    Examples
+    --------
+    >>> # An example of declaring and using SVRGModule.
+    >>> mod = SVRGModule(symbol=lro, data_names=['data'], 
label_names=['lin_reg_label'], update_freq=2)
+    >>> mod.fit(di, eval_metric='mse', optimizer='sgd', 
optimizer_params=(('learning_rate', 0.025),),
+    >>>         num_epoch=num_epoch, kvstore='local')
+    """
+
+    def __init__(self, symbol, data_names=('data',), 
label_names=('softmax_label',),
+                 logger=logging, context=mx.cpu(), work_load_list=None,
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None, update_freq=None):
+        super(SVRGModule, self).__init__(symbol, data_names=data_names, 
label_names=label_names, logger=logger,
+                                         context=context, 
work_load_list=work_load_list,
+                                         fixed_param_names=fixed_param_names, 
state_names=state_names,
+                                         group2ctxs=group2ctxs, 
compression_params=compression_params)
+
+        # Type check update_frequency
+        if isinstance(update_freq, int):
+            self.update_freq = update_freq
+        else:
+            raise TypeError("update_freq must be an integer")
+
+        self._mod_aux = mx.mod.Module(symbol, data_names, label_names, logger, 
context, work_load_list,
+                                      fixed_param_names, state_names, 
group2ctxs, compression_params)
+
+        self._param_dict = [{} for ctx in self._context]
+
+    def _reset_bind(self):
+        """Internal function to reset binded state."""
+        super(SVRGModule, self)._reset_bind()
+        self._mod_aux._reset_bind()
+
+
+    def reshape(self, data_shapes, label_shapes=None):
+        super(SVRGModule, self).reshape(data_shapes, label_shapes=label_shapes)
+        self._mod_aux.reshape(data_shapes, label_shapes=label_shapes)
+
+    def init_optimizer(self, kvstore='local', optimizer='sgd',
+                       optimizer_params=(('learning_rate', 0.01),), 
force_init=False):
+        """Installs and initializes SVRG optimizers. The SVRGOptimizer is a 
wrapper for normal SGD optimizer
+        and special AssignmentOptimizer to accumulate gradients. If KVStore 
exists, additional keys will be
+        pushed to the kvstore for accumulating full grads.
+
+        Parameters
+        ----------
+        kvstore : str or KVStore
+            Default `'local'`.
+        optimizer : str or Optimizer
+            Default `'sgd'`
+        optimizer_params : dict
+            Default `(('learning_rate', 0.01),)`. The default value is not a 
dictionary,
+            just to avoid pylint warning of dangerous default values.
+        force_init : bool
+            Default ``False``, indicating whether we should force 
re-initializing the
+            optimizer in the case an optimizer is already installed.
+        """
+        # Init dict for storing average of full gradients for each device
+        for i in range(len(self._context)):
+            self._param_dict[i] = {key: mx.nd.zeros(shape=value.shape, 
ctx=self._context[i])
+                                   for key, value in 
self.get_params()[0].items()}
+
+        svrg_optimizer = self._create_optimizer(SVRGOptimizer.__name__, 
default_opt=optimizer,
+                                                kvstore=kvstore, 
optimizer_params=optimizer_params)
+
+        super(SVRGModule, self).init_optimizer(kvstore=kvstore, 
optimizer=svrg_optimizer,
+                                optimizer_params=optimizer_params, 
force_init=force_init)
+
+        # Init additional keys for accumulating full grads in KVStore
+        if self._kvstore:
+            for idx, param_on_devs in enumerate(self._exec_group.param_arrays):
+                name = self._exec_group.param_names[idx]
+                self._kvstore.init(name + "_full", 
mx.nd.zeros(shape=self._arg_params[name].shape))
+                if self._update_on_kvstore:
+                    self._kvstore.pull(name + "_full", param_on_devs, 
priority=-idx)
+
+    def _create_optimizer(self, optimizer, default_opt, kvstore, 
optimizer_params):
+        """Helper function to create a svrg optimizer. SVRG optimizer 
encapsulates two optimizers and
+        will redirect update() to the correct optimizer.
+
+        Parameters
+        ----------
+        kvstore : str or KVStore
+            Default `'local'`.
+        optimizer: str
+            Name for SVRGOptimizer
+        default_opt : str or Optimizer that was passed in.
+        optimizer_params : dict
+           optimizer params that was passed in.
+        """
+
+        # code partially copied from mxnet module.init_optimizer() to 
accommodate svrg_optimizer
+        batch_size = self._exec_group.batch_size
+
+        (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, 
len(self._context), self._arg_params)
+        if kv_store and 'dist' in kv_store.type and '_sync' in kv_store.type:
+            batch_size *= kv_store.num_workers
+        rescale_grad = 1.0 / batch_size
+
+        idx2name = {}
+        if update_on_kvstore:
+            idx2name.update(enumerate(self._exec_group.param_names))
+        else:
+            for k in range(len(self._context)):
+                idx2name.update({i * len(self._context) + k: n
+                                 for i, n in 
enumerate(self._exec_group.param_names)})
+
+        # update idx2name to include new keys
+        for key in self._param_dict[0].keys():
+            max_key = max(list(idx2name.keys()))
+            max_key += 1
+            idx2name[max_key] = key + "_full"
+
+        optimizer_params = dict(optimizer_params)
+        if 'rescale_grad' not in optimizer_params:
+            optimizer_params['rescale_grad'] = rescale_grad
+        optimizer_params["default_optimizer"] = default_opt
+        optimizer_params["param_idx2name"] = idx2name
+        optimizer = mx.optimizer.create(optimizer, **optimizer_params)
+
+        return optimizer
+
+    def bind(self, data_shapes, label_shapes=None, for_training=True,
+             inputs_need_grad=False, force_rebind=False, shared_module=None,
+             grad_req='write'):
+        """Binds the symbols to construct executors for both modules. This 
is necessary before one
+        can perform computation with the SVRGModule.
+
+        Parameters
+        ----------
+        data_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_data``.
+        label_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_label``.
+        for_training : bool
+            Default is ``True``. Whether the executors should be bound for 
training.
+        inputs_need_grad : bool
+            Default is ``False``. Whether the gradients to the input data need 
to be computed.
+            Typically this is not needed. But this might be needed when 
implementing composition
+            of modules.
+        force_rebind : bool
+            Default is ``False``. This function does nothing if the executors 
are already
+            bound. But with this ``True``, the executors will be forced to 
rebind.
+        shared_module : Module
+            Default is ``None``. This is used in bucketing. When not ``None``, 
the shared module
+            essentially corresponds to a different bucket -- a module with 
different symbol
+            but with the same sets of parameters (e.g. unrolled RNNs with 
different lengths).
+        """
+        # force rebinding is typically used when one want to switch from
+        # training to prediction phase.
+
+        super(SVRGModule, self).bind(data_shapes, label_shapes, for_training, 
inputs_need_grad, force_rebind,
+                                     shared_module, grad_req)
+
+        if for_training:
+            self._mod_aux.bind(data_shapes, label_shapes, for_training, 
inputs_need_grad, force_rebind, shared_module,
+                               grad_req)
+
+    def forward(self, data_batch, is_train=None):
+        """Forward computation for both modules. It supports data batches 
with different shapes, such as
+        different batch sizes or different image sizes.
+        If reshaping of data batch relates to modification of symbol or 
module, such as
+        changing image layout ordering or switching from training to 
predicting, module
+        rebinding is required.
+
+        See Also
+        ----------
+        :meth:`BaseModule.forward`.
+
+        Parameters
+        ----------
+        data_batch : DataBatch
+            Could be anything with similar API implemented.
+        is_train : bool
+            Default is ``None``, which means ``is_train`` takes the value of 
``self.for_training``.
+        """
+
+        super(SVRGModule, self).forward(data_batch, is_train)
+
+        if is_train:
+            self._mod_aux.forward(data_batch, is_train)
+
+    def backward(self, out_grads=None):
+        """Backward computation.
+
+        See Also
+        ----------
+        :meth:`BaseModule.backward`.
+
+        Parameters
+        ----------
+        out_grads : NDArray or list of NDArray, optional
+            Gradient on the outputs to be propagated back.
+            This parameter is only needed when bind is called
+            on outputs that are not a loss function.
+        """
+
+        super(SVRGModule, self).backward(out_grads)
+
+        if self._mod_aux.binded:
+            self._mod_aux.backward()
+
+    def update(self):
+        """Updates parameters according to the installed optimizer and the 
gradients computed
+        in the previous forward-backward batch. The gradients in the 
_exec_group will be overwritten
+        using the gradients calculated by the SVRG update rule.
+
+        When KVStore is used to update parameters for multi-device or 
multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for 
`row_sparse` parameters,
+        this function does update the copy of parameters in KVStore, but 
doesn't broadcast the
+        updated parameters to all devices / machines. Please call `prepare` to 
broadcast
+        `row_sparse` parameters with the next batch of data.
+
+        See Also
+        ----------
+        :meth:`BaseModule.update`.
+        """
+
+        self.update_svrg_gradients()
+        super(SVRGModule, self).update()
+
+    def update_full_grads(self, train_data):
+        """Computes the gradients over all data w.r.t weights of past
+        m epochs. For distributed env, it will accumulate full grads in the 
kvstore.
+
+        Parameters
+        ----------
+        train_data: DataIter
+            Train data iterator
+
+        """
+        param_names = self._exec_group.param_names
+        arg, aux = self.get_params()
+        self._mod_aux.set_params(arg_params=arg, aux_params=aux)
+        train_data.reset()
+        nbatch = 0
+        padding = 0
+        for batch in train_data:
+            self._mod_aux.forward(batch, is_train=True)
+            self._mod_aux.backward()
+            nbatch += 1
+
+            for i in range(len(self._context)):
+                for j in range(len(param_names)):
+                    grads = self._mod_aux._exec_group.grad_arrays[j][i]
+                    self._param_dict[i][param_names[j]] = 
mx.nd.broadcast_add(self._param_dict[i][param_names[j]],
+                                                                              
grads, axis=0)
+            padding = batch.pad
+
+        # Average full gradients over number of batches, accumulate in the 
kvstore if kvstore is set
+        for i in range(len(self._context)):
+            for key in self._param_dict[i].keys():
+                self._param_dict[i][key] /= (nbatch - padding / 
train_data.batch_size)
+
+        if self._kvstore:
+            # Push a list of gradients from each device in the KVStore
+            for key in self._param_dict[0].keys():
+                grad_list = []
+                for i in range(len(self._param_dict)):
+                    grad_list.append(self._param_dict[i][key])
+
+                self._accumulate_kvstore(key, grad_list)
+
+    def _accumulate_kvstore(self, key, value):
+        """Accumulate gradients over all data in the KVstore. In distributed 
setting, each worker sees a portion of
+        data. The full gradients will be aggregated from each worker in the 
KVStore.
+
+        Parameters
+        ----------
+
+        key: int or str
+        Key in the kvstore.
+        value: NDArray, RowSparseNDArray
+        Average of the full gradients.
+        """
+
+        # Accumulate full gradients for current epochs
+        self._kvstore.push(key + "_full", value)
+
+        self._kvstore._barrier()
+        self._kvstore.pull(key + "_full", value)
+        self._allocate_gradients(key, value)
+
+    def _allocate_gradients(self, key, value):
+        """Allocate average of full gradients accumulated in the KVStore to 
each device.
+
+        Parameters
+        ----------
+
+        key: int or str
+        Key in the kvstore.
+        value: List of NDArray, List of RowSparseNDArray
+        A list of average of the full gradients in the KVStore.
+        """
+
+        num_device = len(self._context)
+        for i in range(len(self._param_dict)):
+            self._param_dict[i][key] = value[i] / num_device
+
+    def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, 
g_curr_batch_special_weight, g_special_weight_all_batch):
+        """Calculates the gradient based on the SVRG update rule.
+        Parameters
+        ----------
+        g_curr_batch_curr_weight : NDArray
+            gradients of current weight of self.mod w.r.t current batch of data
+        g_curr_batch_special_weight: NDArray
+            gradients of the weight of past m epochs of self._mod_aux 
w.r.t current batch of data
+        g_special_weight_all_batch: NDArray
+            average of full gradients over full pass of data
+
+        Returns
+        ----------
+        Gradients calculated using SVRG update rule:
+        grads = g_curr_batch_curr_weight - g_curr_batch_special_weight + 
g_special_weight_all_batch
+        """
+        for i in range(len(g_curr_batch_curr_weight)):
+            g_curr_batch_curr_weight[i] -= g_curr_batch_special_weight[i]
+            g_curr_batch_curr_weight[i] += g_special_weight_all_batch[i]
+        return g_curr_batch_curr_weight
+
+    def update_svrg_gradients(self):
+        """Calculates gradients based on the SVRG update rule.
+        """
+        param_names = self._exec_group.param_names
+        for i in range(len(self._context)):
+            for j in range(len(param_names)):
+                g_curr_batch_reg = self._exec_group.grad_arrays[j][i]
+                g_curr_batch_special = 
self._mod_aux._exec_group.grad_arrays[j][i]
+                g_special_weight_all_batch = 
self._param_dict[i][param_names[j]]
+                g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, 
g_curr_batch_special, g_special_weight_all_batch)
+                self._exec_group.grad_arrays[j][i] = g_svrg
+
+    def fit(self, train_data, eval_data=None, eval_metric='acc',
+            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
+            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+            eval_end_callback=None,
+            eval_batch_end_callback=None, initializer=mx.init.Uniform(0.01),
+            arg_params=None, aux_params=None, allow_missing=False,
+            force_rebind=False, force_init=False, begin_epoch=0, 
num_epoch=None,
+            validation_metric=None, monitor=None, sparse_row_id_fn=None):
+        """Trains the module parameters.
+        Parameters
+        ----------
+        train_data : DataIter
+            Train DataIter.
+        eval_data : DataIter
+            If not ``None``, will be used as validation set and the performance
+            after each epoch will be evaluated.
+        eval_metric : str or EvalMetric
+            Defaults to 'accuracy'. The performance measure used to display 
during training.
+            Other possible predefined metrics are:
+            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
+        epoch_end_callback : function or list of functions
+            Each callback will be called with the current `epoch`, `symbol`, 
`arg_params`
+            and `aux_params`.
+        batch_end_callback : function or list of function
+            Each callback will be called with a `BatchEndParam`.
+        kvstore : str or KVStore
+            Defaults to 'local'.
+        optimizer : str or Optimizer
+            Defaults to 'sgd'.
+        optimizer_params : dict
+            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
+            the optimizer constructor.
+            The default value is not a dict, just to avoid pylint warning on 
dangerous
+            default values.
+        eval_end_callback : function or list of function
+            These will be called at the end of each full evaluation, with the 
metrics over
+            the entire evaluation set.
+        eval_batch_end_callback : function or list of function
+            These will be called at the end of each mini-batch during 
evaluation.
+        initializer : Initializer
+            The initializer is called to initialize the module parameters when 
they are
+            not already initialized.
+        arg_params : dict
+            Defaults to ``None``, if not ``None``, should be existing 
parameters from a trained
+            model or loaded from a checkpoint (previously saved model). In 
this case,
+            the value here will be used to initialize the module parameters, 
unless they
+            are already initialized by the user via a call to `init_params` or 
`fit`.
+            `arg_params` has a higher priority than `initializer`.
+        aux_params : dict
+            Defaults to ``None``. Similar to `arg_params`, except for 
auxiliary states.
+        allow_missing : bool
+            Defaults to ``False``. Indicates whether to allow missing 
parameters when `arg_params`
+            and `aux_params` are not ``None``. If this is ``True``, then the 
missing parameters
+            will be initialized via the `initializer`.
+        force_rebind : bool
+            Defaults to ``False``. Whether to force rebinding the executors if 
already bound.
+        force_init : bool
+            Defaults to ``False``. Indicates whether to force initialization 
even if the
+            parameters are already initialized.
+        begin_epoch : int
+            Defaults to 0. Indicates the starting epoch. Usually, if resumed 
from a
+            checkpoint saved at a previous training phase at epoch N, then 
this value should be
+            N+1.
+        num_epoch : int
+            Number of epochs for training.
+        sparse_row_id_fn : A callback function
+            The function  takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the 
param,
+            and the value is the row id of the param to pull.
+        """
+
+        assert num_epoch is not None, 'please specify number of epochs'
+
+        self.bind(data_shapes=train_data.provide_data, 
label_shapes=train_data.provide_label,
+                  for_training=True, force_rebind=force_rebind)
+        if monitor is not None:
+            self.install_monitor(monitor)
+        self.init_params(initializer=initializer, arg_params=arg_params, 
aux_params=aux_params,
+                         allow_missing=allow_missing, force_init=force_init)
+        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
+                            optimizer_params=optimizer_params)
+
+        if validation_metric is None:
+            validation_metric = eval_metric
+        if not isinstance(eval_metric, mx.metric.EvalMetric):
+            eval_metric = mx.metric.create(eval_metric)
+
+        
################################################################################
+        # training loop
+        
################################################################################
+        for epoch in range(begin_epoch, num_epoch):
+            eval_metric.reset()
+            tic = time.time()
+            if epoch % (self.update_freq) == 0:
+                self.update_full_grads(train_data)
+
+            train_data.reset()
+            data_iter = iter(train_data)
+            end_of_batch = False
+            nbatch = 0
+            next_data_batch = next(data_iter)
+
+            while not end_of_batch:
+                data_batch = next_data_batch
+                if monitor is not None:
+                    monitor.tic()
+
+                self.forward_backward(data_batch)
+                self.update()
+
+                if isinstance(data_batch, list):
+                    self.update_metric(eval_metric,
+                                       [db.label for db in data_batch],
+                                       pre_sliced=True)
+                else:
+                    self.update_metric(eval_metric, data_batch.label)
+
+                try:
+                    # pre fetch next batch
+                    next_data_batch = next(data_iter)
+                    self.prepare(next_data_batch, 
sparse_row_id_fn=sparse_row_id_fn)
+                except StopIteration:
+                    end_of_batch = True
+
+                if monitor is not None:
+                    monitor.toc_print()
+
+                if end_of_batch:
+                    eval_name_vals = eval_metric.get_name_value()
+
+                if batch_end_callback is not None:
+                    batch_end_params = mx.model.BatchEndParam(epoch=epoch, 
nbatch=nbatch,
+                                                              
eval_metric=eval_metric, locals=locals())
+                    for callback in mx.base._as_list(batch_end_callback):
+                        callback(batch_end_params)
+
+                nbatch += 1
+            for name, val in eval_name_vals:
+                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
+            toc = time.time()
+            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
+            print('Epoch[%d] Time cost=%.3f', epoch, eval_metric.get())
 
 Review comment:
   Please use the logger instead of print everywhere, for example self.logger.info.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to