AndrewZhaoLuo commented on a change in pull request #8816:
URL: https://github.com/apache/tvm/pull/8816#discussion_r698710720



##########
File path: include/tvm/relay/attrs/reduce.h
##########
@@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
   }
 };
 
+/*! \brief Attributes for Reduce operators which reduce by finding a single 
element. E.g. argmin */
+struct OneElementReduceAttrs : public tvm::AttrsNode<OneElementReduceAttrs> {

Review comment:
       Done

##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -269,29 +290,46 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const 
std::vector<IndexExpr>& in_s
   }
 }
 
-/*!
- * \brief ArgReduceRel Output type and shape relation evaluation function.
- * \param num_inputs Number of input types in the args.
- * \param attrs The additional attributes of the operator.
- * \param reporter The reporter to report solution to.
- * \return false if This relation cannot be resolved. true if this relation 
has been resolved.
- */
-bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                  const TypeReporter& reporter) {
+template <class T>
+bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                      const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
   ICHECK(static_cast<int>(data->shape.size()) != 0);
   std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
 
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const T* param = attrs.as<T>();
   ICHECK(param != nullptr);
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
   reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
   return true;
 }
+/*!
+ * \brief ArgReduceRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation 
has been resolved.
+ */
+bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  return GenericReduceRel<ReduceAttrs>(types, num_inputs, attrs, reporter);
+}
+
+/*!
+ * \brief SingleElementArgReduceRel Output type and shape relation evaluation 
function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation 
has been resolved.
+ */
+bool SingleElementArgReduceRel(const Array<Type>& types, int num_inputs, const 
Attrs& attrs,
+                               const TypeReporter& reporter) {
+  return GenericReduceRel<OneElementReduceAttrs>(types, num_inputs, attrs, 
reporter);
+}

Review comment:
       Done

##########
File path: include/tvm/topi/reduction.h
##########
@@ -442,35 +481,49 @@ inline Tensor max(const Tensor& data, const 
Array<Integer>& axis, bool keepdims
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
  * \param atleast1d Whether the output need to be atleast1d.
+ * \param select_last_index Whether to select the last index if the minimum 
element
+ * appears multiple times, else select the first index.
  *
  * \return A Tensor whose op member is the argmin operation
  */
 inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool 
keepdims = false,
-                     bool atleast1d = false) {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // 
idx
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // 
val
-    return result;
-  };
-  auto fidentity = [](std::vector<DataType> types) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-    result.push_back(tvm::max_value(types[1]));            // val
-    return result;
-  };
-  auto func = MakeCommReducer(fcombine, fidentity, "argmin");
-  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
+                     bool atleast1d = false, bool select_last_index = false) {
+  auto reducer = MakeArgminReducer(select_last_index);
+  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
 }
 
-inline FCommReduce MakeArgmaxReducer() {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
+inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
+  // Create a Commutative Reducer with a comparison operation, and method to 
get the initial value.
+  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
     Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // 
idx
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // 
val
+
+    // Casting to avoid operator ambiguity
+    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+    // These variables compare the actual values of the array
+    auto is_bigger = lhs_val > rhs_val;
+    auto is_same = lhs_val == rhs_val;
+
+    // This checks if the indices are correct for the reduction. E.g. for 
select_last_index
+    // it gives precedence for later indices of the same element and 
precedence for sooner
+    // indices if not select_last_index;
+    PrimExpr proper_index;
+    if (select_last_index) {
+      proper_index = lhs_idx > rhs_idx;
+    } else {
+      proper_index = lhs_idx < rhs_idx;
+    }
+
+    PrimExpr update_index = is_bigger || (is_same && proper_index);
+    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
+    LOG(WARNING) << result;

Review comment:
       Good catch, done

##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -207,9 +208,29 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const 
Array<te::Tensor>& inp
       return {topi::identity(inputs[0])};
     }
   }
+
   return {f(inputs[0], axes, param->keepdims, false)};
 }
 
+template <typename F>
+Array<te::Tensor> OneElementReduceCompute(const Attrs& attrs, const 
Array<te::Tensor>& inputs,

Review comment:
       Hmm, I've renamed it to ArgReduceCompute to be consistent with the naming 
Masa suggested for the other functions.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to