Update modelling_RW.py
This commit is contained in:
parent
6e61c89591
commit
e49c179d8d
@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def forward(self, q, k):
|
||||
batch, seq_len, head_dim = q.shape
|
||||
cos, sin = self.cos_sin(seq_len, q.device)
|
||||
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
|
||||
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user