This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 568b5a2  bce loss (#7304)
568b5a2 is described below

commit 568b5a2d3e701768ff6f270238e5edccc2f35ff1
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Sun Aug 13 17:01:13 2017 -0700

    bce loss (#7304)
---
 python/mxnet/gluon/loss.py         | 69 +++++++++++++++++++++++++++++---------
 tests/python/unittest/test_loss.py | 30 +++++++++++++++++
 2 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2b31840..5839105 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -20,7 +20,7 @@
 """ losses for training neural networks """
 from __future__ import absolute_import
 
-from .. import symbol, ndarray
+from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
 
@@ -54,6 +54,11 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
 
     return loss
 
+def _reshape_label_as_output(F, output, label):
+    # for symbolic output.shape is not available so we reshape
+    # to empty shape and let it be inferred from output's shape
+    # via the '-' operator later.
+    return label.reshape(output.shape) if F is ndarray else label.reshape(())
 
 class Loss(HybridBlock):
     """Base class for loss.
@@ -113,13 +118,8 @@ class L2Loss(Loss):
         super(L2Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.square(output - label.reshape(output.shape))
-        else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.square(output - label.reshape(()))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.square(output - label)
         loss = _apply_weighting(F, loss, self._weight/2, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
@@ -148,19 +148,56 @@ class L1Loss(Loss):
         super(L1Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.abs(output - label.reshape(output.shape))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.abs(output - label)
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
+
+
+class SigmoidBinaryCrossEntropyLoss(Loss):
+    r"""The cross-entropy loss for binary classification. (alias: 
SigmoidBCELoss)
+
+    BCE loss is useful when training logistic regression.
+
+    .. math::
+        loss(o, t) = - 1/n \sum_i (t[i] * log(o[i]) + (1 - t[i]) * log(1 - o[i]))
+
+
+    Parameters
+    ----------
+    from_sigmoid : bool, default is `False`
+        Whether the input is from the output of sigmoid. Setting this to False makes
+        the loss compute the sigmoid and then the BCE, which is more numerically
+        stable through the log-sum-exp trick.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    """
+    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
+        super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_sigmoid = from_sigmoid
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        label = _reshape_label_as_output(F, output, label)
+        if not self._from_sigmoid:
+            max_val = F.maximum(-output, 0)
+            loss = output - output*label + max_val + F.log(F.exp(-max_val)+F.exp(-output-max_val))
         else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.abs(output - label.reshape(()))
+            loss = -(F.log(output+1e-8)*label + F.log(1.-output+1e-8)*(1.-label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
+SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
+
 
 class SoftmaxCrossEntropyLoss(Loss):
-    """Computes the softmax cross entropy loss.
+    """Computes the softmax cross entropy loss. (alias: SoftmaxCELoss)
 
     If `sparse_label` is `True`, label should contain integer category indicators:
 
@@ -216,6 +253,8 @@ class SoftmaxCrossEntropyLoss(Loss):
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
+SoftmaxCELoss = SoftmaxCrossEntropyLoss
+
 
 class KLDivLoss(Loss):
     """The Kullback-Leibler divergence loss.
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index 8eced7b..714ea75 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -18,6 +18,7 @@
 import mxnet as mx
 import numpy as np
 from mxnet import gluon
+from mxnet.test_utils import assert_almost_equal
 
 
 def test_loss_ndarray():
@@ -81,6 +82,34 @@ def test_ce_loss():
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
 
 
+def test_bce_loss():
+    mx.random.seed(1234)
+    np.random.seed(1234)
+    N = 20
+    data = mx.random.uniform(-1, 1, shape=(N, 20))
+    label = mx.nd.array(np.random.randint(2, size=(N,)), dtype='float32')
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
+    output = get_net(1)
+    fc2 = output.get_internals()['fc2_output']
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 1.},
+            eval_metric=mx.metric.Loss())
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
+
+def test_bce_equal_ce2():
+    N = 100
+    loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    out1 = mx.random.uniform(0, 1, shape=(N, 1))
+    out2 = mx.nd.log(mx.nd.concat(1-out1, out1, dim=1) + 1e-8)
+    label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
+    assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
+
+
 def test_kl_loss():
     mx.random.seed(1234)
     np.random.seed(1234)
@@ -117,6 +146,7 @@ def test_l2_loss():
             eval_metric=mx.metric.Loss())
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
+
 def test_l1_loss():
     mx.random.seed(1234)
     np.random.seed(1234)
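
For completeness, a small NumPy sanity check (an illustrative sketch, not part of the
patch) of the log-sum-exp rewrite used in SigmoidBinaryCrossEntropyLoss.hybrid_forward;
it confirms that o - o*t + max(-o, 0) + log(exp(-max(-o, 0)) + exp(-o - max(-o, 0)))
equals the naive -(t*log(sigmoid(o)) + (1 - t)*log(1 - sigmoid(o))):

    import numpy as np

    rng = np.random.RandomState(0)
    o = rng.uniform(-10, 10, size=100)                  # raw logits
    t = rng.randint(0, 2, size=100).astype('float64')   # binary labels

    # naive form: sigmoid first, then binary cross entropy
    p = 1.0 / (1.0 + np.exp(-o))
    naive = -(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))

    # stable form, as written in the diff's hybrid_forward
    m = np.maximum(-o, 0.0)
    stable = o - o * t + m + np.log(np.exp(-m) + np.exp(-o - m))

    # identical on moderate logits; for very large |o| the naive form
    # overflows/underflows while the stable one stays finite
    assert np.allclose(naive, stable)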

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].
