diff --git a/modeling_phi.py b/modeling_phi.py
index 9f59e97..4b51f34 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -506,7 +506,7 @@ class PhiFlashAttention2(PhiAttention):
             value_states = value_states.to(target_dtype)
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
+            query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()