This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b235cf7  Update doc for sparse related APIs (#7688)
b235cf7 is described below

commit b235cf7778dba503008d76e223dfffe638f95654
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Wed Sep 6 11:59:55 2017 -0700

    Update doc for sparse related APIs (#7688)
    
    * update doc for NDArrayIter
    
    * update kvstore doc. remove int64 restriction for row_ids in rowsparse pull
    
    * add exception test for sparse op
    
    * update doc for optimizers
    
    * update doc for sgd/adam operators
    
    * CR comment and fix lint
---
 docs/mxdoc.py                                |  3 ++-
 python/mxnet/io.py                           | 21 ++++++++++++++-
 python/mxnet/kvstore.py                      | 39 ++++++++++++++++++++--------
 python/mxnet/optimizer.py                    | 34 +++++++++++++++++++++---
 src/kvstore/kvstore_local.h                  |  3 ++-
 src/operator/optimizer_op.cc                 | 23 +++++++++++++---
 tests/python/unittest/test_io.py             |  2 +-
 tests/python/unittest/test_kvstore.py        |  4 +--
 tests/python/unittest/test_module.py         |  8 +++---
 tests/python/unittest/test_sparse_ndarray.py |  8 ++++++
 10 files changed, 117 insertions(+), 28 deletions(-)

diff --git a/docs/mxdoc.py b/docs/mxdoc.py
index 2726a1c..010118c 100644
--- a/docs/mxdoc.py
+++ b/docs/mxdoc.py
@@ -23,7 +23,8 @@ import json
 import sys
 from recommonmark import transform
 import pypandoc
-import StringIO
+# import StringIO from io for python3 compatibility
+from io import StringIO
 import contextlib
 
 # white list to evaluate the code block output, such as ['tutorials/gluon']
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index b169681..0665101 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -34,6 +34,7 @@ from .base import DataIterHandle, NDArrayHandle
 from .base import mx_real_t
 from .base import check_call, build_param_doc as _build_param_doc
 from .ndarray import NDArray
+from .ndarray.sparse import CSRNDArray
 from .ndarray import _ndarray_cls
 from .ndarray import array
 from .ndarray import concatenate
@@ -513,7 +514,8 @@ def _init_data(data, allow_empty, default_name):
     return list(data.items())
 
 class NDArrayIter(DataIter):
-    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray`` or 
``h5py.Dataset``.
+    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, 
``h5py.Dataset``
+    or ``mx.nd.sparse.CSRNDArray``.
 
     Example usage:
     ----------
@@ -576,6 +578,18 @@ class NDArrayIter(DataIter):
     >>> label = {'label1':np.zeros(shape=(10,1)), 
'label2':np.zeros(shape=(20,1))}
     >>> dataiter = mx.io.NDArrayIter(data, label, 3, True, 
last_batch_handle='discard')
 
+    `NDArrayIter` also supports ``mx.nd.sparse.CSRNDArray`` with `shuffle` set 
to `False`
+    and `last_batch_handle` set to `discard`.
+
+    >>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
+    >>> labels = np.ones([10, 1])
+    >>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, 
last_batch_handle='discard')
+    >>> [batch.data[0] for batch in dataiter]
+    [
+    <CSRNDArray 3x4 @cpu(0)>,
+    <CSRNDArray 3x4 @cpu(0)>,
+    <CSRNDArray 3x4 @cpu(0)>]
+
     Parameters
     ----------
     data: array or list of array or dict of string to array
@@ -603,6 +617,11 @@ class NDArrayIter(DataIter):
 
         self.data = _init_data(data, allow_empty=False, default_name=data_name)
         self.label = _init_data(label, allow_empty=True, 
default_name=label_name)
+        if isinstance(data, CSRNDArray) or isinstance(label, CSRNDArray):
+            assert(shuffle is False), \
+                  "`NDArrayIter` only supports ``CSRNDArray`` with `shuffle` 
set to `False`"
+            assert(last_batch_handle == 'discard'), "`NDArrayIter` only 
supports ``CSRNDArray``" \
+                                                    " with `last_batch_handle` 
set to `discard`."
 
         self.idx = np.arange(self.data[0][1].shape[0])
         # shuffle data
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 0da9e39..cf5c159 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -104,7 +104,7 @@ class KVStore(object):
         ----------
         key : str, int, or sequence of str or int
             The keys.
-        value : NDArray or sequence of NDArray
+        value : NDArray, RowSparseNDArray or sequence of NDArray or 
RowSparseNDArray
             Values corresponding to the keys.
 
         Examples
@@ -120,8 +120,15 @@ class KVStore(object):
         [ 2.  2.  2.]]
 
         >>> # init a list of key-value pairs
