def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
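For reference, the rotate_half helper used above is defined alongside this function in modeling_llama.py; it negates the second half of the last dimension and swaps the halves. Below is that helper together with a minimal shape sketch of the broadcasting described in the docstring (the dimensions and the random cos/sin values are illustrative placeholders; a real rotary embedding module produces cosines and sines of the position angles):

import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)  # placeholder values
sin = torch.randn(batch, seq_len, head_dim)
# unsqueeze_dim=1 inserts the heads axis so [batch, 1, seq_len, head_dim]
# broadcasts against [batch, heads, seq_len, head_dim].
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert q_embed.shape == q.shape and k_embed.shape == k.shape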
System Info
main branch
Who can help?
@ArthurZucker @younesbelkada @bojone
Reproduction
transformers/src/transformers/models/llama/modeling_llama.py
Lines 160 to 184 in 481a957
The rotary embedding is applied to q and k in exactly the same way, but shouldn't the two differ, so that when they are multiplied together we end up with a relative position equal to the difference between their positions? The current implementation results in a relative position that is the sum of the positions of q and k, so I think we should put a negative sign in front of k's angle. A correct implementation for k may then look like:
k_embed = (k * cos) + (rotate_half(k) * (-sin))
Expected behavior
The rotation angles applied to q and k should differ in sign.
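One way to probe the claim empirically is to check whether the q·k score depends only on the difference of the two positions: rotate at positions (m, n), then at (m + s, n + s), and compare the dot products under both the current implementation and the proposed sign flip. A minimal self-contained sketch follows; the rope helper and the base-10000 angle construction are illustrative stand-ins for the library's rotary module, not its actual API:

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rope(x, pos, sign=1.0):
    # Standard RoPE angles: theta_i = pos / 10000^(2i/d), duplicated across both halves.
    d = x.shape[-1]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    angles = pos * torch.cat((inv_freq, inv_freq))
    return x * torch.cos(sign * angles) + rotate_half(x) * torch.sin(sign * angles)

q = torch.randn(64)
k = torch.randn(64)

# Current implementation: same sign of angle for q and k.
score = lambda m, n: torch.dot(rope(q, m), rope(k, n))
# Proposed change: negative sign in front of k's angle.
score_flipped = lambda m, n: torch.dot(rope(q, m), rope(k, n, sign=-1.0))

# If a variant encodes only the relative position n - m, shifting both
# positions by the same offset must leave the score unchanged.
print(torch.allclose(score(3.0, 7.0), score(8.0, 12.0), atol=1e-5))
print(torch.allclose(score_flipped(3.0, 7.0), score_flipped(8.0, 12.0), atol=1e-5))

Running this shift-invariance check on both variants shows directly which one makes the score a function of the position difference alone.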