piiswrong closed pull request #8761: Refactor image operators
URL: https://github.com/apache/incubator-mxnet/pull/8761
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/mshadow b/mshadow
index 2d7780c3f2..1e1f633a82 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit 2d7780c3f2eefe4453fa419862d1b2089bedb8d5
+Subproject commit 1e1f633a82c1fec5718fd291e2da6149708635f6
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 740a2a47c7..9b4d197906 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -18,7 +18,7 @@
 # coding: utf-8
 # pylint: disable=
 """Dataset container."""
-__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset', 'LabeledDataset',
+__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset',
            'RecordFileDataset']
 
 import os
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index 54da152b9f..24f66d6b4a 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -28,9 +28,9 @@
 import warnings
 import numpy as np
 
-from . import dataset
-from ..utils import download, check_sha1
-from ... import nd, image, recordio
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
 
 apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
 
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index fa7c0f2cba..e1deef631d 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -58,7 +58,7 @@ def __init__(self):
         super(ToTensor, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.cast(x, 'float32').transpose((2, 0, 1))
+        return F.image.to_tensor(x)
 
 
 class Normalize(HybridBlock):
diff --git a/src/operator/image/image_aug_op.h b/src/operator/image/image_aug_op.h
deleted file mode 100644
index 40315ec85c..0000000000
--- a/src/operator/image/image_aug_op.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
-#define MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
-
-#include <mxnet/operator_util.h>
-#include <vector>
-#include <utility>
-#include <algorithm>
-#include "../mshadow_op.h"
-#include "../elemwise_op_common.h"
-#include "../mxnet_op.h"
-
-namespace mxnet {
-namespace op {
-
-struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
-  nnvm::Tuple<float> mean, std;
-  DMLC_DECLARE_PARAMETER(NormalizeParam) {
-    DMLC_DECLARE_FIELD(mean).set_default(nnvm::Tuple<float>({0.f}))
-      .describe("");
-    DMLC_DECLARE_FIELD(std).set_default(nnvm::Tuple<float>({1.f}))
-      .describe("");
-  }
-};
-
-
-void NormalizeCompute(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<NDArray>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<NDArray>& outputs) {
-  using namespace mxnet_op;
-  const auto& params = dmlc::get<NormalizeParam>(attrs.parsed);
-  CHECK_NE(req[0], kAddTo);
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    auto num_channel = inputs[0].shape_[0];
-    auto size = inputs[0].Size(1, inputs[0].ndim());
-    nnvm::Tuple<DType> mean(params.mean.begin(), params.mean.end());
-    nnvm::Tuple<DType> std(params.std.begin(), params.std.end());
-    DType* src = inputs[0].dptr<DType>();
-    DType* dst = outputs[0].dptr<DType>();
-    for (int i = 0; i < num_channel; ++i) {
-      for (int j = 0; j < size; ++j, ++out, ++src) {
-        *out = (*src - mean[i]) / std[i];
-      }
-    }
-  });
-}
-
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
diff --git a/src/operator/image/image_common.h b/src/operator/image/image_common.h
deleted file mode 100644
index 3b6b8e3298..0000000000
--- a/src/operator/image/image_common.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
-
-/*!
-* \file image_common.h
-* \brief
-* \author
-*/
-#ifndef MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
-#define MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
-
-#include <mxnet/base.h>
-
-namespace mxnet {
-namespace op {
-
-/**
-* @brief convert TBlob to cv::Mat
-* @param input @see TBlob
-* @param hight
-* @param weight
-* @param channel
-* @return
-*/
-static cv::Mat mat_convert(TBlob input, int hight, int weight, int channel) {
-  cv::Mat m;
-  switch (input.type_flag_) {
-    case mshadow::kFloat32: {
-      typedef float DType;
-      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_32F, channel), input.dptr<DType>());
-    }
-    break;
-    case mshadow::kFloat64: {
-      typedef double DType;
-      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_64F, channel), input.dptr<DType>());
-    }
-    break;
-    case mshadow::kFloat16: {
-      typedef mshadow::half::half_t DType;
-      LOG(FATAL) << "not support type enum " << input.type_flag_;
-    }
-    break;
-    case mshadow::kUint8: {
-      typedef uint8_t DType;
-      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_8U, channel), input.dptr<DType>());
-    }
-    break;
-    case mshadow::kInt8: {
-      typedef int8_t DType;
-      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_8S, channel), input.dptr<DType>());
-    }
-    break;
-    case mshadow::kInt32: {
-      typedef int32_t DType;
-      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_32S, channel), input.dptr<DType>());
-    }
-    break;
-    case mshadow::kInt64: {
-      typedef int64_t DType;
-      LOG(FATAL) << "not support type enum " << input.type_flag_;
-    }
-    break;
-    default:
-      LOG(FATAL) << "Unknown type enum " << input.type_flag_;
-  }
-  return m;
-}
-}  // namespace op
-}  // namespace mxnet
-
-
-#endif  // MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
-
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 6f9cdc0e72..f823c8ce06 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -30,16 +30,11 @@
 #include <opencv2/opencv.hpp>
 #include <opencv2/core/mat.hpp>
 #include "../mxnet_op.h"
