diff --git a/modeling_phi.py b/modeling_phi.py index f593327..9f59e97 100644 --- a/modeling_phi.py +++ b/modeling_phi.py @@ -47,9 +47,11 @@ from transformers.utils import ( from .configuration_phi import PhiConfig -if is_flash_attn_2_available(): +try: from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + pass logger = logging.get_logger(__name__)