gemini-code-assist[bot] commented on code in PR #18336:
URL: https://github.com/apache/tvm/pull/18336#discussion_r2370979502
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope
+
+
+def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
+ theta: float,
+ scale: float,
+ head_dim: int,
+ num_q_heads: int,
+ num_kv_heads: int,
+ dtype: str,
+ rope_scaling: Dict[str, Any],
+ rotary_dim: Optional[int] = None,
+):
+    """Return the TIR function that computes Llama-style RoPE with q position map.
+
+ Parameters
+ ----------
+ theta : float
+ The theta value, or "base" in RoPE, which controls the frequency.
+
+ scale : float
+ The RoPE scaling factor.
+
+ head_dim : int
+ The number of features on each head.
+
+ num_q_heads : int
+ The number of query heads.
+
+ num_kv_heads : int
+        The number of key/value heads. It differs from `num_q_heads` in group-query attention.
+
+ dtype : str
+ The dtype of qkv data.
+
+ rope_scaling : Dict
+ The configuration of RoPE scaling.
+
+ rotary_dim : int
+        The number of dimensions in the embedding that RoPE is applied to. By default, the
+        rotary_dim is the same as head_dim.
+ """
+ fused_heads = num_q_heads + num_kv_heads * 2
+ if rotary_dim is None:
+ rotary_dim = head_dim
+ scale = tir.const(scale, "float32")
+ is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+
+ def _rope( # pylint: disable=too-many-arguments
+ x: T.Buffer,
+ s: tir.Var,
+ h: tir.Var,
+ d: tir.Var,
+ pos: tir.Var,
+ ext_factors: Optional[T.Buffer] = None,
+ ):
+ kwargs = {}
+ if ext_factors:
+ kwargs["ext_factors"] = ext_factors
+ cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
+ pos * scale, d, rotary_dim, theta, "float32", **kwargs
+ )
+ cos = cos_freq * x[s, h, d].astype("float32")
+ if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[s, h, d + 1],
+ x[s, h, d - 1],
+ ).astype("float32")
+ else:
+ # Data layout is different for llama4 vs llama3
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[s, h, d + 1],
+ x[s, h, d - 1],
+ ).astype("float32")
+ expr = (cos + sin).astype(dtype)
+ for var, value in var_map.items():
+ expr = tir.Let(var, value, expr)
+ return expr
+
+ @T.prim_func(private=True)
+ def fused_rope( # pylint: disable=too-many-locals
+ var_qkv: T.handle,
+ var_position_map: T.handle,
+ var_q: T.handle,
+ var_k: T.handle,
+ var_v: T.handle,
+ apply_rope: T.int64,
+ ):
+ T.func_attr(
+ {
+ "op_pattern": 8, # 2 means injective, 8 means opaque
+ "tir.noalias": True,
+ }
+ )
+ seq_len = T.int32()
+ position_map_elem_offset = T.int32()
+ qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+ q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+ k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+ v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+ position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+ )
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.block("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ apply_rope > 0 and d < rotary_dim,
+ _rope(qkv, s, h, d, position_map[s]),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ apply_rope > 0 and d < rotary_dim,
+ _rope(qkv, s, h, d, position_map[s]),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+ @T.prim_func
+ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
+ var_qkv: T.handle,
+ var_position_map: T.handle,
+ var_q: T.handle,
+ var_k: T.handle,
+ var_v: T.handle,
+ ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
+ ):
+ T.func_attr(
+ {
+ "op_pattern": 8, # 2 means injective, 8 means opaque
+ "tir.noalias": True,
+ }
+ )
+ seq_len = T.int64()
+ position_map_elem_offset = T.int64()
+ qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+ q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+ k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+ v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+ position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+ )
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.block("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ ext_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ ext_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+ if is_longrope_scaling:
+ return fused_rope_longrope_scaling
+ return fused_rope
Review Comment:

The new function `llama4_rope_with_position_map` is almost an identical copy
of the existing `llama_rope_with_position_map` function. This introduces
significant code duplication, which will make future maintenance more difficult.
The main difference lies in the `_rope` helper function, specifically in the
calculation of `sin`. The logic for `llama4` is identical to the existing logic
for `gptj`.
To avoid duplication, I recommend refactoring this by merging
`llama4_rope_with_position_map` into `llama_rope_with_position_map`. You can
handle the `llama4` case within the `_rope` function of
`llama_rope_with_position_map`, similar to how `gptj` is currently handled. For
example:
```python
def _rope(...):
    # ...
    cos = cos_freq * x[s, h, d].astype("float32")
    if "rope_type" in rope_scaling and rope_scaling["rope_type"] in ("gptj", "llama4"):
        sin = sin_freq * tir.if_then_else(
            d % 2 == 0,
            -x[s, h, d + 1],
            x[s, h, d - 1],
        ).astype("float32")
    else:
        sin = sin_freq * tir.if_then_else(
            d < rotary_dim // 2,
            -x[s, h, d + rotary_dim // 2],
            x[s, h, d - rotary_dim // 2],
        ).astype("float32")
    # ...
```
This would eliminate the need for the entire `llama4_rope_with_position_map`
function.
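For readers unfamiliar with the two layouts, here is a small NumPy sketch (purely illustrative, not code from the PR or from TVM; the function names are made up) contrasting the interleaved-pair rotation used by the `gptj`/`llama4` branch with the half-split rotation used by the default Llama branch:
```python
import numpy as np

def rotate_half_split(x):
    # Default Llama/NeoX layout: element d pairs with element d + rotary_dim // 2.
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def rotate_interleaved(x):
    # gptj/llama4 layout: element d pairs with its even/odd neighbour (d + 1 or d - 1).
    out = np.empty_like(x)
    out[..., 0::2] = -x[..., 1::2]
    out[..., 1::2] = x[..., 0::2]
    return out

# RoPE then combines the rotated vector with the original:
#   out = cos_freq * x + sin_freq * rotate(x)
```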
Additionally, there's a type inconsistency for the `apply_rope` parameter
between the two functions (`T.int32` vs. `T.int64`). This should be unified,
perhaps to `T.int64` for consistency with `fused_rope_longrope_scaling`.
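If it helps, a minimal TVMScript stub of the unified scalar parameter could look like the sketch below (illustrative only: `fused_rope_stub` is a hypothetical name, and the buffer matching and loop body are elided):
```python
from tvm.script import tir as T

@T.prim_func(private=True)
def fused_rope_stub(  # hypothetical stub; real body omitted
    var_qkv: T.handle,
    var_position_map: T.handle,
    var_q: T.handle,
    var_k: T.handle,
    var_v: T.handle,
    apply_rope: T.int64,  # unified to int64 in both code paths
):
    T.func_attr({"tir.noalias": True})
    T.evaluate(0)
```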
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
return cos_freq, sin_freq, {freq_var: freq}
+def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
+ s: tir.Var,
+ d: tir.Var,
+ d_range: int,
+ theta: float,
+ dtype: str,
+ factor: float,
+ low_freq_factor: float,
+ high_freq_factor: float,
+ original_max_position_embeddings: float,
+):
+ """Compute the inverse frequency of RoPE for llama3 RoPE scaling."""
+ orig_freq = tir.const(1, "float32") / tir.power(
+ theta, 2 * (d // 2) / tir.const(d_range, "float32")
+ )
+ orig_freq_var = tir.Var("orig_freq", "float32")
+
+ llama3_inv_scaling_factor = 1.0 / factor
+
+ if high_freq_factor == low_freq_factor:
+ wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
+        threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")
+
+ scaled_freq = tir.if_then_else(
+            wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
+ )
+ smoothed_freq = s * scaled_freq
+
+ else:
+ # Original smooth interpolation logic
+ inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
+
+        llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
+ llama3_beta = low_freq_factor * inv_diff_freq_factor
+        smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta))
+ smoothed_freq = s * (
+            (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var
+ )
Review Comment:

The docstring and some variable names in `rope_freq_llama4` appear to be
copied from the Llama3 implementation. To improve code clarity and
maintainability, they should be updated to refer to Llama4.
Specifically:
- The docstring on line 89 mentions "llama3 RoPE scaling".
- Variable names like `llama3_inv_scaling_factor` (line 95), `llama3_alpha`
(line 110), and `llama3_beta` (line 111) should be updated.
```python
"""Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
orig_freq = tir.const(1, "float32") / tir.power(
theta, 2 * (d // 2) / tir.const(d_range, "float32")
)
orig_freq_var = tir.Var("orig_freq", "float32")
llama4_inv_scaling_factor = 1.0 / factor
if high_freq_factor == low_freq_factor:
wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
threshold_wavelen = tir.const(original_max_position_embeddings /
low_freq_factor, "float32")
scaled_freq = tir.if_then_else(
wavelength > threshold_wavelen, orig_freq_var / factor,
orig_freq_var
)
smoothed_freq = s * scaled_freq
else:
# Original smooth interpolation logic
inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
llama4_alpha = original_max_position_embeddings / (2 * math.pi) *
inv_diff_freq_factor
llama4_beta = low_freq_factor * inv_diff_freq_factor
smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var -
llama4_beta))
smoothed_freq = s * (
(1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor +
smooth * orig_freq_var
)
```
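For reference, the frequency scaling implemented by the snippet above can be written in plain NumPy roughly as follows (a sketch of the quoted logic, not code from the PR; the standalone function name is invented for illustration):
```python
import numpy as np

def llama4_scaled_inv_freq(d, d_range, theta, factor, low_freq_factor,
                           high_freq_factor, original_max_position_embeddings):
    # Unscaled inverse frequency for (paired) dimension d.
    orig_freq = 1.0 / theta ** (2 * (d // 2) / d_range)
    if high_freq_factor == low_freq_factor:
        # Threshold form: only long-wavelength (low-frequency) components are scaled down.
        wavelength = 2 * np.pi / orig_freq
        threshold_wavelen = original_max_position_embeddings / low_freq_factor
        return orig_freq / factor if wavelength > threshold_wavelen else orig_freq
    # Smooth interpolation between the scaled and unscaled frequency.
    inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
    alpha = original_max_position_embeddings / (2 * np.pi) * inv_diff
    beta = low_freq_factor * inv_diff
    smooth = float(np.clip(alpha * orig_freq - beta, 0.0, 1.0))
    return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq

# The TIR version then multiplies the resulting frequency by the position `s`
# before taking cos/sin.
```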
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]