This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3015acd7e6 [CUDA] Update FlashInfer JIT integration (#18353)
3015acd7e6 is described below

commit 3015acd7e678cd97e4334835e130c09b315ab1a7
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Oct 1 20:53:31 2025 -0400

    [CUDA] Update FlashInfer JIT integration (#18353)
    
    Following the recent JIT refactor in FlashInfer, which uses TVM FFI as
    the JIT interface, this PR updates the FlashInfer JIT integration in
    TVM.
    
    Major changes:
    * We leverage FlashInfer's `JitSpec.build_and_load` to compile all
    the JIT-generated source files, and remove the compilation logic
    from TVM.
    * For efficient tensor buffer management and pointer calculation, we
    now require the `byte_offset` field of every auxiliary tensor in the
    KV cache to be zero. The byte offset is instead applied directly to
    the data pointers.
    * We also add a new parameter to the FlashInfer JIT functions that
    controls whether to return a linked shared library or a list of
    compiled object paths (see the sketch below). For unit tests,
    returning a shared library is convenient and preferred, while for
    cases such as MLC model compilation, object files are needed to
    serialize the compiled model.
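    
    As an illustrative sketch of the new interface (not part of the diff;
    the dtypes and head dimensions are placeholder values, and FlashInfer
    and PyTorch must be installed), the updated prefill generator can be
    invoked in either mode:
    
        from tvm.relax.backend.cuda import flashinfer
    
        # Unit-test style: build and load a single shared library module.
        [prefill_mod] = flashinfer.gen_flashinfer_prefill_module(
            dtype_q="float16",
            dtype_kv="float16",
            dtype_o="float16",
            qk_head_dim=128,
            v_head_dim=128,
            enable_inline_rope=False,
            return_static_libs=False,
        )
    
        # Model-compilation style (e.g. MLC): get static library modules,
        # one per compiled object file, to link into the exported model.
        static_mods = flashinfer.gen_flashinfer_prefill_module(
            dtype_q="float16",
            dtype_kv="float16",
            dtype_o="float16",
            qk_head_dim=128,
            v_head_dim=128,
            enable_inline_rope=False,
            return_static_libs=True,
        )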
---
 python/tvm/relax/backend/cuda/flashinfer.py        | 481 ++++++---------------
 python/tvm/relax/frontend/nn/llm/kv_cache.py       |  21 +-
 src/runtime/vm/attn_backend.cc                     |  11 +-
 src/runtime/vm/attn_backend.h                      | 213 +++++++--
 src/runtime/vm/attn_utils.h                        |  34 +-
 src/runtime/vm/paged_kv_cache.cc                   |   2 +-
 tests/python/relax/test_group_gemm_flashinfer.py   |  39 +-
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  71 ++-
 ...ltin_paged_attention_kv_cache_mla_flashinfer.py |  69 ++-
 9 files changed, 396 insertions(+), 545 deletions(-)

diff --git a/python/tvm/relax/backend/cuda/flashinfer.py 
b/python/tvm/relax/backend/cuda/flashinfer.py
index 4e0fc3e854..6b5b1293ff 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -16,203 +16,36 @@
 # under the License.
 
 """FlashInfer JIT compilation module for CUDA backend"""
-import hashlib
-import json
-import os
-import subprocess
-from concurrent.futures import ThreadPoolExecutor
+import re
 from pathlib import Path
 from typing import List
 
-import tvm_ffi
-
 import tvm
 from tvm.target import Target
 
 
-def _compile_flashinfer_kernels(
-    name: str, source_paths: List[Path], target: Target, num_threads: int
-) -> List[Path]:
-    from flashinfer.jit.env import (  # pylint: disable=import-outside-toplevel
-        CUTLASS_INCLUDE_DIRS,
-        FLASHINFER_CSRC_DIR,
-        FLASHINFER_INCLUDE_DIR,
-        FLASHINFER_JIT_DIR,
-        FLASHINFER_TVM_BINDING_DIR,
-    )
-
-    # ------------------------------------------------------------------------
-    # Caching Flow: create build_directory and compute cache hash.
-    # ------------------------------------------------------------------------
-    build_directory = FLASHINFER_JIT_DIR / name
-    build_directory.mkdir(parents=True, exist_ok=True)
-
-    def get_object_file_path(src: Path) -> Path:
-        obj_name = src.stem + ".o"
-        obj_path = build_directory / obj_name
-        return obj_path
-
-    # Compute latest modification time among all source files
-    latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+def _rename_exported_func_names(source_paths: List[Path], prefix: str):
+    """Rename the ffi-exported function names in the source files to the given 
prefix."""
+    pattern = 
re.compile(r"^(\s*TVM_FFI_DLL_EXPORT_TYPED_FUNC\()([A-Za-z0-9_]+)(,.*)$")
+    for source_path in source_paths:
+        if not source_path.name.endswith("_binding.cu"):
+            continue
 
-    # Get modification time for the current file (the one that contains this 
function)
-    current_file_mtime = Path(__file__).stat().st_mtime
+        original_text = source_path.read_text(encoding="utf-8")
+        lines = original_text.splitlines(keepends=True)
+        updated = False
+        for idx, line in enumerate(lines):
+            line_body = line.rstrip("\r\n")
+            line_ending = line[len(line_body) :]
+            match = pattern.match(line_body)
+            if not match:
+                continue
+            new_body = 
f"{match.group(1)}{prefix}_{match.group(2)}{match.group(3)}"
+            lines[idx] = new_body + line_ending
+            updated = True
 
-    # Build the hash key from metadata
-    hash_key = {
-        "name": name,
-        "target": str(target),
-        "latest_src_mtime": latest_src_mtime,
-        "current_file_mtime": current_file_mtime,
-    }
-
-    hash_value = hashlib.md5(
-        json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
-    ).hexdigest()
-
-    # Check if a valid hash exists in the build directory
-    hash_file = build_directory / "hash.md5"
-    if hash_file.exists():
-        with open(hash_file, "r") as f:
-            cached_hash = f.read().strip()
-        if cached_hash == hash_value:
-            # Check that all object files exist
-            object_files = []
-            all_exist = True
-            for src in source_paths:
-                obj_path = get_object_file_path(src)
-                if not obj_path.exists():
-                    all_exist = False
-                    break
-                object_files.append(obj_path)
-            if all_exist:
-                return object_files
-
-    # If we are here, cache is missing or outdated. Write the new hash and 
compile the paths
-    with open(hash_file, "w") as f:
-        f.write(hash_value)
-
-    # ------------------------------------------------------------------------
-    # 1) Common CUDA compile flags
-    # ------------------------------------------------------------------------
-    cuda_cflags = [
-        "-O3",
-        "-std=c++17",
-        "--threads",
-        str(num_threads),
-        "-g",
-        "-use_fast_math",
-        "--expt-relaxed-constexpr",
-        # DMLC default
-        "-DDMLC_USE_FOPEN64=0",
-        "-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>",
-        # Enable `-fPIC` for the host compiler
-        "-Xcompiler=-fPIC",
-        "-DFLASHINFER_ENABLE_F16",
-        "-DFLASHINFER_ENABLE_BF16",
-        "-DFLASHINFER_ENABLE_FP8_E4M3",
-        "-DFLASHINFER_ENABLE_FP8_E5M2",
-    ]
-
-    # Determine compute version
-    compute_version = 
"".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
-    if compute_version in ["90", "100"]:
-        compute_version += "a"
-    cuda_cflags += [
-        "-gencode",
-        f"arch=compute_{compute_version},code=sm_{compute_version}",
-    ]
-
-    # ------------------------------------------------------------------------
-    # 2) Include paths
-    # ------------------------------------------------------------------------
-    include_paths = [
-        FLASHINFER_INCLUDE_DIR,
-        FLASHINFER_CSRC_DIR,
-        FLASHINFER_TVM_BINDING_DIR,
-    ] + CUTLASS_INCLUDE_DIRS
-
-    if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", 
None):
-        # Respect TVM_SOURCE_DIR and TVM_HOME if they are set
-        tvm_home = (
-            os.environ["TVM_SOURCE_DIR"]
-            if os.environ.get("TVM_SOURCE_DIR", None)
-            else os.environ["TVM_HOME"]
-        )
-        include_paths += [
-            Path(tvm_home).resolve() / "include",
-            Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "include",
-            Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "3rdparty" / 
"dlpack" / "include",
-            Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
-        ]
-    else:
-        # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM 
package path
-        tvm_package_path = Path(tvm.__file__).resolve().parent
-        if (tvm_package_path / "include").exists():
-            # The package is installed from pip.
-            tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
-            include_paths += [
-                tvm_package_path / "include",
-                tvm_package_path / "3rdparty" / "dmlc-core" / "include",
-                tvm_ffi_package_path / "include",
-            ]
-        elif (tvm_package_path.parent.parent / "include").exists():
-            # The package is installed from source.
-            include_paths += [
-                tvm_package_path.parent.parent / "include",
-                tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" / 
"include",
-                tvm_package_path.parent.parent
-                / "3rdparty"
-                / "tvm-ffi"
-                / "3rdparty"
-                / "dlpack"
-                / "include",
-                tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / 
"include",
-            ]
-        else:
-            # warning: TVM is not installed in the system.
-            print(
-                "Warning: Include path for TVM cannot be found. "
-                "FlashInfer kernel compilation may fail due to missing 
headers."
-            )
-
-    # ------------------------------------------------------------------------
-    # 3) Function to compile a single source file
-    # ------------------------------------------------------------------------
-    def compile_single_source(src: Path) -> Path:
-        # Derive the .o filename from the source filename
-        obj_path = get_object_file_path(src)
-
-        # Construct the command
-        cmd = (
-            ["nvcc"]
-            + cuda_cflags
-            + [f"-I{inc_path}" for inc_path in include_paths]
-            + ["-c", "-o", str(obj_path), str(src)]
-        )
-
-        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE)
-        out, err = proc.communicate()
-        if proc.returncode != 0:
-            raise RuntimeError(
-                f"FlashInfer JIT compilation failed for {src}\n"
-                f"Command: {' '.join(cmd)}\n"
-                f"stdout:\n{out.decode('utf-8')}\n"
-                f"stderr:\n{err.decode('utf-8')}"
-            )
-        return obj_path
-
-    # ------------------------------------------------------------------------
-    # 4) Compile each source in parallel using ThreadPoolExecutor
-    # ------------------------------------------------------------------------
-    object_files = []
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        futures = [executor.submit(compile_single_source, src) for src in 
source_paths]
-        for f in futures:
-            object_files.append(f.result())  # Will raise if there's a 
compilation error
-
-    # Return list of generated object files for any further linking steps
-    return object_files
+        if updated:
+            source_path.write_text("".join(lines), encoding="utf-8")
 
 
 def _load_flashinfer_modules(object_files: List[Path]) -> 
List[tvm.runtime.Module]:
@@ -228,9 +61,8 @@ def gen_flashinfer_prefill_module(
     dtype_o: str,
     qk_head_dim: int,
     v_head_dim: int,
-    target: Target,
-    enable_inline_rope: bool = True,
-    num_threads: int = 8,
+    enable_inline_rope: bool,
+    return_static_libs: bool = False,
 ) -> List[tvm.runtime.Module]:
     """Generate a FlashInfer module for prefill.
 
@@ -246,12 +78,12 @@ def gen_flashinfer_prefill_module(
         The head dimension of the query and key tensors.
     v_head_dim : int
         The head dimension of the value tensor.
-    target : Target
-        The target device to compile for.
     enable_inline_rope : bool
         Whether to enable inline rotary positional embedding.
-    num_threads : int
-        The number of threads to use for compilation.
+    return_static_libs : bool
+        Whether to return static library modules instead of a loaded shared library module.
+        When it is False, it returns the loaded shared library that links all the object files.
+        When it is True, it returns the static library of each compiled object file.
 
     Returns
     -------
@@ -259,7 +91,7 @@ def gen_flashinfer_prefill_module(
     """
     try:
         from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
-            gen_customize_batch_prefill_tvm_binding,
+            gen_customize_batch_prefill_module,
         )
     except ImportError:
         raise ImportError(
@@ -289,32 +121,33 @@ def gen_flashinfer_prefill_module(
         if backend == "fa2"
         else "#include <flashinfer/attention/hopper/variants.cuh>"
     )
-    jit_args = {
-        "backend": backend,
-        "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+    jit_spec = gen_customize_batch_prefill_module(
+        backend=backend,
+        uri=f"batch_prefill_tvm_dtype_q_{dtype_q}_"
         + f"dtype_kv_{dtype_kv}_"
         + f"dtype_o_{dtype_o}_"
         + f"qk_head_dim_{qk_head_dim}_"
         + f"v_head_dim_{v_head_dim}_"
         + f"enable_inline_rope_{enable_inline_rope}",
-        "dtype_q": torch_dtype_q,
-        "dtype_kv": torch_dtype_kv,
-        "dtype_o": torch_dtype_o,
-        "idtype": torch.int32,
-        "head_dim_qk": qk_head_dim,
-        "head_dim_vo": v_head_dim,
-        "additional_tensor_names": [],
-        "additional_tensor_dtypes": [],
-        "additional_scalar_names": ["sm_scale", "rope_rcp_scale", 
"rope_rcp_theta"],
-        "additional_scalar_dtypes": ["double", "double", "double"],
-        "variant_name": variant_name,
-        "variant_decl": variant_decl,
-        "enable_inline_rope": enable_inline_rope,
-    }
-    uri, source_paths = gen_customize_batch_prefill_tvm_binding(**jit_args)
-    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
-    modules = _load_flashinfer_modules(object_files)
-    return modules
+        dtype_q=torch_dtype_q,
+        dtype_kv=torch_dtype_kv,
+        dtype_o=torch_dtype_o,
+        idtype=torch.int32,
+        head_dim_qk=qk_head_dim,
+        head_dim_vo=v_head_dim,
+        pos_encoding_mode=int(enable_inline_rope),
+        additional_tensor_names=[],
+        additional_tensor_dtypes=[],
+        additional_scalar_names=["sm_scale", "rope_rcp_scale", 
"rope_rcp_theta"],
+        additional_scalar_dtypes=["double", "double", "double"],
+        variant_name=variant_name,
+        variant_decl=variant_decl,
+    )
+    _rename_exported_func_names(jit_spec.sources, "batch_prefill")
+    if return_static_libs:
+        jit_spec.build(verbose=False)
+        return _load_flashinfer_modules(jit_spec.get_object_paths())
+    return [jit_spec.build_and_load()]
 
 
 def gen_flashinfer_decode_module(
@@ -323,8 +156,8 @@ def gen_flashinfer_decode_module(
     dtype_o: str,
     qk_head_dim: int,
     v_head_dim: int,
-    target: Target,
-    num_threads: int = 8,
+    enable_inline_rope: bool,
+    return_static_libs: bool = False,
 ) -> List[tvm.runtime.Module]:
     """Generate a FlashInfer module for decode.
 
@@ -340,10 +173,12 @@ def gen_flashinfer_decode_module(
         The head dimension of the query and key tensors.
     v_head_dim : int
         The head dimension of the value tensor.
-    target : Target
-        The target device to compile for.
-    num_threads : int
-        The number of threads to use for compilation.
+    enable_inline_rope : bool
+        Whether to enable inline rotary positional embedding.
+    return_static_libs : bool
+        Whether to return static library modules instead of a loaded shared library module.
+        When it is False, it returns the loaded shared library that links all the object files.
+        When it is True, it returns the static library of each compiled object file.
 
     Returns
     -------
@@ -351,7 +186,7 @@ def gen_flashinfer_decode_module(
     """
     try:
         from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
-            gen_customize_batch_decode_tvm_binding,
+            gen_customize_batch_decode_module,
         )
     except ImportError:
         raise ImportError(
@@ -366,29 +201,32 @@ def gen_flashinfer_decode_module(
     torch_dtype_q = getattr(torch, dtype_q)
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
-    jit_args = {
-        "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+    jit_spec = gen_customize_batch_decode_module(
+        uri=f"batch_decode_tvm_dtype_q_{dtype_q}_"
         + f"dtype_kv_{dtype_kv}_"
         + f"dtype_o_{dtype_o}_"
         + f"qk_head_dim_{qk_head_dim}_"
-        + f"v_head_dim_{v_head_dim}",
-        "dtype_q": torch_dtype_q,
-        "dtype_kv": torch_dtype_kv,
-        "dtype_o": torch_dtype_o,
-        "idtype": torch.int32,
-        "head_dim_qk": qk_head_dim,
-        "head_dim_vo": v_head_dim,
-        "additional_tensor_names": [],
-        "additional_tensor_dtypes": [],
-        "additional_scalar_names": ["sm_scale", "rope_rcp_scale", 
"rope_rcp_theta"],
-        "additional_scalar_dtypes": ["double", "double", "double"],
-        "variant_name": "DefaultAttention<false, false, false, false>",
-        "variant_decl": "#include <flashinfer/attention/variants.cuh>",
-    }
-    uri, source_paths = gen_customize_batch_decode_tvm_binding(**jit_args)
-    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
-    modules = _load_flashinfer_modules(object_files)
-    return modules
+        + f"v_head_dim_{v_head_dim}_"
+        + f"enable_inline_rope_{enable_inline_rope}",
+        dtype_q=torch_dtype_q,
+        dtype_kv=torch_dtype_kv,
+        dtype_o=torch_dtype_o,
+        idtype=torch.int32,
+        head_dim_qk=qk_head_dim,
+        head_dim_vo=v_head_dim,
+        pos_encoding_mode=int(enable_inline_rope),
+        additional_tensor_names=[],
+        additional_tensor_dtypes=[],
+        additional_scalar_names=["sm_scale", "rope_rcp_scale", 
"rope_rcp_theta"],
+        additional_scalar_dtypes=["double", "double", "double"],
+        variant_name="DefaultAttention<false, false, false, false>",
+        variant_decl="#include <flashinfer/attention/variants.cuh>",
+    )
+    _rename_exported_func_names(jit_spec.sources, "batch_decode")
+    if return_static_libs:
+        jit_spec.build(verbose=False)
+        return _load_flashinfer_modules(jit_spec.get_object_paths())
+    return [jit_spec.build_and_load()]
 
 
 def gen_flashinfer_mla_module(
@@ -397,8 +235,7 @@ def gen_flashinfer_mla_module(
     dtype_o: str,
     head_dim_ckv: int,
     head_dim_kpe: int,
-    target: Target,
-    num_threads: int = 8,
+    return_static_libs: bool = False,
 ) -> List[tvm.runtime.Module]:
     """Generate a FlashInfer module for MLA.
 
@@ -418,6 +255,10 @@ def gen_flashinfer_mla_module(
         The target device to compile for.
     num_threads : int
         The number of threads to use for compilation.
+    return_static_libs : bool
+        Whether to return static library modules instead of a loaded shared library module.
+        When it is False, it returns the loaded shared library that links all the object files.
+        When it is True, it returns the static library of each compiled object file.
 
     Returns
     -------
@@ -425,7 +266,7 @@ def gen_flashinfer_mla_module(
     """
     try:
         from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
-            gen_batch_mla_tvm_binding,
+            gen_batch_mla_module,
         )
     except ImportError:
         raise ImportError(
@@ -440,92 +281,36 @@ def gen_flashinfer_mla_module(
     torch_dtype_q = getattr(torch, dtype_q)
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
-    jit_args = {
-        "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
-        + f"dtype_kv_{dtype_kv}_"
-        + f"dtype_o_{dtype_o}_"
-        + f"head_dim_ckv_{head_dim_ckv}_"
-        + f"head_dim_kpe_{head_dim_kpe}",
-        "dtype_q": torch_dtype_q,
-        "dtype_kv": torch_dtype_kv,
-        "dtype_o": torch_dtype_o,
-        "dtype_idx": torch.int32,
-        "head_dim_ckv": head_dim_ckv,
-        "head_dim_kpe": head_dim_kpe,
-    }
-    uri, source_paths = gen_batch_mla_tvm_binding(**jit_args)
-    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
-    modules = _load_flashinfer_modules(object_files)
-    return modules
-
-
-def gen_sampling_module(target: Target, num_threads: int = 8):
-    """
-    Generate a FlashInfer module for sampling kernels.
-
-    Parameters
-    ----------
-    target : Target
-        The target device for which the module will be compiled.
-    num_threads : int, optional
-        The number of threads to use during compilation (default is 8).
-
-    Returns
-    -------
-    List[tvm.runtime.Module]
-        A list of compiled static library modules for the FlashInfer sampling 
kernels.
-    """
-    try:
-        from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
-            gen_sampling_tvm_binding,
-        )
-    except ImportError:
-        raise ImportError(
-            "FlashInfer is not installed. Please follow instructions "
-            "in https://docs.flashinfer.ai to install FlashInfer."
-        )
-    uri, source_paths = gen_sampling_tvm_binding(uri="sampling")
-    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
-    modules = _load_flashinfer_modules(object_files)
-    return modules
+    jit_spec = gen_batch_mla_module(
+        backend="fa2",
+        dtype_q=torch_dtype_q,
+        dtype_kv=torch_dtype_kv,
+        dtype_o=torch_dtype_o,
+        dtype_idx=torch.int32,
+        head_dim_ckv=head_dim_ckv,
+        head_dim_kpe=head_dim_kpe,
+        use_profiler=False,
+    )
+    _rename_exported_func_names(jit_spec.sources, "batch_mla")
+    if return_static_libs:
+        jit_spec.build(verbose=False)
+        return _load_flashinfer_modules(jit_spec.get_object_paths())
+    return [jit_spec.build_and_load()]
 
 
 def gen_grouped_gemm_module(
-    dtype_a: str,
-    dtype_b: str,
-    dtype_out: str,
-    scale_granularity_m: int,
-    scale_granularity_n: int,
-    scale_granularity_k: int,
-    scale_major_mode: str,
-    mma_sm: int,
-    target: Target,
-    num_threads: int = 8,
+    target: Target, return_static_libs: bool = False
 ) -> List[tvm.runtime.Module]:
     """Generate a FlashInfer module for FP8 grouped GEMM.
 
     Parameters
     ----------
-    dtype_a : str
-        The data type of matrix A (e.g., "float8_e4m3fn").
-    dtype_b : str
-        The data type of matrix B (e.g., "float8_e4m3fn").
-    dtype_out : str
-        The data type of the output matrix (e.g., "bfloat16").
-    scale_granularity_m : int
-        The scaling granularity in the M dimension.
-    scale_granularity_n : int
-        The scaling granularity in the N dimension.
-    scale_granularity_k : int
-        The scaling granularity in the K dimension.
-    scale_major_mode : str
-        The scale storage mode ("K" or "MN").
-    mma_sm : int
-        The MMA scheduling mode (1 or 2).
     target : Target
         The target device to compile for.
-    num_threads : int
-        The number of threads to use for compilation.
+    return_static_libs : bool
+        Whether to return static library modules instead of a loaded shared library module.
+        When it is False, it returns the loaded shared library that links all the object files.
+        When it is True, it returns the static library of each compiled object file.
 
     Returns
     -------
@@ -537,48 +322,24 @@ def gen_grouped_gemm_module(
     when apply grouped gemm on A: (total_m, k), B: (batch_size, n, k), 
m_indptr: (batch_size, )
     requires all m in m_indptr to be multiple of 4
     """
+    # NOTE: This function is still under development,
+    # and we currently only support SM100 grouped gemm
     try:
-        from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
-            gen_grouped_gemm_fp8_tvm_binding,
-            get_grouped_gemm_fp8_uri,
+        from flashinfer.gemm import (  # pylint: 
disable=import-outside-toplevel
+            gen_gemm_sm100_module,
         )
     except ImportError:
         raise ImportError(
             "FlashInfer is not installed. Please follow instructions "
             "in https://docs.flashinfer.ai to install FlashInfer."
         )
-    try:
-        import torch  # pylint: disable=import-outside-toplevel
-    except ImportError:
-        raise ImportError("PyTorch is not installed. Please install PyTorch to 
use FlashInfer.")
-
-    torch_dtype_a = getattr(torch, dtype_a)
-    torch_dtype_b = getattr(torch, dtype_b)
-    torch_dtype_out = getattr(torch, dtype_out)
-
-    uri = get_grouped_gemm_fp8_uri(
-        dtype_a=torch_dtype_a,
-        dtype_b=torch_dtype_b,
-        dtype_out=torch_dtype_out,
-        scale_granularity_m=scale_granularity_m,
-        scale_granularity_n=scale_granularity_n,
-        scale_granularity_k=scale_granularity_k,
-        scale_major_mode=scale_major_mode,
-        mma_sm=mma_sm,
-    )
 
-    uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
-        uri=uri,
-        dtype_a=torch_dtype_a,
-        dtype_b=torch_dtype_b,
-        dtype_out=torch_dtype_out,
-        scale_granularity_m=scale_granularity_m,
-        scale_granularity_n=scale_granularity_n,
-        scale_granularity_k=scale_granularity_k,
-        scale_major_mode=scale_major_mode,
-        mma_sm=mma_sm,
-    )
-
-    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
-    modules = _load_flashinfer_modules(object_files)
-    return modules
+    compute_version = 
"".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
+    if compute_version == "100":
+        jit_spec = gen_gemm_sm100_module()
+    else:
+        raise ValueError(f"Unsupported compute version: {compute_version}")
+    if return_static_libs:
+        jit_spec.build(verbose=False)
+        return _load_flashinfer_modules(jit_spec.get_object_paths())
+    return [jit_spec.build_and_load()]
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index e6e171da99..e94d5c4295 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -371,8 +371,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
         enable_disaggregation : bool
             Whether to enable disaggregation in the KV cache.
         """
-        if rope_mode == RopeMode.INLINE:
-            assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not 
support partial rotary dim."
+        assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support 
inline mode."
 
         attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else 
attn_kind
         if attn_kind_single == "mha_sliding":
@@ -383,8 +382,8 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
             dtype_o=dtype,
             qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else 
mla_original_qk_head_dim),
             v_head_dim=(v_head_dim if attn_kind_single == "mha" else 
mla_original_v_head_dim),
-            target=target,
-            enable_inline_rope=rope_mode == RopeMode.INLINE,
+            enable_inline_rope=False,
+            return_static_libs=True,
         )
         flashinfer_decode_mods = (
             rx.backend.cuda.flashinfer.gen_flashinfer_decode_module(
@@ -393,7 +392,8 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
                 dtype_o=dtype,
                 qk_head_dim=qk_head_dim,
                 v_head_dim=v_head_dim,
-                target=target,
+                enable_inline_rope=False,
+                return_static_libs=True,
             )
             if attn_kind_single == "mha"
             else []
@@ -405,7 +405,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
                 dtype_o=dtype,
                 head_dim_ckv=v_head_dim,
                 head_dim_kpe=qk_head_dim - v_head_dim,
-                target=target,
+                return_static_libs=True,
             )
             if attn_kind_single == "mla"
             else []
@@ -417,8 +417,8 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
         bb = rx.BlockBuilder.current()
         mha_functions = (
             [
-                rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_prefill_with_paged_kv_cache_run"), 
rx.ExternFunc("batch_prefill_with_kv_cache_plan")]),
-                rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_decode_with_paged_kv_cache_run"), 
rx.ExternFunc("batch_decode_with_paged_kv_cache_plan")]),
+                rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_prefill_paged_run"), rx.ExternFunc("batch_prefill_plan")]),
+                rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_decode_run"), rx.ExternFunc("batch_decode_plan")]),
                 rx.Tuple([rx.StringImm("tir"), 
bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, 
qk_head_dim, dtype, True, rope_scaling, target), 
"tir_attention_prefill_sliding_window")]),
                 rx.Tuple([rx.StringImm("tir"), 
bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, 
qk_head_dim, dtype, True, rope_scaling, target), 
"tir_attention_decode_sliding_window")]),
                 rx.Tuple([rx.StringImm("tir"), 
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, 
num_attention_heads, qk_head_dim, dtype, rope_scaling, target), 
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
@@ -427,7 +427,8 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
             if attn_kind_single == "mha"
             else [rx.Tuple([]) for _ in range(6)]
         )
-        mla_function = rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_mla_paged_attention_run"), 
rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" 
else [])
+        ragged_prefill_function = rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_prefill_ragged_run"), 
rx.ExternFunc("batch_prefill_plan")]) if attn_kind_single == "mha" else 
rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan"), 
rx.PrimValue(mla_original_qk_head_dim), rx.PrimValue(mla_original_v_head_dim)])
+        mla_function = rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_mla_run"), rx.ExternFunc("batch_mla_plan")] if 
attn_kind_single == "mla" else [])
         attn_merge_functions = [
             bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, 
dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
         ]
@@ -463,7 +464,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
             rx.op.zeros((), dtype),
             bb.add_func(_kv_cache_transpose_append(num_key_value_heads, 
qk_head_dim, dtype), "kv_cache_transpose_append"),
             bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype), 
"kv_cache_transpose_append_mla"),
-            rx.Tuple([rx.StringImm("flashinfer"), 
rx.ExternFunc("batch_prefill_with_ragged_kv_cache_run"), 
rx.ExternFunc("batch_prefill_with_kv_cache_plan")]),
+            ragged_prefill_function,
             *mha_functions,
             mla_function,
             rx.Tuple(attn_merge_functions),
diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc
index 3b37d9810b..13e151ecd2 100644
--- a/src/runtime/vm/attn_backend.cc
+++ b/src/runtime/vm/attn_backend.cc
@@ -59,11 +59,18 @@ std::unique_ptr<RaggedPrefillFunc> 
ConvertRaggedPrefillFunc(ffi::Array<ffi::Any>
     return std::make_unique<TIRRaggedPrefillFunc>(std::move(attn_func), 
attn_kind);
   }
   if (backend_name == "flashinfer") {
-    CHECK_EQ(args.size(), 3);
+    CHECK(args.size() == 3 || args.size() == 5);
     ffi::Function attn_func = args[1].cast<ffi::Function>();
     ffi::Function plan_func = args[2].cast<ffi::Function>();
+    int64_t qk_head_dim_override = -1;
+    int64_t v_head_dim_override = -1;
+    if (args.size() == 5) {
+      qk_head_dim_override = args[3].cast<int64_t>();
+      v_head_dim_override = args[4].cast<int64_t>();
+    }
     return std::make_unique<FlashInferRaggedPrefillFunc>(std::move(attn_func), 
std::move(plan_func),
-                                                         attn_kind);
+                                                         attn_kind, 
qk_head_dim_override,
+                                                         v_head_dim_override);
   }
   LOG(FATAL) << "Cannot reach here";
   throw;
diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h
index ea5f49c6c0..1fd22a97ab 100644
--- a/src/runtime/vm/attn_backend.h
+++ b/src/runtime/vm/attn_backend.h
@@ -27,6 +27,7 @@
 
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/function.h>
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/int_tuple.h>
 #include <tvm/runtime/logging.h>
 
@@ -57,6 +58,22 @@ class AttnBackendFunc {
   virtual ~AttnBackendFunc() = default;
 
  protected:
+  // helper allocator class for creating strided view of a Tensor
+  // that applies byte offset to the original data pointer
+  class ViewBasedAlloc {
+   public:
+    explicit ViewBasedAlloc(Tensor source) : source_(source) {}
+    void AllocData(DLTensor* tensor, int64_t* strides, int64_t 
extra_byte_offset) {
+      tensor->data = static_cast<char*>(source_->data) + extra_byte_offset;
+      tensor->strides = strides;
+    }
+
+    void FreeData(DLTensor* tensor) {}
+
+   private:
+    Tensor source_;
+  };
+
   ffi::Function attn_func_;
 
  public:
@@ -133,16 +150,34 @@ class FlashInferPagedPrefillFunc : public 
PagedPrefillFunc {
            Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double 
rotary_scale,
            double rotary_theta, double sm_scale, Tensor attn_output, Tensor 
attn_lse,
            TVMStreamHandle compute_stream) final {
+    Device device = q->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, compute_stream);
     auto [float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
           plan_info_vec] = cached_buffers_[depth];
     double rope_rcp_scale = 1 / rotary_scale;
     double rope_rcp_theta = 1 / rotary_theta;
-    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, 
pages, qo_indptr,
-               page_indptr, page_indices, length_info, q_rope_position, 
k_rope_pos_offset,
-               attn_output, attn_lse, 
/*mask_mode_code=*/static_cast<int64_t>(causal),
-               /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode == 
RoPEMode::kInline),
-               /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, 
/*rope_rcp_scale=*/rope_rcp_scale,
-               /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+
+    ICHECK_EQ(pages.ndim(), 5);
+    int H = pages->shape[2];
+    int N = pages->shape[3];
+    int D = pages->shape[4];
+    CHECK(pages.IsContiguous());
+    std::vector<int64_t> pages_k_v_shape = {pages->shape[0], H, N, D};
+    std::vector<int64_t> pages_k_v_strides = {2 * H * N * D, N * D, D, 1};
+    Tensor pages_k =
+        Tensor::FromNDAlloc(ViewBasedAlloc(pages), 
ffi::Shape(pages_k_v_shape), pages->dtype,
+                            pages->device, pages_k_v_strides.data(), 
pages->byte_offset);
+    Tensor pages_v = Tensor::FromNDAlloc(
+        ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, 
pages->device,
+        pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * 
pages.DataType().bytes());
+
+    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, 
pages_k, pages_v,
+               qo_indptr, page_indptr, page_indices, length_info, attn_output, 
attn_lse,
+               /*mask_mode_code=*/static_cast<int64_t>(causal), 
/*layout(HND)=*/1,
+               /*window_left=*/-1, /*enable_pdl=*/false, sm_scale,
+               /*rope_rcp_scale=*/rope_rcp_scale, 
/*rope_rcp_theta=*/rope_rcp_theta);
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
   }
 
   void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor 
page_indptr,
@@ -150,9 +185,43 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc 
{
            Tensor attn_output, Tensor attn_lse, TVMStreamHandle 
compute_stream) final {
     auto [float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
           plan_info_vec] = cached_buffers_[depth];
-    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, 
pages, page_indices,
-               attn_output, attn_lse, 
/*mask_mode_code=*/static_cast<int64_t>(causal),
-               /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], 
sm_scale, compute_stream);
+    Device device = q->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, compute_stream);
+    ICHECK_NE(qk_head_dim_, -1);
+    ICHECK_NE(v_head_dim_, -1);
+    int64_t H = q->shape[1];
+    int64_t page_size = pages->shape[1];
+    int64_t rope_head_dim = qk_head_dim_ - v_head_dim_;
+    int64_t nope_head_dim = q->shape[2] - rope_head_dim;
+
+    // Split q into q_nope and q_pe
+    CHECK(q.IsContiguous());
+    std::vector<int64_t> q_nope_shape = {q->shape[0], H, nope_head_dim};
+    std::vector<int64_t> q_pe_shape = {q->shape[0], H, rope_head_dim};
+    std::vector<int64_t> q_strides = {H * q->shape[2], q->shape[2], 1};
+    Tensor q_nope = Tensor::FromNDAlloc(ViewBasedAlloc(q), 
ffi::Shape(q_nope_shape), q->dtype,
+                                        q->device, q_strides.data(), 
q->byte_offset);
+    Tensor q_pe = Tensor::FromNDAlloc(ViewBasedAlloc(q), 
ffi::Shape(q_pe_shape), q->dtype,
+                                      q->device, q_strides.data(),
+                                      q->byte_offset + nope_head_dim * 
q.DataType().bytes());
+    // Split pages into kv_nope and kv_pe
+    CHECK(pages.IsContiguous());
+    std::vector<int64_t> kv_nope_shape = {pages->shape[0], page_size, 
nope_head_dim};
+    std::vector<int64_t> kv_pe_shape = {pages->shape[0], page_size, 
rope_head_dim};
+    std::vector<int64_t> kv_strides = {page_size * pages->shape[2], 
pages->shape[2], 1};
+    Tensor kv_nope =
+        Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(kv_nope_shape), 
pages->dtype,
+                            pages->device, kv_strides.data(), 
pages->byte_offset);
+    Tensor kv_pe = Tensor::FromNDAlloc(
+        ViewBasedAlloc(pages), ffi::Shape(kv_pe_shape), pages->dtype, 
pages->device,
+        kv_strides.data(), pages->byte_offset + nope_head_dim * 
pages.DataType().bytes());
+
+    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, 
q_nope, q_pe, kv_nope,
+               kv_pe, page_indices, attn_output, attn_lse,
+               /*mask_mode_code=*/static_cast<int64_t>(causal),
+               /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], 
sm_scale);
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
   }
 
   void BeginForward(int depth, Tensor float_workspace_buffer, Tensor 
int_workspace_buffer,
@@ -161,31 +230,37 @@ class FlashInferPagedPrefillFunc : public 
PagedPrefillFunc {
                     int64_t batch_size, int64_t total_qo_len, int64_t 
page_size,
                     int64_t num_qo_heads, int64_t num_kv_heads, int64_t 
qk_head_dim,
                     int64_t v_head_dim, bool causal, TVMStreamHandle 
copy_stream) final {
-    std::vector<int64_t> kv_len;
-    kv_len.reserve(batch_size);
+    Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), 
Device{kDLCPU, 0});
+    int32_t* kv_len_arr_data = static_cast<int32_t*>(kv_len_arr.data_ptr());
     for (int i = 0; i < static_cast<int>(batch_size); ++i) {
-      kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i]
-                           ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * 
page_size +
-                                 (*last_page_len)[i]
-                           : 0);
+      kv_len_arr_data[i] =
+          (*page_indptr)[i + 1] != (*page_indptr)[i]
+              ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + 
(*last_page_len)[i]
+              : 0;
     }
-    IntTuple plan_info_vec;
+    qk_head_dim_ = qk_head_dim;
+    v_head_dim_ = v_head_dim;
+    ffi::Array<int64_t> plan_info_vec;
+    Device device = float_workspace_buffer->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, copy_stream);
     if (attn_kind == AttnKind::kMHA) {
       // Todo(tvm-team): enable cuda graph
       plan_info_vec =
           plan_func_(float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
-                     qo_indptr->as_tensor(), page_indptr->as_tensor(), 
IntTuple(std::move(kv_len)),
-                     total_qo_len, batch_size, num_qo_heads, num_kv_heads, 
page_size,
+                     qo_indptr->as_tensor(), page_indptr->as_tensor(), 
kv_len_arr, total_qo_len,
+                     batch_size, num_qo_heads, num_kv_heads, page_size,
                      /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, 
causal,
-                     /*window_left=*/-1, copy_stream)
-              .cast<IntTuple>();
+                     /*window_left=*/-1, /*fixed_split_size=*/-1, 
/*disable_split_kv=*/false)
+              .cast<ffi::Array<int64_t>>();
     } else if (attn_kind == AttnKind::kMLA) {
       plan_info_vec =
           plan_func_(float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
-                     qo_indptr->as_tensor(), page_indptr->as_tensor(), 
IntTuple(std::move(kv_len)),
-                     num_qo_heads, v_head_dim, causal, copy_stream)
-              .cast<IntTuple>();
+                     qo_indptr->as_tensor(), page_indptr->as_tensor(), 
kv_len_arr, num_qo_heads,
+                     v_head_dim, causal)
+              .cast<ffi::Array<int64_t>>();
     }
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
 
     if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
       cached_buffers_.resize(depth + 1);
@@ -196,8 +271,10 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc 
{
   }
 
  private:
+  int64_t qk_head_dim_ = -1;
+  int64_t v_head_dim_ = -1;
   ffi::Function plan_func_;
-  std::vector<std::tuple<Tensor, Tensor, Tensor, IntTuple>> cached_buffers_;
+  std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>> 
cached_buffers_;
 };
 
 /*! \brief The ragged prefill attention function base class. */
@@ -244,23 +321,30 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc {
 class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
  public:
   explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function 
plan_func,
-                                       AttnKind attn_kind)
+                                       AttnKind attn_kind, int64_t 
qk_head_dim_override,
+                                       int64_t v_head_dim_override)
       : RaggedPrefillFunc(std::move(attn_func), attn_kind, 
AttnBackendKind::kFlashInfer),
+        qk_head_dim_override_(qk_head_dim_override),
+        v_head_dim_override_(v_head_dim_override),
         plan_func_(std::move(plan_func)) {}
 
   void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, 
Tensor q_rope_position,
            Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double 
rotary_scale,
            double rotary_theta, double sm_scale, Tensor attn_output, Tensor 
attn_lse,
            TVMStreamHandle compute_stream) final {
+    Device device = q->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, compute_stream);
     double rope_rcp_scale = 1 / rotary_scale;
     double rope_rcp_theta = 1 / rotary_theta;
     attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_, 
q, k, v, qo_indptr,
-               kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, 
attn_lse,
+               kv_indptr, attn_output, attn_lse,
                /*mask_mode_code=*/static_cast<int64_t>(causal),
-               /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode == 
RoPEMode::kInline),
-               /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale,
+               /*layout(NHD)=*/0, /*window_left=*/-1,
+               /*enable_pdl=*/false, sm_scale,
                /*rope_rcp_scale=*/rope_rcp_scale,
-               /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+               /*rope_rcp_theta=*/rope_rcp_theta);
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
   }
 
   void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
@@ -268,30 +352,42 @@ class FlashInferRaggedPrefillFunc : public 
RaggedPrefillFunc {
                     HostMemoryVector* kv_indptr, int64_t batch_size, int64_t 
total_qo_len,
                     int64_t num_qo_heads, int64_t num_kv_heads, int64_t 
qk_head_dim,
                     int64_t v_head_dim, bool causal, TVMStreamHandle 
copy_stream) final {
-    std::vector<int64_t> kv_len;
-    kv_len.reserve(batch_size);
+    Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), 
Device{kDLCPU, 0});
+    int32_t* kv_len_arr_data = static_cast<int32_t*>(kv_len_arr.data_ptr());
     for (int i = 0; i < static_cast<int>(batch_size); ++i) {
-      kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]);
+      kv_len_arr_data[i] = (*kv_indptr)[i + 1] - (*kv_indptr)[i];
+    }
+    if (qk_head_dim_override_ != -1) {
+      qk_head_dim = qk_head_dim_override_;
+    }
+    if (v_head_dim_override_ != -1) {
+      v_head_dim = v_head_dim_override_;
     }
     // Todo(tvm-team): enable cuda graph
     float_workspace_buffer_ = float_workspace_buffer;
     int_workspace_buffer_ = int_workspace_buffer;
     page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer;
+    Device device = float_workspace_buffer->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, copy_stream);
     plan_info_vec_ =
         plan_func_(float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
-                   qo_indptr->as_tensor(), kv_indptr->as_tensor(), 
IntTuple(std::move(kv_len)),
-                   total_qo_len, batch_size, num_qo_heads, num_kv_heads, 
/*page_size=*/1,
+                   qo_indptr->as_tensor(), kv_indptr->as_tensor(), kv_len_arr, 
total_qo_len,
+                   batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1,
                    /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, 
causal,
-                   /*window_left=*/-1, copy_stream)
-            .cast<IntTuple>();
+                   /*window_left=*/-1, /*fixed_split_size=*/-1, 
/*disable_split_kv=*/false)
+            .cast<ffi::Array<int64_t>>();
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
   }
 
  private:
+  int64_t qk_head_dim_override_;
+  int64_t v_head_dim_override_;
   ffi::Function plan_func_;
   Tensor float_workspace_buffer_;
   Tensor int_workspace_buffer_;
   Tensor page_locked_int_workspace_buffer_;
-  IntTuple plan_info_vec_;
+  ffi::Array<int64_t> plan_info_vec_;
 };
 
 /*! \brief The paged decode attention function base class. */
@@ -359,15 +455,33 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
            Tensor length_info, Tensor k_rope_pos_offset, Tensor 
q_rope_position, RoPEMode rope_mode,
            double rotary_scale, double rotary_theta, double sm_scale, Tensor 
attn_output,
            Tensor attn_lse, TVMStreamHandle compute_stream) final {
+    Device device = q->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, compute_stream);
     auto [float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
           plan_info_vec] = cached_buffers_[depth];
     double rope_rcp_scale = 1 / rotary_scale;
     double rope_rcp_theta = 1 / rotary_theta;
-    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, 
pages, page_indptr,
-               page_indices, length_info, q_rope_position, k_rope_pos_offset, 
attn_output, attn_lse,
-               /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode == 
RoPEMode::kInline),
-               /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, 
/*rope_rcp_scale=*/rope_rcp_scale,
-               /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+
+    ICHECK_EQ(pages.ndim(), 5);
+    int H = pages->shape[2];
+    int N = pages->shape[3];
+    int D = pages->shape[4];
+    CHECK(pages.IsContiguous());
+    std::vector<int64_t> pages_k_v_shape = {pages->shape[0], H, N, D};
+    std::vector<int64_t> pages_k_v_strides = {2 * H * N * D, N * D, D, 1};
+    Tensor pages_k =
+        Tensor::FromNDAlloc(ViewBasedAlloc(pages), 
ffi::Shape(pages_k_v_shape), pages->dtype,
+                            pages->device, pages_k_v_strides.data(), 
pages->byte_offset);
+    Tensor pages_v = Tensor::FromNDAlloc(
+        ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, 
pages->device,
+        pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * 
pages.DataType().bytes());
+
+    attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, 
pages_k, pages_v,
+               page_indptr, page_indices, length_info, attn_output, attn_lse,
+               /*layout(HND)=*/1, /*window_left=*/-1, /*enable_pdl=*/false, 
sm_scale,
+               /*rope_rcp_scale=*/rope_rcp_scale, 
/*rope_rcp_theta=*/rope_rcp_theta);
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
   }
 
   void BeginForward(int depth, Tensor float_workspace_buffer, Tensor 
int_workspace_buffer,
@@ -377,13 +491,18 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
                     RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype,
                     TVMStreamHandle copy_stream) final {
     // Todo(tvm-team): enable cuda graph
-    IntTuple plan_info_vec =
+    Tensor empty_qkv_data = Tensor::Empty({1}, q_dtype, Device{kDLCPU, 0});
+    Device device = float_workspace_buffer->device;
+    TVMStreamHandle original_stream = 
DeviceAPI::Get(device)->GetCurrentStream(device);
+    DeviceAPI::Get(device)->SetStream(device, copy_stream);
+    ffi::Array<int64_t> plan_info_vec =
         plan_func_(float_workspace_buffer, int_workspace_buffer, 
page_locked_int_workspace_buffer,
                    page_indptr->as_tensor(), batch_size, num_qo_heads, 
num_kv_heads, page_size,
                    /*enable_cuda_graph=*/false,
-                   static_cast<int64_t>(rope_mode == RoPEMode::kInline),
-                   /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype, 
kv_dtype, copy_stream)
-            .cast<IntTuple>();
+                   /*window_left=*/-1, /*logits_soft_cap=*/0.0, qk_head_dim, 
v_head_dim,
+                   empty_qkv_data, empty_qkv_data)
+            .cast<ffi::Array<int64_t>>();
+    DeviceAPI::Get(device)->SetStream(device, original_stream);
 
     if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
       cached_buffers_.resize(depth + 1);
@@ -395,7 +514,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
 
  private:
   ffi::Function plan_func_;
-  std::vector<std::tuple<Tensor, Tensor, Tensor, IntTuple>> cached_buffers_;
+  std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>> 
cached_buffers_;
 };
 
 /*! \brief The paged prefill with tree mask attention function base class. */
diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h
index 09557a8f0a..1c695a10e2 100644
--- a/src/runtime/vm/attn_utils.h
+++ b/src/runtime/vm/attn_utils.h
@@ -860,8 +860,9 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
                 sliding_window_offset->data(), n_elem * elem_byte_size_);
     std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ 
+ 2 * n_elem,
                 sink_size->data(), n_elem * elem_byte_size_);
-    Tensor view = merged_attn_aux_data_device_.CreateView(
-        {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+    Tensor view =
+        Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), 
ffi::Shape({3, n_elem}),
+                            dtype_aux_, device_, attn_aux_data_copy_offset_ * 
elem_byte_size_);
     attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem);
     return view;
   }
@@ -895,8 +896,9 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
                 src_data->data(), n_elem * elem_byte_size_);
     std::memcpy(merged_compact_kv_aux_data_host_.data() + 
compact_kv_aux_data_copy_offset_ + n_elem,
                 dst_data->data(), n_elem * elem_byte_size_);
-    Tensor view = merged_compact_kv_aux_data_device_.CreateView(
-        {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * 
elem_byte_size_);
+    Tensor view = 
Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
+                                      ffi::Shape({2, n_elem}), dtype_aux_, 
device_,
+                                      compact_kv_aux_data_copy_offset_ * 
elem_byte_size_);
     compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem);
     return view;
   }
@@ -919,6 +921,20 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
   }
 
  private:
+  // helper allocator class that applies byte offset to the original data 
pointer
+  class ViewHelper {
+   public:
+    explicit ViewHelper(Tensor source) : source_(source) {}
+    void AllocData(DLTensor* tensor, int64_t extra_byte_offset) {
+      tensor->data = static_cast<char*>(source_->data) + extra_byte_offset;
+    }
+
+    void FreeData(DLTensor* tensor) {}
+
+   private:
+    Tensor source_;
+  };
+
   /*!
    * \brief Calculate the start element offsets of the auxiliary arrays in the 
local cache.
    * \return Return the local cache size (total number of elements in the 
local cache).
@@ -990,8 +1006,9 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
     int64_t n_elem = data->size();
     std::memcpy(merged_attn_aux_data_host_.data() + 
attn_aux_data_copy_offset_, data->data(),
                 n_elem * elem_byte_size_);
-    Tensor view = merged_attn_aux_data_device_.CreateView(
-        {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+    Tensor view =
+        Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), 
ffi::Shape({n_elem}),
+                            dtype_aux_, device_, attn_aux_data_copy_offset_ * 
elem_byte_size_);
     attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
     return view;
   }
@@ -1000,8 +1017,9 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
     int64_t n_elem = data->size();
     std::memcpy(merged_compact_kv_aux_data_host_.data() + 
compact_kv_aux_data_copy_offset_,
                 data->data(), n_elem * elem_byte_size_);
-    Tensor view = merged_compact_kv_aux_data_device_.CreateView(
-        {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * 
elem_byte_size_);
+    Tensor view = 
Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
+                                      ffi::Shape({n_elem}), dtype_aux_, 
device_,
+                                      compact_kv_aux_data_copy_offset_ * 
elem_byte_size_);
     compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
     return view;
   }
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 0f3f568661..4fb3cd69d6 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -2052,7 +2052,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
               temp_float_attn_workspace_, temp_int_attn_workspace_[0],
              temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_,
               &cur_append_lengths_indptr_host_, cur_batch_size_,
-              cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_,
+              cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_qo_heads_, qk_head_dim_,
               v_head_dim_, /*causal=*/true, copy_stream_);
         }
       }
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py
index 8333e4b2d6..da6fdacebd 100644
--- a/tests/python/relax/test_group_gemm_flashinfer.py
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -18,14 +18,14 @@
 """Test for FlashInfer GroupedGemm TVM integration"""
 
 import math
+
 import numpy as np
 import pytest
 import torch
+
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.contrib import utils
-from tvm.relax.backend.cuda import flashinfer
 
 DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
 fp8_dtype = "float8_e4m3fn"
@@ -389,36 +389,11 @@ def test_grouped_gemm_correctness(
     device = tvm.cuda(0)
     target = tvm.target.Target.from_device(device)
 
-    def _load_module(name: str, static_modules):
-        """Helper function to load compiled modules."""
-        assert len(static_modules) > 0
-        if len(static_modules) == 1:
-            return static_modules[0]
-        static_mod = static_modules[0]
-        for mod in static_modules[1:]:
-            static_mod.import_module(mod)
-        temp = tvm.contrib.utils.tempdir()
-        mod_path = temp.relpath(f"{name}.so")
-        static_mod.export_library(mod_path)
-        return tvm.runtime.load_module(mod_path)
-
     # Generate the module
-    modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(
-        dtype_a=dtype_a,
-        dtype_b=dtype_b,
-        dtype_out=dtype_out,
-        scale_granularity_m=scale_granularity_m,
-        scale_granularity_n=scale_granularity_n,
-        scale_granularity_k=scale_granularity_k,
-        scale_major_mode=scale_major_mode,
-        mma_sm=mma_sm,
-        target=target,
-        num_threads=4,
-    )
+    mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0]
 
     # Load the module
-    mod = _load_module("flashinfer_grouped_gemm", modules)
-    grouped_gemm_fn = mod["grouped_gemm_fp8_run"]
+    grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
 
     # Generate test data
     test_data = generate_test_data(
@@ -460,7 +435,11 @@ def test_grouped_gemm_correctness(
         test_data["m_indptr"],  # m_indptr
         test_data["n"],  # n (scalar)
         test_data["k"],  # k (scalar)
-        None,  # cuda_stream (use default stream)
+        scale_granularity_m,
+        scale_granularity_n,
+        scale_granularity_k,
+        scale_major_mode,
+        mma_sm,
     )
 
     # Compute reference result
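
A minimal usage sketch of the updated grouped-GEMM entry point, mirroring the test change above. It assumes a CUDA device and a working FlashInfer installation, and only shows module generation and function lookup; the scale-granularity and MMA settings are now passed per call as in the test snippet:

    import tvm
    from tvm import relax

    device = tvm.cuda(0)
    target = tvm.target.Target.from_device(device)
    # The generator now takes only the target; indexing with [0] picks the
    # loaded JIT module from the returned list, as in the updated test.
    mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0]
    grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
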
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index dd29140e9b..4aae9dec59 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -23,7 +23,6 @@ import tvm
 import tvm.testing
 from tvm import dlight as dl
 from tvm import relax
-from tvm.contrib import utils
 from tvm.relax.frontend.nn.llm.kv_cache import (
     AttnKind,
     RopeMode,
@@ -78,7 +77,7 @@ fcopy_cache = None
 fcompact_copy = None
 
 
-def set_global_func():
+def set_global_func(rope_mode: RopeMode):
     global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn
     global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv
     global fattention_prefill, fattention_decode, fattention_prefill_ragged
@@ -98,48 +97,30 @@ def set_global_func():
     )
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
-    def load_module(name: str, static_modules: List[tvm.runtime.Module]):
-        assert len(static_modules) > 0
-        if len(static_modules) == 1:
-            return static_modules[0]
-        static_mod = static_modules[0]
-        for mod in static_modules[1:]:
-            static_mod.import_module(mod)
-        temp = utils.tempdir()
-        mod_path = temp.relpath(f"{name}.so")
-        static_mod.export_library(mod_path)
-        return tvm.runtime.load_module(mod_path)
-
     target = tvm.target.Target.from_device(device)
-    flashinfer_prefill_mod = load_module(
-        "flashinfer_prefill",
-        relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
-            dtype_q=dtype,
-            dtype_kv=dtype,
-            dtype_o=dtype,
-            qk_head_dim=head_dim,
-            v_head_dim=head_dim,
-            target=target,
-        ),
-    )
-    flashinfer_decode_mod = load_module(
-        "flashinfer_decode",
-        relax.backend.cuda.flashinfer.gen_flashinfer_decode_module(
-            dtype_q=dtype,
-            dtype_kv=dtype,
-            dtype_o=dtype,
-            qk_head_dim=head_dim,
-            v_head_dim=head_dim,
-            target=target,
-        ),
-    )
-
-    fattention_prefill = flashinfer_prefill_mod["batch_prefill_with_paged_kv_cache_run"]
-    fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
-    fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"]
-    fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
-    fattention_decode = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_run"]
-    fattention_decode_plan = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_plan"]
+    flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
+        dtype_q=dtype,
+        dtype_kv=dtype,
+        dtype_o=dtype,
+        qk_head_dim=head_dim,
+        v_head_dim=head_dim,
+        enable_inline_rope=rope_mode == RopeMode.INLINE,
+    )[0]
+    flashinfer_decode_mod = relax.backend.cuda.flashinfer.gen_flashinfer_decode_module(
+        dtype_q=dtype,
+        dtype_kv=dtype,
+        dtype_o=dtype,
+        qk_head_dim=head_dim,
+        v_head_dim=head_dim,
+        enable_inline_rope=rope_mode == RopeMode.INLINE,
+    )[0]
+
+    fattention_prefill = flashinfer_prefill_mod["batch_prefill_paged_run"]
+    fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+    fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"]
+    fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+    fattention_decode = flashinfer_decode_mod["batch_decode_run"]
+    fattention_decode_plan = flashinfer_decode_mod["batch_decode_plan"]
 
     builts = []
     for tir_func in [
@@ -560,8 +541,8 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
 
 
 if __name__ == "__main__":
-    set_global_func()
-    for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
+    for rope_mode in [RopeMode.NONE, RopeMode.NORMAL]:
+        set_global_func(rope_mode)
         cache = create_kv_cache(rope_mode)
         test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode))
         test_paged_attention_kv_cache_remove_sequence((cache, rope_mode))
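
A minimal sketch of how the updated test obtains the FlashInfer prefill and decode kernels after the JIT refactor, assuming a CUDA device and FlashInfer installed; the dtype and head-dimension values here are illustrative placeholders, and `[0]` picks the loaded JIT module from the returned list as in the test:

    from tvm import relax
    from tvm.relax.frontend.nn.llm.kv_cache import RopeMode

    rope_mode = RopeMode.NONE
    prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
        dtype_q="float16",
        dtype_kv="float16",
        dtype_o="float16",
        qk_head_dim=128,
        v_head_dim=128,
        enable_inline_rope=rope_mode == RopeMode.INLINE,
    )[0]
    decode_mod = relax.backend.cuda.flashinfer.gen_flashinfer_decode_module(
        dtype_q="float16",
        dtype_kv="float16",
        dtype_o="float16",
        qk_head_dim=128,
        v_head_dim=128,
        enable_inline_rope=rope_mode == RopeMode.INLINE,
    )[0]
    # Plan/run functions are looked up under their new names.
    prefill_plan = prefill_mod["batch_prefill_plan"]
    prefill_paged_run = prefill_mod["batch_prefill_paged_run"]
    prefill_ragged_run = prefill_mod["batch_prefill_ragged_run"]
    decode_plan = decode_mod["batch_decode_plan"]
    decode_run = decode_mod["batch_decode_run"]
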
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index e3de4944fe..cd76f9ce20 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -25,7 +25,6 @@ import tvm
 import tvm.testing
 from tvm import dlight as dl
 from tvm import relax
-from tvm.contrib import utils
 from tvm.relax.frontend.nn.llm.kv_cache import (
     AttnKind,
     RopeMode,
@@ -115,47 +114,27 @@ def set_global_func(dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla")
 
-    def load_module(name: str, static_modules: List[tvm.runtime.Module]):
-        assert len(static_modules) > 0
-        if len(static_modules) == 1:
-            return static_modules[0]
-        static_mod = static_modules[0]
-        for mod in static_modules[1:]:
-            static_mod.import_module(mod)
-        temp = utils.tempdir()
-        mod_path = temp.relpath(f"{name}.so")
-        static_mod.export_library(mod_path)
-        return tvm.runtime.load_module(mod_path)
-
     target = tvm.target.Target.from_device(device)
-    flashinfer_prefill_mod = load_module(
-        "flashinfer_prefill",
-        relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
-            dtype_q=dtype,
-            dtype_kv=dtype,
-            dtype_o=dtype,
-            qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
-            v_head_dim=v_head_dim,
-            target=target,
-            enable_inline_rope=False,
-        ),
-    )
-    flashinfer_mla_mod = load_module(
-        "flashinfer_mla",
-        relax.backend.cuda.flashinfer.gen_flashinfer_mla_module(
-            dtype_q=dtype,
-            dtype_kv=dtype,
-            dtype_o=dtype,
-            head_dim_ckv=kv_lora_rank,
-            head_dim_kpe=qk_rope_head_dim,
-            target=target,
-        ),
-    )
-
-    fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"]
-    fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
-    fmla_prefill = flashinfer_mla_mod["batch_mla_paged_attention_run"]
-    fmla_prefill_plan = flashinfer_mla_mod["batch_mla_paged_attention_plan"]
+    flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
+        dtype_q=dtype,
+        dtype_kv=dtype,
+        dtype_o=dtype,
+        qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
+        v_head_dim=v_head_dim,
+        enable_inline_rope=False,
+    )[0]
+    flashinfer_mla_mod = relax.backend.cuda.flashinfer.gen_flashinfer_mla_module(
+        dtype_q=dtype,
+        dtype_kv=dtype,
+        dtype_o=dtype,
+        head_dim_ckv=kv_lora_rank,
+        head_dim_kpe=qk_rope_head_dim,
+    )[0]
+
+    fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"]
+    fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+    fmla_prefill = flashinfer_mla_mod["batch_mla_run"]
+    fmla_prefill_plan = flashinfer_mla_mod["batch_mla_plan"]
 
     builts = []
     for tir_func in [
@@ -221,7 +200,13 @@ def create_kv_cache(dtype):
         tvm.runtime.empty((), dtype, device=device),
         None,  # f_transpose_append_mha
         ftranspose_append,
-        ["flashinfer", fattn_prefill_ragged, fattn_prefill_ragged_plan],  # 
fattn_prefill_ragged
+        [
+            "flashinfer",
+            fattn_prefill_ragged,
+            fattn_prefill_ragged_plan,
+            qk_nope_head_dim + qk_rope_head_dim,
+            v_head_dim,
+        ],  # fattn_prefill_ragged
         [],  # fattn_prefill
         [],  # fattn_decode
         [],  # fattn_prefill_sliding_window
