ekalda commented on code in PR #16782:
URL: https://github.com/apache/tvm/pull/16782#discussion_r1551560120


##########
tests/python/tir-transform/test_tir_transform_vectorize.py:
##########
@@ -64,28 +61,86 @@ def test_vectorize_vector():
     assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
 
-def test_vectorize_with_if():
-    n = te.var("n")
-    x = te.var("x")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
-        with ib.if_scope(x < n):
-            A[i] = A[i] + 1
-        with ib.else_scope():
-            with ib.if_scope(i < n):
-                A[i] = 2.0
-    stmt = ib.get()
+def test_vectorize_vector_scalable_error():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for j in T.vectorized(T.vscale() * 4):
+                A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    with pytest.raises(tvm.error.InternalError):
+        tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error2():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32xvscalex4")):
+            for j in T.vectorized(4):
+                A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
+
+    with pytest.raises(tvm.error.InternalError):

Review Comment:
   Done



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -37,19 +37,36 @@
 namespace tvm {
 namespace tir {
 
-// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
-  if (e.dtype().lanes() == lanes) return e;
+inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
+  if (is_scalable) {
+    return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor);
+  } else {
+    return lanes_or_vscale_factor;
+  }
+}
+
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
+  // Check if e is already in the expected form
+  if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
+      e.dtype().is_scalable_vector() == is_scalable)
+    return e;
+
   if (const BroadcastNode* op = e.as<BroadcastNode>()) {
-    ICHECK(!e.dtype().is_scalable_vector());
-    int broadcast_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-    if (lanes % broadcast_lanes == 0) {
-      return Broadcast(op->value, lanes);
+    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
+        << "Can't broadcast between scalable and fixed length vectors.";
+    int e_lanes = is_scalable ? op->dtype.vscale_factor() : op->dtype.lanes();

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to