This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f97159504c [Relax] Operator and RoPE support for Llama4 (#18336)
f97159504c is described below
commit f97159504cef41513f77ee8e2cb8636365e4fb52
Author: Pranav Venkatram <[email protected]>
AuthorDate: Tue Sep 23 13:08:38 2025 -0400
[Relax] Operator and RoPE support for Llama4 (#18336)
Added Llama4 implementation and a new RoPE implementation
---
python/tvm/relax/expr.py | 4 +
.../relax/frontend/nn/llm/position_embedding.py | 234 +++++++++++++++++++++
python/tvm/relax/frontend/nn/op.py | 86 ++++++++
tests/python/relax/test_frontend_nn_op.py | 12 +-
4 files changed, 335 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 1a7a5c224a..8dd4eff5c7 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, Mapping
import numpy as _np # type: ignore
import tvm_ffi
+
import tvm.ir
import tvm.relax
from tvm import DataType
@@ -1153,6 +1154,9 @@ def const(
- bool maps to "bool"
- other using the same default rule as numpy.
"""
+ # Needed for bf16 and fp8 support (does not come with numpy)
+ import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel
+
if isinstance(value, (Number, (bool, list))):
value = _np.array(value, dtype=dtype)
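
A minimal sketch (not part of the commit) of what the ml_dtypes import enables: numpy only resolves dtype strings such as "bfloat16" or "float8_e4m3fn" once ml_dtypes has been imported, after which relax.const can build constants with those payloads.

    import ml_dtypes  # registers bfloat16 / float8 dtypes with numpy
    import numpy as np
    from tvm import relax

    # Without the ml_dtypes import, np.dtype("bfloat16") raises TypeError.
    x = np.array([1.0, 2.0], dtype="bfloat16")
    c = relax.const(x, dtype="bfloat16")  # Relax constant carrying a bf16 payload
    print(c.data.dtype)                   # "bfloat16"
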
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 1a1659b29e..6fda4b0bca 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
return cos_freq, sin_freq, {freq_var: freq}
+def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
+ s: tir.Var,
+ d: tir.Var,
+ d_range: int,
+ theta: float,
+ dtype: str,
+ factor: float,
+ low_freq_factor: float,
+ high_freq_factor: float,
+ original_max_position_embeddings: float,
+):
+ """Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
+ orig_freq = tir.const(1, "float32") / tir.power(
+ theta, 2 * (d // 2) / tir.const(d_range, "float32")
+ )
+ orig_freq_var = tir.Var("orig_freq", "float32")
+
+ llama4_inv_scaling_factor = 1.0 / factor
+
+ if high_freq_factor == low_freq_factor:
+ wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
+        threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")
+
+ scaled_freq = tir.if_then_else(
+            wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
+ )
+ smoothed_freq = s * scaled_freq
+
+ else:
+ # Original smooth interpolation logic
+ inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
+
+        llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
+ llama4_beta = low_freq_factor * inv_diff_freq_factor
+        smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta))
+ smoothed_freq = s * (
+            (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var
+ )
+
+ smoothed_freq_var = tir.Var("smoothed_freq", "float32")
+ cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
+ sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
+    return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
+
+
def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
+ if rope_scaling["rope_type"] == "llama4":
+ return partial(
+ rope_freq_llama4,
+ factor=rope_scaling["factor"],
+ low_freq_factor=rope_scaling["low_freq_factor"],
+ high_freq_factor=rope_scaling["high_freq_factor"],
+            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+ )
if rope_scaling["rope_type"] == "longrope":
return partial(
rope_freq_longrope,
@@ -545,3 +598,184 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope
+
+
+def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
+ theta: float,
+ scale: float,
+ head_dim: int,
+ num_q_heads: int,
+ num_kv_heads: int,
+ dtype: str,
+ rope_scaling: Dict[str, Any],
+ rotary_dim: Optional[int] = None,
+):
+ """Return the TIR function that computes Llama-style RoPE with q position
map.
+
+ Parameters
+ ----------
+ theta : float
+ The theta value, or "base" in RoPE, which controls the frequency.
+
+ scale : float
+ The RoPE scaling factor.
+
+ head_dim : int
+ The number of features on each head.
+
+ num_q_heads : int
+ The number of query heads.
+
+ num_kv_heads : int
+        The number of key/value heads. It differs from `num_q_heads` in group-query attention.
+
+ dtype : str
+ The dtype of qkv data.
+
+ rope_scaling : Dict
+ The configuration of RoPE scaling.
+
+ rotary_dim : int
+        The number of dimensions in the embedding that RoPE is applied to. By default, the
+        rotary_dim is the same as head_dim.
+ """
+ fused_heads = num_q_heads + num_kv_heads * 2
+ if rotary_dim is None:
+ rotary_dim = head_dim
+ scale = tir.const(scale, "float32")
+ is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+
+ def _rope( # pylint: disable=too-many-arguments
+ x: T.Buffer,
+ s: tir.Var,
+ h: tir.Var,
+ d: tir.Var,
+ pos: tir.Var,
+ ext_factors: Optional[T.Buffer] = None,
+ ):
+ kwargs = {}
+ if ext_factors:
+ kwargs["ext_factors"] = ext_factors
+ cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
+ pos * scale, d, rotary_dim, theta, "float32", **kwargs
+ )
+ cos = cos_freq * x[s, h, d].astype("float32")
+ if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[s, h, d + 1],
+ x[s, h, d - 1],
+ ).astype("float32")
+ else:
+ # Data layout is different for llama4 vs llama3
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[s, h, d + 1],
+ x[s, h, d - 1],
+ ).astype("float32")
+ expr = (cos + sin).astype(dtype)
+ for var, value in var_map.items():
+ expr = tir.Let(var, value, expr)
+ return expr
+
+ @T.prim_func(private=True)
+ def fused_rope( # pylint: disable=too-many-locals
+ var_qkv: T.handle,
+ var_position_map: T.handle,
+ var_q: T.handle,
+ var_k: T.handle,
+ var_v: T.handle,
+ apply_rope: T.int64,
+ ):
+ T.func_attr(
+ {
+ "op_pattern": 8, # 2 means injective, 8 means opaque
+ "tir.noalias": True,
+ }
+ )
+ seq_len = T.int32()
+ position_map_elem_offset = T.int32()
+ qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+ q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+ k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+ v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+ position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+ )
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.block("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ apply_rope > 0 and d < rotary_dim,
+ _rope(qkv, s, h, d, position_map[s]),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ apply_rope > 0 and d < rotary_dim,
+ _rope(qkv, s, h, d, position_map[s]),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+ @T.prim_func
+ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
+ var_qkv: T.handle,
+ var_position_map: T.handle,
+ var_q: T.handle,
+ var_k: T.handle,
+ var_v: T.handle,
+ ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
+ ):
+ T.func_attr(
+ {
+ "op_pattern": 8, # 2 means injective, 8 means opaque
+ "tir.noalias": True,
+ }
+ )
+ seq_len = T.int64()
+ position_map_elem_offset = T.int64()
+ qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+ q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+ k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+ v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+ position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+ )
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.block("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ ext_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ ext_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+ if is_longrope_scaling:
+ return fused_rope_longrope_scaling
+ return fused_rope
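
For orientation, here is a NumPy sketch of the frequency scaling that rope_freq_llama4 expresses element-wise in TIR above: the standard inverse frequencies theta^(-2*(d//2)/d_range) are either hard-thresholded by wavelength (when the two frequency factors coincide) or smoothly blended between the 1/factor-scaled and unscaled values; the TIR version then multiplies by the (scaled) position before taking cos/sin. The constants below are placeholders, not values taken from the commit or any Llama4 config.

    import math
    import numpy as np

    def llama4_inv_freq(head_dim, theta, factor, low_freq_factor,
                        high_freq_factor, original_max_position_embeddings):
        d = np.arange(0, head_dim, 2).astype("float32")
        orig_freq = 1.0 / theta ** (d / head_dim)  # standard RoPE inverse frequencies
        if high_freq_factor == low_freq_factor:
            # Degenerate case: hard wavelength threshold, no interpolation.
            wavelength = 2 * math.pi / orig_freq
            threshold = original_max_position_embeddings / low_freq_factor
            return np.where(wavelength > threshold, orig_freq / factor, orig_freq)
        # Smooth blend between the 1/factor-scaled and unscaled frequencies.
        inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
        alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff
        beta = low_freq_factor * inv_diff
        smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
        return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq

    # Placeholder constants for illustration only:
    inv_freq = llama4_inv_freq(128, 500000.0, 8.0, 1.0, 4.0, 8192)
    angles = np.outer(np.arange(16), inv_freq)  # position times frequency, as in the TIR
    cos, sin = np.cos(angles), np.sin(angles)
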
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 714ae94782..50d4772d8c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor:
return wrap_nested(_op.exp(x._expr), name)
+def log(x: Tensor, name: str = "log") -> Tensor:
+ r"""Applies the natural logarithm function.
+
+ .. math::
+ \text{Log}(x) = \log(x)
+
+ Parameters
+ ----------
+ x : Tensor
+ The input data to the operator.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+        The computed result.
+
+    Note
+ ----
+ The input tensor is required to have float dtype
+ """
+ return wrap_nested(_op.log(x._expr), name)
+
+
+def floor(x: Tensor, name: str = "floor") -> Tensor:
+ r"""Computes the floor of the input tensor.
+
+ .. math::
+        \text{Floor}(x) = \lfloor x \rfloor
+
+ Parameters
+ ----------
+ x : Tensor
+ The input data to the operator.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+
+ Note
+ ----
+ The input tensor is required to have float dtype
+ """
+ return wrap_nested(_op.floor(x._expr), name)
+
+
+def arange(
+ start: int,
+ end: Optional[int] = None,
+ step: int = 1,
+ dtype: Optional[str] = "float32",
+ name: str = "arange",
+) -> Tensor:
+ r"""Construct a tensor with evenly spaced elements.
+
+ Parameters
+ ----------
+ start : int
+ The start of the interval.
+
+ end : Optional[int]
+ The end of the interval. If not given, it will be set to start,
+ and start will be set to 0.
+
+ step : int
+ The step size.
+
+ dtype : Optional[str]
+ The data type of the created tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.arange(start, end, step, dtype), name)
+
+
def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
"""Permutes the dimensions of the input tensor.
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index e827f643b3..28c11f6dfa 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -384,6 +384,8 @@ def test_chunk():
def test_nn():
class Model(Module):
def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+ log_out = op.log(x)
+ floor_out = op.floor(x)
relu_out = op.relu(x)
relu6_out = op.relu6(x)
silu_out = op.silu(x)
@@ -409,6 +411,8 @@ def test_nn():
) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 4})
with R.dataflow():
+ log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x)
+ floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x)
relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x)
silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
@@ -463,6 +467,8 @@ def test_create():
)
zeros_out = op.zeros([10, 10])
zeros_fp16_out = op.zeros([10, 10], dtype="float16")
+
+ arange_out = op.arange(0, 10, 1, "float32")
return x
# fmt: off
@@ -476,6 +482,7 @@ def test_create():
            full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32")
            zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16")
+            arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32")
            gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
R.output(gv1)
return gv1
@@ -504,7 +511,10 @@ def test_timestep_embedding():
lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
            lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1])
lv3: R.Tensor((5,), dtype="float32") = R.arange(
-                R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32"
+ R.prim_value(T.int64(0)),
+ R.prim_value(T.int64(5)),
+ R.prim_value(T.int64(1)),
+ dtype="float32",
)
lv4: R.Tensor((5,), dtype="float32") = R.multiply(
R.const(-9.2103404998779297, "float32"), lv3