This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 07661ae  decouple record/train and add state readers (#7356)
07661ae is described below

commit 07661ae9a627d2a90b15c04b665fdb0773920285
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Tue Aug 8 15:13:29 2017 -0700

    decouple record/train and add state readers (#7356)

    * decouple record/train and add state readers

    * update per comments

    * update per consensus

    * add API doc

    * fix
---
 docs/api/python/autograd.md            |  21 +++--
 include/mxnet/c_api.h                  |  12 +++
 python/mxnet/autograd.py               | 136 ++++++++++++++++++++++++---------
 python/mxnet/ndarray.py                |   7 +-
 src/c_api/c_api_ndarray.cc             |  12 +++
 tests/python/unittest/test_autograd.py |  37 ++++++++-
 6 files changed, 174 insertions(+), 51 deletions(-)

diff --git a/docs/api/python/autograd.md b/docs/api/python/autograd.md
index 440a1e4..d204a2c 100644
--- a/docs/api/python/autograd.md
+++ b/docs/api/python/autograd.md
@@ -14,19 +14,28 @@
 ## Autograd
 
 ```eval_rst
-.. currentmodule:: mxnet.autograd
-```
-
-
-```eval_rst
 .. autosummary::
     :nosignatures:
 
     record
     pause
-    mark_variables
+    train_mode
+    predict_mode
     backward
     set_training
+    is_training
     set_recording
+    is_recording
+    mark_variables
+```
+
+## API Reference
+
+<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+.. automodule:: mxnet.autograd
+    :members:
 ```
+
+<script>auto_index("api-reference");</script>
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d9a5315..3b8d54c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -566,6 +566,18 @@ MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev);
  */
 MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev);
 /*!
+ * \brief get whether autograd recording is on
+ * \param curr returns the current status.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradIsRecording(bool* curr);
+/*!
+ * \brief get whether training mode is on
+ * \param curr returns the current status.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradIsTraining(bool* curr);
+/*!
 * \brief mark NDArrays as variables to compute gradient for autograd
 * \param num_var number of variable NDArrays
 * \param var_handles variable NDArrays
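For reference, the two new C getters surface in Python through the is_recording()/is_training()
wrappers added in python/mxnet/autograd.py below. A minimal sketch of the intended behavior,
assuming a build that includes this commit (x is just a placeholder array):

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((2, 2))
    # Outside any scope, neither recording nor training mode is on.
    assert not autograd.is_recording() and not autograd.is_training()
    with autograd.record():              # record() defaults to train_mode=True
        assert autograd.is_recording() and autograd.is_training()
        y = mx.nd.Dropout(x, p=0.5)      # runs with training-mode behavior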
diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py
index 2f33052..2c3feab 100644
--- a/python/mxnet/autograd.py
+++ b/python/mxnet/autograd.py
@@ -10,7 +10,7 @@ from .ndarray import NDArray
 from .symbol import _GRAD_REQ_MAP
 
 
-def set_recording(is_recording):
+def set_recording(is_recording): #pylint: disable=redefined-outer-name
     """Set status to recording/not recording. When recording, graph will be constructed
     for gradient computation.
 
@@ -27,14 +27,14 @@
         ctypes.c_int(is_recording), ctypes.byref(prev)))
     return bool(prev.value)
 
-def set_training(is_train):
-    """Set status to training/not training. This affects ctx.is_train in operator
+def set_training(train_mode): #pylint: disable=redefined-outer-name
+    """Set status to training/predicting. This affects ctx.is_train in operator
     running context. For example, Dropout will drop inputs randomly when
-    is_train=True while simply passing through if is_train=False.
+    train_mode=True while simply passing through if train_mode=False.
 
     Parameters
     ----------
-    is_train: bool
+    train_mode: bool
 
     Returns
     -------
@@ -42,43 +42,70 @@
     """
     prev = ctypes.c_int()
     check_call(_LIB.MXAutogradSetIsTraining(
-        ctypes.c_int(is_train), ctypes.byref(prev)))
+        ctypes.c_int(train_mode), ctypes.byref(prev)))
     return bool(prev.value)
 
+def is_recording():
+    """Get status on recording/not recording.
+
+    Returns
+    -------
+    Current state of recording.
+    """
+    curr = ctypes.c_bool()
+    check_call(_LIB.MXAutogradIsRecording(ctypes.byref(curr)))
+    return curr.value
+
+def is_training():
+    """Get status on training/predicting.
+
+    Returns
+    -------
+    Current state of training/predicting.
+    """
+    curr = ctypes.c_bool()
+    check_call(_LIB.MXAutogradIsTraining(ctypes.byref(curr)))
+    return curr.value
+
 
-class RecordingStateScope(object):
+class _RecordingStateScope(object):
     """Scope for managing training state.
 
     Example::
-        with RecordingStateScope(True, True):
+
+        with _RecordingStateScope(True, True):
             y = model(x)
             backward([y])
+
     """
-    def __init__(self, enter_state, is_train):
-        self._enter_state = enter_state
-        self._enter_is_train = is_train
-        self._prev = None
-        self._prev_is_train = None
+    def __init__(self, is_record, train_mode): #pylint: disable=redefined-outer-name
+        self._enter_is_record = is_record
+        self._enter_train_mode = train_mode
+        self._prev_is_record = None
+        self._prev_train_mode = None
 
     def __enter__(self):
-        self._prev = set_recording(self._enter_state)
-        self._prev_is_train = set_training(self._enter_is_train)
+        if self._enter_is_record is not None:
+            self._prev_is_record = set_recording(self._enter_is_record)
+        if self._enter_train_mode is not None:
+            self._prev_train_mode = set_training(self._enter_train_mode)
 
     def __exit__(self, ptype, value, trace):
-        if self._prev != self._enter_state:
-            set_recording(self._prev)
-        if self._prev_is_train != self._enter_is_train:
-            set_training(self._prev_is_train)
+        if self._enter_is_record is not None and self._prev_is_record != self._enter_is_record:
+            set_recording(self._prev_is_record)
+        if self._enter_train_mode is not None and self._prev_train_mode != self._enter_train_mode:
+            set_training(self._prev_train_mode)
 
 
-def record(is_train=True):
-    """Returns a training scope context to be used in 'with' statement
-    and captures training code.
+def record(train_mode=True): #pylint: disable=redefined-outer-name
+    """Returns an autograd recording scope context to be used in 'with' statement
+    and captures code that needs gradients to be calculated.
 
-    .. note:: When forwarding with is_train=False, the corresponding backward
-      should also use is_train=False, otherwise gradient is undefined.
+    .. note:: When forwarding with train_mode=False, the corresponding backward
+      should also use train_mode=False, otherwise gradient is undefined.
 
     Example::
+
         with autograd.record():
             y = model(x)
             backward([y])
@@ -87,17 +114,19 @@
 
     Parameters
     ----------
-    is_train: bool, default True
-        Whether to do forward for training or inference.
+    train_mode: bool, default True
+        Whether the forward pass is in training or predicting mode. This controls the behavior
+        of some layers such as Dropout, BatchNorm.
     """
-    return RecordingStateScope(True, is_train)
+    return _RecordingStateScope(True, train_mode)
 
 
-def pause(is_train=False):
-    """Returns a testing scope context to be used in 'with' statement
-    and captures testing code.
+def pause(train_mode=False): #pylint: disable=redefined-outer-name
+    """Returns a scope context to be used in 'with' statement for code blocks that do not need
+    gradients to be calculated.
 
     Example::
+
         with autograd.record():
             y = model(x)
             backward([y])
@@ -106,10 +135,41 @@
 
     Parameters
     ----------
-    is_train: bool, default False
-        Whether to do forward for training or inference.
+    train_mode: bool, default False
+        Whether to do forward for training or predicting.
+    """
+    return _RecordingStateScope(False, train_mode)
+
+
+def train_mode():
+    """Returns a scope context to be used in 'with' statement
+    in which forward pass behavior is set to training mode,
+    without changing the recording states.
+
+    Example::
+
+        y = model(x)
+        with autograd.train_mode():
+            y = dropout(y)
+
+    """
+    return _RecordingStateScope(None, True)
+
+
+def predict_mode():
+    """Returns a scope context to be used in 'with' statement
+    in which forward pass behavior is set to inference mode,
+    without changing the recording states.
+
+    Example::
+
+        with autograd.record():
+            y = model(x)
+            with autograd.predict_mode():
+                y = sampling(y)
+        backward([y])
     """
-    return RecordingStateScope(False, is_train)
+    return _RecordingStateScope(None, False)
 
 
 def mark_variables(variables, gradients, grad_reqs='write'):
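The train_mode()/predict_mode() scopes added above toggle only the forward-pass behavior and
leave the recording state untouched, so they can be used on their own or nested inside record().
A small sketch under the same assumptions as the previous example:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((10, 10))

    # train_mode() alone: Dropout is active, but nothing is recorded.
    with autograd.train_mode():
        assert not autograd.is_recording() and autograd.is_training()
        y = mx.nd.Dropout(x, p=0.5)      # drops randomly, as in training

    # predict_mode() inside record(): still recording, but layers behave as in inference.
    with autograd.record():
        with autograd.predict_mode():
            assert autograd.is_recording() and not autograd.is_training()
            y = mx.nd.Dropout(x, p=0.5)  # identity pass-through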
@@ -143,7 +203,7 @@
             c_array(NDArrayHandle, gradient_handles)))
 
 
-def backward(heads, head_grads=None, retain_graph=False, is_train=True):
+def backward(heads, head_grads=None, retain_graph=False, train_mode=True): #pylint: disable=redefined-outer-name
     """Compute the gradients of heads w.r.t previously marked variables.
 
     Parameters
@@ -152,8 +212,8 @@
         Output NDArray(s)
     head_grads: NDArray or list of NDArray or None
         Gradients with respect to heads.
-    is_train: bool, optional
-        Whether to do backward for training or inference.
+    train_mode: bool, optional
+        Whether to do backward for training or predicting.
     """
     if isinstance(heads, NDArray):
         assert head_grads is None or isinstance(head_grads, NDArray)
@@ -170,7 +230,7 @@
             c_array(NDArrayHandle, output_handles),
             ctypes.c_void_p(0),
             ctypes.c_int(retain_graph),
-            ctypes.c_int(is_train)))
+            ctypes.c_int(train_mode)))
         return
 
     ograd_handles = []
@@ -187,4 +247,4 @@
         c_array(NDArrayHandle, output_handles),
         c_array(NDArrayHandle, ograd_handles),
         ctypes.c_int(retain_graph),
-        ctypes.c_int(is_train)))
+        ctypes.c_int(train_mode)))
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index b2178a9..d4a0cdb 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -1059,7 +1059,7 @@ fixed-size items.
         check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
         return NDArray(hdl)
 
-    def backward(self, out_grad=None, retain_graph=False, is_train=True):
+    def backward(self, out_grad=None, retain_graph=False, train_mode=True):
         """Compute the gradients of this NDArray w.r.t variables.
 
         Parameters
@@ -1070,7 +1070,7 @@ fixed-size items.
             Whether to retain the computation graph for another backward
             pass on the same graph. By default the computation history
             is cleared.
-        is_train : bool, optional
+        train_mode : bool, optional
             Whether to compute gradient for training or inference.
         """
         if out_grad is None:
@@ -1082,7 +1082,7 @@ fixed-size items.
             1, c_array(NDArrayHandle, [self.handle]),
             c_array(NDArrayHandle, ograd_handles),
             ctypes.c_int(retain_graph),
-            ctypes.c_int(is_train)))
+            ctypes.c_int(train_mode)))
 
 
 def onehot_encode(indices, out):
@@ -2538,7 +2538,6 @@ def _make_ndarray_function(handle, name):
         else:
             signature.append('%s=_Null'%name)
             kwarg_names.append(name)
-    #signature.append('is_train=False')
     signature.append('out=None')
     signature.append('name=None')
     signature.append('**kwargs')
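As the note in record() warns, a forward pass taken with train_mode=False must also be
differentiated with train_mode=False, otherwise the gradient is undefined. A minimal sketch
of that pairing, mirroring the updated test further below:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.ones((10, 10))
    x.attach_grad()
    with autograd.record(train_mode=False):  # build the graph, but run layers in predict mode
        y = mx.nd.Dropout(x, p=0.5)          # pass-through in predict mode
        y.backward(train_mode=False)         # must match the forward train_mode
    assert (x.grad.asnumpy() == x.asnumpy()).all()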
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index f401394..a37e314 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -522,12 +522,24 @@ int MXInvokeCachedOp(CachedOpHandle handle,
   API_END();
 }
 
+int MXAutogradIsTraining(bool* curr) {
+  API_BEGIN();
+  *curr = AutogradRuntime::Get()->IsTraining();
+  API_END();
+}
+
 int MXAutogradSetIsTraining(int is_training, int* prev) {
   API_BEGIN();
   *prev = AutogradRuntime::Get()->SetIsTraining(static_cast<bool>(is_training));
   API_END();
 }
 
+int MXAutogradIsRecording(bool* curr) {
+  API_BEGIN();
+  *curr = AutogradRuntime::Get()->IsRecording();
+  API_END();
+}
+
 int MXAutogradSetIsRecording(int is_recording, int* prev) {
   API_BEGIN();
   *prev = AutogradRuntime::Get()->SetIsRecording(static_cast<bool>(is_recording));
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 172075d..7ee3500 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -251,18 +251,49 @@ def test_attach_grad():
 def test_is_train():
     x = mx.nd.ones((10, 10))
     x.attach_grad()
-    with record(True):
+    with record(train_mode=True):
+        assert is_recording()
+        assert is_training()
         y = mx.nd.Dropout(x, p=0.5)
         assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
         y.backward()
         assert (x.grad.asnumpy() == y.asnumpy()).all()
 
-    with record(False):
+        with predict_mode():
+            assert is_recording()
+            assert not is_training()
+            y = mx.nd.Dropout(x, p=0.5)
+            assert (y.asnumpy() == x.asnumpy()).all()
+            y.backward(train_mode=False)
+            assert (x.grad.asnumpy() == x.asnumpy()).all()
+
+    with record(train_mode=False):
+        assert is_recording()
+        assert not is_training()
         y = mx.nd.Dropout(x, p=0.5)
         assert (y.asnumpy() == x.asnumpy()).all()
-        y.backward(is_train=False)
+        y.backward(train_mode=False)
         assert (x.grad.asnumpy() == x.asnumpy()).all()
 
+        with train_mode():
+            assert is_recording()
+            assert is_training()
+            y = mx.nd.Dropout(x, p=0.5)
+            assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
+            y.backward()
+            assert (x.grad.asnumpy() == y.asnumpy()).all()
+
+    assert not is_recording()
+    assert not is_training()
+    y = mx.nd.Dropout(x, p=0.5)
+    assert (y.asnumpy() == x.asnumpy()).all()
+
+    with train_mode():
+        assert not is_recording()
+        assert is_training()
+        y = mx.nd.Dropout(x, p=0.5)
+        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
+
 if __name__ == "__main__":
     import nose

-- 
To stop receiving notification emails like this one, please contact
"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.