diff --git a/modelling_RW.py b/modelling_RW.py index 77cb0c1..f0c38a9 100644 --- a/modelling_RW.py +++ b/modelling_RW.py @@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module): def forward(self, q, k): batch, seq_len, head_dim = q.shape - cos, sin = self.cos_sin(seq_len, q.device) + cos, sin = self.cos_sin(seq_len, q.device, q.dtype) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)