Update modeling_phi.py

This commit is contained in:
Gustavo de Rosa 2024-01-18 11:27:21 +00:00 committed by system
parent 1a4c7ae2ef
commit 85d00b03fe
No known key found for this signature in database
GPG Key ID: 6A528E38E0733467

@ -362,7 +362,10 @@ class PhiAttention(nn.Module):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# Upcasting queries and keys to fp32 is required by Phi-2 to avoid overflow
attn_weights = torch.matmul(
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(