This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4747a92827 [Python][Relax] Fix YaRN correction dim calculation (#18661)
4747a92827 is described below

commit 4747a9282728614bf751f9c96b2c91a41a3abb1d
Author: Akaash Parthasarathy <[email protected]>
AuthorDate: Fri Jan 30 16:07:40 2026 -0500

    [Python][Relax] Fix YaRN correction dim calculation (#18661)
    
    Precompute ```inv_theta_log_scale```
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py       | 19 ++++++++++++++
 .../relax/frontend/nn/llm/position_embedding.py    | 30 +++++++++++++++-------
 2 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 6b6029630d..d2df939c3d 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -297,6 +297,23 @@ class PagedKVCache(Object):  # pylint: disable=too-few-public-methods
     # pylint: enable=protected-access
 
 
+def _prepare_yarn_rope_scaling(
+    rope_scaling: Optional[Dict[str, Any]],
+    rope_theta: Optional[float],
+) -> Optional[Dict[str, Any]]:
+    """Ensure Yarn-specific scaling configs include the theta metadata."""
+    if rope_scaling is None:
+        return None
+    if rope_scaling.get("rope_type") != "yarn":
+        return rope_scaling
+
+    rope_scaling_updated = dict(rope_scaling)
+    if "inv_theta_log_scale" not in rope_scaling_updated and rope_theta is not None:
+        theta_value = float(rope_theta)
+        rope_scaling_updated["inv_theta_log_scale"] = 1.0 / (2 * math.log(theta_value))
+    return rope_scaling_updated
+
+
 class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
     """Paged KV cache using FlashInfer (CUDA) kernels."""
 
@@ -372,6 +389,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
             Whether to enable disaggregation in the KV cache.
         """
         assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode."
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
 
         attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
         if attn_kind_single == "mha_sliding":
@@ -561,6 +579,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
         target : Target
             The target to build the model to.
         """
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
         attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
         if attn_kind_single == "mha_sliding":
             attn_kind_single = "mha"
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index b90b4bfecd..ee2a356299 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -19,7 +19,7 @@
 
 import math
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from tvm import tir
 from tvm.relax.frontend.nn import Tensor, op
@@ -180,12 +180,12 @@ def rope_freq_longrope(  # pylint: disable=too-many-arguments
 def yarn_find_correction_dim(
     num_rotations: int,
     d: tir.Var,
-    theta: float,
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Inverse dim formula to find dim based on number of rotations"""
-    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(theta)
+    return (
+        d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale
     )
 
 
@@ -193,12 +193,16 @@ def yarn_find_correction_range(
     low_rot: int,
     high_rot: int,
     d: tir.Var,
-    theta: float,
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Find the correction range based on the number of rotations"""
-    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
-    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
+    low = yarn_find_correction_dim(
+        low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
+    high = yarn_find_correction_dim(
+        high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
     return tir.max(low, 0), tir.min(high, d - 1)
 
 
@@ -206,12 +210,13 @@ def rope_freq_yarn(
     s: tir.Var,
     d: tir.Var,
     d_range: int,
-    theta: float,
+    theta: Union[float, tir.PrimExpr],
     dtype: str,
     original_max_position_embeddings: int,
     scaling_factor: float,
     beta_fast: int,
     beta_slow: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):  # pylint: disable=too-many-arguments, too-many-locals
     """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
 
@@ -221,7 +226,11 @@ def rope_freq_yarn(
     freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
 
     low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
+        beta_fast,
+        beta_slow,
+        d_range,
+        original_max_position_embeddings,
+        inv_theta_log_scale=inv_theta_log_scale,
     )
     high = tir.if_then_else(low == high, high + 0.001, high)
     inv_freq_mask = tir.const(1, "float32") - tir.max(
@@ -266,12 +275,15 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
         )
     if rope_scaling["rope_type"] == "yarn":
+        inv_theta_log_scale = rope_scaling.get("inv_theta_log_scale")
+        assert inv_theta_log_scale is not None, "inv_theta_log_scale must be precomputed for YaRN"
         return partial(
             rope_freq_yarn,
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
             scaling_factor=rope_scaling["factor"],
             beta_fast=rope_scaling["beta_fast"],
             beta_slow=rope_scaling["beta_slow"],
+            inv_theta_log_scale=inv_theta_log_scale,
         )
     raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
 

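For readers following along, the patch is numerically a refactor of the YaRN correction-dim formula: rather than dividing by 2 * math.log(theta) on every call to yarn_find_correction_dim, the constant 1 / (2 * math.log(theta)) is computed once in _prepare_yarn_rope_scaling and threaded through as inv_theta_log_scale. A minimal standalone sketch of the equivalence (plain Python, using hypothetical values for theta, d, max_position_embeddings, and the rotation counts, not taken from any particular model config):

```python
import math

# Old form: divide by 2 * log(theta) inside the correction-dim formula.
def correction_dim_old(num_rotations, d, theta, max_position_embeddings):
    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(theta)
    )

# New form: multiply by a precomputed 1 / (2 * log(theta)).
def correction_dim_new(num_rotations, d, max_position_embeddings, inv_theta_log_scale):
    return d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale

# Hypothetical example values for illustration only.
theta = 10000.0
d = 128
max_position_embeddings = 4096
inv_theta_log_scale = 1.0 / (2 * math.log(theta))  # what _prepare_yarn_rope_scaling stores

for num_rotations in (32, 1):  # e.g. beta_fast / beta_slow style rotation counts
    old = correction_dim_old(num_rotations, d, theta, max_position_embeddings)
    new = correction_dim_new(num_rotations, d, max_position_embeddings, inv_theta_log_scale)
    assert math.isclose(old, new)
```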