piiswrong closed pull request #8678: [WIP]hue URL: https://github.com/apache/incubator-mxnet/pull/8678
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 5c552b2073..3bee84321b 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -25,9 +25,12 @@ #ifndef MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_ #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_ + #include <mxnet/base.h> #include <algorithm> #include <vector> +#include <cmath> +#include <limits> #include <algorithm> #include <utility> #include "../mxnet_op.h" @@ -337,11 +340,160 @@ static void RandomSaturation(const nnvm::NodeAttrs &attrs, } } +struct RandomHueParam : public dmlc::Parameter<RandomHueParam> { + float max_hue; + DMLC_DECLARE_PARAMETER(RandomHueParam) { + DMLC_DECLARE_FIELD(max_hue) + .set_default(0.0) + .describe("Max Hue."); + } +}; + +template <typename DType> static +void RGB2HLSConvert(const DType src_r, + const DType src_g, + const DType src_b, + DType *dst_h, + DType *dst_l, + DType *dst_s + ) { + DType b = src_b, g = src_g, r = src_r; + DType h = 0.f, s = 0.f, l; + DType vmin; + DType vmax; + DType diff; + + vmax = vmin = r; + vmax = fmax(vmax, g); + vmax = fmax(vmax, b); + vmin = fmin(vmin, g); + vmin = fmin(vmin, b); + + diff = vmax - vmin; + l = (vmax + vmin) * 0.5f; + + if (diff > std::numeric_limits<DType>::epsilon()) { + s = (l < 0.5f) * diff / (vmax + vmin); + s += (l >= 0.5f) * diff / (2.0f - vmax - vmin); + + diff = 60.f / diff; + + h = (vmax == r) * (g - b) * diff; + h += (vmax != r && vmax == g) * ((b - r) * diff + 120.f); + h += (vmax != r && vmax != g) * ((r - g) * diff + 240.f); + h += (h < 0.f) * 360.f; + } + + *dst_h = h; + *dst_l = l; + *dst_s = s; +} + + +static int c_HlsSectorData[6][3] = { + { 1, 3, 0 }, + { 1, 0, 2 }, + { 3, 0, 
1 }, + { 0, 2, 1 }, + { 0, 1, 3 }, + { 2, 1, 0 } +}; + +template <typename DType> static void HLS2RGBConvert(const DType src_h, + const DType src_l, + const DType src_s, + DType *dst_r, + DType *dst_g, + DType *dst_b) { + + + float h = src_h, l = src_l, s = src_s; + float b = l, g = l, r = l; + + if (s != 0) { + float p2 = (l <= 0.5f) * l * (1 + s); + p2 += (l > 0.5f) * (l + s - l * s); + float p1 = 2 * l - p2; + + if (h < 0) { + do { h += 6; } while (h < 0); + } else if (h >= 6) { + do { h -= 6; } while (h >= 6); + } + + int sector = static_cast<int>(h); + + h -= sector; + + float tab[4]; + tab[0] = p2; + tab[1] = p1; + tab[2] = p1 + (p2 - p1) * (1 - h); + tab[3] = p1 + (p2 - p1) * h; + + b = tab[c_HlsSectorData[sector][0]]; + g = tab[c_HlsSectorData[sector][1]]; + r = tab[c_HlsSectorData[sector][2]]; + } + + *dst_b = b; + *dst_g = g; + *dst_r = r; +} + +template<typename xpu, typename DType> +static void RandomHueKernal(const TBlob &input, + const TBlob &output, + Stream<xpu> *s, + int hight, + int weight, + DType alpha) { + auto input_3d = input.get<xpu, 3, DType>(s); + auto output_3d = output.get<xpu, 3, DType>(s); + for (int h_index = 0; h_index < hight; ++h_index) { + for (int w_index = 0; w_index < weight; ++w_index) { + DType h; + DType l; + DType s; + RGB2HLSConvert(input_3d[0][h_index][w_index], + input_3d[1][h_index][w_index], + input_3d[2][h_index][w_index], + &h, &l, &s); + h += alpha; + h = std::max(DType(0), std::min(DType(180), h)); + + HLS2RGBConvert( + h, l, s, + &output_3d[0][h_index][w_index], + &output_3d[1][h_index][w_index], + &output_3d[2][h_index][w_index]); + } + } +} + +template<typename xpu> static void RandomHue(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector<TBlob> &inputs, const std::vector<OpReqType> &req, const std::vector<TBlob> &outputs) { + using namespace mshadow; + auto input = inputs[0]; + auto output = outputs[0]; + int channel = input.shape_[0]; + int hight = input.shape_[1]; + int weight = 
input.shape_[2]; + Stream<xpu> *s = ctx.get_stream<xpu>(); + Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s); + + const RandomHueParam &param = nnvm::get<RandomHueParam>(attrs.parsed); + float alpha = std::uniform_real_distribution<float>( + -param.max_hue, param.max_hue)(prnd->GetRndEngine()); + auto output_float = output.get<xpu, 3, float>(s); + + MSHADOW_TYPE_SWITCH(input.type_flag_, DType, { + RandomHueKernal<xpu, DType>(input, output, s, hight, weight, alpha); + }); } static void RandomColorJitter(const nnvm::NodeAttrs &attrs, diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 4184382ab5..29edeedeaa 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -136,6 +136,22 @@ NNVM_REGISTER_OP(_image_random_saturation) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(RandomSaturationParam::__FIELDS__()); +DMLC_REGISTER_PARAMETER(RandomHueParam); +NNVM_REGISTER_OP(_image_random_hue) +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser<RandomHueParam>) +.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ResourceRequest::kRandom}; +}) +.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) +.set_attr<FCompute>("FCompute<cpu>", RandomHue<cpu>) +.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(RandomHueParam::__FIELDS__()); + DMLC_REGISTER_PARAMETER(AdjustLightingParam); NNVM_REGISTER_OP(_image_adjust_lighting) .describe(R"code(Adjust the lighting level of the input. Follow the AlexNet style.)code" ADD_FILELINE) ---------------------------------------------------------------- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services