This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ee6e26f2cb [Unity][Op] Avoid indices in TIR matmul being 0 in 
legalization (#14701)
ee6e26f2cb is described below

commit ee6e26f2cb1159ab6a5cfc74b96b37b464b19797
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Apr 22 13:42:37 2023 -0400

    [Unity][Op] Avoid indices in TIR matmul being 0 in legalization (#14701)
    
    This PR changes the behavior of matmul legalization so that, in certain
    cases, we no longer use the constant 0 as an index in the generated TIR.
    
    Since the matmul op supports broadcasting and batching, previously when
    legalizing a matmul op we would emit the constant index "0" for every
    broadcasting dimension of length 1. For example, the code below is
    TIR produced by matmul legalization.
    ```python
    @T.prim_func
    def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), 
"float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), 
"float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), 
T.int64(7)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), 
T.int64(7), T.int64(5)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, 
i2, i3, k])
                T.reads(A[T.int64(0), T.int64(0), v_i2, v_k], B[T.int64(0), 
T.int64(0), v_k, v_i3])
                T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
                matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, 
v_i3] + A[T.int64(0), T.int64(0), v_i2, v_k] * B[T.int64(0), T.int64(0), v_k, 
v_i3]
    ```
    Notice that `T.int64(0)` is used to index the first two dims of `A` and `B`.
    
    However, when both `A` and `B` have length 1 at such a dimension, it is
    more canonical to use a variable as the index, since this form is
    generally easier for analysis functions to accept and detect.
    
    Therefore, this PR updates the behavior so that we emit variables as
    indices when both sides have length 1, as is the case for the shapes in
    the example above. A unit test is added to demonstrate the new behavior.
---
 .../relax/transform/legalize_ops/linear_algebra.py | 14 ++++++---
 ..._transform_legalize_ops_index_linear_algebra.py | 35 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py 
b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index a5e708dce7..7cc75bab1c 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -59,10 +59,16 @@ def _matmul(bb: BlockBuilder, call: Call) -> Expr:
                 for i in range(offset, len(output_shape) - (2 - a_prepended - 
b_appended)):
                     a_dim = a_shape[i if is_a_larger else i - offset]
                     b_dim = b_shape[i if not is_a_larger else i - offset]
-                    a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1
-                    b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1
-                    a_indices.append(0 if a_dim_is_one else idx_spatial[i])
-                    b_indices.append(0 if b_dim_is_one else idx_spatial[i])
+                    dim_equal = a_dim == b_dim
+                    if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0:
+                        a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim 
== 1
+                        b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim 
== 1
+                        a_indices.append(0 if a_dim_is_one else idx_spatial[i])
+                        b_indices.append(0 if b_dim_is_one else idx_spatial[i])
+                    else:
+                        a_indices.append(idx_spatial[i])
+                        b_indices.append(idx_spatial[i])
+
                 if not a_prepended:
                     a_indices.append(idx_spatial[-2 + b_appended])
                 a_indices.append(idx_reduce)
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index d7c0b54af2..85ade3f140 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -886,6 +886,41 @@ def test_matmul_4_5_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_matmul_batching_dim_1():
+    # fmt: off
+    @tvm.script.ir_module
+    class Matmul:
+        @R.function
+        def main(x: R.Tensor((1, 1, 4, 5), "float32"), y: R.Tensor((1, 1, 5, 
7), "float32")) -> R.Tensor((1, 1, 4, 7), "float32"):
+            gv: R.Tensor((1, 1, 4, 7), "float32") = R.matmul(x, y, 
out_dtype="float32")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), 
T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), 
T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), 
T.int64(4), T.int64(7)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), 
T.int64(4), T.int64(7), T.int64(5)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, 
i1, i2, i3, k])
+                    T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+                    T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
+                    with T.init():
+                        matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+                    matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, 
v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+
+        @R.function
+        def main(x: R.Tensor((1, 1, 4, 5), dtype="float32"), y: R.Tensor((1, 
1, 5, 7), dtype="float32")) -> R.Tensor((1, 1, 4, 7), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((1, 1, 4, 
7), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Matmul)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_einsum():
     # fmt: off
     @I.ir_module

Reply via email to