-#include "image_common.h"
-#include "../../operator/operator_common.h"
+#include "../operator_common.h"
 
 namespace mxnet {
 namespace op {
 
-
-enum ImageRandomResource { kRandom };
-
-template<typename xpu>
 static void RandomFlip(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
@@ -73,37 +68,25 @@ inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-template<typename xpu>
 static void ToTensor(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
-  auto input = inputs[0];
-  auto output = outputs[0];
-
-  int height = input.shape_[0];
-  int weight = input.shape_[1];
-  int channel = input.shape_[2];
-
-  typedef float   DstDType;
-  typedef uint8_t SrcDType;
-
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace";
 
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-    auto input_3d =  input.get<xpu, 3, SrcDType>(s);
-    auto output_3d = output.get<xpu, 3, DstDType>(s);
-    for (int h = 0; h < height; ++h) {
-      for (int w = 0; w < weight; ++w) {
-        for (int c = 0; c < channel; ++c) {
-          Assign(output_3d[c][h][w], Req, DstDType(input_3d[h][w][c] / 255.0));
-        }
-      }
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int channel = inputs[0].shape_[2];
+
+  float* output = outputs[0].dptr<float>();
+  uint8_t* input = inputs[0].dptr<uint8_t>();
+
+  for (int l = 0; l < length; ++l) {
+    for (int c = 0; c < channel; ++c) {
+      output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
     }
-  });
+  }
 }
 
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
@@ -117,93 +100,47 @@ struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
   }
 };
 
-struct normalize {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
-                                  const OpReqType req,
-                                  const int nchannel, const int size,
-                                  const float *mean, const float *std) {
-    int c = 0;
-    switch (nchannel) {
-      case 1:
-        break;
-      case 3:
-        if (i < size) {
-          c = 0;
-        } else if (i < (size << 1)) {
-          c = 1;
-        } else {
-          c = 2;
-        }
-        break;
-      default:
-        LOG(FATAL) << "not support channel" << nchannel;
-    }
-    float m = (mean ? mean[c] : 0);
-    KERNEL_ASSIGN(out[i], req, static_cast<DType>((in[i] - m) / std[c]));
-  }
-};
 
-static void NormalizeCheckParam(const nnvm::Tuple<float> &mean,
-                                const nnvm::Tuple<float> &std,
-                                const int nchannel) {
-  CHECK(mean.ndim() == 1 || mean.ndim() == 3)
-    << "Mean must be in dimension 1 or 3.";
-  CHECK(std.ndim() == 1 || std.ndim() == 3)
-    << "Standard deviations must be in dimension 1 or 3.";
-  CHECK(nchannel == 1 || nchannel == 3) << "Image channel must be 1 or 3.";
-  CHECK_EQ(mean.ndim(), nchannel)
-    << "Mean dimension does not agree with image channel.";
-  CHECK_EQ(std.ndim(), nchannel)
-    << "Standard deviations dimension does not agree with image channel.";
-  for (uint32_t c = 0; c < std.ndim(); ++c) {
-    CHECK(std[c] > 0) << "Invalid standard deviation " << std[c];
-  }
+inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape> *in_attrs,
+                          std::vector<TShape> *out_attrs) {
+  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
+  const auto& dshape = (*in_attrs)[0];
+  if (!dshape.ndim()) return false;
+  CHECK_EQ(dshape.ndim(), 3)
+      << "Input must have 3 dimensions";
+
+  auto nchannels = dshape[0];
+  CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
+      << "mean must have either 1 or " << nchannels << " elements";
+  CHECK(param.std.ndim() == 1 || param.std.ndim() == nchannels)
+      << "std must have either 1 or " << nchannels << " elements";
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
 }
 
