gemini-code-assist[bot] commented on code in PR #18771:
URL: https://github.com/apache/tvm/pull/18771#discussion_r2801975951
##########
src/s_tir/transform/profile_instrumentation.cc:
##########
@@ -25,14 +25,15 @@
// and can be used to capture profiling information such as processor cycles.
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
namespace lwp {
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
Review Comment:

The pass configuration options are still using the `tir.` prefix (e.g.,
`tir.lwp_disable_func_prof`). Since this pass has been moved to the `s_tir`
namespace, these should be updated to use the `s_tir.` prefix for consistency
(e.g., `s_tir.lwp_disable_func_prof`). This should be applied to all
`TVM_REGISTER_PASS_CONFIG_OPTION` calls in this file, and the corresponding
usages in tests should be updated as well.
```suggestion
TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_disable_func_prof", Bool);
```
##########
python/tvm/s_tir/transform/transform.py:
##########
@@ -253,3 +253,174 @@ def InjectDoubleBuffer():
The result pass
"""
return _ffi_api.InjectDoubleBuffer() # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+ """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+ Parameters
+ ----------
+ variant : Optional[String]
+ The variant of the pass.
+ variant can have any one of following values ["basic", None(Default)].
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ if variant == "basic":
+ return _ffi_api.HoistIfThenElseBasic() # type: ignore
+ elif variant is None:
+ return _ffi_api.HoistIfThenElse() # type: ignore
+ else:
+ raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def RenormalizeSplitPattern():
+ """Renormalize the split pattern from floordiv(floormod()) to
floormod(floordiv())
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RenormalizeSplitPattern() # type: ignore
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect() # type: ignore
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+ """Insert intrinsic calls to instrument function and loop level profiling.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore
+
+
+def VerifyVTCMLimit(limit=None):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
+
+
+def LowerVtcmAlloc():
+ """Lower vtcm allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerVtcmAlloc() # type: ignore
+
+
+def ThreadSync(storage_scope):
+ """Insert sync between parallel read/write of shared buffers.
+
+ Parameters
+ ----------
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.ThreadSync(storage_scope) # type: ignore
+
+
+def InferFragment():
+ """Infer the TensorCore fragment infomation using tensor intrinsics.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InferFragment() # type: ignore
+
+
+def LowerThreadAllreduce():
+ """Lower cross thread alleduce.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerThreadAllreduce() # type: ignore
+
+
+def LowerAsyncDMA():
+ """Lower async DMA to DMA.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerAsyncDMA() # type: ignore
+
+
+def InjectPTXAsyncCopy():
+ """Rewrite global to shared memory copy on CUDA with asyncronous copy.
Review Comment:

There's a typo in the docstring. `asyncronous` should be `asynchronous`.
```python
"""Rewrite global to shared memory copy on CUDA with asynchronous copy.
```
##########
tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py:
##########
@@ -306,13 +306,13 @@ def func(data: T.handle("float32")):
mod = tvm.IRModule.from_expr(func)
stmt = mod["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
Review Comment:

The pass configuration key should be updated from `"tir.HoistIfThenElse"` to
`"s_tir.HoistIfThenElse"` to match the refactoring in
`src/s_tir/transform/hoist_expression.cc` where
`TVM_REGISTER_PASS_CONFIG_OPTION` was updated. Without this change, the test
may not be using the intended configuration. This issue occurs multiple times
in this test file.
```suggestion
config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
```
##########
python/tvm/s_tir/transform/transform.py:
##########
@@ -253,3 +253,174 @@ def InjectDoubleBuffer():
The result pass
"""
return _ffi_api.InjectDoubleBuffer() # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+ """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+ Parameters
+ ----------
+ variant : Optional[String]
+ The variant of the pass.
+ variant can have any one of following values ["basic", None(Default)].
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ if variant == "basic":
+ return _ffi_api.HoistIfThenElseBasic() # type: ignore
+ elif variant is None:
+ return _ffi_api.HoistIfThenElse() # type: ignore
+ else:
+ raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def RenormalizeSplitPattern():
+ """Renormalize the split pattern from floordiv(floormod()) to
floormod(floordiv())
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RenormalizeSplitPattern() # type: ignore
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect() # type: ignore
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+ """Insert intrinsic calls to instrument function and loop level profiling.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore
+
+
+def VerifyVTCMLimit(limit=None):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
+
+
+def LowerVtcmAlloc():
+ """Lower vtcm allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerVtcmAlloc() # type: ignore
+
+
+def ThreadSync(storage_scope):
+ """Insert sync between parallel read/write of shared buffers.
+
+ Parameters
+ ----------
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.ThreadSync(storage_scope) # type: ignore
+
+
+def InferFragment():
+ """Infer the TensorCore fragment infomation using tensor intrinsics.
Review Comment:

There's a typo in the docstring. `infomation` should be `information`.
```python
"""Infer the TensorCore fragment information using tensor intrinsics.
```
##########
python/tvm/s_tir/transform/transform.py:
##########
@@ -253,3 +253,174 @@ def InjectDoubleBuffer():
The result pass
"""
return _ffi_api.InjectDoubleBuffer() # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+ """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+ Parameters
+ ----------
+ variant : Optional[String]
+ The variant of the pass.
+ variant can have any one of following values ["basic", None(Default)].
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ if variant == "basic":
+ return _ffi_api.HoistIfThenElseBasic() # type: ignore
+ elif variant is None:
+ return _ffi_api.HoistIfThenElse() # type: ignore
+ else:
+ raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def RenormalizeSplitPattern():
+ """Renormalize the split pattern from floordiv(floormod()) to
floormod(floordiv())
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RenormalizeSplitPattern() # type: ignore
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect() # type: ignore
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+ """Insert intrinsic calls to instrument function and loop level profiling.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore
+
+
+def VerifyVTCMLimit(limit=None):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
Review Comment:

The parameter name `limit` is misleading. The underlying C++ pass
`VerifyVTCMLimit` takes an optional `tvm.target.Target`, not a numerical limit.
The limit is then derived from the target's "vtcm-capacity" attribute. To avoid
confusion, please rename the parameter to `default_target` and update the
docstring to clarify its purpose.
```python
def VerifyVTCMLimit(default_target=None):
    """Verify if the size of the allocated vtcm memory satisfies the limit.

    The limit is determined from the "vtcm-capacity" attribute of the target.

    Parameters
    ----------
    default_target : Optional[tvm.target.Target]
        The default target to use if a PrimFunc does not have a target
        attribute.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.VerifyVTCMLimit(default_target)  # type: ignore
```
##########
python/tvm/s_tir/transform/transform.py:
##########
@@ -253,3 +253,174 @@ def InjectDoubleBuffer():
The result pass
"""
return _ffi_api.InjectDoubleBuffer() # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+ """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+ Parameters
+ ----------
+ variant : Optional[String]
+ The variant of the pass.
+ variant can have any one of following values ["basic", None(Default)].
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ if variant == "basic":
+ return _ffi_api.HoistIfThenElseBasic() # type: ignore
+ elif variant is None:
+ return _ffi_api.HoistIfThenElse() # type: ignore
+ else:
+ raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def RenormalizeSplitPattern():
+ """Renormalize the split pattern from floordiv(floormod()) to
floormod(floordiv())
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RenormalizeSplitPattern() # type: ignore
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect() # type: ignore
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+ """Insert intrinsic calls to instrument function and loop level profiling.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore
+
+
+def VerifyVTCMLimit(limit=None):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
+
+
+def LowerVtcmAlloc():
+ """Lower vtcm allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerVtcmAlloc() # type: ignore
+
+
+def ThreadSync(storage_scope):
+ """Insert sync between parallel read/write of shared buffers.
+
+ Parameters
+ ----------
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.ThreadSync(storage_scope) # type: ignore
+
+
+def InferFragment():
+ """Infer the TensorCore fragment infomation using tensor intrinsics.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InferFragment() # type: ignore
+
+
+def LowerThreadAllreduce():
+ """Lower cross thread alleduce.
Review Comment:

There's a typo in the docstring. `alleduce` should be `allreduce`.
```python
"""Lower cross thread allreduce.
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]