piiswrong commented on a change in pull request #10931: [MXNET-349] [WIP] Histogram Operator
URL: https://github.com/apache/incubator-mxnet/pull/10931#discussion_r188037090
 
 

 ##########
 File path: src/operator/tensor/histogram.cc
 ##########
 @@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./histogram-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct ComputeBinKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* 
bin_indices,
+                                  int bin_cnt, float width, float min, float 
max) {
+    DType data = in_data[i];
+    if (data >= min && data <= max) {
+      bin_indices[i] = mshadow_op::floor::Map((in_data[i] - min) / width);
+      bin_indices[i] = mshadow_op::minimum::Map(bin_cnt - 1, bin_indices[i]);
+    } else {
+      bin_indices[i] = -1;
+    }
+  }
+
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* 
bin_indices,
+                                   const DType* bin_bounds, int num_bins) {
+    DType data = in_data[i];
+    int target_idx = -1;
+    if (data >= bin_bounds[0] && data <= bin_bounds[num_bins]) {
+      target_idx = 0;
+      while ((data - bin_bounds[target_idx]) >= 0) {
+        target_idx += 1;
+      }
+      target_idx = mshadow_op::minimum::Map(target_idx - 1, num_bins - 1);
+    }
+    bin_indices[i] = target_idx;
+  }
+};
+
/*!
 * \brief Accumulate per-element bin indices into histogram counts.
 * \param bin_indices bin index of each element; negative entries mark elements
 *                    that fall outside the histogram and are skipped
 * \param out_data    histogram counts; incremented in place (not cleared here)
 * \param input_size  number of entries in bin_indices
 */
template<typename CType>
void ComputeHistogram(const int* bin_indices, CType* out_data, size_t input_size) {
  const int* const last = bin_indices + input_size;
  for (const int* cursor = bin_indices; cursor != last; ++cursor) {
    const int bin = *cursor;
    if (bin < 0) {
      continue;  // element was outside the histogram range
    }
    out_data[bin] += 1;
  }
}
+
+template<typename cpu>
+void HistogramForwardImpl(mshadow::Stream<cpu>* s,
+                          const OpContext& ctx,
+                          const nnvm::NodeAttrs& attrs,
+                          const TBlob& in_data,
+                          const TBlob& bin_bounds,
+                          const TBlob& out_data,
+                          const TBlob& out_bins) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
+  const bool has_cnt = param.bin_cnt.has_value();
+  const bool has_range = param.range.has_value();
+  const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
+  CHECK(legal_param) << "width and range should both or neither be specified";
+
+  CHECK(!has_range || (has_range && (param.range.value().ndim() == 2U)));
+  CHECK(!has_range || (has_range && (param.range.value()[0] < 
param.range.value()[1])));
+
+  Tensor<cpu, 1, int> bin_indices =
+    ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
+  const int bin_cnt = out_data.Size();
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    if (has_cnt) {
+      float max = param.range.value()[1];
+      float min = param.range.value()[0];
+      float width = (max - min) / bin_cnt;
+      Kernel<ComputeBinKernel, cpu>::Launch(
+        s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_,
+        bin_cnt, width, min, max);
+      Kernel<FillBinBoundsKernel, cpu>::Launch(
+        s, bin_cnt+1, out_bins.dptr<DType>(), bin_cnt, min, max);
+    } else {
+      Kernel<ComputeBinKernel, cpu>::Launch(
+        s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_, 
bin_bounds.dptr<DType>(),
+        bin_cnt);
+      Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
+        s, bin_bounds.Size(), out_bins.dptr<DType>(), 
bin_bounds.dptr<DType>());
+    }
+  });
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
+    Kernel<set_zero, cpu>::Launch(s, bin_cnt, out_data.dptr<CType>());
+    ComputeHistogram(bin_indices.dptr_, out_data.dptr<CType>(), 
in_data.Size());
+  });
+}
+
+template<>
+void HistogramBackwardImpl<cpu>(const OpContext& ctx,
+                                const nnvm::NodeAttrs& attrs,
+                                const TBlob& out_grad,
+                                const TBlob& in_data,
+                                const TBlob& bin_bounds,
+                                const TBlob& out_data,
+                                const TBlob& in_grad) {
+  LOG(FATAL) << "Histogram Backward not implemented yet";
+}
+
+DMLC_REGISTER_PARAMETER(HistogramParam);
+
+NNVM_REGISTER_OP(_histogram)
+.describe(R"code(This operators implements the histogram function.
+
+Example::
+  x = [[0, 1], [2, 2], [3, 4]]
+  histo, bin_edges = histogram(data=x, bin_bounds=[], bin_cnt=5, range=(0,5))
+  histo = [1, 1, 2, 1, 1]
+  bin_edges = [0., 1., 2., 3., 4.]
+  histo, bin_edges = histogram(data=x, bin_bounds=[0., 2.1, 3.])
+  histo = [4, 1]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<HistogramParam>)
+.set_num_inputs(2)
+.set_num_outputs(2)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "bins"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", HistogramOpShape)
+.set_attr<nnvm::FInferType>("FInferType", HistogramOpType)
+.set_attr<FCompute>("FCompute<cpu>", HistogramOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseInOut{"_backward_histogram"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{};
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("bins", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(HistogramParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_histogram)
 
 Review comment:
   Histogram op is not differentiable. Don't register a backward op.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to