Update modelling_RW.py

This commit is contained in:
Falcon LLM TII UAE 2023-05-30 06:12:12 +00:00 committed by huggingface-web
parent 6e61c89591
commit e49c179d8d

@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):
def forward(self, q, k):
batch, seq_len, head_dim = q.shape
cos, sin = self.cos_sin(seq_len, q.device)
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)