-template<typename xpu>
+
 static void Normalize(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
-  auto mean = param.mean;
-  auto std = param.std;
-
-  int nchannel = inputs[0].shape_[0];
-  NormalizeCheckParam(mean, std, nchannel);
-
-  int size = inputs[0].Size() / nchannel;
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      mxnet_op::Kernel<normalize, xpu>::Launch(
-        s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
-        Req, nchannel, size, mean.begin(), std.begin());
-    });
-  });
-}
 
-template<typename xpu>
-static void NormalizeBackward(const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
-                              const std::vector<TBlob> &inputs,
-                              const std::vector<OpReqType> &req,
-                              const std::vector<TBlob> &outputs) {
-  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
-  int nchannel = inputs[0].shape_[0];
-
-  NormalizeCheckParam(param.mean, param.std, nchannel);
-
-  int size = inputs[0].Size() / nchannel;
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      mxnet_op::Kernel<normalize, xpu>::Launch(
-        s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
-        Req, nchannel, size, nullptr, param.std.begin());
-      });
+  int nchannels = inputs[0].shape_[0];
+  int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* input = inputs[0].dptr<DType>();
+    DType* output = outputs[0].dptr<DType>();
+
+    for (int i = 0; i < nchannels; ++i) {
+      DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
+      DType std = param.std[param.std.ndim() > 1 ? i : 0];
+      for (int j = 0; j < length; ++j) {
+        output[i*length + j] = (input[i*length + j] - mean) / std;
+      }
+    }
   });
 }
 
@@ -211,99 +148,83 @@ struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {
   float max_brightness;
   DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
     DMLC_DECLARE_FIELD(max_brightness)
-    .set_default(0.0)
+    .set_lower_bound(0.0)
     .describe("Max Brightness.");
   }
 };
 
-template<typename xpu>
 static void RandomBrightness(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  auto input = inputs[0];
-  auto output = outputs[0];
-  int channel = input.shape_[0];
-  int height = input.shape_[1];
-  int weight = input.shape_[2];
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s);
-
   const RandomBrightnessParam &param = nnvm::get<RandomBrightnessParam>(attrs.parsed);
+
+  int length = inputs[0].Size();
+
+  uint8_t* output = outputs[0].dptr<uint8_t>();
+  uint8_t* input = inputs[0].dptr<uint8_t>();
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
   float alpha_b = 1.0 + std::uniform_real_distribution<float>(
-    -param.max_brightness, param.max_brightness)(prnd->GetRndEngine());
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      mxnet_op::Kernel<mxnet_op::op_with_req<mshadow::op::mul, Req>, xpu>::Launch(
-        s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha_b));
-    });
-  });
+      -param.max_brightness, param.max_brightness)(prnd->GetRndEngine());
+
+  for (int l = 0; l < length; ++l) {
+    float val = static_cast<float>(input[l]) * alpha_b;
+    val = std::min(std::max(val, 0.f), 255.f);
+    output[l] = static_cast<uint8_t>(val);
+  }
 }
 
+
 struct RandomContrastParam : public dmlc::Parameter<RandomContrastParam> {
   float max_contrast;
   DMLC_DECLARE_PARAMETER(RandomContrastParam) {
     DMLC_DECLARE_FIELD(max_contrast)
-    .set_default(0.0)
+    .set_lower_bound(0.0)
     .describe("Max Contrast.");
   }
 };
 
