AndrewZhaoLuo commented on a change in pull request #8816:
URL: https://github.com/apache/tvm/pull/8816#discussion_r698717310
##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -269,29 +290,46 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const
std::vector<IndexExpr>& in_s
}
}
-/*!
- * \brief ArgReduceRel Output type and shape relation evaluation function.
- * \param num_inputs Number of input types in the args.
- * \param attrs The additional attributes of the operator.
- * \param reporter The reporter to report solution to.
- * \return false if This relation cannot be resolved. true if this relation
has been resolved.
- */
-bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
+template <class T>
+bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
ICHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
- const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+ const T* param = attrs.as<T>();
ICHECK(param != nullptr);
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
return true;
}
+/*!
+ * \brief ArgReduceRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation
has been resolved.
+ */
+bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ return GenericReduceRel<ReduceAttrs>(types, num_inputs, attrs, reporter);
+}
+
+/*!
+ * \brief SingleElementArgReduceRel Output type and shape relation evaluation
function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation
has been resolved.
+ */
+bool SingleElementArgReduceRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
+ const TypeReporter& reporter) {
+ return GenericReduceRel<OneElementReduceAttrs>(types, num_inputs, attrs,
reporter);
+}
Review comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]