BebDong opened a new issue #20067: URL: https://github.com/apache/incubator-mxnet/issues/20067
## Description

- Cross-entropy loss with online hard example mining (OHEM) is hard to implement for multi-GPU training with the high-level Python API.
- Adding it would also keep MXNet consistent with GluonCV: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/loss.py#L456
- The following implementation works only on a single GPU. It does not apply to multi-GPU training.

```python
from gluoncv import loss as gloss


class OHEMCrossEntropyLoss(gloss.SoftmaxCrossEntropyLoss):
    """OHEM cross-entropy loss. Only supports a single GPU.

    Adapted from:
    https://github.com/PaddlePaddle/PaddleSeg/blob/release/v2.0/paddleseg/models/losses/ohem_cross_entropy_loss.py
    """

    def __init__(self, thresh=0.7, min_kept=10000, num_classes=21, height=None,
                 width=None, crop_size=480, sparse_label=True, batch_axis=0,
                 ignore_label=-1, size_average=True, **kwargs):
        super(OHEMCrossEntropyLoss, self).__init__(
            sparse_label, batch_axis, ignore_label, size_average, **kwargs)
        self._thresh = thresh
        self._min_kept = min_kept
        self._nclass = num_classes
        self._height = height if height is not None else crop_size
        self._width = width if width is not None else crop_size

    def hybrid_forward(self, F, logit, label):
        # NOTE: the Python `if` on NDArray values and `asnumpy()` below only run
        # imperatively (not when hybridized) and force a device-to-host sync.
        label = F.reshape(label, shape=(-1,))
        valid_mask = (label != self._ignore_label)
        num_valid = F.sum(valid_mask)
        # Temporarily map ignored pixels to class 0 so F.pick gets a valid index.
        label = label * valid_mask

        # Softmax over classes, then flatten to (num_classes, N*H*W).
        prob = F.softmax(logit, axis=1)
        prob = F.reshape(F.transpose(prob, axes=(1, 0, 2, 3)),
                         shape=(self._nclass, -1))

        if self._min_kept < num_valid and num_valid > 0:
            # Push the probabilities of ignored pixels to >= 1 so they are
            # never selected as hard examples.
            prob = prob + (1 - valid_mask)
            # Probability assigned to each pixel's ground-truth class.
            prob = F.pick(prob, label, axis=0, keepdims=False)
            threshold = self._thresh
            if self._min_kept > 0:
                index = F.argsort(prob)
                threshold_index = index[min(len(index), self._min_kept) - 1]
                threshold_index = int(threshold_index.asnumpy()[0])
                # Raise the threshold when the min_kept-th smallest probability
                # exceeds thresh, i.e. too few pixels fall below thresh.
                if prob[threshold_index] > self._thresh:
                    threshold = prob[threshold_index]
                # Keep only hard pixels (low ground-truth probability).
                kept_mask = (prob < threshold)
                label = label * kept_mask
                valid_mask = valid_mask * kept_mask

        # Mark ignored and easy pixels with the ignore label.
        label = label + (1 - valid_mask) * self._ignore_label
        label = F.reshape(label, shape=(-1, self._height, self._width))
        return super(OHEMCrossEntropyLoss, self).hybrid_forward(F, logit, label)
```

## References

- A. Shrivastava, A. Gupta, and R. Girshick. Training region-based object detectors with online hard example mining. In CVPR, 2016.
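The selection step in the `hybrid_forward` code above is easier to follow outside MXNet's operator API. Below is a minimal NumPy sketch of the same logic, offered only to illustrate the algorithm on the CPU. `ohem_labels` is a hypothetical helper, not an MXNet or GluonCV API, and the sketch assumes logits of shape (N, C, H, W) and integer labels of shape (N, H, W).

```python
import numpy as np


def ohem_labels(logits, labels, thresh=0.7, min_kept=10000, ignore_label=-1):
    """Hypothetical helper: return a copy of `labels` in which ignored and
    easy pixels are set to `ignore_label`, mirroring the OHEM selection logic.

    logits: float array (N, C, H, W); labels: int array (N, H, W).
    """
    num_classes = logits.shape[1]
    label = labels.reshape(-1).astype(np.int64)
    valid = label != ignore_label
    num_valid = int(valid.sum())
    # Temporary in-range class index for ignored pixels (like label * valid_mask).
    label = np.where(valid, label, 0)

    # Numerically stable softmax over the class axis, flattened to (C, N*H*W).
    prob = np.exp(logits - logits.max(axis=1, keepdims=True))
    prob /= prob.sum(axis=1, keepdims=True)
    prob = prob.transpose(1, 0, 2, 3).reshape(num_classes, -1)

    if min_kept < num_valid and num_valid > 0:
        # Ignored pixels get probability >= 1: they sort last and fail `< threshold`.
        prob = prob + (~valid).astype(prob.dtype)
        # Probability of each pixel's ground-truth class (the F.pick step).
        prob = prob[label, np.arange(label.size)]
        threshold = thresh
        if min_kept > 0:
            k = min(prob.size, min_kept) - 1
            kth_smallest = prob[np.argsort(prob)[k]]
            if kth_smallest > thresh:
                threshold = kth_smallest
            # Keep only hard pixels whose ground-truth probability is below threshold.
            valid = valid & (prob < threshold)

    # Ignored and easy pixels become ignore_label; everything else keeps its class.
    return np.where(valid, label, ignore_label).reshape(labels.shape)


# Usage: one image, two classes, four pixels; the last pixel is ignored.
logits = np.zeros((1, 2, 1, 4))
logits[0, 0, 0] = [3.0, -3.0, 0.0, 1.0]
labels = np.array([[[0, 0, 1, -1]]])
print(ohem_labels(logits, labels, thresh=0.6, min_kept=2))
```

As in the issue's code, the kept mask uses a strict `prob < threshold`. When the threshold is raised to the `min_kept`-th smallest probability, that pixel (and any ties) is therefore dropped as well, so slightly fewer than `min_kept` pixels may survive.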