-/*! \brief mul_add operator */
-struct mul_add {
-  /*! \brief map a, b, c to result using defined operation */
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b, DType c) {
-    return a * b + c;
-  }
-};
 
-template<typename xpu>
 static void RandomContrast(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  auto input = inputs[0];
-  auto output = outputs[0];
-  int channel = input.shape_[0];
-  int height = input.shape_[1];
-  int weight = input.shape_[2];
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s);
+  static const float coef[] = { 0.299f, 0.587f, 0.114f };
+  const RandomContrastParam &param = nnvm::get<RandomContrastParam>(attrs.parsed);
 
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int nchannels = inputs[0].shape_[2];
 
-  const RandomContrastParam &param = nnvm::get<RandomContrastParam>(attrs.parsed);
+  uint8_t* output = outputs[0].dptr<uint8_t>();
+  uint8_t* input = inputs[0].dptr<uint8_t>();
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
   float alpha_c = 1.0 + std::uniform_real_distribution<float>(
     -param.max_contrast, param.max_contrast)(prnd->GetRndEngine());
 
-  const float R2YF = 0.299f;
-  const float G2YF = 0.587f;
-  const float B2YF = 0.114f;
-  static const float coeffs0[] = { R2YF, G2YF, B2YF };
-
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    auto input_3d = input.get<xpu, 3, DType>(s);
-    DType sum = (DType)0.0;
-    for (int c = 0; c < channel; ++c) {
-      for (int h = 0; h < height; ++h) {
-        for (int w = 0; w < weight; ++w) {
-          sum += input_3d[c][h][w] * coeffs0[c];
-        }
-      }
+  float sum = 0.f;
+  if (nchannels > 1) {
+    for (int l = 0; l < length; ++l) {
+      for (int c = 0; c < nchannels; ++c) sum += input[l*nchannels + c] * coef[c];
     }
-    float gray_mean = sum / static_cast<float>(height * weight);
-    float beta = (1 - alpha_c) * gray_mean;
-
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      mxnet_op::Kernel<mxnet_op::op_with_req<mul_add, Req>, xpu>::Launch(
-        s, inputs[0].Size(), outputs[0].dptr<DType>(),
-        inputs[0].dptr<DType>(), DType(alpha_c), DType(beta));
-    });
-  });
+  } else {
+    for (int l = 0; l < length; ++l) sum += input[l];
+  }
+  float gray_mean = sum / static_cast<float>(length);
+  float beta = (1 - alpha_c) * gray_mean;
+
+  for (int l = 0; l < length * nchannels; ++l) {
+    float val = input[l] * alpha_c + beta;
+    val = std::min(std::max(val, 0.f), 255.f);
+    output[l] = static_cast<uint8_t>(val);
+  }
 }
 
 struct RandomSaturationParam : public dmlc::Parameter<RandomSaturationParam> {
@@ -315,55 +236,46 @@ struct RandomSaturationParam : public dmlc::Parameter<RandomSaturationParam> {
   }
 };
 
-template<typename xpu>
 static void RandomSaturation(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  auto input = inputs[0];
-  auto output = outputs[0];
-  int channel = input.shape_[0];
-  int height = input.shape_[1];
-  int weight = input.shape_[2];
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s);
   const RandomSaturationParam &param = nnvm::get<RandomSaturationParam>(attrs.parsed);
