This is an automated email from the ASF dual-hosted git repository. ekalda pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 291c04770a [TIR] Fix Bug in VectorizeLoop (#17039) 291c04770a is described below commit 291c04770a079254d812007c191ae6923857312c Author: Charlie Ruan <53290280+charliefr...@users.noreply.github.com> AuthorDate: Thu May 30 01:02:52 2024 -0700 [TIR] Fix Bug in VectorizeLoop (#17039) * [TIR] Fix Bug in VectorizeLoop This PR fixes a bug in VectorizeLoop introduced by a recent change. Visiting the condition can set the need_scalarize_ flag to true, and the subsequent visit to the then case can then trigger an ICHECK. Visiting the let value can also set the need_scalarize_ flag, in which case we need to scalarize immediately. * Add unit test --------- Co-authored-by: tqchen <tianqi.tc...@gmail.com> --- src/tir/transforms/vectorize_loop.cc | 14 +++++++++-- .../tir-transform/test_tir_transform_vectorize.py | 27 ++++++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index aa62d58505..63569f342a 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -676,12 +676,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp Stmt VisitStmt_(const IfThenElseNode* op) final { ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); + // need scalarize can be marked as true during visit of condition + bool cond_need_scalarize = false; + std::swap(cond_need_scalarize, need_scalarize_); + // temp clear need_scalarize flag, so VisitStmt + // won't trigger an ICHECK eror Stmt then_case = this->VisitStmt(op->then_case); Optional<Stmt> else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } - // Check if we can rewrite the condition with predicated buffers if (EnableBufferLevelPredication(target_) && condition.dtype().is_scalable_or_fixed_length_vector() && 
!else_case.defined()) { @@ -693,7 +697,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp } } - if (condition.dtype().is_scalable_or_fixed_length_vector()) { + if (cond_need_scalarize || condition.dtype().is_scalable_or_fixed_length_vector()) { return Scalarize(GetRef<Stmt>(op)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && @@ -710,6 +714,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp // LetStmt Stmt VisitStmt_(const LetStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); + // if visit of value triggers need scalarize + // we need to scalarize the let + if (need_scalarize_) { + need_scalarize_ = false; + Scalarize(GetRef<Stmt>(op)); + } ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index e02c227b05..7523cab549 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import pytest + import tvm import tvm.testing from tvm import te from tvm.script import ir as I from tvm.script import tir as T -import pytest - simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") @@ -312,6 +312,29 @@ def test_vectorize_if_then_else_vector(extent, target): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_let_if_then_else(): + @I.ir_module + class Before: + @T.prim_func + def main(): + for i in T.vectorized(4): + if i < 2: + result: T.int32 = T.if_then_else(i < 1, 1, 2) + + @I.ir_module + class After: + @T.prim_func + def main(): + for i_s in range(4): + if i_s < 2: + result: T.int32 = T.if_then_else(i_s < 1, 1, 2) + T.evaluate(0) + + with tvm.target.Target(simple_target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_while_fail(): """A while loop inside a vectorized loop should fail."""