This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

The following commit(s) were added to refs/heads/master by this push:
     new ba8a9d1  support softmin operator with unit test (#12306)
ba8a9d1 is described below

commit ba8a9d13e1b549d061f1933c463cfad5e7bdd7aa
Author: Hao Jin <haoj...@users.noreply.github.com>
AuthorDate: Wed Aug 29 10:35:44 2018 -0700

    support softmin operator with unit test (#12306)
---
 src/operator/contrib/ctc_loss-inl.h    |  7 +--
 src/operator/nn/softmax-inl.h          | 88 ++++++++++++++++++++--------------
 src/operator/nn/softmax.cc             | 39 +++++++++++++++
 src/operator/nn/softmax.cu             |  7 +++
 tests/python/unittest/test_operator.py | 24 ++++++++--
 5 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 72209ae..9380be4 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -409,7 +409,8 @@ class CTCLossOp : public Operator {
 
     // since the input is activation before softmax and cudnn ctc takes softmax
     // apply softmax to inputs first.
-    mxnet_op::Softmax<mxnet_op::softmax_fwd>(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
+    mxnet_op::Softmax<mxnet_op::softmax_fwd, false>(
+        s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
 
     CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
                             prob_desc_,
@@ -426,8 +427,8 @@ class CTCLossOp : public Operator {
                             workspace_bytes));
 
     if (req_grad) {
-      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo>(s,
-          prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
+      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo, false>(
+          s, prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
       Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
     }
   }
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 4a19db7..c063e38 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -51,7 +51,7 @@ struct log_softmax_fwd {
 };
 
 
-template<typename OP, typename DType, int ndim>
+template<typename OP, bool negate, typename DType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
@@ -65,30 +65,37 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
   for (int i = 0; i < static_cast<int>(N); ++i) {
     index_t base = unravel_dot(i, sshape, stride);
 
-    DType mmax = in[base];
+    DType mmax = negate ? -in[base] : in[base];
+    DType val;
     for (index_t j = 1; j < M; ++j) {
-      if (mmax < in[base + j*sa]) mmax = in[base + j*sa];
+      val = negate ? -in[base + j*sa] : in[base + j*sa];
+      if (mmax < val) mmax = val;
     }
 
     DType sum = DType(0);
+    DType in_val;
     // By default temperature is 1.0, and only in reinforcement training
    // users would set it to other values.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
-        sum += std::exp(in[base + j*sa] - mmax);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp(in_val - mmax);
       }
 
       for (index_t j = 0; j < M; ++j) {
-        out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map(in_val - mmax, sum);
       }
     } else {
       for (index_t j = 0; j < M; ++j) {
-        sum += std::exp((in[base + j*sa] - mmax)/temperature);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp((in_val - mmax)/temperature);
       }
 
       for (index_t j = 0; j < M; ++j) {
-        out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum);
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum);
       }
     }
   }
@@ -111,7 +118,7 @@ struct log_softmax_bwd {
 };
 
 
-template<typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
@@ -137,12 +144,16 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
     DType final_result;
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
-        final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
         KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
       }
     } else {
       for (index_t j = 0; j < M; ++j) {
-        final_result = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
         KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
       }
     }
@@ -151,7 +162,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 
 #ifdef __CUDACC__
-template<int x_bits, typename OP, typename DType, int ndim>
+template<int x_bits, typename OP, bool negate, typename DType, int ndim>
 __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
@@ -163,7 +174,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::maximum::Reduce(smem[x], in[base + i*sa]);
+    red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -172,9 +183,11 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   __syncthreads();
 
   red::sum::SetInitValue(smem[x]);
+  DType val;
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], static_cast<DType>(expf((in[base + i*sa] - smax)/
-                                                      static_cast<DType>(temperature))));
+    val = negate ? -in[base + i*sa]:in[base + i*sa];
+    red::sum::Reduce(
+      smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
@@ -183,11 +196,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/static_cast<DType>(temperature), ssum);
+    val = negate ? -in[base + i*sa] : in[base + i*sa];
+    out[base + i*sa] = OP::Map((val - smax)/static_cast<DType>(temperature), ssum);
   }
 }
 
-template<typename OP, typename DType, int ndim>
+template<typename OP, bool negate, typename DType, int ndim>
 inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
                     Shape<ndim> shape, int axis, const double temperature) {
   const int x_bits = 7;
@@ -198,14 +212,14 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_compute_kernel<x_bits, OP, DType, ndim>
+  softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
 
-template<int x_bits, typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
@@ -228,13 +242,15 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
   DType final_result;
   for (index_t i = x; i < M; i += x_size) {
     final_result =
-      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) / static_cast<DType>(temperature);
-    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
+      negate ?
+      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
+      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
   }
 }
 
 
-template<typename OP1, typename OP2, int Req, typename DType, int ndim>
+template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
@@ -246,7 +262,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_gradient_kernel<x_bits, OP1, OP2, Req, DType, ndim>
+  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -267,7 +283,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };
 
-template<typename xpu, typename OP>
+template<typename xpu, typename OP, bool negate = false>
 void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
@@ -283,19 +299,19 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (shape.ndim() == 2) {
-      Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<2>(), axis,
-                  static_cast<DType>(temperature));
+      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                          outputs[0].dptr<DType>(), shape.get<2>(), axis,
+                          static_cast<DType>(temperature));
     } else {
-      Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<3>(), axis,
-                  static_cast<DType>(temperature));
+      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                          outputs[0].dptr<DType>(), shape.get<3>(), axis,
+                          static_cast<DType>(temperature));
     }
   });
 }
 
 
-template<typename xpu, typename OP1, typename OP2>
+template<typename xpu, typename OP1, typename OP2, bool negate = false>
 void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
@@ -311,13 +327,13 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
       if (shape.ndim() == 2) {
-        SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                   inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                   shape.get<2>(), axis, static_cast<DType>(temperature));
+        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+                                           shape.get<2>(), axis, static_cast<DType>(temperature));
       } else {
-        SoftmaxGrad<OP1, OP2, Req>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                   inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                   shape.get<3>(), axis, static_cast<DType>(temperature));
+        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
+                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+                                           shape.get<3>(), axis, static_cast<DType>(temperature));
       }
     });
   });
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 0fad3d6..88b7b5f 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -109,6 +109,45 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
+MXNET_OPERATOR_REGISTER_UNARY(softmin)
+.describe(R"code(Applies the softmin function.
+
+The resulting array contains elements in the range (0,1) and the elements along the given axis sum
+up to 1.
+
+.. math::
+   softmin(\mathbf{z/t})_j = \frac{e^{-z_j/t}}{\sum_{k=1}^K e^{-z_k/t}}
+
+for :math:`j = 1, ..., K`
+
+t is the temperature parameter in softmax function. By default, t equals 1.0
+
+Example::
+
+  x = [[ 1.  2.  3.]
+       [ 3.  2.  1.]]
+
+  softmin(x,axis=0) = [[ 0.88079703,  0.5,  0.11920292],
+                       [ 0.11920292,  0.5,  0.88079703]]
+
+  softmin(x,axis=1) = [[ 0.66524094,  0.24472848,  0.09003057],
+                       [ 0.09003057,  0.24472848,  0.66524094]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output"};
+})
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
+.add_arguments(SoftmaxParam::__FIELDS__());
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
+.set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
+                                                        mxnet_op::softmax_bwd, true>);
+
 MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
 .describe(R"code(Computes the log softmax of the input.
 This is equivalent to computing softmax followed by log.
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index 8274642..254e726 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -35,6 +35,13 @@ NNVM_REGISTER_OP(_backward_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
+NNVM_REGISTER_OP(softmin)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::softmax_fwd, true>);
+
+NNVM_REGISTER_OP(_backward_softmin)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
+                                                        mxnet_op::softmax_bwd, true>);
+
 NNVM_REGISTER_OP(log_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 38c90e6..5bd88dd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -267,11 +267,8 @@ def test_rnnrelu_dropout():
     out[0].wait_to_read()
 
 def np_softmax(x, axis=-1, temperature=1.0):
-    # fix for old numpy on Travis not supporting keepdims
-    # x = x - np.max(x, axis=-1, keepdims=True)
     x = x - np.max(x, axis=axis, keepdims=True)
     x = np.exp(x/temperature)
-    # x /= np.sum(x, axis=-1, keepdims=True)
     x /= np.sum(x, axis=axis, keepdims=True)
     return x
 
@@ -4535,6 +4532,27 @@ def test_where():
     test_invalid_shape()
     test_1d_cond()
 
+
+@with_seed()
+def test_softmin():
+    for ndim in range(1, 5):
+        for dtype in [np.float16, np.float32, np.float64]:
+            rtol, atol = (1e-2, 5e-3) if dtype is np.float16 else (1e-3, 1e-3)
+            shape = np.random.randint(1, 5, size=ndim)
+            axis = np.random.randint(-ndim, ndim)
+            data = np.random.uniform(-2, 2, size=shape).astype(dtype)
+            data = data / 10 if dtype is np.float16 else data
+            sym = mx.sym.softmin(axis=axis)
+            expected_fwd = np_softmax(-data, axis=axis)
+            expected_bwd = np.zeros(shape)
+            check_symbolic_forward(sym, [data], [expected_fwd], atol=atol, dtype=dtype)
+            for req in ['null', 'add', 'write']:
+                check_symbolic_backward(sym, [data], [np.ones(expected_fwd.shape)], [expected_bwd],
+                                        rtol=rtol, atol=atol, grad_req=req, dtype=dtype)
+            if dtype is not np.float16:
+                check_numeric_gradient(sym, [data], rtol=rtol, atol=atol, dtype=dtype)
+
+
 @with_seed()
 def test_new_softmax():
     for ndim in range(1, 5):
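
For reference, a minimal usage sketch (not part of the commit) showing what the new operator computes from Python. It assumes the registered softmin op is exposed through MXNet's auto-generated NDArray frontend as mx.nd.softmin, mirroring the mx.sym.softmin call in test_softmin above; the expected values come from the softmin docstring example in softmax.cc.

    import mxnet as mx
    import numpy as np

    def np_softmax(x, axis=-1, temperature=1.0):
        # reference implementation, same as in tests/python/unittest/test_operator.py
        x = x - np.max(x, axis=axis, keepdims=True)
        x = np.exp(x / temperature)
        x /= np.sum(x, axis=axis, keepdims=True)
        return x

    x = np.array([[1., 2., 3.],
                  [3., 2., 1.]])

    # softmin(x) is softmax applied to the negated input
    y = mx.nd.softmin(mx.nd.array(x), axis=1).asnumpy()
    np.testing.assert_allclose(y, np_softmax(-x, axis=1), rtol=1e-5, atol=1e-6)
    # y ~= [[0.66524094, 0.24472848, 0.09003057],
    #      [0.09003057, 0.24472848, 0.66524094]]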