This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ba8a9d1  support softmin operator with unit test (#12306)
ba8a9d1 is described below

commit ba8a9d13e1b549d061f1933c463cfad5e7bdd7aa
Author: Hao Jin <haoj...@users.noreply.github.com>
AuthorDate: Wed Aug 29 10:35:44 2018 -0700

    support softmin operator with unit test (#12306)
---
 src/operator/contrib/ctc_loss-inl.h    |  7 +--
 src/operator/nn/softmax-inl.h          | 88 ++++++++++++++++++++--------------
 src/operator/nn/softmax.cc             | 39 +++++++++++++++
 src/operator/nn/softmax.cu             |  7 +++
 tests/python/unittest/test_operator.py | 24 ++++++++--
 5 files changed, 123 insertions(+), 42 deletions(-)
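
The patch threads a compile-time "negate" flag through the existing softmax kernels, so
softmin is computed as softmax applied to the negated input. As a minimal NumPy sketch of
that identity (illustrative only, not part of the commit; np_softmax mirrors the helper in
test_operator.py and the example values come from the softmin docstring below):

  import numpy as np

  def np_softmax(x, axis=-1, temperature=1.0):
      # numerically stable softmax: shift by the max, exponentiate, normalize
      x = x - np.max(x, axis=axis, keepdims=True)
      x = np.exp(x / temperature)
      x /= np.sum(x, axis=axis, keepdims=True)
      return x

  def np_softmin(x, axis=-1, temperature=1.0):
      # softmin(x) == softmax(-x); this is what the new negate flag selects in the C++ kernels
      return np_softmax(-x, axis=axis, temperature=temperature)

  x = np.array([[1., 2., 3.], [3., 2., 1.]])
  print(np_softmin(x, axis=0))  # ~[[0.881, 0.5, 0.119], [0.119, 0.5, 0.881]]
  print(np_softmin(x, axis=1))  # ~[[0.665, 0.245, 0.090], [0.090, 0.245, 0.665]]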

diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 72209ae..9380be4 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -409,7 +409,8 @@ class CTCLossOp : public Operator {
 
     // since the input is activation before softmax and cudnn ctc takes softmax
     // apply softmax to inputs first.
-    mxnet_op::Softmax<mxnet_op::softmax_fwd>(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
+    mxnet_op::Softmax<mxnet_op::softmax_fwd, false>(
+      s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
 
     CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
                             prob_desc_,
@@ -426,8 +427,8 @@ class CTCLossOp : public Operator {
                             workspace_bytes));
 
     if (req_grad) {
-      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo>(s,
-          prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
+      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo, false>(
+        s, prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
       Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
     }
   }
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 4a19db7..c063e38 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -51,7 +51,7 @@ struct log_softmax_fwd {
 };
 
 
-template<typename OP, typename DType, int ndim>
+template<typename OP, bool negate, typename DType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
@@ -65,30 +65,37 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
   for (int i = 0; i < static_cast<int>(N); ++i) {
     index_t base = unravel_dot(i, sshape, stride);
 
-    DType mmax = in[base];
+    DType mmax = negate ? -in[base] : in[base];
+    DType val;
     for (index_t j = 1; j < M; ++j) {
-      if (mmax < in[base + j*sa]) mmax = in[base + j*sa];
+      val = negate ? -in[base + j*sa] : in[base + j*sa];
+      if (mmax < val) mmax = val;
     }
 
     DType sum = DType(0);
+    DType in_val;
     // By default temperature is 1.0, and only in reinforcement training
     // users would set it to other values.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
-        sum += std::exp(in[base + j*sa] - mmax);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp(in_val - mmax);
       }
 
       for (index_t j = 0; j < M; ++j) {
-        out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map(in_val - mmax, sum);
       }
     } else {
       for (index_t j = 0; j < M; ++j) {
-        sum += std::exp((in[base + j*sa] - mmax)/temperature);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp((in_val - mmax)/temperature);
       }
 
       for (index_t j = 0; j < M; ++j) {
-        out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum);
       }
     }
   }
@@ -111,7 +118,7 @@ struct log_softmax_bwd {
 };
 
 
-template<typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
@@ -137,12 +144,16 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
     DType final_result;
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
-        final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
         KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
       }
     } else {
       for (index_t j = 0; j < M; ++j) {
-        final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
         KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
       }
     }
@@ -151,7 +162,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 
 
 #ifdef __CUDACC__
-template<int x_bits, typename OP, typename DType, int ndim>
+template<int x_bits, typename OP, bool negate, typename DType, int ndim>
 __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
@@ -163,7 +174,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
 
   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::maximum::Reduce(smem[x], in[base + i*sa]);
+    red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -172,9 +183,11 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   __syncthreads();
 
   red::sum::SetInitValue(smem[x]);
+  DType val;
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], static_cast<DType>(expf((in[base + i*sa] - smax)/
-    static_cast<DType>(temperature))));
+    val = negate ? -in[base + i*sa]:in[base + i*sa];
+    red::sum::Reduce(
+      smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
@@ -183,11 +196,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/static_cast<DType>(temperature), ssum);
+    val = negate ? -in[base + i*sa] : in[base + i*sa];
+    out[base + i*sa] = OP::Map((val - smax)/static_cast<DType>(temperature), ssum);
   }
 }
 
-template<typename OP, typename DType, int ndim>
+template<typename OP, bool negate, typename DType, int ndim>
 inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
                     Shape<ndim> shape, int axis, const double temperature) {
   const int x_bits = 7;
@@ -198,14 +212,14 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_compute_kernel<x_bits, OP, DType, ndim>
+  softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
 
-template<int x_bits, typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
@@ -228,13 +242,15 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
   DType final_result;
   for (index_t i = x; i < M; i += x_size) {
     final_result =
-      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) / static_cast<DType>(temperature);
-    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
+      negate ?
+      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
+      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
   }
 }
 
 
-template<typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
@@ -246,7 +262,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_gradient_kernel<x_bits, OP1, OP2, Req, DType, ndim>
+  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -267,7 +283,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };
 
-template<typename xpu, typename OP>
+template<typename xpu, typename OP, bool negate = false>
 void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
@@ -283,19 +299,19 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (shape.ndim() == 2) {
-      Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<DType>(), shape.get<2>(), axis,
-              static_cast<DType>(temperature));
+      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                          outputs[0].dptr<DType>(), shape.get<2>(), axis,
+                          static_cast<DType>(temperature));
     } else {
-      Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<DType>(), shape.get<3>(), axis,
-              static_cast<DType>(temperature));
+      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                          outputs[0].dptr<DType>(), shape.get<3>(), axis,
+                          static_cast<DType>(temperature));
     }
   });
 }
 
 
-template<typename xpu, typename OP1, typename OP2>
+template<typename xpu, typename OP1, typename OP2, bool negate = false>
 void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
@@ -311,13 +327,13 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
       if (shape.ndim() == 2) {
-        SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                   inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                   shape.get<2>(), axis, static_cast<DType>(temperature));
+        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+                                           shape.get<2>(), axis, static_cast<DType>(temperature));
       } else {
-        SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                   inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                   shape.get<3>(), axis, static_cast<DType>(temperature));
+        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+                                           shape.get<3>(), axis, static_cast<DType>(temperature));
       }
     });
   });
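
For reference, the CPU and GPU kernels above follow the same numerically stable recipe:
take the maximum of the (optionally negated) inputs along the axis, subtract it before
exponentiating, scale by the temperature when it is not 1.0, then normalize. A rough Python
rendering of one axis slice (illustrative only; the function name and variables are mine,
not from the source):

  import numpy as np

  def softmin_slice(x, temperature=1.0):
      # mirrors one inner loop of Softmax<softmax_fwd, negate=true> over a single slice
      neg = -x                                # negate ? -in[...] : in[...]
      mmax = np.max(neg)                      # running max of the negated values
      e = np.exp((neg - mmax) / temperature)  # stable exponentials
      return e / np.sum(e)                    # softmax_fwd::Map: exp(...) / sum

  print(softmin_slice(np.array([1.0, 2.0, 3.0])))  # ~[0.665, 0.245, 0.090]
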
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 0fad3d6..88b7b5f 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -109,6 +109,45 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
+MXNET_OPERATOR_REGISTER_UNARY(softmin)
+.describe(R"code(Applies the softmin function.
+
+The resulting array contains elements in the range (0,1) and the elements along the given axis sum
+up to 1.
+
+.. math::
+   softmin(\mathbf{z/t})_j = \frac{e^{-z_j/t}}{\sum_{k=1}^K e^{-z_k/t}}
+
+for :math:`j = 1, ..., K`
+
+t is the temperature parameter in softmax function. By default, t equals 1.0
+
+Example::
+
+  x = [[ 1.  2.  3.]
+       [ 3.  2.  1.]]
+
+  softmin(x,axis=0) = [[ 0.88079703,  0.5,  0.11920292],
+                       [ 0.11920292,  0.5,  0.88079703]]
+
+  softmin(x,axis=1) = [[ 0.66524094,  0.24472848,  0.09003057],
+                       [ 0.09003057,  0.24472848,  0.66524094]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output"};
+})
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
+.add_arguments(SoftmaxParam::__FIELDS__());
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
+.set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
+                                                        mxnet_op::softmax_bwd, true>);
+
 MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
 .describe(R"code(Computes the log softmax of the input.
 This is equivalent to computing softmax followed by log.
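
With the softmin registration above (and the matching GPU registration in softmax.cu below),
softmin is exposed like any other registered operator. A small usage sketch, assuming the
standard mx.nd front end that registration normally provides; the input and expected values
are taken from the softmin docstring example:

  import mxnet as mx

  x = mx.nd.array([[1., 2., 3.],
                   [3., 2., 1.]])
  print(mx.nd.softmin(x, axis=0))  # ~[[0.881, 0.5, 0.119], [0.119, 0.5, 0.881]]
  print(mx.nd.softmin(x, axis=1))  # ~[[0.665, 0.245, 0.090], [0.090, 0.245, 0.665]]
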
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index 8274642..254e726 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -35,6 +35,13 @@ NNVM_REGISTER_OP(_backward_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
+NNVM_REGISTER_OP(softmin)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::softmax_fwd, true>);
+
+NNVM_REGISTER_OP(_backward_softmin)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
+                                                        mxnet_op::softmax_bwd, true>);
+
 NNVM_REGISTER_OP(log_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 38c90e6..5bd88dd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -267,11 +267,8 @@ def test_rnnrelu_dropout():
     out[0].wait_to_read()
 
 def np_softmax(x, axis=-1, temperature=1.0):
-    # fix for old numpy on Travis not supporting keepdims
-    # x = x - np.max(x, axis=-1, keepdims=True)
     x = x - np.max(x, axis=axis, keepdims=True)
     x = np.exp(x/temperature)
-    # x /= np.sum(x, axis=-1, keepdims=True)
     x /= np.sum(x, axis=axis, keepdims=True)
     return x
 
@@ -4535,6 +4532,27 @@ def test_where():
     test_invalid_shape()
     test_1d_cond()
 
+
+@with_seed()
+def test_softmin():
+    for ndim in range(1, 5):
+        for dtype in [np.float16, np.float32, np.float64]:
+            rtol, atol = (1e-2, 5e-3) if dtype is np.float16 else (1e-3, 1e-3)
+            shape = np.random.randint(1, 5, size=ndim)
+            axis = np.random.randint(-ndim, ndim)
+            data = np.random.uniform(-2, 2, size=shape).astype(dtype)
+            data = data / 10 if dtype is np.float16 else data
+            sym = mx.sym.softmin(axis=axis)
+            expected_fwd = np_softmax(-data, axis=axis)
+            expected_bwd = np.zeros(shape)
+            check_symbolic_forward(sym, [data], [expected_fwd], atol=atol, dtype=dtype)
+            for req in ['null', 'add', 'write']:
+                check_symbolic_backward(sym, [data], [np.ones(expected_fwd.shape)], [expected_bwd],
+                                        rtol=rtol, atol=atol, grad_req=req, dtype=dtype)
+            if dtype is not np.float16:
+                check_numeric_gradient(sym, [data], rtol=rtol, atol=atol, dtype=dtype)
+
+
 @with_seed()
 def test_new_softmax():
     for ndim in range(1, 5):
