This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b555a964f Adjusted Longrope embedding function to match Huggingface Implementation (#18422)
4b555a964f is described below

commit 4b555a964f39b519eac13d6350cab30f88466fb3
Author: Sidharth N. Babu <[email protected]>
AuthorDate: Wed Nov 12 13:31:53 2025 -0500

    Adjusted Longrope embedding function to match Huggingface Implementation (#18422)
    
    This updated implementation of longrope takes into account
    `long_factors` and `short_factors`, the per-dimension scaling factors
    provided via HF configs for MSFT's Phi3+ models. In the canonical HF
    implementation of longrope, once the sequence length exceeds a
    pre-configured threshold, a different set of `ext_factors` must be
    used than before. This patch enables that by packing both sets of
    scaling factors into one argument and selecting which to use
    dynamically within the returned `prim_func`.
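
    As a rough illustration (plain NumPy with made-up example values, not
    the TVM code in this patch), the packing convention looks like:

        import numpy as np

        # Per-dimension factors as they appear in an HF rope_scaling
        # config; these values are illustrative only (rotary_dim == 8).
        long_factors = np.array([1.5, 2.0, 2.5, 3.0], dtype="float32")
        short_factors = np.array([1.0, 1.0, 1.1, 1.2], dtype="float32")

        # Pack both halves into one flat argument: long factors first,
        # short factors second, each of length rotary_dim // 2.
        ext_factors = np.concatenate([long_factors, short_factors])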
    
    The HF implementation of this can be found here:
    https://github.com/huggingface/transformers/blob/7b325cd573e40bbb12951b8446176c96e8b1afaa/src/transformers/modeling_rope_utils.py#L521
    
    The link above points directly to the switching logic between long
    and short factors, which has been replicated in this PR.
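
    A minimal sketch of that switching rule (simplified Python;
    `select_ext_factors` is a hypothetical helper for illustration, not
    code from this patch):

        def select_ext_factors(
            seq_len: int,
            original_max_position_embeddings: int,
            long_factors,
            short_factors,
        ):
            # Mirrors the HF rule replicated in this patch: beyond the
            # original training context length, the long factors apply;
            # otherwise the short factors do.
            if seq_len > original_max_position_embeddings:
                return long_factors
            return short_factors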
---
 .../relax/frontend/nn/llm/position_embedding.py    | 107 +++++++++++++++------
 1 file changed, 75 insertions(+), 32 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 6fda4b0bca..35eeb4f5f3 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -464,6 +464,10 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
         rotary_dim = head_dim
     scale = tir.const(scale, "float32")
     is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+    if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
+        original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    else:
+        original_max_position_embeddings = 0
 
     def _rope(  # pylint: disable=too-many-arguments
         x: T.Buffer,
@@ -546,7 +550,7 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {
@@ -563,37 +567,76 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
         position_map = T.match_buffer(
            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
         )
-        for iters in T.grid(seq_len, fused_heads, head_dim):
-            with T.block("llama_fused_rope"):
-                s, h, d = T.axis.remap("SSS", iters)
-                if h < num_q_heads:
-                    q[s, h, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                elif h < num_q_heads + num_kv_heads:
-                    k[s, h - num_q_heads, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                else:
-                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        # long factors are the first half, short factors are the second half
+        long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
+        short_factors = T.Buffer(
+            (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
+        )
+
+        if seq_len > original_max_position_embeddings:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        else:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
 
     if is_longrope_scaling:
         return fused_rope_longrope_scaling
