This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f97159504c [Relax] Operator and RoPE support for Llama4 (#18336)
f97159504c is described below

commit f97159504cef41513f77ee8e2cb8636365e4fb52
Author: Pranav Venkatram <[email protected]>
AuthorDate: Tue Sep 23 13:08:38 2025 -0400

    [Relax] Operator and RoPE support for Llama4 (#18336)
    
    Added Llama4 operator support and a new RoPE implementation.
---
 python/tvm/relax/expr.py                           |   4 +
 .../relax/frontend/nn/llm/position_embedding.py    | 234 +++++++++++++++++++++
 python/tvm/relax/frontend/nn/op.py                 |  86 ++++++++
 tests/python/relax/test_frontend_nn_op.py          |  12 +-
 4 files changed, 335 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 1a7a5c224a..8dd4eff5c7 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, Mapping
 import numpy as _np  # type: ignore
 
 import tvm_ffi
+
 import tvm.ir
 import tvm.relax
 from tvm import DataType
@@ -1153,6 +1154,9 @@ def const(
     - bool maps to "bool"
     - other using the same default rule as numpy.
     """
+    # Needed for bf16 and fp8 support (does not come with numpy)
+    import ml_dtypes  # pylint: disable=unused-import,import-outside-toplevel
+
     if isinstance(value, (Number, (bool, list))):
         value = _np.array(value, dtype=dtype)
 
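For reference, a minimal sketch (not part of the patch, assuming a standard ml_dtypes install) of why this import matters: importing ml_dtypes registers bfloat16 and the float8 variants with numpy, so the string-based dtype lookup that const() relies on can succeed.

    # Illustrative only; run outside TVM to see the registration effect.
    import numpy as np
    import ml_dtypes  # noqa: F401  # importing registers bf16/fp8 dtypes with numpy

    x = np.array([1.0, 2.0], dtype="bfloat16")          # raises TypeError without the import
    print(x.dtype)                                       # bfloat16
    print(ml_dtypes.finfo(ml_dtypes.float8_e4m3fn).max)  # fp8 variants are available as well
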
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 1a1659b29e..6fda4b0bca 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
     return cos_freq, sin_freq, {freq_var: freq}
 
 
+def rope_freq_llama4(  # pylint: disable=too-many-arguments,too-many-locals
+    s: tir.Var,
+    d: tir.Var,
+    d_range: int,
+    theta: float,
+    dtype: str,
+    factor: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    original_max_position_embeddings: float,
+):
+    """Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
+    orig_freq = tir.const(1, "float32") / tir.power(
+        theta, 2 * (d // 2) / tir.const(d_range, "float32")
+    )
+    orig_freq_var = tir.Var("orig_freq", "float32")
+
+    llama4_inv_scaling_factor = 1.0 / factor
+
+    if high_freq_factor == low_freq_factor:
+        wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
+        threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")
+
+        scaled_freq = tir.if_then_else(
+            wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
+        )
+        smoothed_freq = s * scaled_freq
+
+    else:
+        # Original smooth interpolation logic
+        inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
+
+        llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
+        llama4_beta = low_freq_factor * inv_diff_freq_factor
+        smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta))
+        smoothed_freq = s * (
+            (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var
+        )
+
+    smoothed_freq_var = tir.Var("smoothed_freq", "float32")
+    cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
+    sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
+    return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
+
+
 def rope_freq_llama3(  # pylint: disable=too-many-arguments,too-many-locals
     s: tir.Var,
     d: tir.Var,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             high_freq_factor=rope_scaling["high_freq_factor"],
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
         )
+    if rope_scaling["rope_type"] == "llama4":
+        return partial(
+            rope_freq_llama4,
+            factor=rope_scaling["factor"],
+            low_freq_factor=rope_scaling["low_freq_factor"],
+            high_freq_factor=rope_scaling["high_freq_factor"],
+            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+        )
     if rope_scaling["rope_type"] == "longrope":
         return partial(
             rope_freq_longrope,
@@ -545,3 +598,184 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
     if is_longrope_scaling:
         return fused_rope_longrope_scaling
     return fused_rope
+
+
+def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
+    theta: float,
+    scale: float,
+    head_dim: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    dtype: str,
+    rope_scaling: Dict[str, Any],
+    rotary_dim: Optional[int] = None,
+):
+    """Return the TIR function that computes Llama-style RoPE with q position 
map.
+
+    Parameters
+    ----------
+    theta : float
+        The theta value, or "base" in RoPE, which controls the frequency.
+
+    scale : float
+        The RoPE scaling factor.
+
+    head_dim : int
+        The number of features on each head.
+
+    num_q_heads : int
+        The number of query heads.
+
+    num_kv_heads : int
+        The number of key/value heads. It differs from `num_q_heads` in group-query attention.
+
+    dtype : str
+        The dtype of qkv data.
+
+    rope_scaling : Dict
+        The configuration of RoPE scaling.
+
+    rotary_dim : int
+        The number of dimensions in the embedding that RoPE is applied to. By default, the
+        rotary_dim is the same as head_dim.
+    """
+    fused_heads = num_q_heads + num_kv_heads * 2
+    if rotary_dim is None:
+        rotary_dim = head_dim
+    scale = tir.const(scale, "float32")
+    is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+
+    def _rope(  # pylint: disable=too-many-arguments
+        x: T.Buffer,
+        s: tir.Var,
+        h: tir.Var,
+        d: tir.Var,
+        pos: tir.Var,
+        ext_factors: Optional[T.Buffer] = None,
+    ):
+        kwargs = {}
+        if ext_factors:
+            kwargs["ext_factors"] = ext_factors
+        cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
+            pos * scale, d, rotary_dim, theta, "float32", **kwargs
+        )
+        cos = cos_freq * x[s, h, d].astype("float32")
+        if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        else:
+            # Llama4 pairs adjacent elements (interleaved, GPT-J style), unlike llama3's half-split layout
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        expr = (cos + sin).astype(dtype)
+        for var, value in var_map.items():
+            expr = tir.Let(var, value, expr)
+        return expr
+
+    @T.prim_func(private=True)
+    def fused_rope(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        apply_rope: T.int64,
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int32()
+        position_map_elem_offset = T.int32()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    @T.prim_func
+    def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int64()
+        position_map_elem_offset = T.int64()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    if is_longrope_scaling:
+        return fused_rope_longrope_scaling
+    return fused_rope
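For readers checking the math, here is a plain-numpy sketch (not part of the patch; compute_llama4_inv_freq is an illustrative helper name) of the frequency scaling that rope_freq_llama4 expresses in TIR above, assuming rotary_dim equals head_dim:

    import math
    import numpy as np

    def compute_llama4_inv_freq(head_dim, theta, factor, low_freq_factor,
                                high_freq_factor, original_max_position_embeddings):
        d = np.arange(head_dim)
        orig_freq = 1.0 / np.power(theta, 2 * (d // 2) / head_dim)
        if high_freq_factor == low_freq_factor:
            # Degenerate case: only frequencies whose wavelength exceeds the
            # original context window are divided by the scaling factor.
            wavelength = 2 * math.pi / orig_freq
            threshold = original_max_position_embeddings / low_freq_factor
            return np.where(wavelength > threshold, orig_freq / factor, orig_freq)
        # Otherwise, smoothly interpolate between scaled and unscaled frequencies
        # (the llama3-style interpolation).
        inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
        alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff
        beta = low_freq_factor * inv_diff
        smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
        return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq

    # The rotation angle for token position s and feature index d is then
    # s * inv_freq[d]; its cos/sin drive the usual RoPE rotation.

Constructing the fused kernel follows the same pattern as llama_rope_with_position_map; the parameter values below are illustrative, not taken from any specific model config:

    rope_scaling = {
        "rope_type": "llama4",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    }
    fused_rope_func = llama4_rope_with_position_map(
        theta=500000.0, scale=1.0, head_dim=128,
        num_q_heads=32, num_kv_heads=8, dtype="float16",
        rope_scaling=rope_scaling,
    )
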
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 714ae94782..50d4772d8c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor:
     return wrap_nested(_op.exp(x._expr), name)
 
 
+def log(x: Tensor, name: str = "log") -> Tensor:
+    r"""Applies the natural logarithm function.
+
+    .. math::
+        \text{Log}(x) = \log(x)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data to the operator.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return wrap_nested(_op.log(x._expr), name)
+
+
+def floor(x: Tensor, name: str = "floor") -> Tensor:
+    r"""Computes the floor of the input tensor.
+
+    .. math::
+        \text{Floor}(x) = \lfloor x \rfloor
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data to the operator.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return wrap_nested(_op.floor(x._expr), name)
+
+
+def arange(
+    start: int,
+    end: Optional[int] = None,
+    step: int = 1,
+    dtype: Optional[str] = "float32",
+    name: str = "arange",
+) -> Tensor:
+    r"""Construct a tensor with evenly spaced elements.
+
+    Parameters
+    ----------
+    start : int
+        The start of the interval.
+
+    end : Optional[int]
+        The end of the interval. If not given, it will be set to start,
+        and start will be set to 0.
+
+    step : int
+        The step size.
+
+    dtype : Optional[str]
+        The data type of the created tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.arange(start, end, step, dtype), name)
+
+
 def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
     """Permutes the dimensions of the input tensor.
 
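A minimal usage sketch for the new wrappers (assumed to follow the existing nn.Module export pattern; the LogFloor class and spec values are illustrative, not part of the patch):

    from tvm.relax.frontend import nn
    from tvm.relax.frontend.nn import op, spec

    class LogFloor(nn.Module):
        def forward(self, x: nn.Tensor):
            # op.log and op.floor are thin wrappers over R.log and R.floor
            return op.floor(op.log(x))

    irmodule, _ = LogFloor().export_tvm(
        spec={"forward": {"x": spec.Tensor((2, 10), "float32")}}
    )
    irmodule.show()  # the traced function applies R.log followed by R.floor

op.arange(0, 10, 1, "float32") follows the same wrapper pattern and lowers to R.arange with int64 prim values, as the updated test expectations below show.
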
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index e827f643b3..28c11f6dfa 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -384,6 +384,8 @@ def test_chunk():
 def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+            log_out = op.log(x)
+            floor_out = op.floor(x)
             relu_out = op.relu(x)
             relu6_out = op.relu6(x)
             silu_out = op.silu(x)
@@ -409,6 +411,8 @@ def test_nn():
     ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
         R.func_attr({"num_input": 4})
         with R.dataflow():
+            log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x)
+            floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x)
             relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
             relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x)
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
@@ -463,6 +467,8 @@ def test_create():
             )
             zeros_out = op.zeros([10, 10])
             zeros_fp16_out = op.zeros([10, 10], dtype="float16")
+
+            arange_out = op.arange(0, 10, 1, "float32")
             return x
 
     # fmt: off
@@ -476,6 +482,7 @@ def test_create():
             full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
             zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32")
             zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16")
+            arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32")
             gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
             R.output(gv1)
         return gv1
@@ -504,7 +511,10 @@ def test_timestep_embedding():
             lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
             lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1])
             lv3: R.Tensor((5,), dtype="float32") = R.arange(
-                R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32"
+                R.prim_value(T.int64(0)),
+                R.prim_value(T.int64(5)),
+                R.prim_value(T.int64(1)),
+                dtype="float32",
             )
             lv4: R.Tensor((5,), dtype="float32") = R.multiply(
                 R.const(-9.2103404998779297, "float32"), lv3
