diff --git a/modeling_phi.py b/modeling_phi.py
index 4b51f34..3dd2ae7 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -302,6 +302,9 @@ class PhiAttention(nn.Module):
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
+    # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -359,10 +362,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-        attn_weights = torch.matmul(
-            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
-        ) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
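
For context (not part of the patch): a minimal, runnable sketch of why the stacked torch.autocast decorators make the explicit .to(torch.float32) upcast redundant. Entering forward() disables autocast on both device types, so the score matmul runs in the module's parameter dtype (fp32) even when the caller is inside an autocast region. The ToyAttention module, tensor shapes, and the CPU/bfloat16 autocast region below are illustrative assumptions, not code from modeling_phi.py.

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Illustrative stand-in for PhiAttention; not the actual HF module."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # fp32 parameters
        self.k_proj = nn.Linear(dim, dim)

    # Same pattern as the patch: autocast is disabled for the whole body
    # of forward(), so every op below runs without autocast's fp16/bf16 casts.
    @torch.autocast("cpu", enabled=False)
    @torch.autocast("cuda", enabled=False)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        # Without the decorators, an enclosing autocast region would execute
        # this matmul in half precision, which is where Phi-2's attention
        # scores overflow.
        return torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)

attn = ToyAttention()
x = torch.randn(2, 4, 8)
with torch.autocast("cpu", dtype=torch.bfloat16):
    scores = attn(x)
print(scores.dtype)  # torch.float32: the decorators kept the matmul out of autocast

Removing the two decorators makes the same print show torch.bfloat16, since nn.Linear and matmul are autocast-eligible ops; that dtype difference is the overflow risk the patch guards against, now handled once at the function boundary instead of with per-tensor upcasts.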