This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b5fda2d93a [TIR] Ramp and Broadcast lanes fixed to int32 dtype (#16795)
b5fda2d93a is described below

commit b5fda2d93ab91b08753b23eff92916f097bd2620
Author: Anirudh Sundar Subramaniam <quic_sanir...@quicinc.com>
AuthorDate: Mon Apr 1 17:37:11 2024 +0530

    [TIR] Ramp and Broadcast lanes fixed to int32 dtype (#16795)
    
    * [TIR] Ramp and Broadcast lanes fixed to int32 dtype
    
    When Ramp and Broadcast nodes are created with fixed-length lanes,
    the lanes expression is now always an int32 IntImm, regardless of the
    dtype it was given (e.g. int64). int32 is always wide enough, because
    DLDataType only supports uint16 lanes.
    
    * Add test cases for int64 type lanes
    
    * Update test case with int64 iterators
---
 src/tir/ir/expr.cc                                |  8 ++++++--
 tests/python/arith/test_arith_rewrite_simplify.py | 15 +++++++++++++++
 tests/python/tir-base/test_tir_nodes.py           | 10 ++++++++++
 3 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index c2baad2096..90dad72039 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -449,16 +449,18 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) {
     int lanes = static_cast<int>(lanes_as_int->value);
     ICHECK_GT(lanes, 1);
     node->dtype = base.dtype().with_lanes(lanes);
+    // Stick to int32 lanes for fixed length vectors
+    node->lanes = lanes;
   } else { /* scalable vector */
     std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
     ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
 
     node->dtype = base.dtype().with_scalable_vscale_factor(vscale_factor.value());
     lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
+    node->lanes = lanes;
   }
   node->base = base;
   node->stride = stride;
-  node->lanes = lanes;
   node->span = std::move(span);
   data_ = std::move(node);
 }
@@ -481,15 +483,17 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) {
     int lanes = static_cast<int>(lanes_int->value);
     ICHECK_GT(lanes, 1);
     node->dtype = value.dtype().with_lanes(lanes);
+    // Stick to int32 lanes for fixed length vectors
+    node->lanes = lanes;
   } else { /* scalable vector */
     std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
     ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
 
     node->dtype = value.dtype().with_scalable_vscale_factor(vscale_factor.value());
     lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value());
+    node->lanes = lanes;
   }
   node->value = std::move(value);
-  node->lanes = lanes;
   node->span = std::move(span);
   data_ = node;
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py
index 8645e5b26a..9cc44aa6a2 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -75,6 +75,7 @@ class BaseCompare:
 
 class TestVector(BaseCompare):
     x, y, z = te.var("x"), te.var("y"), te.var("z")
+    x64 = te.var("x", dtype="int64")
     vx = te.var("vx", dtype="int32x2")
     vc = te.var("vc", dtype="uint1")
     test_case = tvm.testing.parameter(
@@ -88,6 +89,20 @@ class TestVector(BaseCompare):
         ),
         TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")),
         TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)),
+        # int64 lanes
+        TestCase(
+            tvm.tir.Broadcast(x, 4) + tvm.tir.Ramp(0, 1, tvm.tir.IntImm(dtype="int64", value=4)),
+            tvm.tir.Ramp(x, 1, 4),
+        ),
+        TestCase(
+            tvm.tir.Broadcast(x, tvm.tir.IntImm(dtype="int64", value=4)) + tvm.tir.Ramp(0, 1, 4),
+            tvm.tir.Ramp(x, 1, 4),
+        ),
+        # int64 iterators with int32 lanes
+        TestCase(
+            tvm.tir.Broadcast(x64, 4) + tvm.tir.Ramp(tvm.tir.IntImm(dtype="int64", value=0), 1, 4),
+            tvm.tir.Ramp(x64, 1, 4),
+        ),
         TestCase(
             tvm.tir.Broadcast(0, tir.vscale() * 8) + y, tvm.tir.Broadcast(y, tir.vscale() * 8)
         ),
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index 60f8278ec2..31a1317e68 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -409,6 +409,16 @@ def _create_broadcast(lanes):
     return tvm.tir.Broadcast(0, lanes)
 
 
+@pytest.mark.parametrize("lanes", [(tvm.tir.IntImm(dtype="int64", value=11))])
+@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
+def test_lane_types(lanes, node_func):
+    def _check_dtype(node):
+        assert node.lanes.dtype == "int32"
+        assert node.lanes == 11
+
+    _check_dtype(node_func(lanes))
+
+
 @pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)])
 @pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
 def test_scalable_vec(lanes, node_func):

Reply via email to