This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 291c04770a [TIR] Fix Bug in VectorizeLoop  (#17039)
291c04770a is described below

commit 291c04770a079254d812007c191ae6923857312c
Author: Charlie Ruan <53290280+charliefr...@users.noreply.github.com>
AuthorDate: Thu May 30 01:02:52 2024 -0700

    [TIR] Fix Bug in VectorizeLoop  (#17039)
    
    * [TIR] Fix Bug in VectorizeLoop
    
    This PR fixes a bug in VectorizeLoop that was introduced
    by a recent change.
    
    Visiting the condition can set the need_scalarize flag to true;
    the follow-up visit to the then case can then trigger an ICHECK.
    
    Visiting the let value can also set the need_scalarize flag,
    in which case we need to scalarize immediately.
    
    * Add unit test
    
    ---------
    
    Co-authored-by: tqchen <tianqi.tc...@gmail.com>
---
 src/tir/transforms/vectorize_loop.cc               | 14 +++++++++--
 .../tir-transform/test_tir_transform_vectorize.py  | 27 ++++++++++++++++++++--
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index aa62d58505..63569f342a 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -676,12 +676,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
   Stmt VisitStmt_(const IfThenElseNode* op) final {
     ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
+    // need_scalarize_ can be set to true while visiting the condition.
+    // Stash it in a local and clear the member flag so that visiting the
+    // branches with VisitStmt below does not trigger an ICHECK error.
+    bool cond_need_scalarize = false;
+    std::swap(cond_need_scalarize, need_scalarize_);
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = NullOpt;
     if (op->else_case) {
       else_case = this->VisitStmt(op->else_case.value());
     }
-
     // Check if we can rewrite the condition with predicated buffers
     if (EnableBufferLevelPredication(target_) &&
         condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) {
@@ -693,7 +697,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       }
     }
 
-    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
+    if (cond_need_scalarize || condition.dtype().is_scalable_or_fixed_length_vector()) {
       return Scalarize(GetRef<Stmt>(op));
     }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
@@ -710,6 +714,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
   // LetStmt
   Stmt VisitStmt_(const LetStmtNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
+    // If visiting the value set need_scalarize_, the whole let must be
+    // scalarized: clear the flag and return the scalarized statement.
+    if (need_scalarize_) {
+      need_scalarize_ = false;
+      return Scalarize(GetRef<Stmt>(op));
+    }
     ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
     let_binding_[op->var] = value;
 
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index e02c227b05..7523cab549 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -14,13 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import te
 from tvm.script import ir as I
 from tvm.script import tir as T
-import pytest
-
 
 simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
 sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
@@ -312,6 +312,30 @@ def test_vectorize_if_then_else_vector(extent, target):
         tvm.ir.assert_structural_equal(mod, After)
 
 
+def test_vectorize_let_if_then_else():
+    """A let bound to T.if_then_else inside an if in a vectorized loop is scalarized."""
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main():
+            for i in T.vectorized(4):
+                if i < 2:
+                    result: T.int32 = T.if_then_else(i < 1, 1, 2)
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main():
+            for i_s in range(4):
+                if i_s < 2:
+                    result: T.int32 = T.if_then_else(i_s < 1, 1, 2)
+                    T.evaluate(0)
+
+    with tvm.target.Target(simple_target):
+        mod = tvm.tir.transform.VectorizeLoop()(Before)
+        tvm.ir.assert_structural_equal(mod, After)
+
+
 def test_vectorize_while_fail():
     """A while loop inside a vectorized loop should fail."""
 

Reply via email to