This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 685355e2c7 [Relax] Minor updates for NN frontend (#16558)
685355e2c7 is described below

commit 685355e2c7f4ae98342dadb6b4b6119066d8c305
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Wed Feb 14 02:34:18 2024 +0800

    [Relax] Minor updates for NN frontend (#16558)
    
    * [Relax] Minor updates for NN frontend
    
    This PR includes two changes:
    
    - expose GroupNorm in NN frontend
    - remove `tir_vars` from `tensor_ir_op` when it is not needed. Note that
      the current implementation is correct, but it would convert the call
      into `call_tir_dyn` even when the `tir_vars` list is empty (see the
      sketch below).
    
    * lint
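
    To illustrate the second change, here is a minimal standalone sketch of
    the guard; `emit_tir_call` is a hypothetical helper written for this
    example, not part of the patch:

        from tvm import relax as rx
        from tvm.relax import BlockBuilder

        def emit_tir_call(bb: BlockBuilder, gvar, args, out_sinfo, tir_vars):
            # A non-None but empty tir_vars list is still attached to the
            # call and later lowered to call_tir_dyn, so normalize it to
            # None to keep a plain call_tir.
            if tir_vars is not None and len(tir_vars) == 0:
                tir_vars = None
            return bb.emit(rx.call_tir(gvar, args, out_sinfo, tir_vars=tir_vars))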
---
 python/tvm/relax/frontend/nn/__init__.py  |  1 +
 python/tvm/relax/frontend/nn/op.py        |  3 ++
 tests/python/relax/test_frontend_nn_op.py | 50 +++++++++++++++++++++++++++----
 3 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index 61d1001ea8..a8200d8dd6 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -25,6 +25,7 @@ from .modules import (
     Conv1D,
     ConvTranspose1D,
     Embedding,
+    GroupNorm,
     IOEffect,
     KVCache,
     LayerNorm,
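
With `GroupNorm` re-exported above, it can be used directly from the frontend
namespace. A minimal usage sketch follows; the constructor arguments assume
the usual PyTorch-style GroupNorm signature and should be checked against
`nn.modules.GroupNorm`:

    from tvm.relax.frontend import nn

    class Block(nn.Module):
        def __init__(self):
            # num_groups/num_channels are assumed parameter names here.
            self.norm = nn.GroupNorm(num_groups=4, num_channels=16)

        def test(self, x: nn.Tensor):
            return self.norm(x)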
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index fbca48f0ee..b6c34ca265 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1623,6 +1623,9 @@ def tensor_ir_op(
     bb = BlockBuilder.current()
     global_var = bb.add_func(func, name_hint)
 
+    if len(tir_vars) == 0:
+        tir_vars = None
+
     return wrap_nested(
         bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)),
         name=name_hint,
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index c74e06490f..650d8ace30 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -441,9 +441,10 @@ def test_timestep_embedding():
             get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
                 lv11, dtype="float32"
             )
-            gv1: R.Tuple(
-                R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
-            ) = get_timestep_embedding, (_io,)
+            gv1: R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)) = (
+                get_timestep_embedding,
+                (_io,),
+            )
             R.output(gv1)
         return gv1
 
@@ -470,9 +471,10 @@ def test_scaled_dot_product_attention():
             scaled_dot_product_attention: R.Tensor(
                 (1, 32, 32, 32), dtype="float32"
             ) = R.nn.attention(query, key, value, scale=None, causal_mask=None)
-            gv1: R.Tuple(
-                R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)
-            ) = scaled_dot_product_attention, (_io,)
+            gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = (
+                scaled_dot_product_attention,
+                (_io,),
+            )
             R.output(gv1)
         return gv1
 
@@ -724,6 +726,42 @@ def test_tensor_ir_inplace_op():
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
+def test_tensor_ir_op_no_tir_var():
+    @T.prim_func(private=True)
+    def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
+        T.evaluate(0)
+
+    class Model(Module):
+        def test(self, A: Tensor):
+            tensor_expr_op_out = op.tensor_ir_op(
+                tir_func,
+                "tir_func",
+                args=[A],
+                out=[Tensor.placeholder((16, 16), "float32")],
+            )
+            return tensor_expr_op_out
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    m = Model()
+    irmodule, _ = m.export_tvm(spec={"test": {"A": spec.Tensor([16, 16], "float32")}})
+    tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
 def test_extern():
     class Model(Module):
         def test(self, q: Tensor, k: Tensor, v: Tensor):