-        >>> keys = [5, 7, 9]
+        >>> keys = ['5', '7', '9']
         >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
+
+        >>> # init a row_sparse value
+        >>> kv.init('4', mx.nd.ones(shape).tostype('row_sparse'))
+        >>> b = mx.nd.sparse.zeros('row_sparse', shape)
+        >>> kv.row_sparse_pull('4', row_ids=mx.nd.array([0, 1]), out=b)
+        >>> print b
+        <RowSparseNDArray 2x3 @cpu(0)>
         """
         ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
         if use_str_keys:
@@ -143,7 +150,8 @@ class KVStore(object):
         key : str, int, or sequence of str or int
             Keys.
 
-        value : NDArray or list of NDArray or list of list of NDArray
+        value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray,
+                or list of list of NDArray or RowSparseNDArray
             Values corresponding to the keys.
 
         priority : int, optional
@@ -171,7 +179,7 @@ class KVStore(object):
 
         >>> # push a list of keys.
         >>> # single device
-        >>> keys = [4, 5, 6]
+        >>> keys = ['4', '5', '6']
         >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
@@ -187,6 +195,15 @@ class KVStore(object):
         >>> print b[1][1].asnumpy()
         [[ 4.  4.  4.]
         [ 4.  4.  4.]]
+
+        >>> # push a row_sparse value
+        >>> b = mx.nd.sparse.zeros('row_sparse', shape)
+        >>> kv.init('10', mx.nd.sparse.zeros('row_sparse', shape))
+        >>> kv.push('10', mx.nd.ones(shape).tostype('row_sparse'))
+        >>> # pull out the value
+        >>> kv.row_sparse_pull('10', row_ids=mx.nd.array([0, 1]), out=b)
+        >>> print b
+        <RowSparseNDArray 2x3 @cpu(0)>
         """
         ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
         if use_str_keys:
@@ -209,7 +226,7 @@ class KVStore(object):
 
         The returned values are gauranteed to be the latest values in the 
store.
 
-        For row_sparse values, please use `row_sparse_pull` instead.
+        For `RowSparseNDArray` values, please use ``row_sparse_pull`` instead.
 
         Parameters
         ----------
@@ -242,7 +259,7 @@ class KVStore(object):
 
         >>> # pull a list of key-value pairs.
         >>> # On single device
-        >>> keys = [5, 7, 9]
+        >>> keys = ['5', '7', '9']
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
         >>> print b[1].asnumpy()
@@ -266,8 +283,8 @@ class KVStore(object):
                 self.handle, mx_uint(len(ckeys)), ckeys, cvals, 
ctypes.c_int(priority)))
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
-        """ Pulls a single row_sparse value or a sequence of row_sparse values 
from the store
-         with specified row_ids.
+        """ Pulls a single RowSparseNDArray value or a sequence of 
RowSparseNDArray values \
+        from the store with specified row_ids.
 
         `row_sparse_pull` is executed asynchronously after all previous
         `push`/`pull`/`row_sparse_pull` calls for the same input key(s) are 
finished.
@@ -279,7 +296,7 @@ class KVStore(object):
         key : str, int, or sequence of str or int
             Keys.
 
-        out: NDArray or list of NDArray or list of list of NDArray
+        out: RowSparseNDArray or list of RowSparseNDArray or list of list of 
RowSparseNDArray
             Values corresponding to the keys. The stype is expected to be 
row_sparse
 
         priority : int, optional
@@ -288,14 +305,14 @@ class KVStore(object):
             other pull actions.
 
         row_ids : NDArray or list of NDArray
