# coding=utf-8
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Falcon model. """
2023-05-24 12:17:53 +00:00
import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_falcon import FalconConfig


logger = logging.get_logger(__name__)

FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "tiiuae/falcon-40b",
    "tiiuae/falcon-40b-instruct",
    "tiiuae/falcon-7b",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-rw-7b",
    "tiiuae/falcon-rw-1b",
]
_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
_CONFIG_FOR_DOC = "FalconConfig"


# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, which means there is one
# additional quantization to bfloat16 between the operations. In order not to degrade the quality of our
# HF-port, we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_states = input @ self.weight.T
        if self.bias is None:
            return hidden_states
        return hidden_states + self.bias
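
# Illustrative sketch (hypothetical sizes): FalconLinear matches nn.Linear numerically up to the extra
# bfloat16 rounding mentioned above, since it computes `input @ weight.T + bias` without a fused kernel.
#   layer = FalconLinear(in_features=8, out_features=16, bias=True)
#   out = layer(torch.randn(2, 4, 8))  # -> shape [2, 4, 16], same as nn.Linear(8, 16)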


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
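
# Illustrative example (values worked out by hand): rotate_half swaps the two halves of the last dimension
# and negates the second half:
#   rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # -> tensor([-3., -4.,  1.,  2.])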


class FalconRotaryEmbedding(nn.Module):
    """Implementation of RotaryEmbedding from GPT-NeoX.

    This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
    n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
    """

    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = -1
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def cos_sin(
        self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        total_length = seq_len + past_key_values_length
        if total_length > self.seq_len_cached:
            self.seq_len_cached = total_length
            t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]

            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)

        return (
            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
        )

    def forward(self, query, key, past_key_values_length=0):
        batch, seq_len, head_dim = query.shape
        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
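
# Illustrative usage sketch (shapes only; batch_size, num_heads, num_kv_heads and seq_len are hypothetical,
# head_dim=64 is assumed):
#   rotary = FalconRotaryEmbedding(head_dim=64)
#   q = torch.randn(batch_size * num_heads, seq_len, 64)
#   k = torch.randn(batch_size * num_kv_heads, seq_len, 64)
#   q_rot, k_rot = rotary(q, k, past_key_values_length=0)  # shapes are unchanged
# The cos/sin tables are cached up to the longest total length seen so far and re-sliced on later calls.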


def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
    just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
    target_length, target_length + past_key_values_length]`.
    """
    batch_size, target_length = input_ids_shape

    mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
    # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
    # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
    # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
    past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
    mask = torch.cat([past_mask, mask], dim=-1)

    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
    return expanded_mask
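
# Illustrative example: with no past cache, a length-3 sequence produces an upper-triangular block of True
# values, where True marks positions a token is *not* allowed to attend to:
#   _make_causal_mask(torch.Size([1, 3]), device=torch.device("cpu"), past_key_values_length=0)[0, 0]
#   # tensor([[False,  True,  True],
#   #         [False, False,  True],
#   #         [False, False, False]])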


def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
    """
    Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
    """
    batch_size, total_length = mask.shape
    seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, seq_length, total_length)
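
# Illustrative example: a padding mask [[1, 1, 0]] (last token is padding) with past_key_values_length=0
# becomes a [1, 1, 3, 3] boolean mask where True again marks positions to be masked out:
#   _expand_mask(torch.tensor([[1, 1, 0]]), past_key_values_length=0)[0, 0]
#   # tensor([[False, False,  True],
#   #         [False, False,  True],
#   #         [False, False,  True]])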


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcast correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None].bfloat16() * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
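
# Illustrative example (values worked out from the formula above): for num_heads=4 the per-head ALiBi slopes
# are powers of 2 ** -(8 / num_heads) = 0.25, i.e. [0.25, 0.0625, 0.015625, 0.00390625], and the returned
# tensor has shape [batch_size * num_heads, 1, seq_length] so it can be broadcast over the query dimension.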


# Copied from transformers.models.bloom.modeling_bloom.dropout_add
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out
2023-09-29 14:32:25 +00:00
class FalconAttention ( nn . Module ) :
def __init__ ( self , config : FalconConfig ) :
2023-05-24 12:17:53 +00:00
super ( ) . __init__ ( )
self . hidden_size = config . hidden_size
2023-09-29 14:32:25 +00:00
self . num_heads = config . num_attention_heads
2023-05-24 12:17:53 +00:00
self . head_dim = self . hidden_size / / self . num_heads
self . split_size = self . hidden_size
self . hidden_dropout = config . hidden_dropout
if self . head_dim * self . num_heads != self . hidden_size :
raise ValueError (
f " `hidden_size` must be divisible by num_heads (got `hidden_size`: { self . hidden_size } and `num_heads`: "
f " { self . num_heads } ). "
)
2023-09-29 14:32:25 +00:00
self . maybe_rotary = FalconRotaryEmbedding ( config . head_dim ) if config . rotary else lambda q , k , t : ( q , k )
2023-05-24 12:17:53 +00:00
# Layer-wise attention scaling
self . inv_norm_factor = 1.0 / math . sqrt ( self . head_dim )
self . beta = self . inv_norm_factor
2023-09-29 14:32:25 +00:00
if config . new_decoder_architecture :
qkv_out_dim = ( config . num_kv_heads * 2 + config . num_attention_heads ) * self . head_dim
elif config . multi_query :
qkv_out_dim = self . hidden_size + 2 * self . head_dim
else :
qkv_out_dim = 3 * self . hidden_size
self . query_key_value = FalconLinear ( self . hidden_size , qkv_out_dim , bias = config . bias )
self . new_decoder_architecture = config . new_decoder_architecture
self . multi_query = config . multi_query
self . dense = FalconLinear ( self . hidden_size , self . hidden_size , bias = config . bias )
2023-05-24 12:17:53 +00:00
self . attention_dropout = nn . Dropout ( config . attention_dropout )
2023-09-29 14:32:25 +00:00
self . num_kv_heads = config . num_kv_heads if ( self . new_decoder_architecture or not self . multi_query ) else 1
2023-05-24 12:17:53 +00:00

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        if self.new_decoder_architecture:
            batch, seq_len, _ = fused_qkv.shape
            qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
            query = qkv[:, :, :, :-2]
            key = qkv[:, :, :, [-2]]
            value = qkv[:, :, :, [-1]]
            key = torch.broadcast_to(key, query.shape)
            value = torch.broadcast_to(value, query.shape)

            query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
            return query, key, value
        elif not self.multi_query:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
        else:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
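
    # Illustrative shape sketch of the three layouts handled above (hypothetical sizes, head_dim=64):
    #   - new_decoder_architecture: fused_qkv [batch, seq, (num_kv_heads * 2 + num_heads) * 64]
    #       -> query/key/value each [batch, seq, num_heads, 64] (key/value broadcast from num_kv_heads groups)
    #   - multi_query:              fused_qkv [batch, seq, (num_heads + 2) * 64]
    #       -> query [batch, seq, num_heads, 64], key/value [batch, seq, 1, 64]
    #   - otherwise:                fused_qkv [batch, seq, 3 * num_heads * 64]
    #       -> query/key/value each [batch, seq, num_heads, 64]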

    # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, query_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
            batch_size * num_kv_heads,
            query_length,
            self.head_dim,
        )
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, kv_length, _ = key_layer.shape
        if use_cache:
            present = (key_layer, value_layer)
        else:
            present = None

        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)

        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

        if alibi is None:
            if output_attentions:
                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
                # to do it by hand if we want them
                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
                attention_scores /= math.sqrt(self.head_dim)

                attention_scores = F.softmax(
                    attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
                )
                attn_output = attention_scores @ value_layer_
            else:
                attn_output = F.scaled_dot_product_attention(
                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
                )
                attention_scores = None

            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
            attn_output = attn_output.permute(0, 2, 1, 3)
            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            if output_attentions:
                return output_tensor, present, attention_scores
            else:
                return output_tensor, present

        else:
            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                attention_scores = attention_scores.to(torch.float32)
            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
            # equivalent and more performant, but there might be a numerical difference. If you're reading this
            # and you'd like to experiment and maybe file a PR, feel free!
            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
            attention_logits *= self.inv_norm_factor
            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
            # [batch_size, num_heads, q_length, kv_length]
            attention_probs = self.attention_dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            # change view [batch_size, num_heads, q_length, kv_length]
            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = self._merge_heads(context_layer)

            output_tensor = self.dense(context_layer)

            if output_attentions:
                return output_tensor, present, attention_probs
            else:
                return output_tensor, present


class FalconMLP(nn.Module):
    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
        self.act = nn.GELU()
        self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.dense_h_to_4h(x))
        x = self.dense_4h_to_h(x)
        return x
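
# Illustrative shape sketch: the MLP expands and then contracts the hidden dimension, e.g. with a
# hypothetical hidden_size of 4544 the projections are 4544 -> 18176 -> 4544, with GELU applied in between.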


class FalconDecoderLayer(nn.Module):
    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config)
        self.mlp = FalconMLP(config)
        self.hidden_dropout = config.hidden_dropout
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attn_outputs = self.self_attention(
            attention_layernorm_out,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual = dropout_add(
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        outputs = attn_outputs[1:]

        # MLP.
        mlp_output = self.mlp(mlp_layernorm_out)

        if self.config.new_decoder_architecture or self.config.parallel_attn:
            mlp_output += attention_output

        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions


FALCON_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FALCON_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.

            Each element of `past_key_values` is a tuple (past_key, past_value):
            - past_key: [batch_size * num_heads, head_dim, kv_length]
            - past_value: [batch_size * num_heads, kv_length, head_dim]
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


class FalconPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FalconConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FalconDecoderLayer"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        if isinstance(module, FalconModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_cache_to_standard_format(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
        # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
        # Note that we don't want to use self.num_attention_heads because the number of heads may vary depending
        # on whether we use multi_query attention.
        num_heads = batch_size_times_num_heads // batch_size
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
                layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_rw_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
                layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
            )
            for layer_past in past_key_value
        )
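
# Illustrative sketch of the two cache layouts converted above (hypothetical sizes):
#   standard format: each layer holds (key, value) of shape [batch_size, num_heads, kv_length, head_dim]
#   RW format:       each layer holds (key, value) of shape [batch_size * num_heads, kv_length, head_dim]
# _convert_to_rw_cache flattens the first two dimensions; _convert_cache_to_standard_format undoes it.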


@add_start_docstrings(
    "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
    FALCON_START_DOCSTRING,
)
class FalconModel(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

        # Transformer blocks
        self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    @staticmethod
    def _prepare_attn_mask(
        attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # Create a causal mask
        # The attention mask we receive as input should cover the whole extended sequence, including any past
        # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
        # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
            raise ValueError(
                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
                f" {past_key_values_length}."
            )
        combined_attention_mask = None
        device = attention_mask.device
        _, seq_length = input_shape

        if seq_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
        expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))
        else:
            past_key_values = self._convert_to_rw_cache(past_key_values)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = inputs_embeds

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        if self.use_alibi:
            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        else:
            alibi = None

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if presents is not None:
            presents = self._convert_cache_to_standard_format(presents, batch_size)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@add_start_docstrings(
    "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
    FALCON_START_DOCSTRING,
)
class FalconForCausalLM(FalconPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.transformer = FalconModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
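
    # Illustrative sketch: once a cache is present, only the newest token is fed back in. For a hypothetical
    # prompt of 5 tokens followed by one generated token:
    #   input_ids [batch, 6] + past_key_values from the first pass -> the model only sees input_ids[:, -1:]
    #   ([batch, 1]) on the next step, with the cache supplying the earlier positions.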

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in past
        )
        return reordered_past
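
# Illustrative usage sketch (assumes a Falcon checkpoint such as "tiiuae/falcon-7b" is available locally or on
# the Hugging Face Hub; any compatible checkpoint name works the same way):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
#   model = FalconForCausalLM.from_pretrained("tiiuae/falcon-7b")
#   inputs = tokenizer("Hello, my name is", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=20)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))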


@add_start_docstrings(
    """
    The Falcon Model transformer with a sequence classification head on top (linear layer).

    [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForSequenceClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = FalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FALCON_START_DOCSTRING,
)
class FalconForTokenClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = FalconModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForQuestionAnswering(FalconPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = FalconModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )