This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cae6cb89b7 [IR] Add annotations to Call nodes (#19597)
cae6cb89b7 is described below

commit cae6cb89b732fd7f874b8fabc5fdba95edb41339
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 24 18:57:37 2026 -0400

    [IR] Add annotations to Call nodes (#19597)
    
    This PR adds annotation support to `tirx.Call` so downstream codegen
    users can attach call-level metadata and preserve it through TIRX
    transforms.
    
    What changed:
    - Add `CallNode::annotations` and expose it through reflection.
    - Add Python `tvm.tirx.Call(..., annotations=...)` support.
    - Preserve call annotations in C++ and Python expression mutators.
    - Preserve annotations across TIRX/arith passes that rebuild equivalent
    calls.
    - Print annotated calls as `Tx.Call(..., annotations={...})` and support
    script roundtrip.
    - Add regression coverage for annotated calls, mutator preservation,
    script roundtrip, and simplify preservation.
    
    This PR also finishes the cleanup that #19596 left incomplete.
---
 include/tvm/tirx/expr.h                          | 17 ++++-
 include/tvm/tirx/op.h                            | 28 ++++-----
 python/tvm/rpc/tracker.py                        |  4 +-
 python/tvm/tirx/expr.py                          | 18 +++++-
 python/tvm/tirx/expr_functor.py                  |  2 +-
 python/tvm/tirx/op.py                            | 29 ++++++---
 src/arith/ir_mutator_with_analyzer.cc            |  2 +-
 src/arith/rewrite_simplify.cc                    |  3 +-
 src/s_tir/transform/inject_software_pipeline.cc  |  6 +-
 src/tirx/ir/data_type_rewriter.cc                |  6 +-
 src/tirx/ir/expr.cc                              | 79 ++++++++++++++----------
 src/tirx/ir/expr_functor.cc                      |  2 +-
 src/tirx/ir/stmt.cc                              |  2 +-
 src/tirx/op/op.cc                                | 51 +++++++--------
 src/tirx/script/printer/expr.cc                  | 14 +++++
 src/tirx/transform/lower_warp_memory.cc          |  2 +-
 src/tirx/transform/storage_rewrite.cc            |  3 +-
 src/tirx/transform/tile_primitive_dispatch.cc    |  2 +-
 src/tirx/transform/unsupported_dtype_legalize.cc |  5 +-
 src/tirx/transform/vectorize_loop.cc             | 18 +++---
 tests/python/contrib/test_rpc_tracker.py         |  4 +-
 tests/python/tirx-base/test_tir_constructor.py   | 57 +++++++++++++++++
 22 files changed, 239 insertions(+), 115 deletions(-)

diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h
index db54984c82..68cfcbd361 100644
--- a/include/tvm/tirx/expr.h
+++ b/include/tvm/tirx/expr.h
@@ -731,9 +731,20 @@ class CallNode : public PrimExprNode {
   /*! \brief The arguments. */
   ffi::Array<PrimExpr> args;
 
+  /*!
+   * \brief Additional annotations about the call.
+   *
+   * These annotations can be used to carry target-specific metadata through
+   * TIRX transformations and codegen.
+   */
+  ffi::Map<ffi::String, ffi::Any> annotations;
+
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", 
&CallNode::args);
+    refl::ObjectDef<CallNode>()
+        .def_ro("op", &CallNode::op)
+        .def_ro("args", &CallNode::args)
+        .def_ro("annotations", &CallNode::annotations);
   }
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode);
 };
@@ -744,7 +755,9 @@ class CallNode : public PrimExprNode {
  */
 class Call : public PrimExpr {
  public:
-  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span 
span = Span());
+  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
+               ffi::Map<ffi::String, ffi::Any> annotations = 
ffi::Map<ffi::String, ffi::Any>(),
+               Span span = Span());
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
 };
diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 9093c2c453..549aab4df8 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -736,19 +736,19 @@ inline void CheckMathUnaryOpInputDType(const char* 
op_name, DataType dtype) {
 }
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType)     \
-  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {               \
-    static const Op& op = Op::Get("tirx." #OpName);                      \
-    CheckInputDType(#OpName, x.dtype());                                 \
-    if (x.dtype().is_bfloat16()) {                                       \
-      DataType bf16_dtype = x.dtype();                                   \
-      DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());             \
-      PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span);               \
-      PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, span); \
-      return tirx::Cast(bf16_dtype, {result_fp32}, span);                \
-    } else {                                                             \
-      return tirx::Call(x.dtype(), op, {x}, span);                       \
-    }                                                                    \
+#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType)         \
+  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {                   \
+    static const Op& op = Op::Get("tirx." #OpName);                          \
+    CheckInputDType(#OpName, x.dtype());                                     \
+    if (x.dtype().is_bfloat16()) {                                           \
+      DataType bf16_dtype = x.dtype();                                       \
+      DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());                 \
+      PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span);                   \
+      PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, {}, span); \
+      return tirx::Cast(bf16_dtype, {result_fp32}, span);                    \
+    } else {                                                                 \
+      return tirx::Call(x.dtype(), op, {x}, {}, span);                       \
+    }                                                                        \
   }
 
 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
@@ -786,7 +786,7 @@ TVM_DECLARE_INTRIN_UNARY(clz);
 #define TVM_DECLARE_INTRIN_BINARY(OpName)                              \
   inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
     static const Op& op = Op::Get("tirx." #OpName);                    \
-    return tirx::Call(x.dtype(), op, {x, y}, span);                    \
+    return tirx::Call(x.dtype(), op, {x, y}, {}, span);                \
   }
 
 TVM_DECLARE_INTRIN_BINARY(atan2);
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 0714c64fc9..1af2a26985 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -248,9 +248,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
                 try:
                     self.call_handler(json.loads(msg))
                 except Exception:  # pylint: disable=broad-except
-                    logger.warning(
-                        "Error handling message from %s", self.name(), 
exc_info=True
-                    )
+                    logger.warning("Error handling message from %s", 
self.name(), exc_info=True)
                     self.close()
                     return
             else:
diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py
index e3c341c4e9..bf04b3ae7a 100644
--- a/python/tvm/tirx/expr.py
+++ b/python/tvm/tirx/expr.py
@@ -1302,13 +1302,22 @@ class Call(PrimExprWithOp):
 
     span : Optional[Span]
         The location of this expression in the source code.
+
+    annotations : Optional[dict]
+        Additional metadata attached to the call.
     """
 
     op: Op
     args: list[PrimExpr]
+    annotations: dict
 
     def __init__(
-        self, dtype: str, op: Op | str, args: list[PrimExpr], span: Span | 
None = None
+        self,
+        dtype: str,
+        op: Op | str,
+        args: list[PrimExpr],
+        annotations: dict | None = None,
+        span: Span | None = None,
     ) -> None:
         if isinstance(op, str):
             if not op.startswith("tirx."):
@@ -1321,7 +1330,12 @@ class Call(PrimExprWithOp):
                     % op
                 )
             op = Op.get(op)
-        self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, 
span)  # type: ignore
+        if annotations:
+            self.__init_handle_by_constructor__(  # type: ignore
+                _ffi_api.CallWithAnnotations, dtype, op, args, annotations, 
span
+            )
+        else:
+            self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, 
args, span)  # type: ignore
 
 
 @tvm_ffi.register_object("tirx.Let")
diff --git a/python/tvm/tirx/expr_functor.py b/python/tvm/tirx/expr_functor.py
index e89ed19c1e..b09606602a 100644
--- a/python/tvm/tirx/expr_functor.py
+++ b/python/tvm/tirx/expr_functor.py
@@ -495,7 +495,7 @@ class ExprMutator(ExprFunctor):
         if all(old_arg is new_arg for old_arg, new_arg in zip(op.args, args)):
             return op
         else:
-            return tvm.tirx.Call(op.dtype, op.op, args)
+            return tvm.tirx.Call(op.dtype, op.op, args, 
annotations=op.annotations, span=op.span)
 
     def _mutate_binary_op(self, op_cls, op):
         """Helper to mutate binary operators."""
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 924bec91dc..2f227195b3 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -62,8 +62,10 @@ tir = tirx  # alias for backward compat with upstream 
tir.convert() calls
 
 def _pack_buffer(buf, span=None):
     """Build intrinsics that packs the buffer."""
-    shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span)
-    strides = Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span) 
if buf.strides else 0
+    shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span=span)
+    strides = (
+        Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span=span) if 
buf.strides else 0
+    )
     pack_args = [
         buf.data,
         shape,
@@ -72,7 +74,7 @@ def _pack_buffer(buf, span=None):
         const(0, dtype=buf.dtype),
         buf.elem_offset,
     ]
-    return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args, span)
+    return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args, 
span=span)
 
 
 def call_packed_lowered(*args, span=None):
@@ -101,7 +103,7 @@ def call_packed_lowered(*args, span=None):
     te.extern : Create tensor with extern function call.
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
-    return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args, 
span)
+    return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args, 
span=span)
 
 
 def call_cpacked_lowered(*args, span=None):
@@ -127,7 +129,7 @@ def call_cpacked_lowered(*args, span=None):
     te.extern : Create tensor with extern function call.
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
-    return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args, 
span)
+    return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args, 
span=span)
 
 
 def call_packed(*args, span=None):
@@ -158,7 +160,7 @@ def call_packed(*args, span=None):
     te.extern : Create tensor with extern function call.
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
-    return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span)
+    return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span=span)
 
 
 def call_cpacked(*args, span=None):
@@ -185,10 +187,10 @@ def call_cpacked(*args, span=None):
     te.extern : Create tensor with extern function call.
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
-    return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span)
+    return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span)
 
 
-def call_intrin(dtype, func_name, *args, span=None):
+def call_intrin(dtype, func_name, *args, annotations=None, span=None):
     """Build expression by calling an intrinsic function.
 
     Intrinsics can be overloaded with multiple data types via
@@ -205,6 +207,9 @@ def call_intrin(dtype, func_name, *args, span=None):
     args : list
         Positional arguments.
 
+    annotations : Optional[Dict[str, Object]]
+        Additional annotations about the call.
+
     span : Optional[Span]
         The location of this operator in the source code.
 
@@ -213,7 +218,11 @@ def call_intrin(dtype, func_name, *args, span=None):
     call : PrimExpr
         The call expression.
     """
-    return Call(dtype, func_name, args, span)
+    if annotations is not None:
+        annotations = {
+            k: const(v) if isinstance(v, int | bool) else v for k, v in 
annotations.items()
+        }
+    return Call(dtype, func_name, args, annotations=annotations, span=span)
 
 
 def call_pure_extern(dtype, func_name, *args, span=None):
@@ -238,7 +247,7 @@ def call_pure_extern(dtype, func_name, *args, span=None):
     call : PrimExpr
         The call expression.
     """
-    return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args], 
span)
+    return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args], 
span=span)
 
 
 def call_extern(dtype, func_name, *args, span=None):
diff --git a/src/arith/ir_mutator_with_analyzer.cc 
b/src/arith/ir_mutator_with_analyzer.cc
index 1d35da952f..e902d32aba 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -310,7 +310,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* 
op) {
         false_value.same_as(op->args[2])) {
       return ffi::GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, {cond, true_value, false_value});
+      return Call(op->dtype, op->op, {cond, true_value, false_value}, 
op->annotations, op->span);
     }
   }
   return StmtExprMutator::VisitExpr_(op);
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index c9fec7f599..1765a6b04a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2343,7 +2343,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
CallNode* op) {
       // Only check constant cases to avoid recursion
       if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
           analyzer_->CanProve(inner_else_expr == else_expr)) {
-        return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
+        return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr, 
else_expr},
+                    op->annotations, op->span);
       }
     }
   }
diff --git a/src/s_tir/transform/inject_software_pipeline.cc 
b/src/s_tir/transform/inject_software_pipeline.cc
index 1512644052..717b9b7dc8 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -119,7 +119,7 @@ class PipelineOpaqueAccessRewriter {
         ffi::Array<PrimExpr> new_args = call->args;
         const Buffer& new_buffer = (*it).second;
         new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, 
call->args[4]));
-        return Call(call->dtype, call->op, new_args, call->span);
+        return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
       }
     } else if (call->op.same_as(mma_sync)) {
       ffi::Array<PrimExpr> new_args = call->args;
@@ -133,7 +133,7 @@ class PipelineOpaqueAccessRewriter {
           new_args.Set(i * 2 + 1, new_index);
         }
       }
-      return Call(call->dtype, call->op, new_args, call->span);
+      return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
     } else if (call->op.same_as(access_ptr)) {
       return RewriteBufferAccess(call, {1});
     } else if (call->op.same_as(ptx_mma)) {
@@ -196,7 +196,7 @@ class PipelineOpaqueAccessRewriter {
         new_args.Set(i + 1, new_index);
       }
     }
-    return Call(call->dtype, call->op, new_args, call->span);
+    return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
   }
 
   const ffi::Map<Var, Buffer>& buffer_data_to_buffer_;
diff --git a/src/tirx/ir/data_type_rewriter.cc 
b/src/tirx/ir/data_type_rewriter.cc
index 901d18e5c4..6fab0e3e09 100644
--- a/src/tirx/ir/data_type_rewriter.cc
+++ b/src/tirx/ir/data_type_rewriter.cc
@@ -248,7 +248,8 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(builtin_pow_)) {
     return pow(op->args[0], op->args[1]);
   } else if (op->op.same_as(builtin::if_then_else())) {
-    return if_then_else(op->args[0], op->args[1], op->args[2]);
+    return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]}, 
op->annotations,
+                op->span);
   } else if (op->op.same_as(Op::Get("tirx.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
@@ -564,7 +565,8 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* 
op) {
     is_condition_ = true;
     PrimExpr cond = VisitExpr(op->args[0]);
     is_condition_ = is_condition;
-    return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2]));
+    return Call(op->dtype, op->op, {cond, VisitExpr(op->args[1]), 
VisitExpr(op->args[2])},
+                op->annotations, op->span);
   }
   return Parent::VisitExpr_(op);
 }
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index 3248c009b4..5faa7c24bd 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -590,7 +590,39 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 }
 
 // Call
-Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span) 
{
+using CallArg = ffi::Variant<ffi::String, DLDataType, IterVar, BufferRegion, 
PrimExpr>;
+
+static ffi::Array<PrimExpr> ConvertCallArgs(ffi::Array<CallArg> args) {
+  ffi::Array<PrimExpr> prim_expr_args;
+  for (const auto& it : args) {
+    if (auto opt_str = it.as<ffi::String>()) {
+      prim_expr_args.push_back(StringImm(opt_str.value()));
+    } else if (auto opt_dtype = it.as<DLDataType>()) {
+      
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
+    } else if (const auto* iter_var = it.as<IterVarNode>()) {
+      prim_expr_args.push_back(iter_var->var);
+    } else if (const auto* br = it.as<BufferRegionNode>()) {
+      ffi::Array<PrimExpr> indices;
+      for (Range r : br->region) {
+        if (is_one(r->extent)) {
+          indices.push_back(r->min);
+        } else if (r->extent.as<IntImmNode>()) {
+          indices.push_back(tirx::Ramp(r->min, make_const(r->min->dtype, 1), 
r->extent));
+        } else {
+          TVM_FFI_THROW(ValueError)
+              << "Cannot convert to BufferLoad: " << 
ffi::GetRef<BufferRegion>(br);
+        }
+      }
+      prim_expr_args.push_back(BufferLoad(br->buffer, indices));
+    } else {
+      prim_expr_args.push_back(Downcast<PrimExpr>(it));
+    }
+  }
+  return prim_expr_args;
+}
+
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
+           ffi::Map<ffi::String, ffi::Any> annotations, Span span) {
   for (size_t i = 0; i < args.size(); ++i) {
     TVM_FFI_ICHECK(args[i].defined()) << "arg " << i << " is not defined()";
   }
@@ -599,44 +631,25 @@ Call::Call(DataType dtype, RelaxExpr op, 
ffi::Array<PrimExpr> args, Span span) {
   node->dtype = dtype;
   node->op = std::move(op);
   node->args = std::move(args);
+  node->annotations = std::move(annotations);
   node->span = std::move(span);
   data_ = std::move(node);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def(
-      "tirx.Call",
-      [](ffi::Optional<DataType> dtype, RelaxExpr op,
-         ffi::Array<ffi::Variant<ffi::String, DLDataType, IterVar, 
BufferRegion, PrimExpr>> args,
-         Span span) {
-        ffi::Array<PrimExpr> prim_expr_args;
-        for (const auto& it : args) {
-          if (auto opt_str = it.as<ffi::String>()) {
-            prim_expr_args.push_back(StringImm(opt_str.value()));
-          } else if (auto opt_dtype = it.as<DLDataType>()) {
-            
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
-          } else if (const auto* iter_var = it.as<IterVarNode>()) {
-            prim_expr_args.push_back(iter_var->var);
-          } else if (const auto* br = it.as<BufferRegionNode>()) {
-            ffi::Array<PrimExpr> indices;
-            for (Range r : br->region) {
-              if (is_one(r->extent)) {
-                indices.push_back(r->min);
-              } else if (r->extent.as<IntImmNode>()) {
-                indices.push_back(tirx::Ramp(r->min, make_const(r->min->dtype, 
1), r->extent));
-              } else {
-                TVM_FFI_THROW(ValueError)
-                    << "Cannot convert to BufferLoad: " << 
ffi::GetRef<BufferRegion>(br);
-              }
-            }
-            prim_expr_args.push_back(BufferLoad(br->buffer, indices));
-          } else {
-            prim_expr_args.push_back(Downcast<PrimExpr>(it));
-          }
-        }
-        return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, 
span);
-      });
+  refl::GlobalDef()
+      .def("tirx.Call",
+           [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg> 
args, Span span) {
+             return Call(dtype.value_or(DataType::Void()), op, 
ConvertCallArgs(args),
+                         ffi::Map<ffi::String, ffi::Any>(), span);
+           })
+      .def("tirx.CallWithAnnotations",
+           [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg> 
args,
+              ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations, Span 
span) {
+             return Call(dtype.value_or(DataType::Void()), op, 
ConvertCallArgs(args),
+                         annotations.value_or(ffi::Map<ffi::String, 
ffi::Any>()), span);
+           });
 }
 
 // Shuffle
diff --git a/src/tirx/ir/expr_functor.cc b/src/tirx/ir/expr_functor.cc
index dc9913060e..921ce45623 100644
--- a/src/tirx/ir/expr_functor.cc
+++ b/src/tirx/ir/expr_functor.cc
@@ -155,7 +155,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
   if (args.same_as(op->args)) {
     return ffi::GetRef<PrimExpr>(op);
   } else {
-    return Call(op->dtype, op->op, args);
+    return Call(op->dtype, op->op, args, op->annotations, op->span);
   }
 }
 
diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc
index 1a9abe6ca8..1943135926 100644
--- a/src/tirx/ir/stmt.cc
+++ b/src/tirx/ir/stmt.cc
@@ -674,7 +674,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 PrimExpr TypeAnnotation(DataType dtype, Span span) {
   static auto op = Op::Get("tirx.type_annotation");
-  return tirx::Call(dtype, op, {}, span);
+  return tirx::Call(dtype, op, {}, {}, span);
 }
 
 TVM_TIRX_REGISTER_OP("type_annotation")
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index c2772ad69f..bc500a54cc 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -128,14 +128,14 @@ Type GetTypeFromRuntimeDataType(const DataType& dtype) {
 PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
   return tirx::Call(
       t, tirx::builtin::large_uint_imm(),
-      {make_const(DataType::UInt(32), low, span), 
make_const(DataType::UInt(32), high, span)},
+      {make_const(DataType::UInt(32), low, span), 
make_const(DataType::UInt(32), high, span)}, {},
       span);
 }
 
 // Q-multiplication
 PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span 
span) {
   return tirx::Call(DataType::Int(32, x.dtype().lanes()), 
tirx::builtin::q_multiply_shift(),
-                    {x, y, q, s}, span);
+                    {x, y, q, s}, {}, span);
 }
 
 void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) {  // NOLINT(*)
@@ -263,19 +263,19 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, 
Span span) {  // NOLINT(*)
 
 PrimExpr ret(PrimExpr value, Span span) {
   TVM_FFI_ICHECK(value.defined());
-  return tirx::Call(value.dtype(), tirx::builtin::ret(), {value}, span);
+  return tirx::Call(value.dtype(), tirx::builtin::ret(), {value}, {}, span);
 }
 
 PrimExpr thread_return(Span span) {
-  return tirx::Call(DataType::Void(), tirx::builtin::thread_return(), {}, 
span);
+  return tirx::Call(DataType::Void(), tirx::builtin::thread_return(), {}, {}, 
span);
 }
 
 PrimExpr continue_loop(Span span) {
-  return tirx::Call(DataType::Void(), tirx::builtin::continue_loop(), {}, 
span);
+  return tirx::Call(DataType::Void(), tirx::builtin::continue_loop(), {}, {}, 
span);
 }
 
 PrimExpr break_loop(Span span) {
-  return tirx::Call(DataType::Void(), tirx::builtin::break_loop(), {}, span);
+  return tirx::Call(DataType::Void(), tirx::builtin::break_loop(), {}, {}, 
span);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
@@ -512,7 +512,7 @@ PrimExpr reinterpret(const DataType& t, PrimExpr value, 
Span span) {
                     value.dtype().bytes() * value.dtype().lanes() == t.bytes() 
* t.lanes()))
         << "Reinterpret requires size match " << t << " vs " << value.dtype();
   }
-  return tirx::Call(t, tirx::builtin::reinterpret(), {value}, span);
+  return tirx::Call(t, tirx::builtin::reinterpret(), {value}, {}, span);
 }
 
 // operator+
@@ -654,13 +654,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, 
PrimExpr false_value,
   }
 
   return tirx::Call(true_value.dtype(), tirx::builtin::if_then_else(),
-                    {cond, true_value, false_value}, span);
+                    {cond, true_value, false_value}, {}, span);
 }
 
 // likely
 PrimExpr likely(PrimExpr cond, Span span) {
   if (is_const_int(cond)) return cond;
-  return tirx::Call(cond.dtype(), tirx::builtin::likely(), {cond}, span);
+  return tirx::Call(cond.dtype(), tirx::builtin::likely(), {cond}, {}, span);
 }
 
 // operator>
@@ -786,7 +786,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
     }
   });
 
-  return tirx::Call(a.dtype(), tirx::builtin::shift_right(), {a, b}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::shift_right(), {a, b}, {}, span);
 }
 
 // shift left
@@ -805,7 +805,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
       if (pb->value == 0) return a;
     }
   });
-  return tirx::Call(a.dtype(), tirx::builtin::shift_left(), {a, b}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::shift_left(), {a, b}, {}, span);
 }
 
 // bitwise and
@@ -817,7 +817,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span);
   });
-  return tirx::Call(a.dtype(), tirx::builtin::bitwise_and(), {a, b}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::bitwise_and(), {a, b}, {}, span);
 }
 
 // bitwise_or
@@ -829,7 +829,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span);
   });
-  return tirx::Call(a.dtype(), tirx::builtin::bitwise_or(), {a, b}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::bitwise_or(), {a, b}, {}, span);
 }
 
 // bitwise_xor
@@ -841,7 +841,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span);
   });
-  return tirx::Call(a.dtype(), tirx::builtin::bitwise_xor(), {a, b}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::bitwise_xor(), {a, b}, {}, span);
 }
 
 // bitwise_not
@@ -849,7 +849,7 @@ PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
 
 PrimExpr bitwise_neg(PrimExpr a, Span span) {
   type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
-  return tirx::Call(a.dtype(), tirx::builtin::bitwise_not(), {a}, span);
+  return tirx::Call(a.dtype(), tirx::builtin::bitwise_not(), {a}, {}, span);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
@@ -889,7 +889,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
   }
 
   static auto op = Op::Get("tirx.pow");
-  return tirx::Call(x.dtype(), op, {x, y}, span);
+  return tirx::Call(x.dtype(), op, {x, y}, {}, span);
 }
 
 
TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr<TVectorizable>("TVectorizable", 
true);
@@ -910,7 +910,7 @@ PrimExpr abs(PrimExpr x, Span span) {
       return FloatImm(x.dtype(), std::fabs(fx->value), fx->span);
     }
     static auto op = Op::Get("tirx.fabs");
-    return tirx::Call(x.dtype(), op, {x}, span);
+    return tirx::Call(x.dtype(), op, {x}, {}, span);
   } else if (x.dtype().is_uint()) {
     return x;
   } else {
@@ -935,9 +935,10 @@ PrimExpr isnan(PrimExpr x, Span span) {
     }
     static auto op = Op::Get("tirx.isnan");
     if (x.dtype().bits() == 16) {
-      return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()), 
std::move(x), span)}, span);
+      return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()), 
std::move(x), span)}, {},
+                        span);
     } else {
-      return tirx::Call(t, op, {x}, span);
+      return tirx::Call(t, op, {x}, {}, span);
     }
   } else {
     TVM_FFI_THROW(InternalError) << "Data type " << x.dtype()
@@ -1025,7 +1026,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) {
   BinaryOpMatchTypes(x, y, span);
   TVM_FFI_ICHECK(x.dtype().is_float()) << "fmod only applies to float";
   static auto op = Op::Get("tirx.fmod");
-  return tirx::Call(x.dtype(), op, {x, y}, span);
+  return tirx::Call(x.dtype(), op, {x, y}, {}, span);
 }
 
 TVM_TIR_REGISTER_PURE_UNARY_OP("fmod");
@@ -1039,7 +1040,7 @@ PrimExpr floor(PrimExpr x, Span span) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span);
   static auto op = Op::Get("tirx.floor");
-  return tirx::Call(x.dtype(), op, {x}, span);
+  return tirx::Call(x.dtype(), op, {x}, {}, span);
 }
 
 
TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable",
 true);
@@ -1053,7 +1054,7 @@ PrimExpr ceil(PrimExpr x, Span span) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span);
   static auto op = Op::Get("tirx.ceil");
-  return tirx::Call(x.dtype(), op, {x}, span);
+  return tirx::Call(x.dtype(), op, {x}, {}, span);
 }
 
 
TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable", 
true);
@@ -1067,7 +1068,7 @@ PrimExpr round(PrimExpr x, Span span) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
   static auto op = Op::Get("tirx.round");
-  return tirx::Call(x.dtype(), op, {x}, span);
+  return tirx::Call(x.dtype(), op, {x}, {}, span);
 }
 
 
TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable",
 true);
@@ -1081,7 +1082,7 @@ PrimExpr nearbyint(PrimExpr x, Span span) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
   static auto op = Op::Get("tirx.nearbyint");
-  return tirx::Call(x.dtype(), op, {x}, span);
+  return tirx::Call(x.dtype(), op, {x}, {}, span);
 }
 
 TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
@@ -1098,7 +1099,7 @@ PrimExpr trunc(PrimExpr x, Span span) {
                     fx->span);
   }
   static auto op = Op::Get("tirx.trunc");
-  return tirx::Call(x.dtype(), op, {x}, span);
+  return tirx::Call(x.dtype(), op, {x}, {}, span);
 }
 
 
TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr<TVectorizable>("TVectorizable",
 true);
diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc
index 4b852cd4fa..2c9f5daed3 100644
--- a/src/tirx/script/printer/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -258,6 +258,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tirx::Call>("", [](tirx::Call call, AccessPath call_p, IRDocsifier d) -> Doc {
+      if (!call->annotations.empty()) {
+        ffi::Array<ExprDoc> call_args;
+        int n_args = call->args.size();
+        call_args.reserve(n_args);
+        for (int i = 0; i < n_args; ++i) {
+          call_args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayItem(i)));
+        }
+        ExprDoc op_doc = call->op.as<Op>()
+                             ? LiteralDoc::Str(call->op.as<Op>().value()->name, call_p->Attr("op"))
+                             : d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
+        return TIR(d, "Call")->Call(
+            {LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")), op_doc, ListDoc(call_args)},
+            {"annotations"}, {d->AsDoc<DictDoc>(call->annotations, call_p->Attr("annotations"))});
+      }
       static const OpAttrMap<tirx::TScriptPrinterName>& op_names =
           Op::GetAttrMap<tirx::TScriptPrinterName>("TScriptPrinterName");
       static const OpAttrMap<tirx::TScriptDtypePrintLocation> dtype_locations =
diff --git a/src/tirx/transform/lower_warp_memory.cc b/src/tirx/transform/lower_warp_memory.cc
index ed98c5dfe6..2c3d84fad6 100644
--- a/src/tirx/transform/lower_warp_memory.cc
+++ b/src/tirx/transform/lower_warp_memory.cc
@@ -291,7 +291,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
         new_args.Set(i + 1, local_index);
       }
     }
-    return Call(op->dtype, op->op, new_args);
+    return Call(op->dtype, op->op, new_args, op->annotations, op->span);
   }
 
   PrimExpr VisitExpr_(const CallNode* op) override {
diff --git a/src/tirx/transform/storage_rewrite.cc b/src/tirx/transform/storage_rewrite.cc
index da31b2f9f5..c0c4243bd3 100644
--- a/src/tirx/transform/storage_rewrite.cc
+++ b/src/tirx/transform/storage_rewrite.cc
@@ -496,7 +496,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       if (se->bits_offset != 0) {
         offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
       }
-      return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]});
+      return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]},
+                  op->annotations, op->span);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc
index 70509bd3e0..fbc7786d92 100644
--- a/src/tirx/transform/tile_primitive_dispatch.cc
+++ b/src/tirx/transform/tile_primitive_dispatch.cc
@@ -1160,7 +1160,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator {
         args.push_back(new_arg);
       }
       if (changed) {
-        return tirx::Call(call->dtype, call->op, args, call->span);
+        return tirx::Call(call->dtype, call->op, args, call->annotations, call->span);
       }
     }
     return pred;
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc b/src/tirx/transform/unsupported_dtype_legalize.cc
index 15f5876075..c2934c2d86 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -238,12 +238,13 @@ class ComputeLegalizer : public StmtExprMutator {
     auto fmutate = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); };
     ffi::Array<PrimExpr> args = op->args.Map(fmutate);
     if (MatchDType(op->dtype)) {
-      return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args);
+      return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args, op->annotations,
+                  op->span);
     }
     if (args.same_as(op->args)) {
       return ffi::GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, args);
+      return Call(op->dtype, op->op, args, op->annotations, op->span);
     }
   }
 
diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc
index cdf0bddf4d..da90338956 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -491,9 +491,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       t = BroadcastTo(t, lanes, is_scalable);
       f = BroadcastTo(f, lanes, is_scalable);
       if (is_scalable) {
-        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f});
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f},
+                    op->annotations, op->span);
       } else {
-        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->annotations, op->span);
       }
     }
   }
@@ -506,13 +507,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {
-        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value}, op->annotations,
+                    op->span);
       } else {
         int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
                          op->args[0].dtype() != DataType::Float4E2M1FN())
                             ? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits()
                             : value.dtype().lanes();
-        return Call(op->dtype.with_lanes(new_lanes), op->op, {value});
+        return Call(op->dtype.with_lanes(new_lanes), op->op, {value}, op->annotations, op->span);
       }
     }
   }
@@ -534,7 +536,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       auto new_args = op->args;
       new_args.pop_back();
       new_args.push_back(fcd[0]);
-      return Call(op->dtype.with_lanes(lane), op->op, new_args);
+      return Call(op->dtype.with_lanes(lane), op->op, new_args, op->annotations, op->span);
     } else if (op->op.same_as(builtin::texture2d_store())) {
       int lane = 0;
       // Vectorize the value to store
@@ -549,7 +551,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
           << "Expected Data to be Written equal to Texture Store length";
       ffi::Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                                     op->args[3], op->args[4], mutated_value[0]};
