This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 885fc27390 [TVMScript][TIR] Pretty print TIR LLVM function name (#15953)
885fc27390 is described below

commit 885fc27390b5d0b902cfc17049363a2c68e2ac80
Author: Balint Cristian <cristian.bal...@gmail.com>
AuthorDate: Wed Oct 25 11:37:35 2023 +0300

    [TVMScript][TIR] Pretty print TIR LLVM function name (#15953)
    
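    This allows the TIR printer to print the real name of an LLVM intrinsic
    function. Previously, a counter-intuitive integer ID (e.g. T.uint32(62))
    was printed instead of the real name.
    In addition, tvm.tir.call_llvm_intrin and tvm.tir.call_llvm_pure_intrin now
    raise a ValueError for an unknown intrinsic name instead of emitting a
    warning and falling back to ID 0.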
    Changes
    
    Before: T.call_llvm_pure_intrin("int32x4", T.uint32(62), T.uint32(0))
    After: T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
    
    This is part of #15918.
---
 python/tvm/tir/op.py                               |  9 ++++----
 src/script/printer/tir/expr.cc                     | 25 ++++++++++++++++++++++
 tests/python/unittest/test_tir_ops.py              |  7 ++++++
 .../python/unittest/test_tvmscript_printer_tir.py  |  7 ++++++
 4 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 905d14296d..d7df2a4bb6 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=redefined-builtin, invalid-name
 """Operators used in TIR expression."""
-import warnings
 from typing import Any, Optional
 
 import tvm._ffi
@@ -251,7 +250,7 @@ def call_llvm_intrin(dtype, name, *args, span=None):
        The name of the llvm intrinsic function.
 
     args : list
-       Poistional arguments.
+       Positional arguments.
 
     span : Optional[Span]
         The location of this operator in the source code.
@@ -271,7 +270,7 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     else:
         llvm_id = name
     if llvm_id == 0:
-        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back 
to 0")
+        raise ValueError(f"Unknown llvm intrinsic function {name}")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_intrin"),
@@ -293,7 +292,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
        The name of the llvm intrinsic function.
 
     args : list
-       Poistional arguments.
+       Positional arguments.
 
     span : Optional[Span]
         The location of this operator in the source code.
@@ -313,7 +312,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     else:
         llvm_id = name
     if llvm_id == 0:
-        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back 
to 0")
+        raise ValueError(f"Unknown llvm intrinsic function {name}")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 8de142f861..e25b074401 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -250,6 +250,31 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           dtype_print_location =
               
static_cast<tir::ScriptDtypePrintLocation>(dtype_locations[op].IntValue());
         }
+        if (name == "call_llvm_pure_intrin" || name == "call_llvm_intrin") {
+          int n_args = call->args.size();
+          int64_t id = call->args[0].as<IntImmNode>()->value;
+          auto f_llvm_lookup_intrinsic_name =
+              tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name");
+
+          Array<ExprDoc> args;
+          args.reserve(n_args + 1);
+          if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
+            args.push_back(LiteralDoc::DataType(call->dtype, 
call_p->Attr("dtype")));
+          }
+
+          for (int i = 0; i < n_args; ++i) {
+            if ((i == 0) && (f_llvm_lookup_intrinsic_name)) {
+              String name = (*f_llvm_lookup_intrinsic_name)(id);
+              args.push_back(LiteralDoc::Str(name.c_str(), 
call_p->Attr("args")->ArrayIndex(i)));
+            } else {
+              args.push_back(d->AsDoc<ExprDoc>(call->args[i], 
call_p->Attr("args")->ArrayIndex(i)));
+            }
+          }
+          if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) {
+            args.push_back(LiteralDoc::DataType(call->dtype, 
call_p->Attr("dtype")));
+          }
+          return prefix->Call(args);
+        }
       } else if (call->op.as<GlobalVarNode>()) {
         prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
       } else {
diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py
index 21981d1f0b..8cffe8171a 100644
--- a/tests/python/unittest/test_tir_ops.py
+++ b/tests/python/unittest/test_tir_ops.py
@@ -234,5 +234,12 @@ def test_comm_reducer(num_args):
     assert tvm.tir.max(*range(num_args)) == num_args - 1
 
 
+def test_llvm_intrin():
+    with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function 
llvm.dummy"):
+        a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy", 0)
+    with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function 
llvm.dummy"):
+        a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy", 0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index 70d56e6903..0636a79334 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -504,6 +504,13 @@ T.Cast("float64", a)
     )
 
 
+def test_llvm_intrin_imm():
+    a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))
+    _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", 
T.uint32(0))')
+    a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
+    _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", 
T.uint32(0))')
+
+
 def test_binary_arith():
     a = tir.Var("a", "int32")
     b = tir.Var("b", "int32")

Reply via email to