This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 568b5a2  bce loss (#7304)
568b5a2 is described below

commit 568b5a2d3e701768ff6f270238e5edccc2f35ff1
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Sun Aug 13 17:01:13 2017 -0700

    bce loss (#7304)
---
 python/mxnet/gluon/loss.py         | 69 +++++++++++++++++++++++++++++---------
 tests/python/unittest/test_loss.py | 30 +++++++++++++++++
 2 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2b31840..5839105 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -20,7 +20,7 @@
 """ losses for training neural networks """
 from __future__ import absolute_import
 
-from .. import symbol, ndarray
+from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
 
@@ -54,6 +54,11 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
     return loss
 
 
+def _reshape_label_as_output(F, output, label):
+    # for symbolic F, output.shape is not available, so we reshape the label
+    # to an empty shape and let it be inferred from output's shape via the
+    # '-' operator later.
+    return label.reshape(output.shape) if F is ndarray else label.reshape(())
 
 class Loss(HybridBlock):
     """Base class for loss.
@@ -113,13 +118,8 @@ class L2Loss(Loss):
         super(L2Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.square(output - label.reshape(output.shape))
-        else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.square(output - label.reshape(()))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.square(output - label)
         loss = _apply_weighting(F, loss, self._weight/2, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
@@ -148,19 +148,56 @@ class L1Loss(Loss):
         super(L1Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.abs(output - label.reshape(output.shape))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.abs(output - label)
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
+
+
+class SigmoidBinaryCrossEntropyLoss(Loss):
+    r"""The cross-entropy loss for binary classification. (alias: SigmoidBCELoss)
+
+    BCE loss is useful when training logistic regression.
+
+    .. math::
+        loss(o, t) = - 1/n \sum_i (t[i] * log(o[i]) + (1 - t[i]) * log(1 - o[i]))
+
+
+    Parameters
+    ----------
+    from_sigmoid : bool, default is `False`
+        Whether the input is the output of a sigmoid. Setting this to False makes
+        the loss compute the sigmoid and then the BCE, which is more numerically
+        stable through the log-sum-exp trick.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+    batch_axis : int, default 0
+        The axis that represents the mini-batch.
+    """
+    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
+        super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_sigmoid = from_sigmoid
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        label = _reshape_label_as_output(F, output, label)
+        if not self._from_sigmoid:
+            max_val = F.maximum(-output, 0)
+            loss = output - output*label + max_val + F.log(F.exp(-max_val)+F.exp(-output-max_val))
         else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.abs(output - label.reshape(()))
+            loss = -(F.log(output+1e-8)*label + F.log(1.-output+1e-8)*(1.-label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
+SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
+
 class SoftmaxCrossEntropyLoss(Loss):
-    """Computes the softmax cross entropy loss.
+    """Computes the softmax cross entropy loss. (alias: SoftmaxCELoss)
 
     If `sparse_label` is `True`, label should contain integer
     category indicators:
@@ -216,6 +253,8 @@ class SoftmaxCrossEntropyLoss(Loss):
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
+SoftmaxCELoss = SoftmaxCrossEntropyLoss
+
 
 class KLDivLoss(Loss):
     """The Kullback-Leibler divergence loss.
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index 8eced7b..714ea75 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -18,6 +18,7 @@
 import mxnet as mx
 import numpy as np
 from mxnet import gluon
+from mxnet.test_utils import assert_almost_equal
 
 
 def test_loss_ndarray():
@@ -81,6 +82,34 @@ def test_ce_loss():
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
 
 
+def test_bce_loss():
+    mx.random.seed(1234)
+    np.random.seed(1234)
+    N = 20
+    data = mx.random.uniform(-1, 1, shape=(N, 20))
+    label = mx.nd.array(np.random.randint(2, size=(N,)), dtype='float32')
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
+    output = get_net(1)
+    fc2 = output.get_internals()['fc2_output']
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 1.},
+            eval_metric=mx.metric.Loss())
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
+
+def test_bce_equal_ce2():
+    N = 100
+    loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    out1 = mx.random.uniform(0, 1, shape=(N, 1))
+    out2 = mx.nd.log(mx.nd.concat(1-out1, out1, dim=1) + 1e-8)
+    label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
+    assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
+
+
 def test_kl_loss():
     mx.random.seed(1234)
     np.random.seed(1234)
@@ -117,6 +146,7 @@ def test_l2_loss():
             eval_metric=mx.metric.Loss())
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
+
 def test_l1_loss():
     mx.random.seed(1234)
     np.random.seed(1234)