From e49c179d8d3194b84fb4d0908737b8e7537528ea Mon Sep 17 00:00:00 2001 From: Falcon LLM TII UAE Date: Tue, 30 May 2023 06:12:12 +0000 Subject: [PATCH] Update modelling_RW.py --- modelling_RW.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelling_RW.py b/modelling_RW.py index 77cb0c1..f0c38a9 100644 --- a/modelling_RW.py +++ b/modelling_RW.py @@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module): def forward(self, q, k): batch, seq_len, head_dim = q.shape - cos, sin = self.cos_sin(seq_len, q.device) + cos, sin = self.cos_sin(seq_len, q.device, q.dtype) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)