MasterJH5574 commented on code in PR #18336:
URL: https://github.com/apache/tvm/pull/18336#discussion_r2371041586
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint:
disable=too-many-locals
if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope
+
+
def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
    theta: float,
    scale: float,
    head_dim: int,
    num_q_heads: int,
    num_kv_heads: int,
    dtype: str,
    rope_scaling: Dict[str, Any],
    rotary_dim: Optional[int] = None,
):
    """Return the TIR function that computes Llama4-style RoPE with q position map.

    The returned function reads a fused ``qkv`` buffer of shape
    ``(seq_len, num_q_heads + 2 * num_kv_heads, head_dim)``. It writes query,
    key and value heads into separate ``q``, ``k`` and ``v`` buffers, applying
    RoPE to the first ``rotary_dim`` features of every query and key head.
    Value heads are copied through unchanged.

    For every ``rope_type``, feature ``d`` is rotated together with its
    interleaved neighbour: ``d + 1`` for even ``d`` and ``d - 1`` for odd ``d``.

    Parameters
    ----------
    theta : float
        The theta value, or "base" in RoPE, which controls the frequency.

    scale : float
        The RoPE scaling factor. Positions are multiplied by it before the
        frequencies are computed.

    head_dim : int
        The number of features on each head.

    num_q_heads : int
        The number of query heads.

    num_kv_heads : int
        The number of key/value heads. It differs from `num_q_heads` in
        group-query attention.

    dtype : str
        The dtype of qkv data.

    rope_scaling : Dict[str, Any]
        The configuration of RoPE scaling. It is passed to
        ``switch_rope_freq_func`` to pick the frequency function. A
        ``"rope_type"`` of ``"longrope"`` selects the variant that takes an
        ``ext_factors`` buffer.

    rotary_dim : Optional[int]
        The number of dimensions in the embedding that RoPE is applied to. By
        default, the rotary_dim is the same as head_dim.

    Returns
    -------
    func : tir.PrimFunc
        ``fused_rope_longrope_scaling`` when ``rope_scaling["rope_type"]`` is
        ``"longrope"``, otherwise ``fused_rope``.
    """
    fused_heads = num_q_heads + num_kv_heads * 2
    if rotary_dim is None:
        rotary_dim = head_dim
    scale = tir.const(scale, "float32")
    is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"

    def _rope(  # pylint: disable=too-many-arguments
        x: T.Buffer,
        s: tir.Var,
        h: tir.Var,
        d: tir.Var,
        pos: tir.Var,
        ext_factors: Optional[T.Buffer] = None,
    ):
        """Build the TIR expression for the rotated value of ``x[s, h, d]``.

        ``ext_factors`` is forwarded to the frequency function only when it
        is given (longrope scaling).
        """
        kwargs = {}
        if ext_factors is not None:
            kwargs["ext_factors"] = ext_factors
        cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
            pos * scale, d, rotary_dim, theta, "float32", **kwargs
        )
        cos = cos_freq * x[s, h, d].astype("float32")
        # Llama4 keeps each rotary pair interleaved as (d, d + 1). That is the
        # same layout the "gptj" rope_type uses, so one formula serves every
        # rope_type: even features pair with -x[d + 1], odd features with
        # x[d - 1].
        sin = sin_freq * tir.if_then_else(
            d % 2 == 0,
            -x[s, h, d + 1],
            x[s, h, d - 1],
        ).astype("float32")
        expr = (cos + sin).astype(dtype)
        # Wrap the expression in Let bindings for the helper variables that
        # the frequency function returned.
        for var, value in var_map.items():
            expr = tir.Let(var, value, expr)
        return expr

    @T.prim_func(private=True)
    def fused_rope(  # pylint: disable=too-many-locals
        var_qkv: T.handle,
        var_position_map: T.handle,
        var_q: T.handle,
        var_k: T.handle,
        var_v: T.handle,
        apply_rope: T.int64,
    ):
        T.func_attr(
            {
                "op_pattern": 8,  # 2 means injective, 8 means opaque
                "tir.noalias": True,
            }
        )
        seq_len = T.int32()
        position_map_elem_offset = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
        position_map = T.match_buffer(
            var_position_map,
            (seq_len,),
            "int32",
            elem_offset=position_map_elem_offset,
        )
        for iters in T.grid(seq_len, fused_heads, head_dim):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", iters)
                # Fused heads are ordered [q heads | k heads | v heads].
                # RoPE is applied only when apply_rope > 0 and only to the
                # first rotary_dim features; all other features pass through.
                if h < num_q_heads:
                    q[s, h, d] = T.if_then_else(
                        apply_rope > 0 and d < rotary_dim,
                        _rope(qkv, s, h, d, position_map[s]),
                        qkv[s, h, d],
                    )
                elif h < num_q_heads + num_kv_heads:
                    k[s, h - num_q_heads, d] = T.if_then_else(
                        apply_rope > 0 and d < rotary_dim,
                        _rope(qkv, s, h, d, position_map[s]),
                        qkv[s, h, d],
                    )
                else:
                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

    @T.prim_func
    def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
        var_qkv: T.handle,
        var_position_map: T.handle,
        var_q: T.handle,
        var_k: T.handle,
        var_v: T.handle,
        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
    ):
        T.func_attr(
            {
                "op_pattern": 8,  # 2 means injective, 8 means opaque
                "tir.noalias": True,
            }
        )
        # NOTE(review): this variant declares int64 sizes while fused_rope
        # uses int32; confirm which one the caller expects.
        seq_len = T.int64()
        position_map_elem_offset = T.int64()
        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
        position_map = T.match_buffer(
            var_position_map,
            (seq_len,),
            "int32",
            elem_offset=position_map_elem_offset,
        )
        for iters in T.grid(seq_len, fused_heads, head_dim):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", iters)
                # Fused heads are ordered [q heads | k heads | v heads].
                # NOTE: this prim_func is defined even when it is not returned.
                # ext_factors is therefore forwarded only for longrope
                # configs, so the expression stays valid for other rope types.
                if h < num_q_heads:
                    q[s, h, d] = T.if_then_else(
                        d < rotary_dim,
                        _rope(
                            qkv,
                            s,
                            h,
                            d,
                            position_map[s],
                            ext_factors if is_longrope_scaling else None,
                        ),
                        qkv[s, h, d],
                    )
                elif h < num_q_heads + num_kv_heads:
                    k[s, h - num_q_heads, d] = T.if_then_else(
                        d < rotary_dim,
                        _rope(
                            qkv,
                            s,
                            h,
                            d,
                            position_map[s],
                            ext_factors if is_longrope_scaling else None,
                        ),
                        qkv[s, h, d],
                    )
                else:
                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

    if is_longrope_scaling:
        return fused_rope_longrope_scaling
    return fused_rope
Review Comment:
Thanks. We'll get this PR in and clean up the duplication as a follow-up.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]