-      return Call(op->dtype.with_lanes(lane), op->op, new_args);
+      return Call(op->dtype.with_lanes(lane), op->op, new_args, op->annotations, op->span);
     } else if (op->op.same_as(builtin::reinterpret())) {
       return MutateReinterpretExpr_(op);
     }
@@ -571,7 +573,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       if (op->args.same_as(new_args)) {
         return ffi::GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype, op->op, new_args);
+        return Call(op->dtype, op->op, new_args, op->annotations, op->span);
       }
     } else {
       int lane = 0;
@@ -597,7 +599,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       if (op->args.same_as(new_args)) {
         return ffi::GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype.with_lanes(lane), op->op, new_args);
+        return Call(op->dtype.with_lanes(lane), op->op, new_args, op->annotations, op->span);
       }
     }
   }
diff --git a/tests/python/contrib/test_rpc_tracker.py b/tests/python/contrib/test_rpc_tracker.py
index 486d5abce4..a5351b62a6 100644
--- a/tests/python/contrib/test_rpc_tracker.py
+++ b/tests/python/contrib/test_rpc_tracker.py
@@ -139,9 +139,7 @@ def check_tracker_rejects_oversized_msg_size():
                     break
                 time.sleep(0.05)
             else:
-                raise AssertionError(
-                    "tracker did not close connection after oversized msg_size"
-                )
+                raise AssertionError("tracker did not close connection after oversized msg_size")
         finally:
             tserver.terminate()
     except ImportError:
diff --git a/tests/python/tirx-base/test_tir_constructor.py b/tests/python/tirx-base/test_tir_constructor.py
index 16f85f9625..00cd63fa85 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -19,6 +19,19 @@ import pytest
 
 import tvm
 from tvm import te, topi
+from tvm.tirx.expr_functor import ExprMutator
+
+
+class ReplaceVar(ExprMutator):
+    def __init__(self, old_var, new_var):
+        super().__init__()
+        self.old_var = old_var
+        self.new_var = new_var
+
+    def visit_var_(self, op):
+        if op.same_as(self.old_var):
+            return self.new_var
+        return op
 
 
 def test_expr_constructor():
@@ -120,6 +133,50 @@ def test_expr_constructor():
     assert x.dtype == "float32"
     assert x.op.name == "tirx.call_extern"
     assert x.args[1] == a
+    assert len(x.annotations) == 0
+
+    annotated_arg = tvm.tirx.Var("annotated_arg", "float32")
+    x_with_annotations = tvm.tirx.Call(
+        "float32",
+        "tirx.call_extern",
+        [tvm.tirx.StringImm("xyz"), annotated_arg],
+        annotations={"disable_tma": True},
+    )
+    assert bool(x_with_annotations.annotations["disable_tma"])
+    assert not tvm.ir.structural_equal(x, x_with_annotations)
+    script = tvm.tirx.Evaluate(x_with_annotations).script()
+    assert "annotations" in script
+    assert "disable_tma" in script
+    func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_annotations))
+    assert tvm.script.from_source(func.script()).script() == func.script()
+
+    y = tvm.tirx.Var("y", "float32")
+    mutated = ReplaceVar(annotated_arg, y)(x_with_annotations)
+    assert bool(mutated.annotations["disable_tma"])
+    assert mutated.args[1].same_as(y)
+
+    x_from_intrin = tvm.tirx.call_intrin(
+        "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"), annotations={"disable_tma": True}
+    )
+    assert int(x_from_intrin.annotations["disable_tma"]) == 1
+
+    cond0 = tvm.tirx.Var("cond0", "bool")
+    cond1 = tvm.tirx.Var("cond1", "bool")
+    inner_if = tvm.tirx.Call(
+        "int32",
+        "tirx.if_then_else",
+        [cond1, tvm.tirx.IntImm("int32", 1), tvm.tirx.IntImm("int32", 0)],
+    )
+    outer_if = tvm.tirx.Call(
+        "int32",
+        "tirx.if_then_else",
+        [cond0, inner_if, tvm.tirx.IntImm("int32", 0)],
+        annotations={"keep": True},
+    )
+    simplified = tvm.tirx.transform.Simplify()(
+        tvm.IRModule({"main": tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(outer_if))})
+    )["main"].body.value
+    assert bool(simplified.annotations["keep"])
 
     v = tvm.tirx.Var("aa", "int32")
     x = tvm.tirx.Let(v, 1, v)


Reply via email to