-            The row_ids for which to pull for each value. Each row_id is an 
1D-NDArray \
+            The row_ids for which to pull for each value. Each row_id is a 1D 
NDArray \
             whose values don't have to be unique nor sorted.
 
         Examples
         --------
         >>> shape = (3, 3)
         >>> kv.init('3', mx.nd.ones(shape).tostype('row_sparse'))
-        >>> a = mx.nd.zeros(shape, stype='row_sparse')
+        >>> a = mx.nd.sparse.zeros('row_sparse', shape)
         >>> row_ids = mx.nd.array([0, 2], dtype='int64')
         >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids)
         >>> print a.asnumpy()
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index cc34c47..967ba24 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 # pylint: disable=too-many-lines
 """Weight updating functions."""
+# pylint: disable=too-many-lines
 import math
 import pickle
 import logging
@@ -333,16 +334,26 @@ class Optimizer(object):
 # convenience wrapper for Optimizer.Register
 register = Optimizer.register   # pylint: disable=invalid-name
 
+# pylint: disable=line-too-long
 @register
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
     The optimizer updates the weight by::
 
-        state = momentum * state + lr * rescale_grad * clip(grad, 
clip_gradient) + wd * weight
+        rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * 
weight)
+        state = momentum * state + rescaled_grad
         weight = weight - state
 
-    Sparse updating is supported. For details of the update algorithm see
+    If the storage types of weight, state and grad are all ``row_sparse``, \
+    sparse updates are applied by::
+
+        for row in grad.indices:
+            rescaled_grad[row] = lr * (rescale_grad * clip(grad[row], 
clip_gradient) + wd * weight[row])
+            state[row] = momentum * state[row] + rescaled_grad[row]
+            weight[row] = weight[row] - state[row]
+
+    For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and 
:class:`~mxnet.ndarray.sgd_mom_update`.
 
     This optimizer accepts the following parameters in addition to those 
accepted
@@ -411,6 +422,7 @@ class SGD(Optimizer):
                 mp_sgd_update(weight, grad, state[1], out=weight,
                               lr=lr, wd=wd, **kwargs)
 
+# pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
     """The DCASGD optimizer.
@@ -545,10 +557,26 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
 
+    The optimizer updates the weight by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        m = beta1 * m + (1 - beta1) * rescaled_grad
+        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
+        w = w - learning_rate * m / (sqrt(v) + epsilon)
+
+    If the storage types of weight, state and grad are all ``row_sparse``, \
+    sparse updates are applied by::
+
+        for row in grad.indices:
+            rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * 
weight[row], clip_gradient)
+            m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
+            v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
+            w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
+
     This optimizer accepts the following parameters in addition to those 
accepted
     by :class:`.Optimizer`.
 
-    For details of the update algorithm, see :class:`ndarray.adam_update`.
+    For details of the update algorithm, see 
:class:`~mxnet.ndarray.adam_update`.
 
     Parameters
     ----------
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index e05819b..db1d04a 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -216,7 +216,8 @@ class KVStoreLocal : public KVStore {
       const size_t num_vals = target_val_rowids.size();
       for (size_t i = 0; i < num_vals; i++) {
         auto &row_id = target_val_rowids[i].second;
-        NDArray indices = row_id.Copy(pinned_ctx_);
+        NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+        CopyFromTo(row_id, &indices, 0);
         Unique(&indices, priority);
         target_val_rowids[i].second = indices;
       }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 9b2b088..eace28a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -40,8 +40,11 @@ It updates the weights using::
 
  weight = weight - learning_rate * gradient
 
-If weight is stored with `row_sparse` storage type,
-only the row slices whose indices appear in grad.indices are updated.
+If weight is of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated::
+
+ for row in gradient.indices:
+     weight[row] = weight[row] - learning_rate * gradient[row]
 
 )code" ADD_FILELINE)
 .set_num_inputs(2)
@@ -74,8 +77,12 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at 
each epoch.
 
-If weights are stored with `row_sparse` storage type,
-only the row slices whose indices appear in grad.indices are updated (for both 
weight and momentum).
+If weight and momentum are both of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated (for both 
weight and momentum)::
+
+  for row in gradient.indices:
+      v[row] = momentum * v[row] - learning_rate * gradient[row]
+      weight[row] += v[row]
 
 )code" ADD_FILELINE)
 .set_num_inputs(3)
@@ -149,6 +156,14 @@ It updates the weights using::
  v = beta2*v + (1-beta2)*(grad**2)
  w += - learning_rate * m / (sqrt(v) + epsilon)
 
+If w, m and v are all of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated (for w, m 
and v)::
+
+ for row in grad.indices:
+     m[row] = beta1*m[row] + (1-beta1)*grad[row]
+     v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
+     w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
+
 )code" ADD_FILELINE)
 .set_num_inputs(4)
 .set_num_outputs(1)
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 6ec462e..a1f14ef 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -163,7 +163,7 @@ def test_NDArrayIter_csr():
     dns = csr.asnumpy()
 
     # make iterators
-    csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size))
+    csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size, 
last_batch_handle='discard'))
     begin = 0
     for batch in csr_iter:
         expected = np.zeros((batch_size, num_cols))
diff --git a/tests/python/unittest/test_kvstore.py 
b/tests/python/unittest/test_kvstore.py
index 20ad2cd..12feb7e 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -78,7 +78,7 @@ def test_row_sparse_pull():
         for i in range(count):
             vals.append(mx.nd.zeros(shape).tostype('row_sparse'))
             row_id = np.random.randint(num_rows, size=num_rows)
-            row_ids.append(mx.nd.array(row_id, dtype='int64'))
+            row_ids.append(mx.nd.array(row_id))
         row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
         vals_to_pull = vals[0] if len(vals) == 1 else vals
 
@@ -165,7 +165,7 @@ def test_sparse_aggregator():
         expected_sum += v.asnumpy()
 
     # prepare row_ids
-    all_rows = mx.nd.array(np.arange(shape[0]), dtype='int64')
+    all_rows = mx.nd.array(np.arange(shape[0]))
     kv.push('a', vals)
     kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
     result_sum = np.zeros(shape)
diff --git a/tests/python/unittest/test_module.py 
b/tests/python/unittest/test_module.py
index da02e8b..6813c48 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -506,9 +506,8 @@ def test_factorization_machine_module():
     csr_nd = rand_ndarray((num_samples, feature_dim), 'csr', 0.1)
     label = mx.nd.ones((num_samples,1))
     # the alternative is to use LibSVMIter
-    train_iter = mx.io.NDArrayIter(data=csr_nd,
-                                   label={'label':label},
-                                   batch_size=batch_size)
+    train_iter = mx.io.NDArrayIter(data=csr_nd, label={'label':label},
+                                   batch_size=batch_size, 
last_batch_handle='discard')
     # create module
     mod = mx.mod.Module(symbol=model, data_names=['data'], 
label_names=['label'])
     # allocate memory by given the input data and lable shapes
@@ -548,7 +547,8 @@ def test_module_initializer():
 
     data = mx.nd.zeros(shape=(n, m), stype='csr')
     label = mx.nd.zeros((n, 1))
-    iterator = mx.io.NDArrayIter(data=data, label={'label':label}, 
batch_size=n)
+    iterator = mx.io.NDArrayIter(data=data, label={'label':label},
+                                 batch_size=n, last_batch_handle='discard')
 
     # create module
     mod = mx.mod.Module(symbol=model, data_names=['data'], 
label_names=['label'])
diff --git a/tests/python/unittest/test_sparse_ndarray.py 
b/tests/python/unittest/test_sparse_ndarray.py
index f96c94c..94ea228 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -548,6 +548,14 @@ def test_synthetic_dataset_generator():
     test_powerlaw_generator(csr_arr_big, final_row=4)
     test_powerlaw_generator(csr_arr_square, final_row=6)
 
+def test_sparse_nd_exception():
+    """Test that an invalid sparse operator argument raises MXNetError."""
+    a = mx.nd.zeros((2,2))
+    try:
+        mx.nd.sparse.retain(a, invalid_arg="garbage_value")
+    except mx.base.MXNetError:
+        return
+    assert False, "invalid sparse operator call did not raise MXNetError"
 
 if __name__ == '__main__':
     import nose

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to