This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch fix-tirx-cuda-sm-guards in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 7b1c7124aee60dfc3d5a482d4399dbcbd0ca4f59 Author: spectrometerHBH <[email protected]> AuthorDate: Tue Jun 9 14:07:04 2026 -0400 [Tests] Guard TIRX CUDA tests by compute capability --- tests/python/tirx/codegen/test_codegen_cuda.py | 5 +++++ tests/python/tirx/codegen/test_codegen_dsmem.py | 2 ++ tests/python/tirx/codegen/test_codegen_hopper.py | 1 + .../tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py | 3 +++ .../operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py | 7 +++++++ .../operator/tile_primitive/cuda/gemm_async/test_gemm_async.py | 10 ++++++++++ .../operator/tile_primitive/cuda/reduction/test_reduction.py | 1 + 7 files changed, 29 insertions(+) diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py b/tests/python/tirx/codegen/test_codegen_cuda.py index f253d6d375..cd5baaac33 100644 --- a/tests/python/tirx/codegen/test_codegen_cuda.py +++ b/tests/python/tirx/codegen/test_codegen_cuda.py @@ -87,6 +87,7 @@ def test_serial_pragma_unroll_codegen(): assert "break;" in src +@tvm.testing.requires_cuda_compute_version(9) def test_cluster_cta_id_codegen_uses_coordinate_sregs(): @T.prim_func def main(A: T.Buffer((1,), "int32")): @@ -160,6 +161,7 @@ def test_ptx_ld_acquire_and_volatile_codegen(): assert "ld.volatile.global.u64" in src +@tvm.testing.requires_cuda_compute_version(10) def test_megamoe_extracted_intrinsics_codegen(): @T.prim_func def main( @@ -265,6 +267,7 @@ def test_megamoe_extracted_intrinsics_codegen(): assert snippet in src +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_cp_async_bulk_non_tma_form_codegen(): @T.prim_func def main( @@ -304,6 +307,7 @@ def test_tensor_map_param_codegen(): assert "((unsigned long long)(&(A_map)))" in src +@tvm.testing.requires_cuda_compute_version(9) def test_tma_cache_policy_operand_codegen(): @T.prim_func def main(Cache: T.Buffer((1,), "uint64")): @@ -537,6 +541,7 @@ def test_warp_shuffle_xor_sync(): @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256]) @pytest.mark.parametrize("predicate",
[-1, T.int32(0), T.int32(1)]) @pytest.mark.parametrize("fill_mode", ["", "zero"]) +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, predicate, fill_mode): if fill_mode != "" and predicate == -1: return diff --git a/tests/python/tirx/codegen/test_codegen_dsmem.py b/tests/python/tirx/codegen/test_codegen_dsmem.py index d538be571f..ed4f1e7e18 100644 --- a/tests/python/tirx/codegen/test_codegen_dsmem.py +++ b/tests/python/tirx/codegen/test_codegen_dsmem.py @@ -30,6 +30,7 @@ def _get_source(func: tvm.tirx.PrimFunc) -> str: return src +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_cp_async_bulk_s2c_codegen(): """Test that T.ptx.cp_async.bulk.s2c emits the correct PTX instruction.""" @@ -58,6 +59,7 @@ def test_ptx_cp_async_bulk_s2c_codegen(): assert "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes" in src +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_cp_async_bulk_s2c_codegen_address_conversion(): """Test that the codegen correctly converts addresses to shared space.""" diff --git a/tests/python/tirx/codegen/test_codegen_hopper.py b/tests/python/tirx/codegen/test_codegen_hopper.py index 8f14dfc3c2..90b1921503 100644 --- a/tests/python/tirx/codegen/test_codegen_hopper.py +++ b/tests/python/tirx/codegen/test_codegen_hopper.py @@ -139,6 +139,7 @@ def test_stmatrix_sync_aligned(trans): @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("num", [1, 2, 4]) +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_stmatrix(trans, num): # fmt: off @T.prim_func diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py index 0f910a4376..af180e15cc 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py @@ -30,6 +30,7 @@ from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
@pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) +@tvm.testing.requires_cuda_compute_version(10) def test_copy_tmem2reg_async(dtype, width_32b): """Test async tmem<->local copy using copy_async instead of copy. @@ -135,6 +136,7 @@ def test_copy_tmem2reg_async(dtype, width_32b): @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"]) @pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("offset_32b", [0, 3, 10]) +@tvm.testing.requires_cuda_compute_version(10) def test_copy_tmem2reg(dtype, width_32b, offset_32b): def next_power_of_2(x): if x <= 1: @@ -227,6 +229,7 @@ def test_copy_tmem2reg(dtype, width_32b, offset_32b): @pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) @pytest.mark.parametrize("local_offset_32b", [0, 2, 4]) +@tvm.testing.requires_cuda_compute_version(10) def test_copy_tmem2reg_sliced_local(dtype, width_32b, local_offset_32b): """tmem<->local copy with a sliced local buffer region.""" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py index 4209359460..eab1b83d89 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -155,6 +155,7 @@ def _expected_reg_value_16b( @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) # subset; full reps below @pytest.mark.parametrize("dtype", ["float32"]) +@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype): """Bit-exact verification of ``tcgen05.<shape>.x<rep>.b32`` load.""" if rep not in _SHAPE_REPS[shape]: @@ -170,6 +171,7 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype): ("16x128b", 64), ], )
+@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep): """High-rep entries that aren't in the parametrize-cross above.""" _run_load_test(shape, rep, "float32") @@ -178,6 +180,7 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep): @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): """Self-consistent round-trip for 16-bit pack::16b path. @@ -204,6 +207,7 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype): if rep not in _SHAPE_REPS[shape]: pytest.skip(f"rep {rep} not valid for {shape}") @@ -217,6 +221,7 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype): @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_16xnb_roundtrip_16b_layout_F(shape, rep, dtype): if rep not in _SHAPE_REPS[shape]: pytest.skip(f"rep {rep} not valid for {shape}") @@ -642,6 +647,7 @@ def _run_load_test(shape: str, rep: int, dtype: str): @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 4, 16]) @pytest.mark.parametrize("dtype", ["float32"]) +@tvm.testing.requires_cuda_compute_version(10) def test_tcgen05_st_16xnb_store(shape, rep, dtype): """Round-trip test: write the M=64 fragment via .<shape>.x<rep>.st then read via the standard .32x32b path; verify the host-known fragment data ends up @@ -807,6 +813,7 @@ def
test_tcgen05_st_16xnb_store(shape, rep, dtype): ("16x256b", 64, 64), # .16x256b.x8 fp32 ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_alloc_tcgen05_frag_wrapper_compiles(shape, frag_rows, K_cols): """Ensure T.alloc_tcgen05_ldst_frag yields a buffer that ``T.copy_async`` accepts and lowers to the correct tcgen05 atom for each supported instr_shape.""" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index 8c32bbe048..359bbbe171 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -179,6 +179,7 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128): ) ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_tcgen05_cta_group_1(task): ( (C_shape, C_dtype, C_region), @@ -293,6 +294,7 @@ def test_gemm_tcgen05_cta_group_1(task): np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_tcgen05_cta_group_1_layout_f_m64(): """M=64 MMA with C operand allocated as Layout F (datapath="F"). @@ -417,6 +419,7 @@ def test_gemm_tcgen05_cta_group_1_layout_f_m64(): ) ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_tcgen05_cta_group_2(task): ( (C_shape, C_dtype, C_region), @@ -545,6 +548,7 @@ def test_gemm_tcgen05_cta_group_2(task): np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_tcgen05_cta_group_2_layout_b(): """Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA). @@ -689,6 +693,7 @@ def test_gemm_tcgen05_cta_group_2_layout_b(): ) ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_block_scaled_fp8_cta_group_1(task): """Test block-scaled fp8 GEMM with cta_group=1 using gemm_async op.
@@ -882,6 +887,7 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): ) ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_block_scaled_fp8_cta_group_2(task): """Test block-scaled fp8 GEMM with cta_group=2 using gemm_async op. @@ -1090,6 +1096,7 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_block_scaled_nvfp4_cta_group_1(): """Test block-scaled nvfp4 GEMM with cta_group=1. @@ -1259,6 +1266,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_block_scaled_nvfp4_cta_group_2(): """Test block-scaled nvfp4 GEMM with cta_group=2. @@ -1463,6 +1471,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_block_scaled_fp8_sf_id(): """Test sf_id auto-derivation from layout for fp8 block-scaled MMA. @@ -1809,6 +1818,7 @@ def test_gemm_block_scaled_fp8_sf_id(): "transA_kmajor_smem", ], ) +@tvm.testing.requires_cuda_compute_version(10) def test_gemm_tcgen05_arbitrary_tiles(task): """Test arbitrary tile decomposition for tcgen05 gemm_async.
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py index 0474ad2dc4..92077fa449 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py @@ -687,6 +687,7 @@ def test_reduction_local_optimized_3input_maxmin(reduction_len, op_type, accum): @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65, 100]) @pytest.mark.parametrize("accum", [False, True]) +@tvm.testing.requires_cuda_compute_version(10) def test_reduction_local_optimized_packed_add_sum(reduction_len, accum): """Test thread-level sum reduction using packed add with add.f32x2 PTX instruction.""" dev = tvm.cuda(0)
