This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c82c71933 [flashinfer] Support directing JIT to FlashInfer 
GroupedGemm kernels (#18325)
4c82c71933 is described below

commit 4c82c71933a5ed30e686a2d938b5963ef0715285
Author: Anrui(Henry) Liu <[email protected]>
AuthorDate: Sun Sep 21 23:37:44 2025 -0400

    [flashinfer] Support directing JIT to FlashInfer GroupedGemm kernels 
(#18325)
    
    in tvm/python/tvm/relax/backend/cuda/flashinfer.py added a
    `gen_grouped_gemm_module`
    in tvm/tests/python/relax/test_group_gemm_flashinfer.py added
    tests for different combinations of
    - input and output types: ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
    ("float8_e4m3fn", "float8_e4m3fn", "float16"),
    - scale granularity of m, n, k: (1, 128, 128),
    - scale major mode: "MN", "K"
    - mma_sm: 1, 2
    - different batch sizes and m_sizes
---
 python/tvm/relax/backend/cuda/flashinfer.py      |  96 ++++-
 tests/python/relax/test_group_gemm_flashinfer.py | 496 +++++++++++++++++++++++
 2 files changed, 591 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/backend/cuda/flashinfer.py 
b/python/tvm/relax/backend/cuda/flashinfer.py
index f1af2f3d15..4e0fc3e854 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -116,7 +116,7 @@ def _compile_flashinfer_kernels(
 
     # Determine compute version
     compute_version = 
"".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
-    if compute_version in ["90"]:
+    if compute_version in ["90", "100"]:
         compute_version += "a"
     cuda_cflags += [
         "-gencode",
@@ -488,3 +488,97 @@ def gen_sampling_module(target: Target, num_threads: int = 
8):
     object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
     modules = _load_flashinfer_modules(object_files)
     return modules
+
+
+def gen_grouped_gemm_module(
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    mma_sm: int,
+    target: Target,
+    num_threads: int = 8,
+) -> List[tvm.runtime.Module]:
+    """Generate a FlashInfer module for FP8 grouped GEMM.
+
+    Parameters
+    ----------
+    dtype_a : str
+        The data type of matrix A (e.g., "float8_e4m3fn").
+    dtype_b : str
+        The data type of matrix B (e.g., "float8_e4m3fn").
+    dtype_out : str
+        The data type of the output matrix (e.g., "bfloat16").
+    scale_granularity_m : int
+        The scaling granularity in the M dimension.
+    scale_granularity_n : int
+        The scaling granularity in the N dimension.
+    scale_granularity_k : int
+        The scaling granularity in the K dimension.
+    scale_major_mode : str
+        The scale storage mode ("K" or "MN").
+    mma_sm : int
+        The MMA scheduling mode (1 or 2).
+    target : Target
+        The target device to compile for.
+    num_threads : int
+        The number of threads to use for compilation.
+
+    Returns
+    -------
+    List[tvm.runtime.Module]
+        A list of compiled static library modules for FlashInfer FP8 grouped 
GEMM kernels.
+
+    Note
+    ----
+    When applying grouped GEMM on A: (total_m, k), B: (batch_size, n, k), 
m_indptr: (batch_size, ),
+    every m in m_indptr must be a multiple of 4.
+    """
+    try:
+        from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
+            gen_grouped_gemm_fp8_tvm_binding,
+            get_grouped_gemm_fp8_uri,
+        )
+    except ImportError:
+        raise ImportError(
+            "FlashInfer is not installed. Please follow instructions "
+            "in https://docs.flashinfer.ai to install FlashInfer."
+        )
+    try:
+        import torch  # pylint: disable=import-outside-toplevel
+    except ImportError:
+        raise ImportError("PyTorch is not installed. Please install PyTorch to 
use FlashInfer.")
+
+    torch_dtype_a = getattr(torch, dtype_a)
+    torch_dtype_b = getattr(torch, dtype_b)
+    torch_dtype_out = getattr(torch, dtype_out)
+
+    uri = get_grouped_gemm_fp8_uri(
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
+        uri=uri,
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    object_files = _compile_flashinfer_kernels(uri, source_paths, target, 
num_threads)
+    modules = _load_flashinfer_modules(object_files)
+    return modules
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py 
b/tests/python/relax/test_group_gemm_flashinfer.py
new file mode 100644
index 0000000000..8333e4b2d6
--- /dev/null
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -0,0 +1,496 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test for FlashInfer GroupedGemm TVM integration"""
+
+import math
+import numpy as np
+import pytest
+import torch
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.contrib import utils
+from tvm.relax.backend.cuda import flashinfer
+
+DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
+fp8_dtype = "float8_e4m3fn"
+
+
+###########################################
+################# Helpers #################
+###########################################
+def has_flashinfer():
+    """Check if FlashInfer is available"""
+    try:
+        from tvm.relax.backend.cuda import (  # pylint: 
disable=import-outside-toplevel
+            flashinfer,
+        )
+
+        return True
+    except ImportError:
+        return False
+
+
+def has_cutlass():
+    """Check if CUTLASS is available for SM90+ operations"""
+    if not tvm.get_global_func("device_api.cuda", True):
+        return False
+    try:
+        import pynvml  # pylint: disable=import-outside-toplevel
+
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+        return major >= 9  # SM90+
+    except:
+        return False
+
+
+def calc_diff(x: np.ndarray, y: np.ndarray):
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
+    from einops import rearrange, reduce, repeat
+
+    """
+    Quantizes a 2D or 3D tensor to FP8.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D input tensor.
+        scale_shape (tuple): The shape of the scale tensor.
+        tile_shape (tuple): The shape of the tiles.
+        scale_major_mode (str): The tiling order, "K" for row-major like,
+                                or another value for column-major like.
+
+    Returns:
+        tuple: A tuple containing the quantized FP8 tensor and the
+               calculated float32 scales.
+    """
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(scale_shape) == len(tile_shape)
+
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32)
+
+    # 2. Tiling and Scale Calculation
+    if ndim == 2:
+        s0, s1 = scale_shape
+        t0, t1 = tile_shape
+        if scale_major_mode == "K":
+            # Tile x and find the max absolute value in each tile
+            x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, 
s1=s1)
+            abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", 
"max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Broadcast scales back to the original tensor shape
+            scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", 
t0=t0, t1=t1)
+        else:
+            # Handle column-major tiling
+            x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, 
s1=s1)
+            abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", 
"max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Permute scale axes before repeating to match layout
+            scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0")
+            scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 
t1)", t0=t0, t1=t1)
+
+    elif ndim == 3:
+        s0, s1, s2 = scale_shape
+        t0, t1, t2 = tile_shape
+        if scale_major_mode == "K":
+            # Tile x and find the max absolute value in each tile
+            x_tiled = rearrange(
+                x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, 
s1=s1, s2=s2
+            )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", 
"max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Broadcast scales back to the original tensor shape
+            scales_repeated = repeat(
+                x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, 
t2=t2
+            )
+        else:
+            # Handle layout where the last two axes are swapped
+            x_tiled = rearrange(
+                x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, 
s1=s1, s2=s2
+            )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", 
"max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+            # Permute scale axes before repeating to match layout
+            scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1")
+            scales_repeated = repeat(
+                scales_permuted,
+                "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)",
+                t0=t0,
+                t1=t1,
+                t2=t2,
+            )
+    # 3. Final Quantization
+    # Divide the original tensor by the broadcasted scales
+    x_fp32 = x / (scales_repeated + 1e-8)
+
+    # Convert the result to the target FP8 format
+    x_fp8 = x_fp32.to(torch.float8_e4m3fn)
+
+    return x_fp8, x_scale
+
+
+def dequantize_fp8(x, x_scale, scale_major_mode):
+    from einops import rearrange
+
+    """
+    Dequantizes a 2D or 3D FP8-quantized tensor back to float32.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D quantized input tensor.
+        x_scale (torch.Tensor): The per-tile scales; must have the same
+                                number of dimensions as x.
+        scale_major_mode (str): The scale layout, "K" for row-major like,
+                                or another value for column-major like.
+
+    Returns:
+        torch.Tensor: The float32 tensor obtained by multiplying each
+                      tile of x by its corresponding scale.
+    """
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(x_scale.shape)
+
+    # 2. Tiling and Scale Calculation
+    if ndim == 2:
+        if scale_major_mode == "K":
+            s0, s1 = x_scale.shape
+        else:
+            s1, s0 = x_scale.shape
+        x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", 
s0=s0, s1=s1)
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1")
+        out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)")
+    elif ndim == 3:
+        if scale_major_mode == "K":
+            s0, s1, s2 = x_scale.shape
+        else:
+            s0, s2, s1 = x_scale.shape
+        x = rearrange(
+            x.to(torch.float32),
+            "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2",
+            s0=s0,
+            s1=s1,
+            s2=s2,
+        )
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
+        out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 
t2)")
+
+    return out
+
+
+###########################################
+########## Reference generation ###########
+###########################################
+def compute_reference_grouped_gemm(
+    a_fp32: torch.Tensor,  # (total_m, k)
+    b_fp32: torch.Tensor,  # (batch_size, n, k)
+    m_indptr: torch.Tensor,
+    dtype_out: str,  # (total_m, n)
+):
+    """Compute reference result using PyTorch operations"""
+    """Compute reference result using original FP32 tensors"""
+
+    total_m, k = a_fp32.shape
+    batch_size, n, k2 = b_fp32.shape
+    assert k == k2
+
+    # Perform grouped GEMM computation directly on original FP32 data
+    results = []
+
+    for i in range(batch_size):
+        start_m = m_indptr[i].item()
+        end_m = m_indptr[i + 1].item()
+
+        # Extract group's portion of A
+        a_group = a_fp32[start_m:end_m, :]  # [m_sizes[i], k]
+        b_group = b_fp32[i]
+
+        # Multiply with this group's B matrix
+        result_group = torch.mm(a_group, b_group.T)  # [m_sizes[i], n]
+        results.append(result_group)
+
+    result_fp32 = torch.cat(results, dim=0)
+
+    # Convert to output dtype
+    if dtype_out == "bfloat16":
+        result = result_fp32.to(torch.bfloat16)
+    elif dtype_out == "float16":
+        result = result_fp32.to(torch.float16)
+    else:
+        result = result_fp32
+
+    return result
+
+
+###########################################
+########### Test data generation ##########
+###########################################
+def generate_test_data(
+    m_sizes: list,
+    batch_size: int,
+    n: int,
+    k: int,
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    device: tvm.runtime.Device,
+):
+    """Generate test data for grouped GEMM operations"""
+    assert batch_size == len(
+        m_sizes
+    ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})"
+
+    # print(f"Device object: {device}")
+    torch_device = torch.device(f"cuda:{device.index}")
+
+    cum_m = [0] + list(np.cumsum(m_sizes))
+    total_m = cum_m[-1]
+
+    # Generate input matrices A and B (asserted to be fp8): create random 
data in fp32 first, then quantize
+    assert dtype_a == "float8_e4m3fn"
+    a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32)
+
+    assert dtype_b == "float8_e4m3fn"
+    b_fp32 = torch.randn(batch_size, n, k, device=torch_device, 
dtype=torch.float32) / math.sqrt(k)
+
+    if scale_major_mode == "K":  # K mode:
+        scale_a_shape = (total_m // scale_granularity_m, k // 
scale_granularity_k)
+        scale_b_shape = (batch_size, n // scale_granularity_n, k // 
scale_granularity_k)
+
+    else:  # MN mode
+        scale_a_shape = (k // scale_granularity_k, total_m // 
scale_granularity_m)
+        scale_b_shape = (batch_size, k // scale_granularity_k, n // 
scale_granularity_n)
+
+    tile_a_shape = (scale_granularity_m, scale_granularity_k)
+    tile_b_shape = (1, scale_granularity_n, scale_granularity_k)
+
+    # quantize A, B
+    a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, 
scale_major_mode)
+    b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, 
scale_major_mode)
+
+    if dtype_a == "float8_e4m3fn":
+        a_tvm = tvm.runtime.tensor(
+            a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), 
device=device
+        )
+    else:
+        a_tvm = tvm.runtime.from_dlpack(a_quantized)
+
+    if dtype_b == "float8_e4m3fn":
+        b_tvm = tvm.runtime.tensor(
+            b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), 
device=device
+        )
+    else:
+        b_tvm = tvm.runtime.from_dlpack(b_quantized)
+
+    scale_a_tvm = tvm.runtime.from_dlpack(scale_a)
+    scale_b_tvm = tvm.runtime.from_dlpack(scale_b)
+
+    # Create m_indptr for grouped operation
+    m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32)
+    m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device)
+
+    return {
+        "a": a_tvm,
+        "b": b_tvm,
+        "torch_a": a_fp32,
+        "torch_b": b_fp32,
+        "scale_a": scale_a_tvm,
+        "scale_b": scale_b_tvm,
+        "m_indptr": m_indptr_tvm,
+        "m_sizes": m_sizes,
+        "n": n,
+        "k": k,
+        "total_m": total_m,
+        "torch_scale_a": scale_a,
+        "torch_scale_b": scale_b,
+        "torch_m_indptr": m_indptr,
+    }
+
+
+###########################################
+############### Test driver ###############
+###########################################
[email protected](not has_flashinfer(), reason="FlashInfer not available")
[email protected](not has_cutlass(), reason="CUTLASS SM90+ not available")
[email protected](
+    "dtype_a,dtype_b,dtype_out",
+    [
+        ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
+        ("float8_e4m3fn", "float8_e4m3fn", "float16"),
+    ],
+)
[email protected](
+    "scale_granularity_m,scale_granularity_n,scale_granularity_k",
+    [
+        (1, 128, 128),  # Row-wise A, block-wise B
+    ],
+)
[email protected]("scale_major_mode", ["K", "MN"])
[email protected]("mma_sm", [1, 2])
[email protected](
+    "test_case",
+    [
+        {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 
1024},
+        {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512},
+        {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768},
+        {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768},
+    ],
+)
+def test_grouped_gemm_correctness(
+    dtype_a,
+    dtype_b,
+    dtype_out,
+    scale_granularity_m,
+    scale_granularity_n,
+    scale_granularity_k,
+    scale_major_mode,
+    mma_sm,
+    test_case,
+):
+    """Test correctness of GroupedGemm operations"""
+    device = tvm.cuda(0)
+    target = tvm.target.Target.from_device(device)
+
+    def _load_module(name: str, static_modules):
+        """Helper function to load compiled modules."""
+        assert len(static_modules) > 0
+        if len(static_modules) == 1:
+            return static_modules[0]
+        static_mod = static_modules[0]
+        for mod in static_modules[1:]:
+            static_mod.import_module(mod)
+        temp = tvm.contrib.utils.tempdir()
+        mod_path = temp.relpath(f"{name}.so")
+        static_mod.export_library(mod_path)
+        return tvm.runtime.load_module(mod_path)
+
+    # Generate the module
+    modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(
+        dtype_a=dtype_a,
+        dtype_b=dtype_b,
+        dtype_out=dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+        target=target,
+        num_threads=4,
+    )
+
+    # Load the module
+    mod = _load_module("flashinfer_grouped_gemm", modules)
+    grouped_gemm_fn = mod["grouped_gemm_fp8_run"]
+
+    # Generate test data
+    test_data = generate_test_data(
+        batch_size=test_case["batch_size"],
+        m_sizes=test_case["m_sizes"],
+        n=test_case["n"],
+        k=test_case["k"],
+        dtype_a=dtype_a,
+        dtype_b=dtype_b,
+        dtype_out=dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        device=device,
+    )
+
+    # Prepare output buffer
+    output_shape = (test_data["total_m"], test_data["n"])
+    if dtype_out == "bfloat16":
+        output = tvm.runtime.empty(output_shape, dtype="bfloat16", 
device=device)
+    elif dtype_out == "float16":
+        output = tvm.runtime.empty(output_shape, dtype="float16", 
device=device)
+    else:
+        output = tvm.runtime.empty(output_shape, dtype="float32", 
device=device)
+
+    # Create workspace buffers (required by the interface)
+    int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), 
dtype="int32", device=device)
+    float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), 
dtype="float32", device=device)
+
+    grouped_gemm_fn(
+        int_workspace,  # int_workspace_buffer
+        float_workspace,  # float_workspace_buffer
+        test_data["a"],  # A
+        test_data["b"],  # B
+        test_data["scale_a"],  # SFA
+        test_data["scale_b"],  # SFB
+        output,  # D
+        test_data["m_indptr"],  # m_indptr
+        test_data["n"],  # n (scalar)
+        test_data["k"],  # k (scalar)
+        None,  # cuda_stream (use default stream)
+    )
+
+    # Compute reference result
+    reference = compute_reference_grouped_gemm(
+        test_data["torch_a"],
+        test_data["torch_b"],
+        test_data["torch_m_indptr"],
+        dtype_out,
+    )
+
+    # Convert TVM output to PyTorch for comparison
+    output_torch = torch.as_tensor(output, device=test_data["torch_a"].device)
+    output_torch
+
+    # Compare results with appropriate tolerance
+    if dtype_out == "bfloat16":
+        rtol, atol = 1e-2, 1e-2
+    elif dtype_out == "float16":
+        rtol, atol = 1e-3, 1e-3
+    else:
+        rtol, atol = 1e-4, 1e-4
+
+    # Check shapes match
+    assert (
+        output_torch.shape == reference.shape
+    ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}"
+
+    diff = calc_diff(output_torch.cpu().double().numpy(), 
reference.cpu().double().numpy())
+    assert diff < 1e-3, f"diff too large {diff}"
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to