This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new b235cf7  Update doc for sparse related APIs (#7688)
b235cf7 is described below

commit b235cf7778dba503008d76e223dfffe638f95654
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Wed Sep 6 11:59:55 2017 -0700

    Update doc for sparse related APIs (#7688)

    * update doc for NDArrayIter

    * update kvstore doc. remove int64 restriction for row_ids in rowsparse pull

    * add exception test for sparse op

    * update doc for optimizers

    * update doc for sgd/adam operators

    * CR comment and fix lint
---
 docs/mxdoc.py                                |  3 ++-
 python/mxnet/io.py                           | 21 ++++++++++++++-
 python/mxnet/kvstore.py                      | 39 ++++++++++++++++++++--------
 python/mxnet/optimizer.py                    | 34 +++++++++++++++++++++---
 src/kvstore/kvstore_local.h                  |  3 ++-
 src/operator/optimizer_op.cc                 | 23 +++++++++++++---
 tests/python/unittest/test_io.py             |  2 +-
 tests/python/unittest/test_kvstore.py        |  4 +--
 tests/python/unittest/test_module.py         |  8 +++---
 tests/python/unittest/test_sparse_ndarray.py |  8 ++++++
 10 files changed, 117 insertions(+), 28 deletions(-)

diff --git a/docs/mxdoc.py b/docs/mxdoc.py
index 2726a1c..010118c 100644
--- a/docs/mxdoc.py
+++ b/docs/mxdoc.py
@@ -23,7 +23,8 @@ import json
 import sys
 from recommonmark import transform
 import pypandoc
-import StringIO
+# import StringIO from io for python3 compatibility
+from io import StringIO
 import contextlib

 # white list to evaluate the code block output, such as ['tutorials/gluon']
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index b169681..0665101 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -34,6 +34,7 @@ from .base import DataIterHandle, NDArrayHandle
 from .base import mx_real_t
 from .base import check_call, build_param_doc as _build_param_doc
 from .ndarray import NDArray
+from .ndarray.sparse import CSRNDArray
 from .ndarray import _ndarray_cls
 from .ndarray import array
 from .ndarray import concatenate
@@ -513,7 +514,8 @@ def _init_data(data, allow_empty, default_name):
     return list(data.items())

 class NDArrayIter(DataIter):
-    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray`` or ``h5py.Dataset``.
+    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
+    or ``mx.nd.sparse.CSRNDArray``.

     Example usage:
     ----------
@@ -576,6 +578,18 @@ class NDArrayIter(DataIter):
     >>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
     >>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')

+    `NDArrayIter` also supports ``mx.nd.sparse.CSRNDArray`` with `shuffle` set to `False`
+    and `last_batch_handle` set to `discard`.
+
+    >>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
+    >>> labels = np.ones([10, 1])
+    >>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
+    >>> [batch.data[0] for batch in dataiter]
+    [
+    <CSRNDArray 3x4 @cpu(0)>,
+    <CSRNDArray 3x4 @cpu(0)>,
+    <CSRNDArray 3x4 @cpu(0)>]
+
     Parameters
     ----------
     data: array or list of array or dict of string to array
@@ -603,6 +617,11 @@ class NDArrayIter(DataIter):

         self.data = _init_data(data, allow_empty=False, default_name=data_name)
         self.label = _init_data(label, allow_empty=True, default_name=label_name)
+        if isinstance(data, CSRNDArray) or isinstance(label, CSRNDArray):
+            assert(shuffle is False), \
+                  "`NDArrayIter` only supports ``CSRNDArray`` with `shuffle` set to `False`"
+            assert(last_batch_handle == 'discard'), "`NDArrayIter` only supports ``CSRNDArray``" \
+                  " with `last_batch_handle` set to `discard`."

         self.idx = np.arange(self.data[0][1].shape[0])
         # shuffle data
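For readers following the thread, a minimal self-contained sketch of the iterator behavior documented above, mirroring the new doctest (the 10x4 shape and batch size of 3 are illustrative; assumes an MXNet build that ships the mx.nd.sparse module):

    import numpy as np
    import mxnet as mx

    # a 10x4 dense matrix converted to compressed sparse row storage
    csr_data = mx.nd.array(np.arange(40).reshape((10, 4))).tostype('csr')
    labels = np.ones((10, 1))

    # shuffle must remain False and last_batch_handle must be 'discard',
    # otherwise the assertions added in NDArrayIter.__init__ above fire
    dataiter = mx.io.NDArrayIter(csr_data, labels, batch_size=3,
                                 last_batch_handle='discard')
    for batch in dataiter:
        print(batch.data[0])  # each batch holds a 3x4 CSRNDArray

With batch size 3 over 10 rows, 'discard' drops the final partial batch, so the loop yields exactly three batches.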
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 0da9e39..cf5c159 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -104,7 +104,7 @@ class KVStore(object):
         ----------
         key : str, int, or sequence of str or int
             The keys.
-        value : NDArray or sequence of NDArray
+        value : NDArray, RowSparseNDArray or sequence of NDArray or RowSparseNDArray
            Values corresponding to the keys.

         Examples
@@ -120,8 +120,15 @@ class KVStore(object):
         [ 2.  2.  2.]]

         >>> # init a list of key-value pairs
-        >>> keys = [5, 7, 9]
+        >>> keys = ['5', '7', '9']
         >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
+
+        >>> # init a row_sparse value
+        >>> kv.init('4', mx.nd.ones(shape).tostype('row_sparse'))
+        >>> b = mx.nd.sparse.zeros('row_sparse', shape)
+        >>> kv.row_sparse_pull('4', row_ids=mx.nd.array([0, 1]), out=b)
+        >>> print b
+        <RowSparseNDArray 2x3 @cpu(0)>
         """
         ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
         if use_str_keys:
@@ -143,7 +150,8 @@ class KVStore(object):
         key : str, int, or sequence of str or int
             Keys.

-        value : NDArray or list of NDArray or list of list of NDArray
+        value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray,
+                or list of list of NDArray or RowSparseNDArray
             Values corresponding to the keys.

         priority : int, optional
@@ -171,7 +179,7 @@ class KVStore(object):

         >>> # push a list of keys.
         >>> # single device
-        >>> keys = [4, 5, 6]
+        >>> keys = ['4', '5', '6']
         >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
@@ -187,6 +195,15 @@ class KVStore(object):
         >>> print b[1][1].asnumpy()
         [[ 4.  4.  4.]
          [ 4.  4.  4.]]
+
+        >>> # push a row_sparse value
+        >>> b = mx.nd.sparse.zeros('row_sparse', shape)
+        >>> kv.init('10', mx.nd.sparse.zeros('row_sparse', shape))
+        >>> kv.push('10', mx.nd.ones(shape).tostype('row_sparse'))
+        >>> # pull out the value
+        >>> kv.row_sparse_pull('10', row_ids=mx.nd.array([0, 1]), out=b)
+        >>> print b
+        <RowSparseNDArray 2x3 @cpu(0)>
         """
         ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
         if use_str_keys:
@@ -209,7 +226,7 @@ class KVStore(object):

         The returned values are guaranteed to be the latest values in the store.

-        For row_sparse values, please use `row_sparse_pull` instead.
+        For `RowSparseNDArray` values, please use ``row_sparse_pull`` instead.

         Parameters
         ----------
@@ -242,7 +259,7 @@ class KVStore(object):

         >>> # pull a list of key-value pairs.
         >>> # On single device
-        >>> keys = [5, 7, 9]
+        >>> keys = ['5', '7', '9']
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
         >>> print b[1].asnumpy()
@@ -266,8 +283,8 @@ class KVStore(object):
             self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))

     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
-        """ Pulls a single row_sparse value or a sequence of row_sparse values from the store
-        with specified row_ids.
+        """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
+        from the store with specified row_ids.

         `row_sparse_pull` is executed asynchronously after all previous
         `push`/`pull`/`row_sparse_pull` calls for the same input key(s) are finished.
@@ -279,7 +296,7 @@ class KVStore(object):
         key : str, int, or sequence of str or int
             Keys.

-        out: NDArray or list of NDArray or list of list of NDArray
+        out: RowSparseNDArray or list of RowSparseNDArray or list of list of RowSparseNDArray
             Values corresponding to the keys. The stype is expected to be row_sparse

         priority : int, optional
@@ -288,14 +305,14 @@ class KVStore(object):
             other pull actions.

         row_ids : NDArray or list of NDArray
-            The row_ids for which to pull for each value. Each row_id is an 1D-NDArray \
+            The row_ids for which to pull for each value. Each row_id is a 1D NDArray \
             whose values don't have to be unique nor sorted.

         Examples
         --------
         >>> shape = (3, 3)
         >>> kv.init('3', mx.nd.ones(shape).tostype('row_sparse'))
-        >>> a = mx.nd.zeros(shape, stype='row_sparse')
+        >>> a = mx.nd.sparse.zeros('row_sparse', shape)
         >>> row_ids = mx.nd.array([0, 2], dtype='int64')
         >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids)
         >>> print a.asnumpy()
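As a companion to the new docstrings, a hypothetical round trip through a local store (the key name 'w' and the shape are illustrative; it mirrors the doctests above, where a pushed value ends up as the stored value and row_sparse_pull retrieves only the requested rows):

    import mxnet as mx

    kv = mx.kv.create('local')
    shape = (3, 3)
    # both the stored value and the pushed value use row_sparse storage
    kv.init('w', mx.nd.sparse.zeros('row_sparse', shape))
    kv.push('w', mx.nd.ones(shape).tostype('row_sparse'))

    # pull only rows 0 and 2; `out` must already be a RowSparseNDArray,
    # since pull() no longer accepts row_sparse values
    out = mx.nd.sparse.zeros('row_sparse', shape)
    kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([0, 2]))
    print(out.asnumpy())  # rows 0 and 2 are populated, row 1 stays zero

Note that after this commit row_ids no longer has to be created with dtype='int64'; the local kvstore copies the ids into an int64 array internally (see the kvstore_local.h hunk below).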
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index cc34c47..967ba24 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 # pylint: disable=too-many-lines
 """Weight updating functions."""
+# pylint: disable=too-many-lines
 import math
 import pickle
 import logging
@@ -333,16 +334,26 @@ class Optimizer(object):
 # convenience wrapper for Optimizer.Register
 register = Optimizer.register   # pylint: disable=invalid-name

+# pylint: disable=line-too-long
 @register
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.

     The optimizer updates the weight by::

-        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + rescaled_grad
         weight = weight - state

-    Sparse updating is supported. For details of the update algorithm see
+    If the storage types of weight, state and grad are all ``row_sparse``, \
+    sparse updates are applied by::
+
+        for row in grad.indices:
+            rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
+            state[row] = momentum[row] * state[row] + rescaled_grad[row]
+            weight[row] = weight[row] - state[row]
+
+    For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.

     This optimizer accepts the following parameters in addition to those accepted
@@ -411,6 +422,7 @@ class SGD(Optimizer):
             mp_sgd_update(weight, grad, state[1], out=weight,
                           lr=lr, wd=wd, **kwargs)

+# pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
     """The DCASGD optimizer.
@@ -545,10 +557,26 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.

+    The optimizer updates the weight by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        m = beta1 * m + (1 - beta1) * rescaled_grad
+        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
+        w = w - learning_rate * m / (sqrt(v) + epsilon)
+
+    If the storage types of weight, state and grad are all ``row_sparse``, \
+    sparse updates are applied by::
+
+        for row in grad.indices:
+            rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
+            m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
+            v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
+            w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
+
     This optimizer accepts the following parameters in addition to those accepted
     by :class:`.Optimizer`.

-    For details of the update algorithm, see :class:`ndarray.adam_update`.
+    For details of the update algorithm, see :class:`~mxnet.ndarray.adam_update`.

     Parameters
     ----------
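To make the per-row semantics concrete, here is a small standalone NumPy sketch of the row-sparse SGD-with-momentum rule quoted in the docstring above (a hypothetical helper, not MXNet code; note that momentum is a scalar hyperparameter even though the docstring writes momentum[row]):

    import numpy as np

    def row_sparse_sgd_mom(weight, state, grad, rows, lr, momentum,
                           wd=0.0, rescale_grad=1.0, clip_gradient=None):
        # Only the rows listed in `rows` (grad.indices) are touched; every
        # other row of `weight` and `state` keeps its previous value.
        for row in rows:
            g = grad[row]
            if clip_gradient is not None:
                g = np.clip(g, -clip_gradient, clip_gradient)
            rescaled = lr * rescale_grad * g + wd * weight[row]
            state[row] = momentum * state[row] + rescaled
            weight[row] = weight[row] - state[row]

Skipping rows absent from grad.indices is what makes the sparse path cheap: the cost per step is proportional to the number of rows with nonzero gradient, not to the full weight matrix.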
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index e05819b..db1d04a 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -216,7 +216,8 @@ class KVStoreLocal : public KVStore {
     const size_t num_vals = target_val_rowids.size();
     for (size_t i = 0; i < num_vals; i++) {
       auto &row_id = target_val_rowids[i].second;
-      NDArray indices = row_id.Copy(pinned_ctx_);
+      NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+      CopyFromTo(row_id, &indices, 0);
       Unique(&indices, priority);
       target_val_rowids[i].second = indices;
     }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 9b2b088..eace28a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -40,8 +40,11 @@ It updates the weights using::

 weight = weight - learning_rate * gradient

-If weight is stored with `row_sparse` storage type,
-only the row slices whose indices appear in grad.indices are updated.
+If weight is of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated::
+
+  for row in gradient.indices:
+      weight[row] = weight[row] - learning_rate * gradient[row]

 )code" ADD_FILELINE)
 .set_num_inputs(2)
@@ -74,8 +77,12 @@ It updates the weights using::

 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

-If weights are stored with `row_sparse` storage type,
-only the row slices whose indices appear in grad.indices are updated (for both weight and momentum).
+If weight and momentum are both of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
+
+  for row in gradient.indices:
+      v[row] = momentum[row] * v[row] - learning_rate * gradient[row]
+      weight[row] += v[row]

 )code" ADD_FILELINE)
 .set_num_inputs(3)
@@ -149,6 +156,14 @@ It updates the weights using::
 v = beta2*v + (1-beta2)*(grad**2)
 w += - learning_rate * m / (sqrt(v) + epsilon)

+If w, m and v are all of ``row_sparse`` storage type,
+only the row slices whose indices appear in grad.indices are updated (for w, m and v)::
+
+ for row in grad.indices:
+     m[row] = beta1*m[row] + (1-beta1)*grad[row]
+     v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
+     w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
+
 )code" ADD_FILELINE)
 .set_num_inputs(4)
 .set_num_outputs(1)
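The same per-row idea applies to the Adam operator documented just above; a hypothetical standalone NumPy sketch of that update (all names illustrative, scalar beta1/beta2/epsilon, mirroring the documented formula):

    import numpy as np

    def row_sparse_adam(w, m, v, grad, rows, learning_rate,
                        beta1=0.9, beta2=0.999, epsilon=1e-8):
        # Rows absent from the sparse gradient keep their stale m, v and w;
        # only rows in `rows` (grad.indices) receive a moment/weight update.
        for row in rows:
            m[row] = beta1 * m[row] + (1 - beta1) * grad[row]
            v[row] = beta2 * v[row] + (1 - beta2) * (grad[row] ** 2)
            w[row] -= learning_rate * m[row] / (np.sqrt(v[row]) + epsilon)

One consequence of the lazy update, worth keeping in mind when comparing against dense Adam: the first and second moments of untouched rows do not decay on steps where those rows have no gradient.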
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 6ec462e..a1f14ef 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -163,7 +163,7 @@ def test_NDArrayIter_csr():
     dns = csr.asnumpy()

     # make iterators
-    csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size))
+    csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size, last_batch_handle='discard'))
     begin = 0
     for batch in csr_iter:
         expected = np.zeros((batch_size, num_cols))
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 20ad2cd..12feb7e 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -78,7 +78,7 @@ def test_row_sparse_pull():
     for i in range(count):
         vals.append(mx.nd.zeros(shape).tostype('row_sparse'))
         row_id = np.random.randint(num_rows, size=num_rows)
-        row_ids.append(mx.nd.array(row_id, dtype='int64'))
+        row_ids.append(mx.nd.array(row_id))

     row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
     vals_to_pull = vals[0] if len(vals) == 1 else vals
@@ -165,7 +165,7 @@ def test_sparse_aggregator():
         expected_sum += v.asnumpy()

     # prepare row_ids
-    all_rows = mx.nd.array(np.arange(shape[0]), dtype='int64')
+    all_rows = mx.nd.array(np.arange(shape[0]))
     kv.push('a', vals)
     kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
     result_sum = np.zeros(shape)
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index da02e8b..6813c48 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -506,9 +506,8 @@ def test_factorization_machine_module():
     csr_nd = rand_ndarray((num_samples, feature_dim), 'csr', 0.1)
     label = mx.nd.ones((num_samples,1))
     # the alternative is to use LibSVMIter
-    train_iter = mx.io.NDArrayIter(data=csr_nd,
-                                   label={'label':label},
-                                   batch_size=batch_size)
+    train_iter = mx.io.NDArrayIter(data=csr_nd, label={'label':label},
+                                   batch_size=batch_size, last_batch_handle='discard')
     # create module
     mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label'])
     # allocate memory by giving the input data and label shapes
@@ -548,7 +547,8 @@ def test_module_initializer():

     data = mx.nd.zeros(shape=(n, m), stype='csr')
     label = mx.nd.zeros((n, 1))
-    iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n)
+    iterator = mx.io.NDArrayIter(data=data, label={'label':label},
+                                 batch_size=n, last_batch_handle='discard')

     # create module
     mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label'])
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index f96c94c..94ea228 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -548,6 +548,14 @@ def test_synthetic_dataset_generator():
     test_powerlaw_generator(csr_arr_big, final_row=4)
     test_powerlaw_generator(csr_arr_square, final_row=6)

+def test_sparse_nd_exception():
+    """ test that an invalid sparse operator call throws an exception """
+    a = mx.nd.zeros((2,2))
+    try:
+        b = mx.nd.sparse.retain(a, invalid_arg="garbage_value")
+        assert(False)
+    except:
+        return
+
 if __name__ == '__main__':
     import nose
--
To stop receiving notification emails like this one, please contact
"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.