This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3015acd7e6 [CUDA] Update FlashInfer JIT integration (#18353)
3015acd7e6 is described below
commit 3015acd7e678cd97e4334835e130c09b315ab1a7
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Oct 1 20:53:31 2025 -0400
[CUDA] Update FlashInfer JIT integration (#18353)
Following the recent JIT refactor in FlashInfer, which uses TVM FFI as
the JIT interface, this PR updates the FlashInfer JIT integration in TVM.
Major changes:
* We leverage FlashInfer's `JitSpec.build_and_load` to compile all
the JIT-generated source files, and remove the corresponding compilation
logic from TVM.
* For efficient tensor buffer management and pointer calculation, we
enforce all `byte_offset` fields of auxiliary tensors in the KV cache
to be zero. The byte offset is now applied directly to the data pointers.
* We also add a new parameter to the FlashInfer JIT integration that
controls whether to return a linked shared library or a list of compiled
object paths (see the usage sketch after this list). For unit tests,
returning a shared library is convenient and preferred, while for cases
such as MLC model compilation, object files are needed to serialize the
compiled model.
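As an illustration, below is a minimal usage sketch of the updated Python
entry point, assuming FlashInfer and PyTorch are installed and a CUDA GPU is
available. The dtypes and head dimensions are placeholder values chosen for
the example, not values prescribed by this commit.

    from tvm import relax

    # Build and load one shared library module that links all compiled
    # object files (return_static_libs defaults to False).
    prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
        dtype_q="float16",
        dtype_kv="float16",
        dtype_o="float16",
        qk_head_dim=128,
        v_head_dim=128,
        enable_inline_rope=False,
    )[0]
    plan_func = prefill_mod["batch_prefill_plan"]
    run_func = prefill_mod["batch_prefill_paged_run"]

    # For flows such as MLC model compilation, request static library
    # modules instead, so the compiled objects can be serialized together
    # with the model.
    static_mods = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
        dtype_q="float16",
        dtype_kv="float16",
        dtype_o="float16",
        qk_head_dim=128,
        v_head_dim=128,
        enable_inline_rope=False,
        return_static_libs=True,
    )

The retrieved `batch_prefill_plan` and `batch_prefill_paged_run` functions are
the renamed FFI exports produced by `_rename_exported_func_names` in the diff
below.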
---
python/tvm/relax/backend/cuda/flashinfer.py | 481 ++++++---------------
python/tvm/relax/frontend/nn/llm/kv_cache.py | 21 +-
src/runtime/vm/attn_backend.cc | 11 +-
src/runtime/vm/attn_backend.h | 213 +++++++--
src/runtime/vm/attn_utils.h | 34 +-
src/runtime/vm/paged_kv_cache.cc | 2 +-
tests/python/relax/test_group_gemm_flashinfer.py | 39 +-
..._builtin_paged_attention_kv_cache_flashinfer.py | 71 ++-
...ltin_paged_attention_kv_cache_mla_flashinfer.py | 69 ++-
9 files changed, 396 insertions(+), 545 deletions(-)
diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py
index 4e0fc3e854..6b5b1293ff 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -16,203 +16,36 @@
# under the License.
"""FlashInfer JIT compilation module for CUDA backend"""
-import hashlib
-import json
-import os
-import subprocess
-from concurrent.futures import ThreadPoolExecutor
+import re
from pathlib import Path
from typing import List
-import tvm_ffi
-
import tvm
from tvm.target import Target
-def _compile_flashinfer_kernels(
- name: str, source_paths: List[Path], target: Target, num_threads: int
-) -> List[Path]:
- from flashinfer.jit.env import ( # pylint: disable=import-outside-toplevel
- CUTLASS_INCLUDE_DIRS,
- FLASHINFER_CSRC_DIR,
- FLASHINFER_INCLUDE_DIR,
- FLASHINFER_JIT_DIR,
- FLASHINFER_TVM_BINDING_DIR,
- )
-
- # ------------------------------------------------------------------------
- # Caching Flow: create build_directory and compute cache hash.
- # ------------------------------------------------------------------------
- build_directory = FLASHINFER_JIT_DIR / name
- build_directory.mkdir(parents=True, exist_ok=True)
-
- def get_object_file_path(src: Path) -> Path:
- obj_name = src.stem + ".o"
- obj_path = build_directory / obj_name
- return obj_path
-
- # Compute latest modification time among all source files
- latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+def _rename_exported_func_names(source_paths: List[Path], prefix: str):
+ """Rename the ffi-exported function names in the source files to the given
prefix."""
+ pattern =
re.compile(r"^(\s*TVM_FFI_DLL_EXPORT_TYPED_FUNC\()([A-Za-z0-9_]+)(,.*)$")
+ for source_path in source_paths:
+ if not source_path.name.endswith("_binding.cu"):
+ continue
- # Get modification time for the current file (the one that contains this function)
- current_file_mtime = Path(__file__).stat().st_mtime
+ original_text = source_path.read_text(encoding="utf-8")
+ lines = original_text.splitlines(keepends=True)
+ updated = False
+ for idx, line in enumerate(lines):
+ line_body = line.rstrip("\r\n")
+ line_ending = line[len(line_body) :]
+ match = pattern.match(line_body)
+ if not match:
+ continue
+ new_body = f"{match.group(1)}{prefix}_{match.group(2)}{match.group(3)}"
+ lines[idx] = new_body + line_ending
+ updated = True
- # Build the hash key from metadata
- hash_key = {
- "name": name,
- "target": str(target),
- "latest_src_mtime": latest_src_mtime,
- "current_file_mtime": current_file_mtime,
- }
-
- hash_value = hashlib.md5(
- json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
- ).hexdigest()
-
- # Check if a valid hash exists in the build directory
- hash_file = build_directory / "hash.md5"
- if hash_file.exists():
- with open(hash_file, "r") as f:
- cached_hash = f.read().strip()
- if cached_hash == hash_value:
- # Check that all object files exist
- object_files = []
- all_exist = True
- for src in source_paths:
- obj_path = get_object_file_path(src)
- if not obj_path.exists():
- all_exist = False
- break
- object_files.append(obj_path)
- if all_exist:
- return object_files
-
- # If we are here, cache is missing or outdated. Write the new hash and compile the paths
- with open(hash_file, "w") as f:
- f.write(hash_value)
-
- # ------------------------------------------------------------------------
- # 1) Common CUDA compile flags
- # ------------------------------------------------------------------------
- cuda_cflags = [
- "-O3",
- "-std=c++17",
- "--threads",
- str(num_threads),
- "-g",
- "-use_fast_math",
- "--expt-relaxed-constexpr",
- # DMLC default
- "-DDMLC_USE_FOPEN64=0",
- "-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>",
- # Enable `-fPIC` for the host compiler
- "-Xcompiler=-fPIC",
- "-DFLASHINFER_ENABLE_F16",
- "-DFLASHINFER_ENABLE_BF16",
- "-DFLASHINFER_ENABLE_FP8_E4M3",
- "-DFLASHINFER_ENABLE_FP8_E5M2",
- ]
-
- # Determine compute version
- compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
- if compute_version in ["90", "100"]:
- compute_version += "a"
- cuda_cflags += [
- "-gencode",
- f"arch=compute_{compute_version},code=sm_{compute_version}",
- ]
-
- # ------------------------------------------------------------------------
- # 2) Include paths
- # ------------------------------------------------------------------------
- include_paths = [
- FLASHINFER_INCLUDE_DIR,
- FLASHINFER_CSRC_DIR,
- FLASHINFER_TVM_BINDING_DIR,
- ] + CUTLASS_INCLUDE_DIRS
-
- if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME",
None):
- # Respect TVM_SOURCE_DIR and TVM_HOME if they are set
- tvm_home = (
- os.environ["TVM_SOURCE_DIR"]
- if os.environ.get("TVM_SOURCE_DIR", None)
- else os.environ["TVM_HOME"]
- )
- include_paths += [
- Path(tvm_home).resolve() / "include",
- Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "include",
- Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "3rdparty" /
"dlpack" / "include",
- Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
- ]
- else:
- # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path
- tvm_package_path = Path(tvm.__file__).resolve().parent
- if (tvm_package_path / "include").exists():
- # The package is installed from pip.
- tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
- include_paths += [
- tvm_package_path / "include",
- tvm_package_path / "3rdparty" / "dmlc-core" / "include",
- tvm_ffi_package_path / "include",
- ]
- elif (tvm_package_path.parent.parent / "include").exists():
- # The package is installed from source.
- include_paths += [
- tvm_package_path.parent.parent / "include",
- tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" /
"include",
- tvm_package_path.parent.parent
- / "3rdparty"
- / "tvm-ffi"
- / "3rdparty"
- / "dlpack"
- / "include",
- tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" /
"include",
- ]
- else:
- # warning: TVM is not installed in the system.
- print(
- "Warning: Include path for TVM cannot be found. "
- "FlashInfer kernel compilation may fail due to missing
headers."
- )
-
- # ------------------------------------------------------------------------
- # 3) Function to compile a single source file
- # ------------------------------------------------------------------------
- def compile_single_source(src: Path) -> Path:
- # Derive the .o filename from the source filename
- obj_path = get_object_file_path(src)
-
- # Construct the command
- cmd = (
- ["nvcc"]
- + cuda_cflags
- + [f"-I{inc_path}" for inc_path in include_paths]
- + ["-c", "-o", str(obj_path), str(src)]
- )
-
- proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- out, err = proc.communicate()
- if proc.returncode != 0:
- raise RuntimeError(
- f"FlashInfer JIT compilation failed for {src}\n"
- f"Command: {' '.join(cmd)}\n"
- f"stdout:\n{out.decode('utf-8')}\n"
- f"stderr:\n{err.decode('utf-8')}"
- )
- return obj_path
-
- # ------------------------------------------------------------------------
- # 4) Compile each source in parallel using ThreadPoolExecutor
- # ------------------------------------------------------------------------
- object_files = []
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
- futures = [executor.submit(compile_single_source, src) for src in source_paths]
- for f in futures:
- object_files.append(f.result())  # Will raise if there's a compilation error
-
- # Return list of generated object files for any further linking steps
- return object_files
+ if updated:
+ source_path.write_text("".join(lines), encoding="utf-8")
def _load_flashinfer_modules(object_files: List[Path]) -> List[tvm.runtime.Module]:
@@ -228,9 +61,8 @@ def gen_flashinfer_prefill_module(
dtype_o: str,
qk_head_dim: int,
v_head_dim: int,
- target: Target,
- enable_inline_rope: bool = True,
- num_threads: int = 8,
+ enable_inline_rope: bool,
+ return_static_libs: bool = False,
) -> List[tvm.runtime.Module]:
"""Generate a FlashInfer module for prefill.
@@ -246,12 +78,12 @@ def gen_flashinfer_prefill_module(
The head dimension of the query and key tensors.
v_head_dim : int
The head dimension of the value tensor.
- target : Target
- The target device to compile for.
enable_inline_rope : bool
Whether to enable inline rotary positional embedding.
- num_threads : int
- The number of threads to use for compilation.
+ return_static_libs : bool
+ Whether to return static library modules instead of compiled modules.
+ When it is False, it returns the loaded shared library that links all the object files.
+ When it is True, it returns a static library module for each compiled object file.
Returns
-------
@@ -259,7 +91,7 @@ def gen_flashinfer_prefill_module(
"""
try:
from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
- gen_customize_batch_prefill_tvm_binding,
+ gen_customize_batch_prefill_module,
)
except ImportError:
raise ImportError(
@@ -289,32 +121,33 @@ def gen_flashinfer_prefill_module(
if backend == "fa2"
else "#include <flashinfer/attention/hopper/variants.cuh>"
)
- jit_args = {
- "backend": backend,
- "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+ jit_spec = gen_customize_batch_prefill_module(
+ backend=backend,
+ uri=f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+ f"dtype_kv_{dtype_kv}_"
+ f"dtype_o_{dtype_o}_"
+ f"qk_head_dim_{qk_head_dim}_"
+ f"v_head_dim_{v_head_dim}_"
+ f"enable_inline_rope_{enable_inline_rope}",
- "dtype_q": torch_dtype_q,
- "dtype_kv": torch_dtype_kv,
- "dtype_o": torch_dtype_o,
- "idtype": torch.int32,
- "head_dim_qk": qk_head_dim,
- "head_dim_vo": v_head_dim,
- "additional_tensor_names": [],
- "additional_tensor_dtypes": [],
- "additional_scalar_names": ["sm_scale", "rope_rcp_scale",
"rope_rcp_theta"],
- "additional_scalar_dtypes": ["double", "double", "double"],
- "variant_name": variant_name,
- "variant_decl": variant_decl,
- "enable_inline_rope": enable_inline_rope,
- }
- uri, source_paths = gen_customize_batch_prefill_tvm_binding(**jit_args)
- object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
- modules = _load_flashinfer_modules(object_files)
- return modules
+ dtype_q=torch_dtype_q,
+ dtype_kv=torch_dtype_kv,
+ dtype_o=torch_dtype_o,
+ idtype=torch.int32,
+ head_dim_qk=qk_head_dim,
+ head_dim_vo=v_head_dim,
+ pos_encoding_mode=int(enable_inline_rope),
+ additional_tensor_names=[],
+ additional_tensor_dtypes=[],
+ additional_scalar_names=["sm_scale", "rope_rcp_scale",
"rope_rcp_theta"],
+ additional_scalar_dtypes=["double", "double", "double"],
+ variant_name=variant_name,
+ variant_decl=variant_decl,
+ )
+ _rename_exported_func_names(jit_spec.sources, "batch_prefill")
+ if return_static_libs:
+ jit_spec.build(verbose=False)
+ return _load_flashinfer_modules(jit_spec.get_object_paths())
+ return [jit_spec.build_and_load()]
def gen_flashinfer_decode_module(
@@ -323,8 +156,8 @@ def gen_flashinfer_decode_module(
dtype_o: str,
qk_head_dim: int,
v_head_dim: int,
- target: Target,
- num_threads: int = 8,
+ enable_inline_rope: bool,
+ return_static_libs: bool = False,
) -> List[tvm.runtime.Module]:
"""Generate a FlashInfer module for decode.
@@ -340,10 +173,12 @@ def gen_flashinfer_decode_module(
The head dimension of the query and key tensors.
v_head_dim : int
The head dimension of the value tensor.
- target : Target
- The target device to compile for.
- num_threads : int
- The number of threads to use for compilation.
+ enable_inline_rope : bool
+ Whether to enable inline rotary positional embedding.
+ return_static_libs : bool
+ Whether to return static library modules instead of compiled modules.
+ When it is False, it returns the loaded shared library that links all the object files.
+ When it is True, it returns a static library module for each compiled object file.
Returns
-------
@@ -351,7 +186,7 @@ def gen_flashinfer_decode_module(
"""
try:
from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
- gen_customize_batch_decode_tvm_binding,
+ gen_customize_batch_decode_module,
)
except ImportError:
raise ImportError(
@@ -366,29 +201,32 @@ def gen_flashinfer_decode_module(
torch_dtype_q = getattr(torch, dtype_q)
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
- jit_args = {
- "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+ jit_spec = gen_customize_batch_decode_module(
+ uri=f"batch_decode_tvm_dtype_q_{dtype_q}_"
+ f"dtype_kv_{dtype_kv}_"
+ f"dtype_o_{dtype_o}_"
+ f"qk_head_dim_{qk_head_dim}_"
- + f"v_head_dim_{v_head_dim}",
- "dtype_q": torch_dtype_q,
- "dtype_kv": torch_dtype_kv,
- "dtype_o": torch_dtype_o,
- "idtype": torch.int32,
- "head_dim_qk": qk_head_dim,
- "head_dim_vo": v_head_dim,
- "additional_tensor_names": [],
- "additional_tensor_dtypes": [],
- "additional_scalar_names": ["sm_scale", "rope_rcp_scale",
"rope_rcp_theta"],
- "additional_scalar_dtypes": ["double", "double", "double"],
- "variant_name": "DefaultAttention<false, false, false, false>",
- "variant_decl": "#include <flashinfer/attention/variants.cuh>",
- }
- uri, source_paths = gen_customize_batch_decode_tvm_binding(**jit_args)
- object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
- modules = _load_flashinfer_modules(object_files)
- return modules
+ + f"v_head_dim_{v_head_dim}_"
+ + f"enable_inline_rope_{enable_inline_rope}",
+ dtype_q=torch_dtype_q,
+ dtype_kv=torch_dtype_kv,
+ dtype_o=torch_dtype_o,
+ idtype=torch.int32,
+ head_dim_qk=qk_head_dim,
+ head_dim_vo=v_head_dim,
+ pos_encoding_mode=int(enable_inline_rope),
+ additional_tensor_names=[],
+ additional_tensor_dtypes=[],
+ additional_scalar_names=["sm_scale", "rope_rcp_scale",
"rope_rcp_theta"],
+ additional_scalar_dtypes=["double", "double", "double"],
+ variant_name="DefaultAttention<false, false, false, false>",
+ variant_decl="#include <flashinfer/attention/variants.cuh>",
+ )
+ _rename_exported_func_names(jit_spec.sources, "batch_decode")
+ if return_static_libs:
+ jit_spec.build(verbose=False)
+ return _load_flashinfer_modules(jit_spec.get_object_paths())
+ return [jit_spec.build_and_load()]
def gen_flashinfer_mla_module(
@@ -397,8 +235,7 @@ def gen_flashinfer_mla_module(
dtype_o: str,
head_dim_ckv: int,
head_dim_kpe: int,
- target: Target,
- num_threads: int = 8,
+ return_static_libs: bool = False,
) -> List[tvm.runtime.Module]:
"""Generate a FlashInfer module for MLA.
@@ -418,6 +255,10 @@ def gen_flashinfer_mla_module(
The target device to compile for.
num_threads : int
The number of threads to use for compilation.
+ return_static_libs : bool
+ Whether to return static library modules instead of compiled modules.
+ When it is False, it returns the loaded shared library that links all the object files.
+ When it is True, it returns a static library module for each compiled object file.
Returns
-------
@@ -425,7 +266,7 @@ def gen_flashinfer_mla_module(
"""
try:
from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
- gen_batch_mla_tvm_binding,
+ gen_batch_mla_module,
)
except ImportError:
raise ImportError(
@@ -440,92 +281,36 @@ def gen_flashinfer_mla_module(
torch_dtype_q = getattr(torch, dtype_q)
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
- jit_args = {
- "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
- + f"dtype_kv_{dtype_kv}_"
- + f"dtype_o_{dtype_o}_"
- + f"head_dim_ckv_{head_dim_ckv}_"
- + f"head_dim_kpe_{head_dim_kpe}",
- "dtype_q": torch_dtype_q,
- "dtype_kv": torch_dtype_kv,
- "dtype_o": torch_dtype_o,
- "dtype_idx": torch.int32,
- "head_dim_ckv": head_dim_ckv,
- "head_dim_kpe": head_dim_kpe,
- }
- uri, source_paths = gen_batch_mla_tvm_binding(**jit_args)
- object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
- modules = _load_flashinfer_modules(object_files)
- return modules
-
-
-def gen_sampling_module(target: Target, num_threads: int = 8):
- """
- Generate a FlashInfer module for sampling kernels.
-
- Parameters
- ----------
- target : Target
- The target device for which the module will be compiled.
- num_threads : int, optional
- The number of threads to use during compilation (default is 8).
-
- Returns
- -------
- List[tvm.runtime.Module]
- A list of compiled static library modules for the FlashInfer sampling kernels.
- """
- try:
- from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
- gen_sampling_tvm_binding,
- )
- except ImportError:
- raise ImportError(
- "FlashInfer is not installed. Please follow instructions "
- "in https://docs.flashinfer.ai to install FlashInfer."
- )
- uri, source_paths = gen_sampling_tvm_binding(uri="sampling")
- object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
- modules = _load_flashinfer_modules(object_files)
- return modules
+ jit_spec = gen_batch_mla_module(
+ backend="fa2",
+ dtype_q=torch_dtype_q,
+ dtype_kv=torch_dtype_kv,
+ dtype_o=torch_dtype_o,
+ dtype_idx=torch.int32,
+ head_dim_ckv=head_dim_ckv,
+ head_dim_kpe=head_dim_kpe,
+ use_profiler=False,
+ )
+ _rename_exported_func_names(jit_spec.sources, "batch_mla")
+ if return_static_libs:
+ jit_spec.build(verbose=False)
+ return _load_flashinfer_modules(jit_spec.get_object_paths())
+ return [jit_spec.build_and_load()]
def gen_grouped_gemm_module(
- dtype_a: str,
- dtype_b: str,
- dtype_out: str,
- scale_granularity_m: int,
- scale_granularity_n: int,
- scale_granularity_k: int,
- scale_major_mode: str,
- mma_sm: int,
- target: Target,
- num_threads: int = 8,
+ target: Target, return_static_libs: bool = False
) -> List[tvm.runtime.Module]:
"""Generate a FlashInfer module for FP8 grouped GEMM.
Parameters
----------
- dtype_a : str
- The data type of matrix A (e.g., "float8_e4m3fn").
- dtype_b : str
- The data type of matrix B (e.g., "float8_e4m3fn").
- dtype_out : str
- The data type of the output matrix (e.g., "bfloat16").
- scale_granularity_m : int
- The scaling granularity in the M dimension.
- scale_granularity_n : int
- The scaling granularity in the N dimension.
- scale_granularity_k : int
- The scaling granularity in the K dimension.
- scale_major_mode : str
- The scale storage mode ("K" or "MN").
- mma_sm : int
- The MMA scheduling mode (1 or 2).
target : Target
The target device to compile for.
- num_threads : int
- The number of threads to use for compilation.
+ return_static_libs : bool
+ Whether to return static library modules instead of compiled modules.
+ When it is False, it returns the loaded shared library that links all the object files.
+ When it is True, it returns a static library module for each compiled object file.
Returns
-------
@@ -537,48 +322,24 @@ def gen_grouped_gemm_module(
when apply grouped gemm on A: (total_m, k), B: (batch_size, n, k),
m_indptr: (batch_size, )
requires all m in m_indptr to be multiple of 4
"""
+ # NOTE: This function is still under development,
+ # and we currently only support SM100 grouped gemm
try:
- from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
- gen_grouped_gemm_fp8_tvm_binding,
- get_grouped_gemm_fp8_uri,
+ from flashinfer.gemm import ( # pylint: disable=import-outside-toplevel
+ gen_gemm_sm100_module,
)
except ImportError:
raise ImportError(
"FlashInfer is not installed. Please follow instructions "
"in https://docs.flashinfer.ai to install FlashInfer."
)
- try:
- import torch # pylint: disable=import-outside-toplevel
- except ImportError:
- raise ImportError("PyTorch is not installed. Please install PyTorch to
use FlashInfer.")
-
- torch_dtype_a = getattr(torch, dtype_a)
- torch_dtype_b = getattr(torch, dtype_b)
- torch_dtype_out = getattr(torch, dtype_out)
-
- uri = get_grouped_gemm_fp8_uri(
- dtype_a=torch_dtype_a,
- dtype_b=torch_dtype_b,
- dtype_out=torch_dtype_out,
- scale_granularity_m=scale_granularity_m,
- scale_granularity_n=scale_granularity_n,
- scale_granularity_k=scale_granularity_k,
- scale_major_mode=scale_major_mode,
- mma_sm=mma_sm,
- )
- uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
- uri=uri,
- dtype_a=torch_dtype_a,
- dtype_b=torch_dtype_b,
- dtype_out=torch_dtype_out,
- scale_granularity_m=scale_granularity_m,
- scale_granularity_n=scale_granularity_n,
- scale_granularity_k=scale_granularity_k,
- scale_major_mode=scale_major_mode,
- mma_sm=mma_sm,
- )
-
- object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
- modules = _load_flashinfer_modules(object_files)
- return modules
+ compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
+ if compute_version == "100":
+ jit_spec = gen_gemm_sm100_module()
+ else:
+ raise ValueError(f"Unsupported compute version: {compute_version}")
+ if return_static_libs:
+ jit_spec.build(verbose=False)
+ return _load_flashinfer_modules(jit_spec.get_object_paths())
+ return [jit_spec.build_and_load()]
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index e6e171da99..e94d5c4295 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -371,8 +371,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
enable_disaggregation : bool
Whether to enable disaggregation in the KV cache.
"""
- if rope_mode == RopeMode.INLINE:
- assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not
support partial rotary dim."
+ assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support
inline mode."
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
@@ -383,8 +382,8 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
dtype_o=dtype,
qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else
mla_original_qk_head_dim),
v_head_dim=(v_head_dim if attn_kind_single == "mha" else
mla_original_v_head_dim),
- target=target,
- enable_inline_rope=rope_mode == RopeMode.INLINE,
+ enable_inline_rope=False,
+ return_static_libs=True,
)
flashinfer_decode_mods = (
rx.backend.cuda.flashinfer.gen_flashinfer_decode_module(
@@ -393,7 +392,8 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
dtype_o=dtype,
qk_head_dim=qk_head_dim,
v_head_dim=v_head_dim,
- target=target,
+ enable_inline_rope=False,
+ return_static_libs=True,
)
if attn_kind_single == "mha"
else []
@@ -405,7 +405,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
dtype_o=dtype,
head_dim_ckv=v_head_dim,
head_dim_kpe=qk_head_dim - v_head_dim,
- target=target,
+ return_static_libs=True,
)
if attn_kind_single == "mla"
else []
@@ -417,8 +417,8 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
bb = rx.BlockBuilder.current()
mha_functions = (
[
- rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_prefill_with_paged_kv_cache_run"),
rx.ExternFunc("batch_prefill_with_kv_cache_plan")]),
- rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_decode_with_paged_kv_cache_run"),
rx.ExternFunc("batch_decode_with_paged_kv_cache_plan")]),
+ rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_prefill_paged_run"), rx.ExternFunc("batch_prefill_plan")]),
+ rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_decode_run"), rx.ExternFunc("batch_decode_plan")]),
rx.Tuple([rx.StringImm("tir"),
bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads,
qk_head_dim, dtype, True, rope_scaling, target),
"tir_attention_prefill_sliding_window")]),
rx.Tuple([rx.StringImm("tir"),
bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads,
qk_head_dim, dtype, True, rope_scaling, target),
"tir_attention_decode_sliding_window")]),
rx.Tuple([rx.StringImm("tir"),
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads,
num_attention_heads, qk_head_dim, dtype, rope_scaling, target),
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
@@ -427,7 +427,8 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
if attn_kind_single == "mha"
else [rx.Tuple([]) for _ in range(6)]
)
- mla_function = rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_mla_paged_attention_run"),
rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla"
else [])
+ ragged_prefill_function = rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_prefill_ragged_run"),
rx.ExternFunc("batch_prefill_plan")]) if attn_kind_single == "mha" else
rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan"),
rx.PrimValue(mla_original_qk_head_dim), rx.PrimValue(mla_original_v_head_dim)])
+ mla_function = rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_mla_run"), rx.ExternFunc("batch_mla_plan")] if
attn_kind_single == "mla" else [])
attn_merge_functions = [
bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim,
dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
]
@@ -463,7 +464,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
rx.op.zeros((), dtype),
bb.add_func(_kv_cache_transpose_append(num_key_value_heads,
qk_head_dim, dtype), "kv_cache_transpose_append"),
bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype),
"kv_cache_transpose_append_mla"),
- rx.Tuple([rx.StringImm("flashinfer"),
rx.ExternFunc("batch_prefill_with_ragged_kv_cache_run"),
rx.ExternFunc("batch_prefill_with_kv_cache_plan")]),
+ ragged_prefill_function,
*mha_functions,
mla_function,
rx.Tuple(attn_merge_functions),
diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc
index 3b37d9810b..13e151ecd2 100644
--- a/src/runtime/vm/attn_backend.cc
+++ b/src/runtime/vm/attn_backend.cc
@@ -59,11 +59,18 @@ std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(ffi::Array<ffi::Any>
return std::make_unique<TIRRaggedPrefillFunc>(std::move(attn_func), attn_kind);
}
if (backend_name == "flashinfer") {
- CHECK_EQ(args.size(), 3);
+ CHECK(args.size() == 3 || args.size() == 5);
ffi::Function attn_func = args[1].cast<ffi::Function>();
ffi::Function plan_func = args[2].cast<ffi::Function>();
+ int64_t qk_head_dim_override = -1;
+ int64_t v_head_dim_override = -1;
+ if (args.size() == 5) {
+ qk_head_dim_override = args[3].cast<int64_t>();
+ v_head_dim_override = args[4].cast<int64_t>();
+ }
return std::make_unique<FlashInferRaggedPrefillFunc>(std::move(attn_func), std::move(plan_func),
- attn_kind);
+ attn_kind, qk_head_dim_override,
+ v_head_dim_override);
}
LOG(FATAL) << "Cannot reach here";
throw;
diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h
index ea5f49c6c0..1fd22a97ab 100644
--- a/src/runtime/vm/attn_backend.h
+++ b/src/runtime/vm/attn_backend.h
@@ -27,6 +27,7 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/function.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/int_tuple.h>
#include <tvm/runtime/logging.h>
@@ -57,6 +58,22 @@ class AttnBackendFunc {
virtual ~AttnBackendFunc() = default;
protected:
+ // helper allocator class for creating strided view of a Tensor
+ // that applies byte offset to the original data pointer
+ class ViewBasedAlloc {
+ public:
+ explicit ViewBasedAlloc(Tensor source) : source_(source) {}
+ void AllocData(DLTensor* tensor, int64_t* strides, int64_t
extra_byte_offset) {
+ tensor->data = static_cast<char*>(source_->data) + extra_byte_offset;
+ tensor->strides = strides;
+ }
+
+ void FreeData(DLTensor* tensor) {}
+
+ private:
+ Tensor source_;
+ };
+
ffi::Function attn_func_;
public:
@@ -133,16 +150,34 @@ class FlashInferPagedPrefillFunc : public
PagedPrefillFunc {
Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double
rotary_scale,
double rotary_theta, double sm_scale, Tensor attn_output, Tensor
attn_lse,
TVMStreamHandle compute_stream) final {
+ Device device = q->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, compute_stream);
auto [float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
plan_info_vec] = cached_buffers_[depth];
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, qo_indptr,
- page_indptr, page_indices, length_info, q_rope_position,
k_rope_pos_offset,
- attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(HND)=*/1, /*window_left=*/-1, sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+
+ ICHECK_EQ(pages.ndim(), 5);
+ int H = pages->shape[2];
+ int N = pages->shape[3];
+ int D = pages->shape[4];
+ CHECK(pages.IsContiguous());
+ std::vector<int64_t> pages_k_v_shape = {pages->shape[0], H, N, D};
+ std::vector<int64_t> pages_k_v_strides = {2 * H * N * D, N * D, D, 1};
+ Tensor pages_k =
+ Tensor::FromNDAlloc(ViewBasedAlloc(pages),
ffi::Shape(pages_k_v_shape), pages->dtype,
+ pages->device, pages_k_v_strides.data(),
pages->byte_offset);
+ Tensor pages_v = Tensor::FromNDAlloc(
+ ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype,
pages->device,
+ pages_k_v_strides.data(), pages->byte_offset + (H * N * D) *
pages.DataType().bytes());
+
+ attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages_k, pages_v,
+ qo_indptr, page_indptr, page_indices, length_info, attn_output,
attn_lse,
+ /*mask_mode_code=*/static_cast<int64_t>(causal),
/*layout(HND)=*/1,
+ /*window_left=*/-1, /*enable_pdl=*/false, sm_scale,
+ /*rope_rcp_scale=*/rope_rcp_scale,
/*rope_rcp_theta=*/rope_rcp_theta);
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
}
void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor
page_indptr,
@@ -150,9 +185,43 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc
{
Tensor attn_output, Tensor attn_lse, TVMStreamHandle
compute_stream) final {
auto [float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
plan_info_vec] = cached_buffers_[depth];
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, page_indices,
- attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1],
sm_scale, compute_stream);
+ Device device = q->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, compute_stream);
+ ICHECK_NE(qk_head_dim_, -1);
+ ICHECK_NE(v_head_dim_, -1);
+ int64_t H = q->shape[1];
+ int64_t page_size = pages->shape[1];
+ int64_t rope_head_dim = qk_head_dim_ - v_head_dim_;
+ int64_t nope_head_dim = q->shape[2] - rope_head_dim;
+
+ // Split q into q_nope and q_pe
+ CHECK(q.IsContiguous());
+ std::vector<int64_t> q_nope_shape = {q->shape[0], H, nope_head_dim};
+ std::vector<int64_t> q_pe_shape = {q->shape[0], H, rope_head_dim};
+ std::vector<int64_t> q_strides = {H * q->shape[2], q->shape[2], 1};
+ Tensor q_nope = Tensor::FromNDAlloc(ViewBasedAlloc(q),
ffi::Shape(q_nope_shape), q->dtype,
+ q->device, q_strides.data(),
q->byte_offset);
+ Tensor q_pe = Tensor::FromNDAlloc(ViewBasedAlloc(q),
ffi::Shape(q_pe_shape), q->dtype,
+ q->device, q_strides.data(),
+ q->byte_offset + nope_head_dim *
q.DataType().bytes());
+ // Split pages into kv_nope and kv_pe
+ CHECK(pages.IsContiguous());
+ std::vector<int64_t> kv_nope_shape = {pages->shape[0], page_size,
nope_head_dim};
+ std::vector<int64_t> kv_pe_shape = {pages->shape[0], page_size,
rope_head_dim};
+ std::vector<int64_t> kv_strides = {page_size * pages->shape[2],
pages->shape[2], 1};
+ Tensor kv_nope =
+ Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(kv_nope_shape),
pages->dtype,
+ pages->device, kv_strides.data(),
pages->byte_offset);
+ Tensor kv_pe = Tensor::FromNDAlloc(
+ ViewBasedAlloc(pages), ffi::Shape(kv_pe_shape), pages->dtype,
pages->device,
+ kv_strides.data(), pages->byte_offset + nope_head_dim *
pages.DataType().bytes());
+
+ attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec,
q_nope, q_pe, kv_nope,
+ kv_pe, page_indices, attn_output, attn_lse,
+ /*mask_mode_code=*/static_cast<int64_t>(causal),
+ /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1],
sm_scale);
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
}
void BeginForward(int depth, Tensor float_workspace_buffer, Tensor
int_workspace_buffer,
@@ -161,31 +230,37 @@ class FlashInferPagedPrefillFunc : public
PagedPrefillFunc {
int64_t batch_size, int64_t total_qo_len, int64_t
page_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t
qk_head_dim,
int64_t v_head_dim, bool causal, TVMStreamHandle
copy_stream) final {
- std::vector<int64_t> kv_len;
- kv_len.reserve(batch_size);
+ Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32),
Device{kDLCPU, 0});
+ int32_t* kv_len_arr_data = static_cast<int32_t*>(kv_len_arr.data_ptr());
for (int i = 0; i < static_cast<int>(batch_size); ++i) {
- kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i]
- ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) *
page_size +
- (*last_page_len)[i]
- : 0);
+ kv_len_arr_data[i] =
+ (*page_indptr)[i + 1] != (*page_indptr)[i]
+ ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size +
(*last_page_len)[i]
+ : 0;
}
- IntTuple plan_info_vec;
+ qk_head_dim_ = qk_head_dim;
+ v_head_dim_ = v_head_dim;
+ ffi::Array<int64_t> plan_info_vec;
+ Device device = float_workspace_buffer->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, copy_stream);
if (attn_kind == AttnKind::kMHA) {
// Todo(tvm-team): enable cuda graph
plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), page_indptr->as_tensor(),
IntTuple(std::move(kv_len)),
- total_qo_len, batch_size, num_qo_heads, num_kv_heads,
page_size,
+ qo_indptr->as_tensor(), page_indptr->as_tensor(),
kv_len_arr, total_qo_len,
+ batch_size, num_qo_heads, num_kv_heads, page_size,
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal,
- /*window_left=*/-1, copy_stream)
- .cast<IntTuple>();
+ /*window_left=*/-1, /*fixed_split_size=*/-1,
/*disable_split_kv=*/false)
+ .cast<ffi::Array<int64_t>>();
} else if (attn_kind == AttnKind::kMLA) {
plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), page_indptr->as_tensor(),
IntTuple(std::move(kv_len)),
- num_qo_heads, v_head_dim, causal, copy_stream)
- .cast<IntTuple>();
+ qo_indptr->as_tensor(), page_indptr->as_tensor(),
kv_len_arr, num_qo_heads,
+ v_head_dim, causal)
+ .cast<ffi::Array<int64_t>>();
}
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
cached_buffers_.resize(depth + 1);
@@ -196,8 +271,10 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc
{
}
private:
+ int64_t qk_head_dim_ = -1;
+ int64_t v_head_dim_ = -1;
ffi::Function plan_func_;
- std::vector<std::tuple<Tensor, Tensor, Tensor, IntTuple>> cached_buffers_;
+ std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>>
cached_buffers_;
};
/*! \brief The ragged prefill attention function base class. */
@@ -244,23 +321,30 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc {
class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
public:
explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function
plan_func,
- AttnKind attn_kind)
+ AttnKind attn_kind, int64_t
qk_head_dim_override,
+ int64_t v_head_dim_override)
: RaggedPrefillFunc(std::move(attn_func), attn_kind,
AttnBackendKind::kFlashInfer),
+ qk_head_dim_override_(qk_head_dim_override),
+ v_head_dim_override_(v_head_dim_override),
plan_func_(std::move(plan_func)) {}
void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr,
Tensor q_rope_position,
Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double
rotary_scale,
double rotary_theta, double sm_scale, Tensor attn_output, Tensor
attn_lse,
TVMStreamHandle compute_stream) final {
+ Device device = q->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, compute_stream);
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_,
q, k, v, qo_indptr,
- kv_indptr, q_rope_position, k_rope_pos_offset, attn_output,
attn_lse,
+ kv_indptr, attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale,
+ /*layout(NHD)=*/0, /*window_left=*/-1,
+ /*enable_pdl=*/false, sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+ /*rope_rcp_theta=*/rope_rcp_theta);
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
}
void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
@@ -268,30 +352,42 @@ class FlashInferRaggedPrefillFunc : public
RaggedPrefillFunc {
HostMemoryVector* kv_indptr, int64_t batch_size, int64_t
total_qo_len,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t
qk_head_dim,
int64_t v_head_dim, bool causal, TVMStreamHandle
copy_stream) final {
- std::vector<int64_t> kv_len;
- kv_len.reserve(batch_size);
+ Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32),
Device{kDLCPU, 0});
+ int32_t* kv_len_arr_data = static_cast<int32_t*>(kv_len_arr.data_ptr());
for (int i = 0; i < static_cast<int>(batch_size); ++i) {
- kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]);
+ kv_len_arr_data[i] = (*kv_indptr)[i + 1] - (*kv_indptr)[i];
+ }
+ if (qk_head_dim_override_ != -1) {
+ qk_head_dim = qk_head_dim_override_;
+ }
+ if (v_head_dim_override_ != -1) {
+ v_head_dim = v_head_dim_override_;
}
// Todo(tvm-team): enable cuda graph
float_workspace_buffer_ = float_workspace_buffer;
int_workspace_buffer_ = int_workspace_buffer;
page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer;
+ Device device = float_workspace_buffer->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, copy_stream);
plan_info_vec_ =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), kv_indptr->as_tensor(),
IntTuple(std::move(kv_len)),
- total_qo_len, batch_size, num_qo_heads, num_kv_heads,
/*page_size=*/1,
+ qo_indptr->as_tensor(), kv_indptr->as_tensor(), kv_len_arr,
total_qo_len,
+ batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1,
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal,
- /*window_left=*/-1, copy_stream)
- .cast<IntTuple>();
+ /*window_left=*/-1, /*fixed_split_size=*/-1,
/*disable_split_kv=*/false)
+ .cast<ffi::Array<int64_t>>();
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
}
private:
+ int64_t qk_head_dim_override_;
+ int64_t v_head_dim_override_;
ffi::Function plan_func_;
Tensor float_workspace_buffer_;
Tensor int_workspace_buffer_;
Tensor page_locked_int_workspace_buffer_;
- IntTuple plan_info_vec_;
+ ffi::Array<int64_t> plan_info_vec_;
};
/*! \brief The paged decode attention function base class. */
@@ -359,15 +455,33 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
Tensor length_info, Tensor k_rope_pos_offset, Tensor
q_rope_position, RoPEMode rope_mode,
double rotary_scale, double rotary_theta, double sm_scale, Tensor
attn_output,
Tensor attn_lse, TVMStreamHandle compute_stream) final {
+ Device device = q->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, compute_stream);
auto [float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
plan_info_vec] = cached_buffers_[depth];
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, page_indptr,
- page_indices, length_info, q_rope_position, k_rope_pos_offset,
attn_output, attn_lse,
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(HND)=*/1, /*window_left=*/-1, sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+
+ ICHECK_EQ(pages.ndim(), 5);
+ int H = pages->shape[2];
+ int N = pages->shape[3];
+ int D = pages->shape[4];
+ CHECK(pages.IsContiguous());
+ std::vector<int64_t> pages_k_v_shape = {pages->shape[0], H, N, D};
+ std::vector<int64_t> pages_k_v_strides = {2 * H * N * D, N * D, D, 1};
+ Tensor pages_k =
+ Tensor::FromNDAlloc(ViewBasedAlloc(pages),
ffi::Shape(pages_k_v_shape), pages->dtype,
+ pages->device, pages_k_v_strides.data(),
pages->byte_offset);
+ Tensor pages_v = Tensor::FromNDAlloc(
+ ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype,
pages->device,
+ pages_k_v_strides.data(), pages->byte_offset + (H * N * D) *
pages.DataType().bytes());
+
+ attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages_k, pages_v,
+ page_indptr, page_indices, length_info, attn_output, attn_lse,
+ /*layout(HND)=*/1, /*window_left=*/-1, /*enable_pdl=*/false,
sm_scale,
+ /*rope_rcp_scale=*/rope_rcp_scale,
/*rope_rcp_theta=*/rope_rcp_theta);
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
}
void BeginForward(int depth, Tensor float_workspace_buffer, Tensor
int_workspace_buffer,
@@ -377,13 +491,18 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype,
TVMStreamHandle copy_stream) final {
// Todo(tvm-team): enable cuda graph
- IntTuple plan_info_vec =
+ Tensor empty_qkv_data = Tensor::Empty({1}, q_dtype, Device{kDLCPU, 0});
+ Device device = float_workspace_buffer->device;
+ TVMStreamHandle original_stream =
DeviceAPI::Get(device)->GetCurrentStream(device);
+ DeviceAPI::Get(device)->SetStream(device, copy_stream);
+ ffi::Array<int64_t> plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
page_indptr->as_tensor(), batch_size, num_qo_heads,
num_kv_heads, page_size,
/*enable_cuda_graph=*/false,
- static_cast<int64_t>(rope_mode == RoPEMode::kInline),
- /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype,
kv_dtype, copy_stream)
- .cast<IntTuple>();
+ /*window_left=*/-1, /*logits_soft_cap=*/0.0, qk_head_dim,
v_head_dim,
+ empty_qkv_data, empty_qkv_data)
+ .cast<ffi::Array<int64_t>>();
+ DeviceAPI::Get(device)->SetStream(device, original_stream);
if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
cached_buffers_.resize(depth + 1);
@@ -395,7 +514,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
private:
ffi::Function plan_func_;
- std::vector<std::tuple<Tensor, Tensor, Tensor, IntTuple>> cached_buffers_;
+ std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>>
cached_buffers_;
};
/*! \brief The paged prefill with tree mask attention function base class. */
diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h
index 09557a8f0a..1c695a10e2 100644
--- a/src/runtime/vm/attn_utils.h
+++ b/src/runtime/vm/attn_utils.h
@@ -860,8 +860,9 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
sliding_window_offset->data(), n_elem * elem_byte_size_);
std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_
+ 2 * n_elem,
sink_size->data(), n_elem * elem_byte_size_);
- Tensor view = merged_attn_aux_data_device_.CreateView(
- {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+ Tensor view =
+ Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_),
ffi::Shape({3, n_elem}),
+ dtype_aux_, device_, attn_aux_data_copy_offset_ *
elem_byte_size_);
attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem);
return view;
}
@@ -895,8 +896,9 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
src_data->data(), n_elem * elem_byte_size_);
std::memcpy(merged_compact_kv_aux_data_host_.data() +
compact_kv_aux_data_copy_offset_ + n_elem,
dst_data->data(), n_elem * elem_byte_size_);
- Tensor view = merged_compact_kv_aux_data_device_.CreateView(
- {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ *
elem_byte_size_);
+ Tensor view =
Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
+ ffi::Shape({2, n_elem}), dtype_aux_,
device_,
+ compact_kv_aux_data_copy_offset_ *
elem_byte_size_);
compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem);
return view;
}
@@ -919,6 +921,20 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
}
private:
+ // helper allocator class that applies byte offset to the original data
pointer
+ class ViewHelper {
+ public:
+ explicit ViewHelper(Tensor source) : source_(source) {}
+ void AllocData(DLTensor* tensor, int64_t extra_byte_offset) {
+ tensor->data = static_cast<char*>(source_->data) + extra_byte_offset;
+ }
+
+ void FreeData(DLTensor* tensor) {}
+
+ private:
+ Tensor source_;
+ };
+
/*!
* \brief Calculate the start element offsets of the auxiliary arrays in the
local cache.
* \return Return the local cache size (total number of elements in the
local cache).
@@ -990,8 +1006,9 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
int64_t n_elem = data->size();
std::memcpy(merged_attn_aux_data_host_.data() +
attn_aux_data_copy_offset_, data->data(),
n_elem * elem_byte_size_);
- Tensor view = merged_attn_aux_data_device_.CreateView(
- {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+ Tensor view =
+ Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_),
ffi::Shape({n_elem}),
+ dtype_aux_, device_, attn_aux_data_copy_offset_ *
elem_byte_size_);
attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
return view;
}
@@ -1000,8 +1017,9 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
int64_t n_elem = data->size();
std::memcpy(merged_compact_kv_aux_data_host_.data() +
compact_kv_aux_data_copy_offset_,
data->data(), n_elem * elem_byte_size_);
- Tensor view = merged_compact_kv_aux_data_device_.CreateView(
- {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ *
elem_byte_size_);
+ Tensor view =
Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
+ ffi::Shape({n_elem}), dtype_aux_,
device_,
+ compact_kv_aux_data_copy_offset_ *
elem_byte_size_);
compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
return view;
}
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 0f3f568661..4fb3cd69d6 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -2052,7 +2052,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
temp_float_attn_workspace_, temp_int_attn_workspace_[0],
temp_int_pinned_attn_workspace_[0],
&cur_append_lengths_indptr_host_,
&cur_append_lengths_indptr_host_, cur_batch_size_,
- cur_append_lengths_indptr_host_.back(), num_qo_heads_,
num_kv_heads_, qk_head_dim_,
+ cur_append_lengths_indptr_host_.back(), num_qo_heads_,
num_qo_heads_, qk_head_dim_,
v_head_dim_, /*causal=*/true, copy_stream_);
}
}
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py
index 8333e4b2d6..da6fdacebd 100644
--- a/tests/python/relax/test_group_gemm_flashinfer.py
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -18,14 +18,14 @@
"""Test for FlashInfer GroupedGemm TVM integration"""
import math
+
import numpy as np
import pytest
import torch
+
import tvm
import tvm.testing
from tvm import relax
-from tvm.contrib import utils
-from tvm.relax.backend.cuda import flashinfer
DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
fp8_dtype = "float8_e4m3fn"
@@ -389,36 +389,11 @@ def test_grouped_gemm_correctness(
device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
- def _load_module(name: str, static_modules):
- """Helper function to load compiled modules."""
- assert len(static_modules) > 0
- if len(static_modules) == 1:
- return static_modules[0]
- static_mod = static_modules[0]
- for mod in static_modules[1:]:
- static_mod.import_module(mod)
- temp = tvm.contrib.utils.tempdir()
- mod_path = temp.relpath(f"{name}.so")
- static_mod.export_library(mod_path)
- return tvm.runtime.load_module(mod_path)
-
# Generate the module
- modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(
- dtype_a=dtype_a,
- dtype_b=dtype_b,
- dtype_out=dtype_out,
- scale_granularity_m=scale_granularity_m,
- scale_granularity_n=scale_granularity_n,
- scale_granularity_k=scale_granularity_k,
- scale_major_mode=scale_major_mode,
- mma_sm=mma_sm,
- target=target,
- num_threads=4,
- )
+ mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0]
# Load the module
- mod = _load_module("flashinfer_grouped_gemm", modules)
- grouped_gemm_fn = mod["grouped_gemm_fp8_run"]
+ grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
# Generate test data
test_data = generate_test_data(
@@ -460,7 +435,11 @@ def test_grouped_gemm_correctness(
test_data["m_indptr"], # m_indptr
test_data["n"], # n (scalar)
test_data["k"], # k (scalar)
- None, # cuda_stream (use default stream)
+ scale_granularity_m,
+ scale_granularity_n,
+ scale_granularity_k,
+ scale_major_mode,
+ mma_sm,
)
# Compute reference result
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index dd29140e9b..4aae9dec59 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -23,7 +23,6 @@ import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import relax
-from tvm.contrib import utils
from tvm.relax.frontend.nn.llm.kv_cache import (
AttnKind,
RopeMode,
@@ -78,7 +77,7 @@ fcopy_cache = None
fcompact_copy = None
-def set_global_func():
+def set_global_func(rope_mode: RopeMode):
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn
global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv,
fdebug_get_kv
global fattention_prefill, fattention_decode, fattention_prefill_ragged
@@ -98,48 +97,30 @@ def set_global_func():
)
fdebug_get_kv =
tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
- def load_module(name: str, static_modules: List[tvm.runtime.Module]):
- assert len(static_modules) > 0
- if len(static_modules) == 1:
- return static_modules[0]
- static_mod = static_modules[0]
- for mod in static_modules[1:]:
- static_mod.import_module(mod)
- temp = utils.tempdir()
- mod_path = temp.relpath(f"{name}.so")
- static_mod.export_library(mod_path)
- return tvm.runtime.load_module(mod_path)
-
target = tvm.target.Target.from_device(device)
- flashinfer_prefill_mod = load_module(
- "flashinfer_prefill",
- relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
- dtype_q=dtype,
- dtype_kv=dtype,
- dtype_o=dtype,
- qk_head_dim=head_dim,
- v_head_dim=head_dim,
- target=target,
- ),
- )
- flashinfer_decode_mod = load_module(
- "flashinfer_decode",
- relax.backend.cuda.flashinfer.gen_flashinfer_decode_module(
- dtype_q=dtype,
- dtype_kv=dtype,
- dtype_o=dtype,
- qk_head_dim=head_dim,
- v_head_dim=head_dim,
- target=target,
- ),
- )
-
- fattention_prefill =
flashinfer_prefill_mod["batch_prefill_with_paged_kv_cache_run"]
- fattention_prefill_plan =
flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
- fattention_prefill_ragged =
flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"]
- fattention_prefill_ragged_plan =
flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
- fattention_decode =
flashinfer_decode_mod["batch_decode_with_paged_kv_cache_run"]
- fattention_decode_plan =
flashinfer_decode_mod["batch_decode_with_paged_kv_cache_plan"]
+ flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
+ dtype_q=dtype,
+ dtype_kv=dtype,
+ dtype_o=dtype,
+ qk_head_dim=head_dim,
+ v_head_dim=head_dim,
+ enable_inline_rope=rope_mode == RopeMode.INLINE,
+ )[0]
+ flashinfer_decode_mod = relax.backend.cuda.flashinfer.gen_flashinfer_decode_module(
+ dtype_q=dtype,
+ dtype_kv=dtype,
+ dtype_o=dtype,
+ qk_head_dim=head_dim,
+ v_head_dim=head_dim,
+ enable_inline_rope=rope_mode == RopeMode.INLINE,
+ )[0]
+
+ fattention_prefill = flashinfer_prefill_mod["batch_prefill_paged_run"]
+ fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+ fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"]
+ fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+ fattention_decode = flashinfer_decode_mod["batch_decode_run"]
+ fattention_decode_plan = flashinfer_decode_mod["batch_decode_plan"]
builts = []
for tir_func in [
@@ -560,8 +541,8 @@ def
test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
if __name__ == "__main__":
- set_global_func()
- for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
+ for rope_mode in [RopeMode.NONE, RopeMode.NORMAL]:
+ set_global_func(rope_mode)
cache = create_kv_cache(rope_mode)
test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode))
test_paged_attention_kv_cache_remove_sequence((cache, rope_mode))
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index e3de4944fe..cd76f9ce20 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -25,7 +25,6 @@ import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import relax
-from tvm.contrib import utils
from tvm.relax.frontend.nn.llm.kv_cache import (
AttnKind,
RopeMode,
@@ -115,47 +114,27 @@ def set_global_func(dtype):
fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
fdebug_get_kv =
tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla")
- def load_module(name: str, static_modules: List[tvm.runtime.Module]):
- assert len(static_modules) > 0
- if len(static_modules) == 1:
- return static_modules[0]
- static_mod = static_modules[0]
- for mod in static_modules[1:]:
- static_mod.import_module(mod)
- temp = utils.tempdir()
- mod_path = temp.relpath(f"{name}.so")
- static_mod.export_library(mod_path)
- return tvm.runtime.load_module(mod_path)
-
target = tvm.target.Target.from_device(device)
- flashinfer_prefill_mod = load_module(
- "flashinfer_prefill",
- relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
- dtype_q=dtype,
- dtype_kv=dtype,
- dtype_o=dtype,
- qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
- v_head_dim=v_head_dim,
- target=target,
- enable_inline_rope=False,
- ),
- )
- flashinfer_mla_mod = load_module(
- "flashinfer_mla",
- relax.backend.cuda.flashinfer.gen_flashinfer_mla_module(
- dtype_q=dtype,
- dtype_kv=dtype,
- dtype_o=dtype,
- head_dim_ckv=kv_lora_rank,
- head_dim_kpe=qk_rope_head_dim,
- target=target,
- ),
- )
-
- fattn_prefill_ragged =
flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"]
- fattn_prefill_ragged_plan =
flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"]
- fmla_prefill = flashinfer_mla_mod["batch_mla_paged_attention_run"]
- fmla_prefill_plan = flashinfer_mla_mod["batch_mla_paged_attention_plan"]
+ flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
+ dtype_q=dtype,
+ dtype_kv=dtype,
+ dtype_o=dtype,
+ qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
+ v_head_dim=v_head_dim,
+ enable_inline_rope=False,
+ )[0]
+ flashinfer_mla_mod = relax.backend.cuda.flashinfer.gen_flashinfer_mla_module(
+ dtype_q=dtype,
+ dtype_kv=dtype,
+ dtype_o=dtype,
+ head_dim_ckv=kv_lora_rank,
+ head_dim_kpe=qk_rope_head_dim,
+ )[0]
+
+ fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"]
+ fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"]
+ fmla_prefill = flashinfer_mla_mod["batch_mla_run"]
+ fmla_prefill_plan = flashinfer_mla_mod["batch_mla_plan"]
builts = []
for tir_func in [
@@ -221,7 +200,13 @@ def create_kv_cache(dtype):
tvm.runtime.empty((), dtype, device=device),
None, # f_transpose_append_mha
ftranspose_append,
- ["flashinfer", fattn_prefill_ragged, fattn_prefill_ragged_plan], #
fattn_prefill_ragged
+ [
+ "flashinfer",
+ fattn_prefill_ragged,
+ fattn_prefill_ragged_plan,
+ qk_nope_head_dim + qk_rope_head_dim,
+ v_head_dim,
+ ], # fattn_prefill_ragged
[], # fattn_prefill
[], # fattn_decode
[], # fattn_prefill_sliding_window