This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4c82c71933 [flashinfer] Support directing JIT to FlashInfer GroupedGemm kernels (#18325)
4c82c71933 is described below
commit 4c82c71933a5ed30e686a2d938b5963ef0715285
Author: Anrui(Henry) Liu <[email protected]>
AuthorDate: Sun Sep 21 23:37:44 2025 -0400
[flashinfer] Support directing JIT to FlashInfer GroupedGemm kernels (#18325)
In python/tvm/relax/backend/cuda/flashinfer.py, added a
`gen_grouped_gemm_module` function that JIT-compiles FlashInfer's FP8
grouped GEMM kernels.
In tests/python/relax/test_group_gemm_flashinfer.py, added tests covering
combinations of:
- input and output dtypes: ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
  ("float8_e4m3fn", "float8_e4m3fn", "float16")
- scale granularity of (m, n, k): (1, 128, 128)
- scale major mode: "MN", "K"
- mma_sm: 1, 2
- different batch sizes and m_sizes
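
A minimal usage sketch (values mirror the new tests; assumes a CUDA device
with FlashInfer and PyTorch installed):

    import tvm
    from tvm.relax.backend.cuda import flashinfer

    device = tvm.cuda(0)
    target = tvm.target.Target.from_device(device)
    # JIT-compile FlashInfer's FP8 grouped GEMM kernels for this target.
    modules = flashinfer.gen_grouped_gemm_module(
        dtype_a="float8_e4m3fn",
        dtype_b="float8_e4m3fn",
        dtype_out="bfloat16",
        scale_granularity_m=1,
        scale_granularity_n=128,
        scale_granularity_k=128,
        scale_major_mode="MN",
        mma_sm=1,
        target=target,
    )
    # The returned static libraries can be linked into one module whose
    # "grouped_gemm_fp8_run" function runs the kernel (see the test below).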
---
python/tvm/relax/backend/cuda/flashinfer.py | 96 ++++-
tests/python/relax/test_group_gemm_flashinfer.py | 496 +++++++++++++++++++++++
2 files changed, 591 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/backend/cuda/flashinfer.py
b/python/tvm/relax/backend/cuda/flashinfer.py
index f1af2f3d15..4e0fc3e854 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -116,7 +116,7 @@ def _compile_flashinfer_kernels(
     # Determine compute version
     compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
-    if compute_version in ["90"]:
+    if compute_version in ["90", "100"]:
         compute_version += "a"
     cuda_cflags += [
         "-gencode",
@@ -488,3 +488,97 @@ def gen_sampling_module(target: Target, num_threads: int = 8):
     object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
     modules = _load_flashinfer_modules(object_files)
     return modules
+
+
+def gen_grouped_gemm_module(
+ dtype_a: str,
+ dtype_b: str,
+ dtype_out: str,
+ scale_granularity_m: int,
+ scale_granularity_n: int,
+ scale_granularity_k: int,
+ scale_major_mode: str,
+ mma_sm: int,
+ target: Target,
+ num_threads: int = 8,
+) -> List[tvm.runtime.Module]:
+ """Generate a FlashInfer module for FP8 grouped GEMM.
+
+ Parameters
+ ----------
+ dtype_a : str
+ The data type of matrix A (e.g., "float8_e4m3fn").
+ dtype_b : str
+ The data type of matrix B (e.g., "float8_e4m3fn").
+ dtype_out : str
+ The data type of the output matrix (e.g., "bfloat16").
+ scale_granularity_m : int
+ The scaling granularity in the M dimension.
+ scale_granularity_n : int
+ The scaling granularity in the N dimension.
+ scale_granularity_k : int
+ The scaling granularity in the K dimension.
+ scale_major_mode : str
+ The scale storage mode ("K" or "MN").
+ mma_sm : int
+ The MMA scheduling mode (1 or 2).
+ target : Target
+ The target device to compile for.
+ num_threads : int
+ The number of threads to use for compilation.
+
+ Returns
+ -------
+ List[tvm.runtime.Module]
+        A list of compiled static library modules for FlashInfer FP8 grouped GEMM kernels.
+
+    Note
+    ----
+    When applying grouped GEMM to A: (total_m, k), B: (batch_size, n, k), and
+    m_indptr: (batch_size + 1,), every entry of m_indptr must be a multiple of 4.
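+    For example, m_sizes = [128, 256] yields m_indptr = [0, 128, 384]: every
+    entry is a multiple of 4, and rows 0:128 and 128:384 of A form the groups.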
+ """
+ try:
+ from flashinfer.jit import ( # pylint: disable=import-outside-toplevel
+ gen_grouped_gemm_fp8_tvm_binding,
+ get_grouped_gemm_fp8_uri,
+ )
+ except ImportError:
+ raise ImportError(
+ "FlashInfer is not installed. Please follow instructions "
+ "in https://docs.flashinfer.ai to install FlashInfer."
+ )
+ try:
+ import torch # pylint: disable=import-outside-toplevel
+ except ImportError:
+ raise ImportError("PyTorch is not installed. Please install PyTorch to
use FlashInfer.")
+
+ torch_dtype_a = getattr(torch, dtype_a)
+ torch_dtype_b = getattr(torch, dtype_b)
+ torch_dtype_out = getattr(torch, dtype_out)
+
+ uri = get_grouped_gemm_fp8_uri(
+ dtype_a=torch_dtype_a,
+ dtype_b=torch_dtype_b,
+ dtype_out=torch_dtype_out,
+ scale_granularity_m=scale_granularity_m,
+ scale_granularity_n=scale_granularity_n,
+ scale_granularity_k=scale_granularity_k,
+ scale_major_mode=scale_major_mode,
+ mma_sm=mma_sm,
+ )
+
+ uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
+ uri=uri,
+ dtype_a=torch_dtype_a,
+ dtype_b=torch_dtype_b,
+ dtype_out=torch_dtype_out,
+ scale_granularity_m=scale_granularity_m,
+ scale_granularity_n=scale_granularity_n,
+ scale_granularity_k=scale_granularity_k,
+ scale_major_mode=scale_major_mode,
+ mma_sm=mma_sm,
+ )
+
+    object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
+ modules = _load_flashinfer_modules(object_files)
+ return modules
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py
new file mode 100644
index 0000000000..8333e4b2d6
--- /dev/null
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -0,0 +1,496 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test for FlashInfer GroupedGemm TVM integration"""
+
+import math
+import numpy as np
+import pytest
+import torch
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.contrib import utils
+from tvm.relax.backend.cuda import flashinfer
+
+DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
+fp8_dtype = "float8_e4m3fn"
+
+
+###########################################
+################# Helpers #################
+###########################################
+def has_flashinfer():
+ """Check if FlashInfer is available"""
+ try:
+        from tvm.relax.backend.cuda import (  # pylint: disable=import-outside-toplevel
+ flashinfer,
+ )
+
+ return True
+ except ImportError:
+ return False
+
+
+def has_cutlass():
+ """Check if CUTLASS is available for SM90+ operations"""
+ if not tvm.get_global_func("device_api.cuda", True):
+ return False
+ try:
+ import pynvml # pylint: disable=import-outside-toplevel
+
+ pynvml.nvmlInit()
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+ major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+ return major >= 9 # SM90+
+    except Exception:  # pylint: disable=broad-except
+ return False
+
+
+def calc_diff(x: np.ndarray, y: np.ndarray):
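+    """Return 1 - 2<x, y> / (|x|^2 + |y|^2); zero when x == y, growing with mismatch."""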
+ denominator = (x * x + y * y).sum()
+ sim = 2 * (x * y).sum() / denominator
+ return 1 - sim
+
+
+def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
+ """
+ Quantizes a 2D or 3D tensor to FP8.
+
+ Args:
+ x (torch.Tensor): The 2D or 3D input tensor.
+ scale_shape (tuple): The shape of the scale tensor.
+ tile_shape (tuple): The shape of the tiles.
+ scale_major_mode (str): The tiling order, "K" for row-major like,
+ or another value for column-major like.
+
+ Returns:
+ tuple: A tuple containing the quantized FP8 tensor and the
+ calculated float32 scales.
+ """
+ # 1. Assertions and Initial Setup
+ ndim = x.ndim
+ assert ndim == len(scale_shape) == len(tile_shape)
+
+ fp8_info = torch.finfo(torch.float8_e4m3fn)
+ fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32)
+
+ # 2. Tiling and Scale Calculation
+ if ndim == 2:
+ s0, s1 = scale_shape
+ t0, t1 = tile_shape
+ if scale_major_mode == "K":
+ # Tile x and find the max absolute value in each tile
+ x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0,
s1=s1)
+ abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1",
"max").clamp(1e-4)
+ x_scale = abs_max / fp8_amax
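+            # Round each scale up to the next power of two so rescaling only touches exponent bits.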
+ x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+ # Broadcast scales back to the original tensor shape
+            scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1)
+ else:
+ # Handle column-major tiling
+ x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0,
s1=s1)
+ abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1",
"max").clamp(1e-4)
+ x_scale = abs_max / fp8_amax
+ x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+ # Permute scale axes before repeating to match layout
+ scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0")
+            scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1)
+
+ elif ndim == 3:
+ s0, s1, s2 = scale_shape
+ t0, t1, t2 = tile_shape
+ if scale_major_mode == "K":
+ # Tile x and find the max absolute value in each tile
+ x_tiled = rearrange(
+ x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0,
s1=s1, s2=s2
+ )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
+ x_scale = abs_max / fp8_amax
+ x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+ # Broadcast scales back to the original tensor shape
+ scales_repeated = repeat(
+                x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2
+ )
+ else:
+ # Handle layout where the last two axes are swapped
+ x_tiled = rearrange(
+ x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0,
s1=s1, s2=s2
+ )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
+ x_scale = abs_max / fp8_amax
+ x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+ # Permute scale axes before repeating to match layout
+ scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1")
+ scales_repeated = repeat(
+ scales_permuted,
+ "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)",
+ t0=t0,
+ t1=t1,
+ t2=t2,
+ )
+ # 3. Final Quantization
+ # Divide the original tensor by the broadcasted scales
+ x_fp32 = x / (scales_repeated + 1e-8)
+
+ # Convert the result to the target FP8 format
+ x_fp8 = x_fp32.to(torch.float8_e4m3fn)
+
+ return x_fp8, x_scale
+
+
+def dequantize_fp8(x, x_scale, scale_major_mode):
+    """
+    Dequantizes a 2D or 3D FP8 tensor back to float32.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D FP8 input tensor.
+        x_scale (torch.Tensor): The float32 scales produced during quantization.
+        scale_major_mode (str): The tiling order, "K" for row-major like,
+            or another value for column-major like.
+
+    Returns:
+        torch.Tensor: The dequantized float32 tensor.
+    """
+    from einops import rearrange
+
+ # 1. Assertions and Initial Setup
+ ndim = x.ndim
+ assert ndim == len(x_scale.shape)
+
+ # 2. Tiling and Scale Calculation
+ if ndim == 2:
+ if scale_major_mode == "K":
+ s0, s1 = x_scale.shape
+ else:
+ s1, s0 = x_scale.shape
+        x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
+ if scale_major_mode == "K":
+ x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1")
+ else:
+ x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1")
+ out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)")
+ elif ndim == 3:
+ if scale_major_mode == "K":
+ s0, s1, s2 = x_scale.shape
+ else:
+ s0, s2, s1 = x_scale.shape
+ x = rearrange(
+ x.to(torch.float32),
+ "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2",
+ s0=s0,
+ s1=s1,
+ s2=s2,
+ )
+ if scale_major_mode == "K":
+ x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1")
+ else:
+ x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
+        out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)")
+
+ return out
+
+
+###########################################
+########### Reference generation ##########
+###########################################
+def compute_reference_grouped_gemm(
+    a_fp32: torch.Tensor,  # (total_m, k)
+    b_fp32: torch.Tensor,  # (batch_size, n, k)
+    m_indptr: torch.Tensor,  # (batch_size + 1,)
+    dtype_out: str,
+):  # returns (total_m, n)
+    """Compute the reference result with PyTorch on the original FP32 tensors."""
+
+ total_m, k = a_fp32.shape
+ batch_size, n, k2 = b_fp32.shape
+ assert k == k2
+
+ # Perform grouped GEMM computation directly on original FP32 data
+ results = []
+
+ for i in range(batch_size):
+ start_m = m_indptr[i].item()
+ end_m = m_indptr[i + 1].item()
+
+ # Extract group's portion of A
+ a_group = a_fp32[start_m:end_m, :] # [m_sizes[i], k]
+ b_group = b_fp32[i]
+
+        # Multiply by this group's B matrix
+ result_group = torch.mm(a_group, b_group.T) # [m_sizes[i], n]
+ results.append(result_group)
+
+ result_fp32 = torch.cat(results, dim=0)
+
+ # Convert to output dtype
+ if dtype_out == "bfloat16":
+ result = result_fp32.to(torch.bfloat16)
+ elif dtype_out == "float16":
+ result = result_fp32.to(torch.float16)
+ else:
+ result = result_fp32
+
+ return result
+
+
+###########################################
+########### Test data generation ##########
+###########################################
+def generate_test_data(
+ m_sizes: list,
+ batch_size: int,
+ n: int,
+ k: int,
+ dtype_a: str,
+ dtype_b: str,
+ dtype_out: str,
+ scale_granularity_m: int,
+ scale_granularity_n: int,
+ scale_granularity_k: int,
+ scale_major_mode: str,
+ device: tvm.runtime.Device,
+):
+ """Generate test data for grouped GEMM operations"""
+ assert batch_size == len(
+ m_sizes
+ ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})"
+
+ # print(f"Device object: {device}")
+ torch_device = torch.device(f"cuda:{device.index}")
+
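+    # Cumulative row offsets: rows cum_m[i]:cum_m[i + 1] of A belong to group i.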
+ cum_m = [0] + list(np.cumsum(m_sizes))
+ total_m = cum_m[-1]
+
+    # Generate input matrices A and B (asserted to be FP8) as random FP32 data, then quantize
+ assert dtype_a == "float8_e4m3fn"
+ a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32)
+
+ assert dtype_b == "float8_e4m3fn"
+    b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k)
+
+ if scale_major_mode == "K": # K mode:
+ scale_a_shape = (total_m // scale_granularity_m, k //
scale_granularity_k)
+ scale_b_shape = (batch_size, n // scale_granularity_n, k //
scale_granularity_k)
+
+ else: # MN mode
+ scale_a_shape = (k // scale_granularity_k, total_m //
scale_granularity_m)
+ scale_b_shape = (batch_size, k // scale_granularity_k, n //
scale_granularity_n)
+
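+    # One scale factor per tile; tile_b's leading 1 keeps scales per batch element.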
+ tile_a_shape = (scale_granularity_m, scale_granularity_k)
+ tile_b_shape = (1, scale_granularity_n, scale_granularity_k)
+
+ # quantize A, B
+    a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode)
+    b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode)
+
+ if dtype_a == "float8_e4m3fn":
+ a_tvm = tvm.runtime.tensor(
+            a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+ )
+ else:
+ a_tvm = tvm.runtime.from_dlpack(a_quantized)
+
+ if dtype_b == "float8_e4m3fn":
+ b_tvm = tvm.runtime.tensor(
+            b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+ )
+ else:
+ b_tvm = tvm.runtime.from_dlpack(b_quantized)
+
+ scale_a_tvm = tvm.runtime.from_dlpack(scale_a)
+ scale_b_tvm = tvm.runtime.from_dlpack(scale_b)
+
+ # Create m_indptr for grouped operation
+ m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32)
+ m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device)
+
+ return {
+ "a": a_tvm,
+ "b": b_tvm,
+ "torch_a": a_fp32,
+ "torch_b": b_fp32,
+ "scale_a": scale_a_tvm,
+ "scale_b": scale_b_tvm,
+ "m_indptr": m_indptr_tvm,
+ "m_sizes": m_sizes,
+ "n": n,
+ "k": k,
+ "total_m": total_m,
+ "torch_scale_a": scale_a,
+ "torch_scale_b": scale_b,
+ "torch_m_indptr": m_indptr,
+ }
+
+
+###########################################
+############### Test driver ###############
+###########################################
[email protected](not has_flashinfer(), reason="FlashInfer not available")
[email protected](not has_cutlass(), reason="CUTLASS SM90+ not available")
[email protected](
+ "dtype_a,dtype_b,dtype_out",
+ [
+ ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
+ ("float8_e4m3fn", "float8_e4m3fn", "float16"),
+ ],
+)
[email protected](
+ "scale_granularity_m,scale_granularity_n,scale_granularity_k",
+ [
+ (1, 128, 128), # Row-wise A, block-wise B
+ ],
+)
[email protected]("scale_major_mode", ["K", "MN"])
[email protected]("mma_sm", [1, 2])
[email protected](
+ "test_case",
+ [
+ {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k":
1024},
+ {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512},
+ {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768},
+ {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768},
+ ],
+)
+def test_grouped_gemm_correctness(
+ dtype_a,
+ dtype_b,
+ dtype_out,
+ scale_granularity_m,
+ scale_granularity_n,
+ scale_granularity_k,
+ scale_major_mode,
+ mma_sm,
+ test_case,
+):
+ """Test correctness of GroupedGemm operations"""
+ device = tvm.cuda(0)
+ target = tvm.target.Target.from_device(device)
+
+ def _load_module(name: str, static_modules):
+ """Helper function to load compiled modules."""
+ assert len(static_modules) > 0
+ if len(static_modules) == 1:
+ return static_modules[0]
+ static_mod = static_modules[0]
+ for mod in static_modules[1:]:
+ static_mod.import_module(mod)
+ temp = tvm.contrib.utils.tempdir()
+ mod_path = temp.relpath(f"{name}.so")
+ static_mod.export_library(mod_path)
+ return tvm.runtime.load_module(mod_path)
+
+ # Generate the module
+ modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(
+ dtype_a=dtype_a,
+ dtype_b=dtype_b,
+ dtype_out=dtype_out,
+ scale_granularity_m=scale_granularity_m,
+ scale_granularity_n=scale_granularity_n,
+ scale_granularity_k=scale_granularity_k,
+ scale_major_mode=scale_major_mode,
+ mma_sm=mma_sm,
+ target=target,
+ num_threads=4,
+ )
+
+ # Load the module
+ mod = _load_module("flashinfer_grouped_gemm", modules)
+ grouped_gemm_fn = mod["grouped_gemm_fp8_run"]
+
+ # Generate test data
+ test_data = generate_test_data(
+ batch_size=test_case["batch_size"],
+ m_sizes=test_case["m_sizes"],
+ n=test_case["n"],
+ k=test_case["k"],
+ dtype_a=dtype_a,
+ dtype_b=dtype_b,
+ dtype_out=dtype_out,
+ scale_granularity_m=scale_granularity_m,
+ scale_granularity_n=scale_granularity_n,
+ scale_granularity_k=scale_granularity_k,
+ scale_major_mode=scale_major_mode,
+ device=device,
+ )
+
+ # Prepare output buffer
+ output_shape = (test_data["total_m"], test_data["n"])
+ if dtype_out == "bfloat16":
+        output = tvm.runtime.empty(output_shape, dtype="bfloat16", device=device)
+    elif dtype_out == "float16":
+        output = tvm.runtime.empty(output_shape, dtype="float16", device=device)
+    else:
+        output = tvm.runtime.empty(output_shape, dtype="float32", device=device)
+
+ # Create workspace buffers (required by the interface)
+    int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device)
+    float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device)
+
+ grouped_gemm_fn(
+ int_workspace, # int_workspace_buffer
+ float_workspace, # float_workspace_buffer
+ test_data["a"], # A
+ test_data["b"], # B
+ test_data["scale_a"], # SFA
+ test_data["scale_b"], # SFB
+ output, # D
+ test_data["m_indptr"], # m_indptr
+ test_data["n"], # n (scalar)
+ test_data["k"], # k (scalar)
+ None, # cuda_stream (use default stream)
+ )
+
+ # Compute reference result
+ reference = compute_reference_grouped_gemm(
+ test_data["torch_a"],
+ test_data["torch_b"],
+ test_data["torch_m_indptr"],
+ dtype_out,
+ )
+
+ # Convert TVM output to PyTorch for comparison
+ output_torch = torch.as_tensor(output, device=test_data["torch_a"].device)
+
+ # Compare results with appropriate tolerance
+ if dtype_out == "bfloat16":
+ rtol, atol = 1e-2, 1e-2
+ elif dtype_out == "float16":
+ rtol, atol = 1e-3, 1e-3
+ else:
+ rtol, atol = 1e-4, 1e-4
+
+ # Check shapes match
+ assert (
+ output_torch.shape == reference.shape
+ ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}"
+
+    diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy())
+ assert diff < 1e-3, f"diff too large {diff}"
+
+
+if __name__ == "__main__":
+ tvm.testing.main()