This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3f5ac1c6b [Relax] Correct YaRN RoPE frequency scaling formula to align with the original paper (#18576)
e3f5ac1c6b is described below

commit e3f5ac1c6bccebc4bf1c35c9a1d81cf4c0a1740d
Author: Yeongjae Jang <[email protected]>
AuthorDate: Thu Jan 1 05:49:16 2026 +0900

    [Relax] Correct YaRN RoPE frequency scaling formula to align with the original paper (#18576)
    
    ## Summary
    Fixed the frequency calculation for YaRN RoPE scaling and corrected the
    correction-range computation.
    
    ## Description
    Greetings:
    
    This PR corrects the mathematical formulation of the
    [YaRN](https://arxiv.org/abs/2309.00071) RoPE scaling.
    I have verified that this change eliminates the discrepancy observed
    when comparing against a PyTorch baseline (an implementation of
    `gpt-oss`).
    
    ### in `yarn_find_correction_range()`
    #### `low`, `high`
    Removed the `tir.floor` and `tir.ceil` operations that wrapped the
    `yarn_find_correction_dim()` calls.
    The YaRN paper applies no floor or ceil when computing these values.
    The `gpt-oss` implementation keeps these thresholds as floating-point
    values so that the ramp function interpolates smoothly.
    Rounding them introduced quantization errors in the ramp mask.
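    
    As a rough illustration of the rounding effect (not the TVM code), the
    plain-Python sketch below evaluates a linear ramp with unrounded and with
    floor/ceil bounds. The helper names and the parameter values (`dim=64`,
    `base=10000`, `max_pos=4096`, 32 and 1 rotations) are assumptions chosen
    only for this example.
    
    ```python
    import math


    def find_correction_dim(num_rotations, dim, base, max_pos):
        """Dimension index at which RoPE makes `num_rotations` turns over `max_pos` positions."""
        return dim * math.log(max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


    def linear_ramp(i, low, high):
        """Clamp (i - low) / (high - low) to the interval [0, 1]."""
        return min(1.0, max(0.0, (i - low) / (high - low)))


    dim, base, max_pos = 64, 10000.0, 4096  # illustrative values, not from the commit
    low = max(find_correction_dim(32, dim, base, max_pos), 0.0)
    high = min(find_correction_dim(1, dim, base, max_pos), dim - 1.0)

    # Ramp with the unrounded bounds versus the floor/ceil bounds.
    exact = [linear_ramp(i, low, high) for i in range(dim // 2)]
    rounded = [linear_ramp(i, math.floor(low), math.ceil(high)) for i in range(dim // 2)]

    # Largest per-index gap between the two ramps.
    max_gap = max(abs(a - b) for a, b in zip(exact, rounded))
    print(f"low={low:.3f} high={high:.3f} max_gap={max_gap:.4f}")
    ```
    
    With these assumed values both bounds are fractional, so the rounded ramp
    starts earlier and ends later than the unrounded one.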
    
    ### in `rope_freq_yarn()`
    #### `freq_inter`
    Before this change, the implementation calculated the inverse frequency as:
    ```
    freq_inter = tir.const(1, "float32") / tir.power(
        scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
    )
    ```
    
    Here `scaling_factor` is raised to the exponent together with `theta`,
    which leads to non-uniform scaling across dimensions.
    
    According to the YaRN method (and an implementation of `gpt-oss`), the
    scaling factor should be applied linearly:
    ```
    exponent = d * 2 % d_range / tir.const(d_range, "float32")
    freq_power = tir.power(theta, exponent)
    freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
    ```
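    
    The difference can be checked numerically. The sketch below is plain
    Python rather than TIR, and the values of `theta`, `scaling_factor`, and
    `d_range` are assumptions picked for illustration. It compares the
    effective scale `freq_extra / freq_inter` under both formulas.
    
    ```python
    theta, scaling_factor, d_range = 10000.0, 4.0, 64  # illustrative values


    def exponent(d):
        return (d * 2 % d_range) / d_range


    def freq_extra(d):
        return 1.0 / theta ** exponent(d)


    def freq_inter_old(d):
        # scaling_factor sits inside the power, so it is raised to the exponent too.
        return 1.0 / (scaling_factor * theta) ** exponent(d)


    def freq_inter_new(d):
        # scaling_factor divides the frequency directly.
        return 1.0 / (scaling_factor * theta ** exponent(d))


    # freq_extra / freq_inter is the effective scale applied at index d.
    old_scale = [freq_extra(d) / freq_inter_old(d) for d in range(d_range // 2)]
    new_scale = [freq_extra(d) / freq_inter_new(d) for d in range(d_range // 2)]

    print(min(old_scale), max(old_scale))  # varies with d
    print(min(new_scale), max(new_scale))  # stays at scaling_factor
    ```
    
    Under the old formula the effective scale equals
    `scaling_factor ** exponent`, so it starts at 1 and grows with `d`. Under
    the new formula it is `scaling_factor` for every index.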
    
    #### `d_range`
    `rope_freq_yarn()` incorrectly passed the current dimension index `d` to
    `yarn_find_correction_range()` when calculating the thresholds.
    This caused the ramp boundaries to shift per dimension.
    It now passes the total dimension size (`d_range`), which keeps the
    frequency thresholds consistent.
    
    Before:
    ```
    yarn_find_correction_range(..., d, ...)
    ```
    
    After:
    ```
    yarn_find_correction_range(..., d_range, ...)
    ```
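    
    A small plain-Python sketch of why the argument matters, assuming the
    correction dimension follows the formula from the YaRN reference code.
    The helper name and the values below are hypothetical and used only for
    illustration.
    
    ```python
    import math


    def find_correction_dim(num_rotations, dim, base, max_pos):
        """Correction dimension; note that the result is proportional to `dim`."""
        return dim * math.log(max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


    d_range, theta, max_pos = 64, 10000.0, 4096  # illustrative values

    # Before: the per-index `d` is passed, so the lower threshold moves with d.
    low_before = [find_correction_dim(32, d, theta, max_pos) for d in range(d_range // 2)]

    # After: `d_range` is passed, so every index sees the same threshold.
    low_after = [find_correction_dim(32, d_range, theta, max_pos) for _ in range(d_range // 2)]

    print(low_before[:3])
    print(low_after[:3])
    ```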
    
    Thank you very much for reading.
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/tvm/relax/frontend/nn/llm/position_embedding.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 35eeb4f5f3..60808a6b35 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -197,8 +197,8 @@ def yarn_find_correction_range(
     max_position_embeddings: int,
 ):
     """Find the correction range based on the number of rotations"""
-    low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings))
-    high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings))
+    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
+    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
     return tir.max(low, 0), tir.min(high, d - 1)
 
 
@@ -214,16 +214,14 @@ def rope_freq_yarn(
     beta_slow: int,
 ):  # pylint: disable=too-many-arguments, too-many-locals
     """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
-    freq_extra = tir.const(1, "float32") / tir.power(
-        theta, d * 2 % d_range / tir.const(d_range, "float32")
-    )
 
-    freq_inter = tir.const(1, "float32") / tir.power(
-        scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
-    )
+    exponent = d * 2 % d_range / tir.const(d_range, "float32")
+    freq_power = tir.power(theta, exponent)
+    freq_extra = tir.const(1, "float32") / freq_power
+    freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
 
     low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d, theta, original_max_position_embeddings
+        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
     )
     high = tir.if_then_else(low == high, high + 0.001, high)
     inv_freq_mask = tir.const(1, "float32") - tir.max(
