This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new a2bafe1 [MXNET-912] Refactoring ctc loss operator (#12637) a2bafe1 is described below commit a2bafe177940e9966949c6f63cda421f10b22b19 Author: Lin Yuan <apefor...@gmail.com> AuthorDate: Mon Oct 8 18:36:58 2018 -0700 [MXNET-912] Refactoring ctc loss operator (#12637) * Implement ctc_loss as a normal operator * Update unit test * Update unit test and fix bug in backward * fix lint error * refactoring * Fix compilation error in CUDA * Fix CPU compilation error * Move ctc_include to nn folder and refactor * temporarily disable lint on 3rd party includes * move ctc_include to 3rdparty * remove contrib ctc_loss operator * revert a change by mistake * Fix a bug in kDevCPU * revert change by mistake * add alias to make it backward compatible * add unit test for backward compatibility * linting --- .../contrib => 3rdparty}/ctc_include/LICENSE | 0 .../ctc_include/contrib/moderngpu/LICENSE | 0 .../moderngpu/include/device/ctaloadbalance.cuh | 0 .../contrib/moderngpu/include/device/ctamerge.cuh | 0 .../contrib/moderngpu/include/device/ctascan.cuh | 0 .../contrib/moderngpu/include/device/ctasearch.cuh | 0 .../moderngpu/include/device/ctasegreduce.cuh | 0 .../moderngpu/include/device/ctasegscan.cuh | 0 .../moderngpu/include/device/ctasegsort.cuh | 0 .../moderngpu/include/device/ctasortedsearch.cuh | 0 .../moderngpu/include/device/devicetypes.cuh | 0 .../moderngpu/include/device/deviceutil.cuh | 0 .../moderngpu/include/device/intrinsics.cuh | 0 .../contrib/moderngpu/include/device/loadstore.cuh | 0 .../moderngpu/include/device/serialsets.cuh | 0 .../moderngpu/include/device/sortnetwork.cuh | 0 .../contrib/moderngpu/include/mgpudevice.cuh | 0 .../contrib/moderngpu/include/mgpuenums.h | 0 .../contrib/moderngpu/include/util/static.h | 0 .../ctc_include/detail/cpu_ctc.h | 0 .../ctc_include/detail/ctc_helper.h | 0 .../ctc_include/detail/gpu_ctc.h | 0 .../ctc_include/detail/gpu_ctc_kernels.h | 0 
.../ctc_include/detail/hostdevice.h | 0 python/mxnet/gluon/loss.py | 8 +- src/operator/contrib/ctc_loss-inl.h | 591 --------------------- src/operator/nn/ctc_loss-inl.h | 397 ++++++++++++++ src/operator/{contrib => nn}/ctc_loss.cc | 87 +-- src/operator/{contrib => nn}/ctc_loss.cu | 27 +- tests/python/unittest/test_operator.py | 114 +++- 30 files changed, 573 insertions(+), 651 deletions(-) diff --git a/src/operator/contrib/ctc_include/LICENSE b/3rdparty/ctc_include/LICENSE similarity index 100% rename from src/operator/contrib/ctc_include/LICENSE rename to 3rdparty/ctc_include/LICENSE diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE b/3rdparty/ctc_include/contrib/moderngpu/LICENSE similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE rename to 3rdparty/ctc_include/contrib/moderngpu/LICENSE diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh diff --git 
a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh similarity index 100% 
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh diff --git 
a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h b/3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h b/3rdparty/ctc_include/contrib/moderngpu/include/util/static.h similarity index 100% rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h rename to 3rdparty/ctc_include/contrib/moderngpu/include/util/static.h diff --git a/src/operator/contrib/ctc_include/detail/cpu_ctc.h b/3rdparty/ctc_include/detail/cpu_ctc.h similarity index 100% rename from src/operator/contrib/ctc_include/detail/cpu_ctc.h rename to 3rdparty/ctc_include/detail/cpu_ctc.h diff --git a/src/operator/contrib/ctc_include/detail/ctc_helper.h b/3rdparty/ctc_include/detail/ctc_helper.h similarity index 100% rename from src/operator/contrib/ctc_include/detail/ctc_helper.h rename to 3rdparty/ctc_include/detail/ctc_helper.h diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc.h b/3rdparty/ctc_include/detail/gpu_ctc.h similarity index 100% rename from src/operator/contrib/ctc_include/detail/gpu_ctc.h rename to 3rdparty/ctc_include/detail/gpu_ctc.h diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h b/3rdparty/ctc_include/detail/gpu_ctc_kernels.h similarity index 100% rename from src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h rename to 3rdparty/ctc_include/detail/gpu_ctc_kernels.h diff --git 
a/src/operator/contrib/ctc_include/detail/hostdevice.h b/3rdparty/ctc_include/detail/hostdevice.h similarity index 100% rename from src/operator/contrib/ctc_include/detail/hostdevice.h rename to 3rdparty/ctc_include/detail/hostdevice.h diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index 2be4398..7e4d345 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -468,10 +468,10 @@ class CTCLoss(Loss): pred = F.swapaxes(pred, 0, 1) if self._batch_axis == 1: label = F.swapaxes(label, 0, 1) - loss = F.contrib.CTCLoss(pred, label, pred_lengths, label_lengths, - use_data_lengths=pred_lengths is not None, - use_label_lengths=label_lengths is not None, - blank_label='last') + loss = F.CTCLoss(pred, label, pred_lengths, label_lengths, + use_data_lengths=pred_lengths is not None, + use_label_lengths=label_lengths is not None, + blank_label='last') return _apply_weighting(F, loss, self._weight, sample_weight) diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h deleted file mode 100644 index c8a8b26..0000000 --- a/src/operator/contrib/ctc_loss-inl.h +++ /dev/null @@ -1,591 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2016 by Contributors - * \file ctc_loss-inl.h - * \brief - * \author Sebastian Bodenstien -*/ - -#ifndef MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_ -#define MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_ - -#include <dmlc/logging.h> -#include <dmlc/parameter.h> -#include <mxnet/operator.h> -#include <algorithm> -#include <map> -#include <vector> -#include <string> -#include <utility> -#include <ctime> -#include <cstring> -#include <iostream> -#include "../operator_common.h" -#include "../sequence_op_common.h" -#include "../mshadow_op.h" -#include "../nn/sequence_mask-inl.h" - -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 -#define CUDNN_LABEL_LENGTH_LIMIT 256 -#include "../nn/softmax-inl.h" -#endif // CUDNN - -namespace mxnet { -namespace op { - -namespace ctc_loss { -enum CTCLossOpInputs { kData, kLabel }; -enum CTCLossOpOutputs { kOut, kGrad }; -enum CTCLossOpForwardResource { kTempSpace }; -} - -template <typename T> -inline void get_workspace_size(std::vector<int> *label_lengths, - std::vector<int> *data_lengths, - int alphabet_size, int minibatch, bool gpu, - size_t *size_bytes) { - // This is the max of all S and T for all examples in the minibatch. 
- int maxL = *std::max_element(label_lengths->data(), - label_lengths->data() + minibatch); - int maxT = *std::max_element(data_lengths->data(), - data_lengths->data() + minibatch); - - const int S = 2 * maxL + 1; - - *size_bytes = 0; - - if (gpu) { - // GPU storage - // nll_forward, nll_backward - *size_bytes += 2 * sizeof(T) * minibatch; - - // repeats - *size_bytes += sizeof(int) * minibatch; - - // label offsets - *size_bytes += sizeof(int) * minibatch; - - // utt_length - *size_bytes += sizeof(int) * minibatch; - - // label lengths - *size_bytes += sizeof(int) * minibatch; - - // labels without blanks - overallocate for now - *size_bytes += sizeof(int) * maxL * minibatch; - - // labels with blanks - *size_bytes += sizeof(int) * S * minibatch; - - // alphas - *size_bytes += sizeof(T) * S * maxT * minibatch; - - // denoms - *size_bytes += sizeof(T) * maxT * minibatch; - - // probs (since we will pass in activations) - *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch; - - } else { - // cpu can eventually replace all minibatch with - // max number of concurrent threads if memory is - // really tight - - // per minibatch memory - size_t per_minibatch_bytes = 0; - - // output - per_minibatch_bytes += sizeof(T) * alphabet_size; - - // alphas - per_minibatch_bytes += sizeof(T) * S * maxT; - - // betas - per_minibatch_bytes += sizeof(T) * S; - - // labels w/blanks, e_inc, s_inc - per_minibatch_bytes += 3 * sizeof(int) * S; - - *size_bytes = per_minibatch_bytes * minibatch; - - // probs - *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch; - } -} - -// Takes a tensor of labels, and interprets 0-elements at the end of the vector -// as padding. The tensor is packed into an std::vector without padding -// characters. The label sequence lengths are also inferred from the padding chars. -// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded. 
-template <typename DType, typename xpu> -inline bool LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels, - int padding_mask, - std::vector<int> *packed_labels, - std::vector<int> *label_lengths) { - int batch = labels.size(0); - int max_num_labels = labels.size(1); - bool exceed_limit = false; - - std::vector<int> cpu_labels(max_num_labels*batch); - mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D(); - IndexTensorToVector(flat_labels, &cpu_labels); - - for (int b = 0; b < batch; ++b) { - auto start = cpu_labels.data()+b*max_num_labels; - auto res = std::find(start, start+max_num_labels, padding_mask); - int len = std::distance(start, res); -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT; -#endif - std::copy(start, start + len, - std::back_inserter(*packed_labels)); - label_lengths->at(b) = len; - } - return exceed_limit; -} - -// Takes a tensor of labels, and a vector which specifies the actual length of each label -// The tensor is packed into an std::vector without padding characters. -// The label length vector is copied into an std::vector. -// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded. 
-template <typename DType, typename xpu> -inline bool PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels, - mshadow::Tensor<xpu, 1, DType> in_label_lengths, - std::vector<int> *packed_labels, - std::vector<int> *label_lengths) { - int batch = labels.size(0); - int max_num_labels = labels.size(1); - bool exceed_limit = false; - - IndexTensorToVector(in_label_lengths, label_lengths); - - std::vector<int> cpu_labels(max_num_labels*batch); - mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D(); - IndexTensorToVector(flat_labels, &cpu_labels); - - for (int b = 0; b < batch; ++b) { - auto start = cpu_labels.data()+b*max_num_labels; - int len = label_lengths->at(b); -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT; -#endif - std::copy(start, start + len, - std::back_inserter(*packed_labels)); - } - return exceed_limit; -} - -struct CTCLossParam : public dmlc::Parameter<CTCLossParam> { - bool use_data_lengths; - bool use_label_lengths; - int blank_label; - DMLC_DECLARE_PARAMETER(CTCLossParam) { - DMLC_DECLARE_FIELD(use_data_lengths).set_default(false) - .describe("Whether the data lenghts are decided by `data_lengths`. " - "If false, the lengths are equal to the max sequence length."); - DMLC_DECLARE_FIELD(use_label_lengths).set_default(false) - .describe("Whether the label lenghts are decided by " - "`label_lengths`, or derived from `padding_mask`. " - "If false, the lengths are derived from the " - "first occurrence of the value of `padding_mask`. " - "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, " - "and ``-1`` when last label is reserved for blank. See `blank_label`."); - DMLC_DECLARE_FIELD(blank_label) - .add_enum("first", 0) - .add_enum("last", 1) - .set_default(0) - .describe("Set the label that is reserved for blank label." 
- "If \"first\", 0-th label is reserved, and " - "label values for tokens in the vocabulary are " - "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. " - "If \"last\", last label value ``alphabet_size-1`` " - "is reserved for blank label instead, " - "and label values for tokens in the vocabulary are " - "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``."); - } -}; - -template <typename xpu> -class CTCLossOp : public Operator { - public: - explicit CTCLossOp(CTCLossParam p) { - this->param_ = p; - exceed_cudnn_limit = false; -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctc_desc_)); - CUDNN_CALL(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&prob_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&grad_desc_)); -#endif - } - - ~CTCLossOp() { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctc_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(prob_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(grad_desc_)); -#endif - } - - virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data, - const std::vector<OpReqType> &req, - const std::vector<TBlob> &out_data, - const std::vector<TBlob> &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 2U+param_.use_data_lengths+param_.use_label_lengths); - CHECK_EQ(out_data.size(), 2U); - exceed_cudnn_limit = false; - Stream<xpu> *s = ctx.get_stream<xpu>(); - - MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, { - Tensor<xpu, 3, real_t> data = - in_data[ctc_loss::kData].get<xpu, 3, real_t>(s); - Tensor<xpu, 2, DType> labels = - in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s); - - Tensor<xpu, 1, real_t> costs = - out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s); - Tensor<xpu, 3, real_t> grad = - out_data[ctc_loss::kGrad].get<xpu, 3, 
real_t>(s); - - int max_seq_len = data.size(0); - int batch_size = data.size(1); - int alphabet_size = data.size(2); - - // data_lengths - std::vector<int> data_lengths(batch_size, max_seq_len); - if (param_.use_data_lengths) { - int kInputLength = 2; - IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths); - } - - // label_lengths - std::vector<int> packed_labels; - std::vector<int> label_lengths(batch_size); - - if (param_.use_label_lengths) { - int kLabelLength = 2 + param_.use_data_lengths; - exceed_cudnn_limit = - PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s), - &packed_labels, &label_lengths); - } else { - exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1, - &packed_labels, &label_lengths); - } - - // CUDNN is disabled due to lack of support for input lengths - /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */ - /* if (!exceed_cudnn_limit) { */ - /* cudnn_forward(ctx, s, data, costs, grad, */ - /* &data_lengths, &label_lengths, &packed_labels, */ - /* max_seq_len, batch_size, alphabet_size, */ - /* req[ctc_loss::kGrad] != mxnet::kNullOp); */ - /* } else { */ - /* baidu_forward(ctx, s, data, costs, grad, */ - /* &data_lengths, &label_lengths, &packed_labels, */ - /* batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/ - /* } */ - /* #else */ - - baidu_forward(ctx, s, data, costs, grad, - &data_lengths, &label_lengths, &packed_labels, - batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); - - if (param_.use_data_lengths) { - // baidu warp CTC implementation sometimes includes undefined gradients - // for data outside of length mask. Setting to 0 to make it consistent - // with CPU implementation. 
- int kInputLength = 2; - mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s), - static_cast<real_t>(0)); - } - }); - } - - virtual void Backward(const OpContext &ctx, - const std::vector<TBlob> &out_grad, - const std::vector<TBlob> &in_data, - const std::vector<TBlob> &out_data, - const std::vector<OpReqType> &req, - const std::vector<TBlob> &in_grad, - const std::vector<TBlob> &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - - Stream<xpu> *s = ctx.get_stream<xpu>(); - - Tensor<xpu, 3, real_t> data_grad = - in_grad[ctc_loss::kData].get<xpu, 3, real_t>(s); - Tensor<xpu, 1, real_t> output_grad = - out_grad[ctc_loss::kOut].get<xpu, 1, real_t>(s); - - Tensor<xpu, 3, real_t> data_grad_computed = - out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s); - - Assign(data_grad, req[ctc_loss::kData], - mshadow::expr::broadcast<1>(output_grad, data_grad.shape_) * data_grad_computed); - } - - private: - CTCLossParam param_; - bool exceed_cudnn_limit; - -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - cudnnDataType_t dtype_; - cudnnCTCLossDescriptor_t ctc_desc_; - cudnnTensorDescriptor_t prob_desc_, grad_desc_; - - inline virtual void cudnn_forward(const OpContext &ctx, - mshadow::Stream<xpu>* s, - mshadow::Tensor<xpu, 3, real_t> data, - mshadow::Tensor<xpu, 1, real_t> costs, - mshadow::Tensor<xpu, 3, real_t> grad, - std::vector<int>* data_lengths, - std::vector<int>* label_lengths, - std::vector<int>* packed_labels, - int max_seq_len, - int batch_size, - int alphabet_size, - bool req_grad) { - using namespace mshadow; - - // call cudnn to calculate ctc loss - dtype_ = CUDNN_DATA_FLOAT; - int dims[3], strides[3]; - size_t workspace_bytes; - int workspace_size; - dims[0] = max_seq_len; - dims[1] = batch_size; - dims[2] = alphabet_size; - strides[0] = batch_size*alphabet_size; - strides[1] = alphabet_size; - strides[2] = 1; - cudnnCTCLossAlgo_t ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC; - 
CUDNN_CALL(cudnnSetTensorNdDescriptor(prob_desc_, - dtype_, - 3, - dims, - strides)); - CUDNN_CALL(cudnnSetTensorNdDescriptor(grad_desc_, - dtype_, - 3, - dims, - strides)); - CUDNN_CALL(cudnnGetCTCLossWorkspaceSize(s->dnn_handle_, - prob_desc_, - req_grad?grad_desc_:NULL, - packed_labels->data(), - label_lengths->data(), - data_lengths->data(), - ctc_algo, - ctc_desc_, - &workspace_bytes)); - workspace_size = (workspace_bytes + sizeof(real_t) - 1)/sizeof(real_t); - - Tensor<xpu, 1, real_t> temp_space = - ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>( - mshadow::Shape1(workspace_size+data.shape_.FlatTo1D()[0]), s); - - Tensor<gpu, 1, real_t> work_space(temp_space.dptr_, - mshadow::Shape1(workspace_size), s); - Tensor<xpu, 3, real_t> prob(temp_space.dptr_+workspace_size, - data.shape_, s); - - // since the input is activation before softmax and cudnn ctc takes softmax - // apply softmax to inputs first. - mxnet_op::Softmax<mxnet_op::softmax_fwd, false>( - s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0); - - CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_, - prob_desc_, - prob.dptr_, - packed_labels->data(), - label_lengths->data(), - data_lengths->data(), - costs.dptr_, - req_grad?grad_desc_:NULL, - req_grad?grad.dptr_:NULL, - ctc_algo, - ctc_desc_, - work_space.dptr_, - workspace_bytes)); - - if (req_grad) { - mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo, false>( - s, prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0); - Assign(grad, mxnet::kWriteInplace, grad * alphabet_size); - } - } -#endif // __CUDACC__ && CUDNN - - inline void baidu_forward(const OpContext &ctx, - mshadow::Stream<xpu>* s, - mshadow::Tensor<xpu, 3, real_t> data, - mshadow::Tensor<xpu, 1, real_t> costs, - mshadow::Tensor<xpu, 3, real_t> grad, - std::vector<int>* data_lengths, - std::vector<int>* label_lengths, - std::vector<int>* packed_labels, - int batch_size, - int alphabet_size, - bool req_grad) { - using namespace mshadow; - // allocate temporary 
workspace - size_t size_bytes; - bool gpu = data.kDevCPU ? false : true; - get_workspace_size<real_t>(label_lengths, data_lengths, alphabet_size, - batch_size, gpu, &size_bytes); - - // round-up so there are enough elems in memory - int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t); - Tensor<xpu, 1, real_t> workspace = - ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>( - Shape1(num_tmp_elems), s); - - compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(), - label_lengths->data(), data_lengths->data(), - workspace.dptr_, req_grad, - param_.blank_label == 0 ? 0 : (alphabet_size-1)); - } -}; // class CTCLossOp - -template <typename xpu> -Operator *CreateOp(CTCLossParam param, int dtype); - -#if DMLC_USE_CXX11 -class CTCLossProp : public OperatorProperty { - public: - int NumVisibleOutputs() const override { return 1; } - - int NumOutputs() const override { return 2; } - - std::vector<std::string> ListArguments() const override { - if (param_.use_data_lengths && param_.use_label_lengths) { - return {"data", "label", "data_lengths", "label_lengths"}; - } else if (param_.use_data_lengths) { - return {"data", "label", "data_lengths"}; - } else if (param_.use_label_lengths) { - return {"data", "label", "label_lengths"}; - } else { - return {"data", "label"}; - } - } - - std::vector<std::string> ListOutputs() const override { - return {"output", "grad"}; - } - - void Init(const std::vector<std::pair<std::string, std::string>> &kwargs) override { - param_.Init(kwargs); - } - - std::map<std::string, std::string> GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape, - std::vector<TShape> *aux_shape) const override { - using namespace mshadow; - index_t expected_inputs = 2+param_.use_data_lengths+param_.use_label_lengths; - CHECK_EQ(in_shape->size(), expected_inputs) - << "Expect " << expected_inputs << " inputs to the symbol."; - - 
const TShape &dshape = (*in_shape)[ctc_loss::kData]; - const TShape &lshape = (*in_shape)[ctc_loss::kLabel]; - CHECK_EQ(dshape.ndim(), 3U) << "The data array must be of rank 3."; - CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2."; - CHECK_EQ(dshape[1], lshape[0]) - << "The batch size for the labels and data arrays must be the same."; - if (param_.use_data_lengths) { - int kInputLength = 2; - const TShape &dlshape = (*in_shape)[kInputLength]; - CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector."; - CHECK_EQ(dlshape[0], dshape[1]) - << "The batch size for the data and data lengths must be the same."; - } - if (param_.use_label_lengths) { - int kLabelLength = 2+param_.use_data_lengths; - const TShape &llshape = (*in_shape)[kLabelLength]; - CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector."; - CHECK_EQ(llshape[0], lshape[0]) - << "The batch size for the labels and label lengths must be the same."; - } - - CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed " - "the maximum sequence length of the " - "data."; - - TShape oshape(1); - oshape[0] = dshape[1]; // batch size - out_shape->clear(); - out_shape->push_back(oshape); // forward output - out_shape->push_back(dshape); // grad output - return true; - } - - bool InferType(std::vector<int> *in_type, - std::vector<int> *out_type, - std::vector<int> *aux_type) const override { - CHECK_LE(in_type->size(), this->ListArguments().size()); - int dtype = (*in_type)[ctc_loss::kData]; - CHECK_NE(dtype, -1) << "Input data must have specified type"; - - out_type->clear(); - out_type->push_back(dtype); // forward output - out_type->push_back(dtype); // grad output - return true; - } - - OperatorProperty *Copy() const override { - auto ptr = new CTCLossProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { return "_contrib_CTCLoss"; } - - std::vector<ResourceRequest> ForwardResource( - const std::vector<TShape> 
&in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - std::vector<int> DeclareBackwardDependency( - const std::vector<int> &out_grad, const std::vector<int> &in_data, - const std::vector<int> &out_data) const override { - return {out_grad[ctc_loss::kOut], out_data[ctc_loss::kGrad]}; - } - - Operator *CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not Implemented."; - return NULL; - } - - Operator *CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, - std::vector<int> *in_type) const override; - - private: - CTCLossParam param_; -}; // class CTCLossProp -#endif // DMLC_USE_CXX11 -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_ diff --git a/src/operator/nn/ctc_loss-inl.h b/src/operator/nn/ctc_loss-inl.h new file mode 100644 index 0000000..754cf84 --- /dev/null +++ b/src/operator/nn/ctc_loss-inl.h @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2017 by Contributors + * \file ctc_loss-inl.h + * \brief CTC Loss operator +*/ + +#ifndef MXNET_OPERATOR_NN_CTC_LOSS_INL_H_ +#define MXNET_OPERATOR_NN_CTC_LOSS_INL_H_ + +#include <mxnet/operator_util.h> +#include <vector> +#include <algorithm> +#include <string> +#include "../mshadow_op.h" +#include "./sequence_mask-inl.h" +#include "../sequence_op_common.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +namespace ctc_loss { +enum CTCLossOpInputs { kData, kLabel }; +enum CTCLossOpOutputs { kOut, kGrad }; +} + +template <typename T> +inline void get_workspace_size(const std::vector<int> *label_lengths, + const std::vector<int> *data_lengths, + int alphabet_size, int minibatch, bool isGPU, + size_t *size_bytes) { + // This is the max of all S and T for all examples in the minibatch. + int maxL = *std::max_element(label_lengths->data(), + label_lengths->data() + minibatch); + int maxT = *std::max_element(data_lengths->data(), + data_lengths->data() + minibatch); + + const int S = 2 * maxL + 1; + + *size_bytes = 0; + + if (isGPU) { + // GPU storage + // nll_forward, nll_backward + *size_bytes += 2 * sizeof(T) * minibatch; + + // repeats + *size_bytes += sizeof(int) * minibatch; + + // label offsets + *size_bytes += sizeof(int) * minibatch; + + // utt_length + *size_bytes += sizeof(int) * minibatch; + + // label lengths + *size_bytes += sizeof(int) * minibatch; + + // labels without blanks - overallocate for now + *size_bytes += sizeof(int) * maxL * minibatch; + + // labels with blanks + *size_bytes += sizeof(int) * S * minibatch; + + // alphas + *size_bytes += sizeof(T) * S * maxT * minibatch; + + // denoms + *size_bytes += sizeof(T) * maxT * minibatch; + + // probs (since we will pass in activations) + *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch; + + } else { + // cpu can eventually replace all minibatch with + // max number of concurrent threads if memory is + // really 
tight + + // per minibatch memory + size_t per_minibatch_bytes = 0; + + // output + per_minibatch_bytes += sizeof(T) * alphabet_size; + + // alphas + per_minibatch_bytes += sizeof(T) * S * maxT; + + // betas + per_minibatch_bytes += sizeof(T) * S; + + // labels w/blanks, e_inc, s_inc + per_minibatch_bytes += 3 * sizeof(int) * S; + + *size_bytes = per_minibatch_bytes * minibatch; + + // probs + *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch; + } +} + +// Takes a tensor of labels, and interprets 0-elements at the end of the vector +// as padding. The tensor is packed into an std::vector without padding +// characters. The label sequence lengths are also inferred from the padding chars. +template <typename DType, typename xpu> +inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels, + int padding_mask, + std::vector<int> *packed_labels, + std::vector<int> *label_lengths) { + int batch = labels.size(0); + int max_num_labels = labels.size(1); + + std::vector<int> cpu_labels(max_num_labels * batch); + mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D(); + IndexTensorToVector(flat_labels, &cpu_labels); + + for (int b = 0; b < batch; ++b) { + auto start = cpu_labels.data() + b * max_num_labels; + auto res = std::find(start, start + max_num_labels, padding_mask); + int len = std::distance(start, res); + std::copy(start, start + len, + std::back_inserter(*packed_labels)); + label_lengths->at(b) = len; + } +} + +// Takes a tensor of labels, and a vector which specifies the actual length of each label +// The tensor is packed into an std::vector without padding characters. +// The label length vector is copied into an std::vector. 
+template <typename DType, typename xpu> +inline void PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels, + mshadow::Tensor<xpu, 1, DType> in_label_lengths, + std::vector<int> *packed_labels, + std::vector<int> *label_lengths) { + int batch = labels.size(0); + int max_num_labels = labels.size(1); + + IndexTensorToVector(in_label_lengths, label_lengths); + + std::vector<int> cpu_labels(max_num_labels * batch); + mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D(); + IndexTensorToVector(flat_labels, &cpu_labels); + + for (int b = 0; b < batch; ++b) { + auto start = cpu_labels.data() + b * max_num_labels; + int len = label_lengths->at(b); + std::copy(start, start + len, + std::back_inserter(*packed_labels)); + } +} + +struct CTCLossOpParam : public dmlc::Parameter<CTCLossOpParam> { + bool use_data_lengths; + bool use_label_lengths; + int blank_label; + DMLC_DECLARE_PARAMETER(CTCLossOpParam) { + DMLC_DECLARE_FIELD(use_data_lengths).set_default(false) + .describe("Whether the data lengths are decided by `data_lengths`. " + "If false, the lengths are equal to the max sequence length."); + DMLC_DECLARE_FIELD(use_label_lengths).set_default(false) + .describe("Whether the label lengths are decided by " + "`label_lengths`, or derived from `padding_mask`. " + "If false, the lengths are derived from the " + "first occurrence of the value of `padding_mask`. " + "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, " + "and ``-1`` when last label is reserved for blank. See `blank_label`."); + DMLC_DECLARE_FIELD(blank_label) + .add_enum("first", 0) + .add_enum("last", 1) + .set_default(0) + .describe("Set the label that is reserved for blank label. " + "If \"first\", 0-th label is reserved, and " + "label values for tokens in the vocabulary are " + "between ``1`` and ``alphabet_size-1``, and the padding mask is ``0``. 
" "If \"last\", last label value ``alphabet_size-1`` " + "is reserved for blank label instead, " + "and label values for tokens in the vocabulary are " + "between ``0`` and ``alphabet_size-2``, and the padding mask is ``-1``."); + } +}; + +// By default, the inputs must include data array and label array +// if use_data_lengths parameter is set, user should also supply +// data_lengths array; if use_label_lengths parameter is set, user +// should also specify label_lengths array. +inline uint32_t CTCLossOpNumInputs(const NodeAttrs& attrs) { + const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed); + return 2U + param.use_data_lengths + param.use_label_lengths; +} + +inline bool CTCLossOpShape(const nnvm::NodeAttrs &attrs, + std::vector<TShape>* in_attrs, + std::vector<TShape>* out_attrs) { + const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed); + CHECK_EQ(in_attrs->size(), CTCLossOpNumInputs(attrs)); + CHECK_EQ(out_attrs->size(), 2U); + + const TShape &dshape = (*in_attrs)[ctc_loss::kData]; + const TShape &lshape = (*in_attrs)[ctc_loss::kLabel]; + CHECK_EQ(dshape.ndim(), 3U) << "The number of dimensions of data array must be 3."; + CHECK_EQ(lshape.ndim(), 2U) << "The number of dimensions of labels array must be 2."; + CHECK_EQ(dshape[1], lshape[0]) + << "The batch size for the labels and data arrays must be the same."; + + if (param.use_data_lengths) { + int kInputLength = 2; + const TShape &dlshape = (*in_attrs)[kInputLength]; + CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector."; + CHECK_EQ(dlshape[0], dshape[1]) + << "The batch size for the data and data lengths must be the same."; + } + if (param.use_label_lengths) { + int kLabelLength = 2 + param.use_data_lengths; + const TShape &llshape = (*in_attrs)[kLabelLength]; + CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector."; + CHECK_EQ(llshape[0], lshape[0]) + << "The batch size for the labels and label lengths must be the same."; + } + 
CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed " + "the maximum sequence length of the " + "data."; + + TShape oshape(1); + oshape[0] = dshape[1]; // batch size + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); // forward output + SHAPE_ASSIGN_CHECK(*out_attrs, 1, dshape); // grad output + return true; +} + +inline bool CTCLossOpType(const nnvm::NodeAttrs& attrs, + std::vector<int>* in_attrs, + std::vector<int>* out_attrs) { + CHECK_GE(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + int dtype = (*in_attrs)[ctc_loss::kData]; + CHECK_NE(dtype, -1) << "Input data must have specified type"; + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); // forward output + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); // grad output + return true; +} + +inline bool CTCLossOpStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int>* in_attrs, + std::vector<int>* out_attrs) { + CHECK_GE(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + const int in_stype = in_attrs->at(0); + bool dispatched = false; + if (!dispatched && in_stype == kDefaultStorage) { + // dns -> dns + dispatched = storage_type_assign(out_attrs, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + if (!dispatched) { + dispatched = dispatch_fallback(out_attrs, dispatch_mode); + } + return dispatched; +} + + +inline std::vector<std::string> CTCLossOpListInputNames(const NodeAttrs& attrs) { + const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed); + if (param.use_data_lengths && param.use_label_lengths) { + return {"data", "label", "data_lengths", "label_lengths"}; + } else if (param.use_data_lengths) { + return {"data", "label", "data_lengths"}; + } else if (param.use_label_lengths) { + return {"data", "label", "label_lengths"}; + } else { + return {"data", "label"}; + } +} + +template<typename xpu> +void CTCLossOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const 
std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + + const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed); + CHECK_EQ(inputs.size(), CTCLossOpNumInputs(attrs)); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + + const TBlob& in_data = inputs[ctc_loss::kData]; + const TBlob& in_label = inputs[ctc_loss::kLabel]; + const TBlob& out_data = outputs[ctc_loss::kOut]; + const TBlob& out_grad = outputs[ctc_loss::kGrad]; + + Stream<xpu> *s = ctx.get_stream<xpu>(); + MSHADOW_TYPE_SWITCH(inputs[ctc_loss::kLabel].type_flag_, DType, { + Tensor<xpu, 3, real_t> data = in_data.get<xpu, 3, real_t>(s); + Tensor<xpu, 2, DType> labels = in_label.get<xpu, 2, DType>(s); + Tensor<xpu, 1, real_t> costs = out_data.get<xpu, 1, real_t>(s); + Tensor<xpu, 3, real_t> grad = out_grad.get<xpu, 3, real_t>(s); + + int max_seq_len = data.size(0); + int batch_size = data.size(1); + int alphabet_size = data.size(2); + + // data_lengths + std::vector<int> data_lengths(batch_size, max_seq_len); + if (param.use_data_lengths) { + int kInputLength = 2; + IndexTensorToVector(inputs[kInputLength].get<xpu, 1, real_t>(s), &data_lengths); + } + + // label_lengths + std::vector<int> packed_labels; + std::vector<int> label_lengths(batch_size); + + if (param.use_label_lengths) { + int kLabelLength = 2 + param.use_data_lengths; + PackLabelByLength(labels, inputs[kLabelLength].get<xpu, 1, DType>(s), + &packed_labels, &label_lengths); + } else { + LabelTensorToPackedVector(labels, param.blank_label == 0 ? 0 : -1, + &packed_labels, &label_lengths); + } + + size_t size_bytes; + get_workspace_size<real_t>(&label_lengths, &data_lengths, alphabet_size, + batch_size, data.kDevCPU ? 
false : true, &size_bytes); + + // round-up so there are enough elems in memory + int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t); + Tensor<xpu, 1, real_t> workspace = + ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(num_tmp_elems), s); + + compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels.data(), + label_lengths.data(), data_lengths.data(), + workspace.dptr_, req[ctc_loss::kGrad] != mxnet::kNullOp, + param.blank_label == 0 ? 0 : (alphabet_size - 1)); + + if (param.use_data_lengths) { + // baidu warp CTC implementation sometimes includes undefined gradients + // for data outside of length mask. Setting to 0 to make it consistent + // with CPU implementation. + int kInputLength = 2; + mxnet_op::SequenceMask(grad, inputs[kInputLength].get<xpu, 1, real_t>(s), + static_cast<real_t>(0)); + } + }); +} + +template<typename xpu> +void CTCLossOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + + Stream<xpu> *s = ctx.get_stream<xpu>(); + const TBlob& in_grad = outputs[0]; + const TBlob& out_grad = inputs[0]; + const TBlob& grad_computed = inputs[3]; // grad computed in the forward step + + Tensor<xpu, 3, real_t> igrad_data = in_grad.get<xpu, 3, real_t>(s); + Tensor<xpu, 1, real_t> ograd_data = out_grad.get<xpu, 1, real_t>(s); + Tensor<xpu, 3, real_t> computed_grad_data = grad_computed.get<xpu, 3, real_t>(s); + + Assign(igrad_data, req[0], + mshadow::expr::broadcast<1>(ograd_data, computed_grad_data.shape_) * computed_grad_data); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NN_CTC_LOSS_INL_H_ + diff --git a/src/operator/contrib/ctc_loss.cc b/src/operator/nn/ctc_loss.cc similarity index 64% rename from src/operator/contrib/ctc_loss.cc rename to src/operator/nn/ctc_loss.cc index 32e8e62..c381677 100644 --- 
a/src/operator/contrib/ctc_loss.cc +++ b/src/operator/nn/ctc_loss.cc @@ -18,26 +18,22 @@ */ /*! - * Copyright (c) 2015 by Contributors * \file ctc_loss.cc - * \brief - * \author Sebastian Bodenstein -*/ - + * \brief CPU Implementation of CTC Loss op + */ #include "./ctc_loss-inl.h" -#include "./ctc_include/detail/cpu_ctc.h" +#include "../../../3rdparty/ctc_include/detail/cpu_ctc.h" namespace mshadow { - template <typename DType> ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations, DType *costs, DType *grads, int *labels, int *label_lengths, int *data_lengths, - void *workspace, int train, int blank_label) { + void *workspace, bool isTraining, int blank_label) { int minibatch = static_cast<int>(activations.size(1)); int alphabet_size = static_cast<int>(activations.size(2)); mxnet_warpctc::CpuCTC<DType> ctc(alphabet_size, minibatch, workspace, blank_label); - if (train) { + if (isTraining) { return ctc.cost_and_grad(activations.dptr_, grads, costs, labels, label_lengths, data_lengths); } else { @@ -45,32 +41,18 @@ ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations, data_lengths); } } - } // namespace mshadow namespace mxnet { namespace op { -template <> -Operator *CreateOp<cpu>(CTCLossParam param, int dtype) { - return new CTCLossOp<cpu>(param); -} - -// DO_BIND_DISPATCH comes from operator_common.h -Operator *CTCLossProp::CreateOperatorEx(Context ctx, - std::vector<TShape> *in_shape, - std::vector<int> *in_type) const { - std::vector<TShape> out_shape, aux_shape; - std::vector<int> out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); -} - -DMLC_REGISTER_PARAMETER(CTCLossParam); -MXNET_REGISTER_OP_PROPERTY(_contrib_CTCLoss, CTCLossProp) - .describe(R"code(Connectionist Temporal Classification Loss. 
+DMLC_REGISTER_PARAMETER(CTCLossOpParam); +NNVM_REGISTER_OP(CTCLoss) +.add_alias("ctc_loss") +.add_alias("_contrib_CTCLoss") +.add_alias("_contrib_ctc_loss") +.describe(R"code(Connectionist Temporal Classification Loss. The shapes of the inputs and outputs: - **data**: `(sequence_length, batch_size, alphabet_size)` @@ -113,18 +95,41 @@ Sequence Data with Recurrent Neural Networks*, A. Graves *et al*. for more information on the definition and the algorithm. )code" ADD_FILELINE) - .add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.") - .add_argument("label", "NDArray-or-Symbol", - "Ground-truth labels for the loss.") - .add_argument("data_lengths", "NDArray-or-Symbol", - "Lengths of data for each of the samples. Only required " - "when use_data_lengths is true.") - .add_argument("label_lengths", "NDArray-or-Symbol", - "Lengths of labels for each of the samples. Only required " - "when use_label_lengths is true.") - .add_arguments(CTCLossParam::__FIELDS__()); - -NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss"); +.set_attr_parser(ParamParser<CTCLossOpParam>) +.set_num_inputs(CTCLossOpNumInputs) +.set_num_outputs(2) +.set_attr<nnvm::FListInputNames>("FListInputNames", CTCLossOpListInputNames) +.set_attr<nnvm::FListOutputNames>("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"output", "grad"}; + }) +.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + return 1; + }) +.set_attr<nnvm::FInferShape>("FInferShape", CTCLossOpShape) +.set_attr<nnvm::FInferType>("FInferType", CTCLossOpType) +.set_attr<FInferStorageType>("FInferStorageType", CTCLossOpStorageType) +.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; }) +.set_attr<FCompute>("FCompute<cpu>", CTCLossOpForward<cpu>) +.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_ctc_loss"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_argument("label", "NDArray-or-Symbol", "Ground-truth labels for the loss.") +.add_argument("data_lengths", "NDArray-or-Symbol", + "Lengths of data for each of the samples. Only required " + "when use_data_lengths is true.") +.add_argument("label_lengths", "NDArray-or-Symbol", + "Lengths of labels for each of the samples. Only required " + "when use_label_lengths is true.") +.add_arguments(CTCLossOpParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_ctc_loss) +.set_attr_parser(ParamParser<CTCLossOpParam>) +.set_num_inputs(1) +.set_num_outputs(CTCLossOpNumInputs) +.set_attr<nnvm::TIsBackward>("TIsBackward", true) +.set_attr<FCompute>("FCompute<cpu>", CTCLossOpBackward<cpu>); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/ctc_loss.cu b/src/operator/nn/ctc_loss.cu similarity index 81% rename from src/operator/contrib/ctc_loss.cu rename to src/operator/nn/ctc_loss.cu index 3f5f12c..a4491bf 100644 --- a/src/operator/contrib/ctc_loss.cu +++ b/src/operator/nn/ctc_loss.cu @@ -18,14 +18,13 @@ */ /*! 
- * Copyright (c) 2015 by Contributors + * Copyright (c) 2018 by Contributors * \file ctc_loss.cu - * \brief - * \author Sebastian Bodenstein -*/ -#include <algorithm> + * \brief GPU Implementation of ctc_loss op + */ + #include "./ctc_loss-inl.h" -#include "./ctc_include/detail/gpu_ctc.h" +#include "../../../3rdparty/ctc_include/detail/gpu_ctc.h" namespace mshadow { @@ -45,17 +44,19 @@ ctcStatus_t compute_ctc_cost(const Tensor<gpu, 3, DType> activations, return ctc.score_forward(activations.dptr_, costs, labels, label_lengths, input_lengths); } - } // namespace mshadow -//////////////////////////////////////////////////////////////////////////////// - namespace mxnet { namespace op { -template <> -Operator *CreateOp<gpu>(CTCLossParam param, int dtype) { - return new CTCLossOp<gpu>(param); -} + +NNVM_REGISTER_OP(CTCLoss) +.add_alias("ctc_loss") +.add_alias("_contrib_ctc_loss") +.add_alias("_contrib_CTCLoss") +.set_attr<FCompute>("FCompute<gpu>", CTCLossOpForward<gpu>); + +NNVM_REGISTER_OP(_backward_ctc_loss) +.set_attr<FCompute>("FCompute<gpu>", CTCLossOpBackward<gpu>); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b17562c..5332517 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4491,11 +4491,34 @@ def test_pick(): def check_ctc_loss(acts, labels, loss_truth): in_var = mx.sym.Variable('input') labels_var = mx.sym.Variable('labels') - ctc = mx.sym.contrib.ctc_loss(in_var, labels_var) + ctc = mx.sym.ctc_loss(in_var, labels_var) acts_nd = mx.nd.array(acts, ctx=default_context()) labels_nd = mx.nd.array(labels, ctx=default_context()) exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd]) + # test forward with grad calc + exe.forward(is_train=True) + outTest = exe.outputs[0] # test forward without grad calc + exe.forward(is_train=False) + outTrain = exe.outputs[0] + # make sure losses calculated with both modes are 
the same + assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy()) + + # test against ground truth, if available + if loss_truth is not None: + assert_almost_equal(outTest.asnumpy(), loss_truth) + # test grad + check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3) + +# check contrib operator for backward compatibility +def check_contrib_ctc_loss(acts, labels, loss_truth): + in_var = mx.sym.Variable('input') + labels_var = mx.sym.Variable('labels') + ctc = mx.sym.contrib.ctc_loss(in_var, labels_var) + acts_nd = mx.nd.array(acts, ctx=default_context()) + labels_nd = mx.nd.array(labels, ctx=default_context()) + exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd]) + # test forward with grad calc exe.forward(is_train=True) outTest = exe.outputs[0] # test forward without grad calc @@ -4503,6 +4526,7 @@ def check_ctc_loss(acts, labels, loss_truth): outTrain = exe.outputs[0] # make sure losses calculated with both modes are the same assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy()) + # test against ground truth, if available if loss_truth is not None: assert_almost_equal(outTest.asnumpy(), loss_truth) @@ -4520,6 +4544,8 @@ def test_ctc_loss(): labels = np.array([[2, 3, 0], [2, 3, 0]]) true_loss = np.array([4.04789, 4.04789], dtype=np.float32) # from Torch check_ctc_loss(acts, labels, true_loss) + check_contrib_ctc_loss(acts, labels, true_loss) + # Test 2: acts2 = np.array([ [[-5, -4, -3, -2, -1], [1.2, 3.4, 1.2, -0.1, -2.34]], @@ -4528,11 +4554,13 @@ def test_ctc_loss(): labels2 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.float32) true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch check_ctc_loss(acts2, labels2, true_loss) + check_contrib_ctc_loss(acts2, labels2, true_loss) # Test 3: check use integer type as label labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32) true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch check_ctc_loss(acts2, labels3, true_loss) + 
check_contrib_ctc_loss(acts2, labels3, true_loss) @with_seed() def test_ctc_loss_with_large_classes(): @@ -4550,7 +4578,7 @@ def test_ctc_loss_with_large_classes(): [1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32) nd_data = mx.nd.array(data) nd_label = mx.nd.array(label) - loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label) + loss = mx.nd.ctc_loss(data=nd_data, label=nd_label) expected_loss = np.array([688.02826, 145.34462]) assert_almost_equal(loss.asnumpy(), expected_loss) @@ -4624,6 +4652,85 @@ def test_ctc_loss_grad(): label = mx.nd.array(labels) data.attach_grad() with mx.autograd.record(): + l = mx.ndarray.CTCLoss(data, label, + use_data_lengths=True, + use_label_lengths=True, + data_lengths=mx.nd.array(seq_lens), + label_lengths=mx.nd.array(label_lens), + blank_label=blank_label) + l.backward() + assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5) + assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5) + + # check contrib operator for backward compatibility + def check_contrib_ctc_loss_grad(blank_label): # from tf + vocab_size = 5 + max_label_len = 5 + padding_mask = -1+ (blank_label=='first') + + targets_0 = [0, 1, 2, 1, 0] + loss_log_prob_0 = -3.34211 + input_prob_matrix_0 = np.asarray( + [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], + [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436], + [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688], + [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533], + [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]], + dtype=np.float32) + gradient_log_prob_0 = np.asarray( + [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], + [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436], + [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688], + [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533], + [-0.541765, 0.396634, 0.123377, 0.00648837, 
0.00903441, 0.00623107]], + dtype=np.float32) + + targets_1 = [0, 1, 1, 0] + loss_log_prob_1 = -5.42262 + input_prob_matrix_1 = np.asarray( + [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508], + [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549], + [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456], + [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345], + [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]], + dtype=np.float32) + gradient_log_prob_1 = np.asarray( + [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508], + [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549], + [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544], + [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345], + [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]], + dtype=np.float32) + + inputs = [ + np.vstack( + [input_prob_matrix_0[t, :], input_prob_matrix_1[t, :]]) + for t in range(5) + ] + 2 * [np.nan * np.ones((2, vocab_size+1), np.float32)] + inputs = np.log(np.asarray(inputs, dtype=np.float32)) + + grad_truth = np.array([ + np.vstack( + [gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]]) + for t in range(5) + ] + 2 * [np.zeros((2, vocab_size+1), np.float32)]) + + if blank_label == 'first': + inputs = np.roll(inputs, 1, axis=2) + grad_truth = np.roll(grad_truth, 1, axis=2) + + labels = (np.asarray([x + [padding_mask]*(max_label_len-len(x)) + for x in [targets_0, targets_1]])+(blank_label == 'first')) + + seq_lens = np.array([5, 5], dtype=np.int32) + label_lens = np.array([5, 4], dtype=np.int32) + loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32) + + with default_context(): + data = mx.nd.array(inputs) + label = mx.nd.array(labels) + data.attach_grad() + with mx.autograd.record(): l = mx.contrib.ndarray.CTCLoss(data, label, use_data_lengths=True, use_label_lengths=True, @@ -4634,8 +4741,11 @@ def test_ctc_loss_grad(): 
assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5) assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5) + check_ctc_loss_grad('first') check_ctc_loss_grad('last') + check_contrib_ctc_loss_grad('first') + check_contrib_ctc_loss_grad('last') @with_seed()