Update modeling_phi.py
parent accfee56d8
commit 1a4c7ae2ef
@@ -302,6 +302,9 @@ class PhiAttention(nn.Module):
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

+    # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
     def forward(
         self,
         hidden_states: torch.Tensor,
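For context, a minimal sketch (not part of this commit) of what the added decorators do: a torch.autocast(..., enabled=False) instance can wrap a method so the ops inside it opt out of an enclosing autocast region and keep their input dtype instead of being downcast to fp16. The Plain/AutocastDisabled modules and the tensor shapes below are illustrative only.

import torch

class Plain(torch.nn.Module):
    def forward(self, q, k):
        return torch.matmul(q, k.transpose(-1, -2))

class AutocastDisabled(torch.nn.Module):
    # same decorator pattern as the commit: keep this forward out of any enclosing autocast region
    @torch.autocast("cuda", enabled=False)
    def forward(self, q, k):
        return torch.matmul(q, k.transpose(-1, -2))

if torch.cuda.is_available():
    q = torch.randn(1, 4, 64, device="cuda")  # fp32 inputs
    k = torch.randn(1, 4, 64, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        print(Plain()(q, k).dtype)             # torch.float16 -- autocast downcasts the matmul
        print(AutocastDisabled()(q, k).dtype)  # torch.float32 -- decorator keeps it out of autocast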
@@ -359,10 +362,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-        attn_weights = torch.matmul(
-            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
-        ) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
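For context, a short illustration (not taken from the model code) of the overflow the removed comment refers to: fp16's largest finite value is roughly 65504, so a query-key dot product can already exceed the representable range and become inf before the softmax, while the same value is unremarkable in fp32. The shapes and fill values below are arbitrary.

import torch

# a single query/key dot product over a 256-wide head with moderately large activations
q = torch.full((1, 256), 16.0)   # fp32
k = torch.full((1, 256), 16.0)

score = torch.matmul(q, k.t())   # 256 * 16 * 16 = 65536
print(score)                     # tensor([[65536.]])  -- fine in fp32
print(score.to(torch.float16))   # tensor([[inf]], dtype=torch.float16)  -- overflows fp16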