Update modeling_phi.py
This commit is contained in:
parent
cb2f453360
commit
eb8bbd1d37
@ -47,9 +47,11 @@ from transformers.utils import (
|
|||||||
from .configuration_phi import PhiConfig
|
from .configuration_phi import PhiConfig
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
try:
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user