This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 9c7aaace43 [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973) 9c7aaace43 is described below commit 9c7aaace4355c67403be563de3059d34fb8e29f5 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Mon Jul 18 11:03:54 2022 -0500 [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973) * [TIR] Moved PrimExpr operator overload from op.h to expr.h If a compilation unit includes `<tvm/ir/expr.h>`, but does not include `<tvm/tir/op.h>`, the operator overloads for `ObjectRef` are declared, but the operator overloads for `PrimExpr` are not. In this case, any use of `expr_a == expr_b` would use `ObjectRef`'s implementation and compare reference equality of the two expressions, rather than returning a `PrimExpr` that represents the comparison. By placing the operator overloads in the `<tvm/ir/expr.h>` header file, directly adjacent to the `PrimExpr` declaration, the correct overload is always available wherever `PrimExpr` can be used. Even though this only affects `operator==`, `operator!=`, and `operator<` (the three operators defined for `ObjectRef`), this PR moves all operator overloads to `expr.h` for consistency. The named versions of the operators (e.g. `tvm::add`) do not have overloaded variants, and so they are intentionally kept in `<tvm/tir/op.h>`. * Explicitly convert TVMRetValue to bool in target.cc Needed to avoid ambiguity between `TVMRetValue -> bool` conversion and `TVMRetValue -> int -> PrimExpr` conversion. * Used vector/unordered_set to track BufferInfoExtractor::call_order_ Use of `std::set<Call>` was ambiguous between the `PrimExpr` and `ObjectRef` overloads of `operator<`. The comment for `call_order_` implied that the previous usage of `std::set<Call>` was intended to produce a de-duplicated list in order of occurrence. However, the `std::set` was ordered by `ObjectRef::operator<`, not by insertion order.
Switching to using a `vector` for ordering and `unordered_set` for de-duplication resolves this issue, and also removes the use of `operator<`. * Remove C-style cast to fix lint error --- include/tvm/ir/expr.h | 214 +++++++++++++++++++++++++++ include/tvm/tir/op.h | 195 ------------------------ src/target/target.cc | 9 +- src/tir/usmp/analysis/extract_buffer_info.cc | 11 +- 4 files changed, 228 insertions(+), 201 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b2cfc295b6..5e358ed50e 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -133,6 +133,220 @@ class PrimExpr : public BaseExpr { TVM_DLL static PrimExpr FromObject_(ObjectRef ref); }; +/*! + * \brief add operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); + +/*! + * \brief subtraction operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); + +/*! + * \brief negation. + * + * \param a input. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator-(PrimExpr a); + +/*! + * \brief multiplication operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b); + +/*! + * \brief division operator + * + * \param a left operand + * \param b right operand + * \return The result expression. 
+ * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b); + +/*! + * \brief left shift operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b); + +/*! + * \brief right shift operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b); + +/*! + * \brief greater + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b); + +/*! + * \brief greater_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b); + +/*! + * \brief less + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b); + +/*! + * \brief less_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b); + +/*! + * \brief equal + * + * \param a left operand + * \param b right operand + * \return The result expression. 
+ * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b); + +/*! + * \brief not_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b); + +/*! + * \brief and + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b); + +/*! + * \brief or + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b); + +/*! + * \brief not + * + * \param a left operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL PrimExpr operator!(PrimExpr a); + +/*! + * \brief take bitwise and of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b); + +/*! + * \brief take bitwise or of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b); + +/*! + * \brief take bitwise xor of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b); + +/*! 
+ * \brief take bitwise negation of two values + * + * \param a the input expression. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr operator~(PrimExpr a); + /*! * \brief Base node of all non-primitive expressions. * diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 34935aec61..7236c6a611 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -42,7 +42,6 @@ namespace tvm { // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. -// It is also necessary to overload operators for PrimExpr. // // We put more developer oriented APIs -- make_const and is_const under tir // as they are more specific to the tir namespace. @@ -143,16 +142,6 @@ TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span * index types(int32, int64) when possible. */ TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief add operator - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); /*! * \brief subtraction operator * @@ -164,16 +153,6 @@ TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief subtraction operator - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); /*! * \brief negation. * @@ -184,15 +163,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. 
*/ TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span()); -/*! - * \brief negation. - * - * \param a input. - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator-(PrimExpr a); /*! * \brief multiplication operator * @@ -204,26 +174,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief multiplication operator - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b); -/*! - * \brief division operator - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b); /*! * \brief left shift operator * @@ -235,16 +185,6 @@ TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief left shift operator - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b); /*! * \brief right shift operator * @@ -256,16 +196,6 @@ TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief right shift operator - * - * \param a left operand - * \param b right operand - * \return The result expression. 
- * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b); /*! * \brief greater * @@ -277,16 +207,6 @@ TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief greater - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b); /*! * \brief greater_equal * @@ -298,16 +218,6 @@ TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief greater_equal - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b); /*! * \brief less * @@ -319,16 +229,6 @@ TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief less - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b); /*! * \brief less_equal * @@ -340,16 +240,6 @@ TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief less_equal - * - * \param a left operand - * \param b right operand - * \return The result expression. 
- * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b); /*! * \brief equal * @@ -361,16 +251,6 @@ TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief equal - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b); /*! * \brief not_equal * @@ -382,16 +262,6 @@ TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief not_equal - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b); /*! * \brief and * @@ -402,15 +272,6 @@ TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b); * \note This operator does eager constant folding. */ TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief and - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note This operator does eager constant folding. - */ -TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b); /*! * \brief or * @@ -421,15 +282,6 @@ TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b); * \note This operator does eager constant folding. */ TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief or - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note This operator does eager constant folding. 
- */ -TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b); /*! * \brief not * @@ -439,14 +291,6 @@ TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b); * \note This operator does eager constant folding. */ TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span()); -/*! - * \brief not - * - * \param a left operand - * \return The result expression. - * \note This operator does eager constant folding. - */ -TVM_DLL PrimExpr operator!(PrimExpr a); /*! * \brief compute division in C semantics. * @@ -601,16 +445,6 @@ TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span()); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief take bitwise and of two values - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b); /*! * \brief take bitwise or of two values * @@ -622,16 +456,6 @@ TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief take bitwise or of two values - * - * \param a left operand - * \param b right operand - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b); /*! * \brief take bitwise xor of two values * @@ -643,16 +467,6 @@ TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span()); -/*! - * \brief take bitwise xor of two values - * - * \param a left operand - * \param b right operand - * \return The result expression. 
- * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b); /*! * \brief take bitwise negation of two values * @@ -663,15 +477,6 @@ TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span()); -/*! - * \brief take bitwise negation of two values - * - * \param a the input expression. - * \return The result expression. - * \note this function does eager constant folding for - * index types(int32, int64) when possible. - */ -TVM_DLL PrimExpr operator~(PrimExpr a); /*! * \brief Conditional expression. * diff --git a/src/target/target.cc b/src/target/target.cc index 07b347f098..01f9bfaeec 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -847,10 +847,11 @@ std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id, TVMRetValue ret; api->GetAttr(device, runtime::kExist, &ret); - if (!ret) { - ICHECK(ret) << "Requested reading the parameters for " << target->kind->name - << " from device_id " << device_id << ", but device_id " << device_id - << " doesn't exist. Using default target parameters."; + bool device_exists = ret; + if (!device_exists) { + ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name + << " from device_id " << device_id << ", but device_id " << device_id + << " doesn't exist. Using default target parameters."; return output; } diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index ba8f6aa911..74d428f6dd 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -92,7 +92,11 @@ class BufferInfoExtractor : public StmtExprVisitor { /*! * \brief Records the order of calls in the main for stability. */ - std::set<Call> call_order_; + std::vector<Call> call_order_; + /*! 
+ * \brief Lookup to avoid adding duplicates to `call_order_`. + */ + std::unordered_set<Call, ObjectPtrHash, ObjectPtrEqual> call_order_contents_; /*! * \brief Records first access in-terms of Stmts to each buffer per call * @@ -469,7 +473,10 @@ void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call) scope_stack_.top().allocate_nodes, scope_stack_.top().allocate_const_nodes, scope_stack_.top().initial_stmt_of_the_nested_loops}; - call_order_.insert(call); + if (call_order_contents_.count(call) == 0) { + call_order_contents_.insert(call); + call_order_.push_back(call); + } scope_stack_.push(si); this->VisitStmt(func->body); scope_stack_.pop();