-  float alpha_s = 1.0 + std::uniform_real_distribution<float>(
+  static const float coef[] = { 0.299f, 0.587f, 0.114f };
+
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int nchannels = inputs[0].shape_[2];
+
+  uint8_t* output = outputs[0].dptr<uint8_t>();
+  uint8_t* input = inputs[0].dptr<uint8_t>();
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+  float alpha_s = 1.f + std::uniform_real_distribution<float>(
     -param.max_saturation, param.max_saturation)(prnd->GetRndEngine());
-  float alpha_o = 1 - alpha_s;
-  const float R2YF = 0.299f;
-  const float G2YF = 0.587f;
-  const float B2YF = 0.114f;
-  static const float coeffs0[] = { R2YF, G2YF, B2YF };
+  float alpha_o = 1.f - alpha_s;
 
+  if (nchannels == 1) {
+    for (int l = 0; l < length * nchannels; ++l) output[l] = input[l];
+    return;
+  }
 
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      auto input_3d =  input.get<xpu, 3, DType>(s);
-      auto output_3d = output.get<xpu, 3, DType>(s);
-      switch (channel) {
-        case 1:
-          Assign(output_3d, Req, input_3d)
-          break;
-        case 3:
-          for (int h = 0; h < height; ++h) {
-            for (int w = 0; w < weight; ++w) {
-              float gray =
-                input_3d[0][h][w] * R2YF + input_3d[1][h][w] * G2YF + input_3d[2][h][w] * B2YF;
-              Assign(output_3d[0][h][w], Req, DType(gray * alpha_s + input_3d[0][h][w] * alpha_o))
-            }
-          }
-          break;
-        default:
-          LOG(FATAL) << "not support channel" << channel;
-      }
-    });
-  });
+  for (int l = 0; l < length; ++l) {
+    float gray = 0.f;
+    for (int c = 0; c < nchannels; ++c) {
+      gray = input[l*nchannels + c] * coef[c];
+    }
+    gray *= alpha_o;
+    for (int c = 0; c < nchannels; ++c) {
+      float val = gray + input[l*nchannels + c] * alpha_s;
+      val = std::min(std::max(val, 0.f), 255.f);
+      output[l*nchannels + c] = static_cast<uint8_t>(val);
+    }
+  }
 }
 
-template<typename xpu>
 static void RandomHue(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
@@ -371,7 +283,6 @@ static void RandomHue(const nnvm::NodeAttrs &attrs,
                       const std::vector<TBlob> &outputs) {
 }
 
-template<typename xpu>
 static void RandomColorJitter(const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
@@ -379,7 +290,6 @@ static void RandomColorJitter(const nnvm::NodeAttrs &attrs,
                               const std::vector<TBlob> &outputs) {
 }
 
-template<typename xpu>
 static void RandomLighting(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index e32a6777c2..7ff73284a0 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -40,10 +40,11 @@ NNVM_REGISTER_OP(_image_to_tensor)
 })
 .set_attr<nnvm::FInferShape>("FInferShape", ToTensorShape)
 .set_attr<nnvm::FInferType>("FInferType", ToTensorType)
-.set_attr<FCompute>("FCompute<cpu>", ToTensor<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", ToTensor)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.");
 
+
 DMLC_REGISTER_PARAMETER(NormalizeParam);
 NNVM_REGISTER_OP(_image_normalize)
 .describe(R"code()code" ADD_FILELINE)
@@ -56,25 +57,14 @@ NNVM_REGISTER_OP(_image_normalize)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  return std::vector<std::pair<int, int> >{{0, 0}};
-})
-.set_attr<FCompute>("FCompute<cpu>", Normalize<cpu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_image_backward_normalize" })
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", Normalize)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(NormalizeParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_image_backward_normalize)
-.describe(R"code()code" ADD_FILELINE)
-.set_num_inputs(1)
-.set_num_outputs(1)
-.set_attr_parser(ParamParser<NormalizeParam>)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  return std::vector<std::pair<int, int> >{{0, 0}};
-})
-.set_attr<FCompute>("FCompute<cpu>", NormalizeBackward<cpu>);
 
 DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
 NNVM_REGISTER_OP(_image_random_brightness)
@@ -87,7 +77,11 @@ NNVM_REGISTER_OP(_image_random_brightness)
 })
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", RandomBrightness<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomBrightnessParam::__FIELDS__());
@@ -103,7 +97,11 @@ NNVM_REGISTER_OP(_image_random_contrast)
 })
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", RandomContrast<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomContrast)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomContrastParam::__FIELDS__());
@@ -119,7 +117,11 @@ NNVM_REGISTER_OP(_image_random_saturation)
 })
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", RandomSaturation<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomSaturation)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomSaturationParam::__FIELDS__());


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to