wkcn commented on a change in pull request #16884: [Backport][v1.6.x] Fix the 
wrong result of sum, mean, argmin, argmax when inputs contain inf or nan
URL: https://github.com/apache/incubator-mxnet/pull/16884#discussion_r349416162
 
 

 ##########
 File path: src/operator/tensor/elemwise_unary_op.h
 ##########
 @@ -660,6 +660,134 @@ void AroundOpForward(const nnvm::NodeAttrs& attrs,
   }
 }
 
+/*!
+ * \brief Parameters of the NumPy-compatible nan_to_num operator: whether to
+ *        operate on a copy of the input, and the values substituted for NaN,
+ *        positive infinity and negative infinity.
+ */
+struct NumpyNanToNumParam : public dmlc::Parameter<NumpyNanToNumParam> {
+  bool copy;
+  double nan;
+  // Empty optionals by default; see the field descriptions below for what
+  // happens when no replacement value is supplied.
+  dmlc::optional<double> posinf, neginf;
+  DMLC_DECLARE_PARAMETER(NumpyNanToNumParam) {
+    // Adjacent string literals are concatenated with no separator, so every
+    // fragment except the last one ends with a space.
+    DMLC_DECLARE_FIELD(copy)
+    .set_default(true)
+    .describe("Whether to create a copy of `x` (True) or to replace values "
+              "in-place (False). The in-place operation only occurs if "
+              "casting to an array does not require a copy. "
+              "Default is True.");
+    DMLC_DECLARE_FIELD(nan)
+    .set_default(0.0)
+    .describe("Value to be used to fill NaN values. If no value is passed "
+              "then NaN values will be replaced with 0.0.");
+    DMLC_DECLARE_FIELD(posinf)
+    .set_default(dmlc::optional<double>())
+    .describe("Value to be used to fill positive infinity values. "
+              "If no value is passed then positive infinity values will be "
+              "replaced with a very large number.");
+    DMLC_DECLARE_FIELD(neginf)
+    .set_default(dmlc::optional<double>())
+    .describe("Value to be used to fill negative infinity values. "
+              "If no value is passed then negative infinity values "
+              "will be replaced with a very small (or negative) number.");
+  }
+};
+
+template<int req>
+struct nan_to_num_forward {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  DType* out_data,
+                                  const DType* in_data,
+                                  const DType nan,
+                                  const DType posinf,
+                                  const DType neginf) {
+    DType val = in_data[i];
+    if (mshadow_op::IsNan<DType>(val))  val = nan;
+    if (val > 0 && mshadow_op::IsInf(val))  val = posinf;
 
 Review comment:
   Do I need to remove my changes to
`src/operator/tensor/elemwise_unary_op.h`?
   The class `NumpyNanToNumParam` was added in another PR; I only modified
a few lines (702-704 and 761-763) in it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to