Update instruct_pipeline.py

Matthew Hayes 2023-04-17 19:28:39 +00:00 committed by huggingface-web
parent 1c11fae95c
commit 758a161dda

--- a/instruct_pipeline.py
+++ b/instruct_pipeline.py
@@ -1,9 +1,15 @@
 import logging
 import re
+from typing import List
 
 import numpy as np
 from transformers import Pipeline, PreTrainedTokenizer
+from transformers.utils import is_tf_available
+
+if is_tf_available():
+    import tensorflow as tf
 
 logger = logging.getLogger(__name__)
 
 INSTRUCTION_KEY = "### Instruction:"
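
The new guarded import mirrors how transformers handles optional backends: TensorFlow is only imported when it is actually installed, so torch-only environments keep working. A minimal sketch of the pattern, assuming a PyTorch-only install:

# Sketch of the optional-backend guard added above (assumption: TensorFlow is
# not installed). is_tf_available() then returns False, the tf import is
# skipped, and only the "pt" branch added to _forward below can run.
from transformers.utils import is_tf_available

if is_tf_available():
    import tensorflow as tf
else:
    tf = None  # the "tf" reshape branch in _forward is never reached here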
@@ -55,9 +61,22 @@ class InstructionTextGenerationPipeline(Pipeline):
     def __init__(
         self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
     ):
-        super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
+        """Initialize the pipeline
+
+        Args:
+            do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
+            max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 128.
+            top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
+                probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
+            top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                Defaults to 0.
+        """
+        super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k,
+                         **kwargs)
 
-    def _sanitize_parameters(self, return_instruction_text=False, **generate_kwargs):
+    def _sanitize_parameters(self,
+                             return_full_text: bool = None,
+                             **generate_kwargs):
         preprocess_params = {}
 
         # newer versions of the tokenizer configure the response key as a special token. newer versions still may
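
The docstring above documents defaults that are simply forwarded to the parent Pipeline as generation kwargs (note the signature default for max_new_tokens is 256, while the docstring says 128). A hedged construction sketch; model and tokenizer are assumed to be a causal LM and its tokenizer loaded elsewhere:

# Hypothetical construction; any of the documented kwargs can be overridden here
# and will be routed to model.generate() via _sanitize_parameters.
pipe = InstructionTextGenerationPipeline(
    model=model,
    tokenizer=tokenizer,
    do_sample=True,      # sampling on (signature default)
    max_new_tokens=256,  # signature default; the docstring above says 128
    top_p=0.92,
    top_k=0,
)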
@@ -81,10 +100,12 @@ class InstructionTextGenerationPipeline(Pipeline):
         forward_params = generate_kwargs
         postprocess_params = {
             "response_key_token_id": response_key_token_id,
-            "end_key_token_id": end_key_token_id,
-            "return_instruction_text": return_instruction_text,
+            "end_key_token_id": end_key_token_id
         }
 
+        if return_full_text is not None:
+            postprocess_params["return_full_text"] = return_full_text
+
         return preprocess_params, forward_params, postprocess_params
 
     def preprocess(self, instruction_text, **generate_kwargs):
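
With return_full_text now routed through _sanitize_parameters, the flag becomes a per-call option: it is only added to postprocess_params when the caller sets it, so postprocess keeps its own default otherwise. A small usage sketch, assuming pipe is an instance of this pipeline:

# Assumed instance `pipe`; the instruction text is only illustrative.
response_only = pipe("Explain what instruction tuning is.")
full_text = pipe("Explain what instruction tuning is.", return_full_text=True)
# In the second call, postprocess (see the next hunk) prefixes the decoded
# response with the original instruction text.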
@@ -100,66 +121,92 @@ class InstructionTextGenerationPipeline(Pipeline):
     def _forward(self, model_inputs, **generate_kwargs):
         input_ids = model_inputs["input_ids"]
         attention_mask = model_inputs.get("attention_mask", None)
+
+        if input_ids.shape[1] == 0:
+            input_ids = None
+            attention_mask = None
+            in_b = 1
+        else:
+            in_b = input_ids.shape[0]
+
         generated_sequence = self.model.generate(
             input_ids=input_ids.to(self.model.device),
             attention_mask=attention_mask,
             pad_token_id=self.tokenizer.pad_token_id,
             **generate_kwargs,
-        )[0].cpu()
+        )
+
+        out_b = generated_sequence.shape[0]
+        if self.framework == "pt":
+            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+        elif self.framework == "tf":
+            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
+
         instruction_text = model_inputs.pop("instruction_text")
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
 
-    def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_instruction_text):
-        sequence = model_outputs["generated_sequence"]
+    def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_full_text: bool = False):
+
+        generated_sequence = model_outputs["generated_sequence"][0]
         instruction_text = model_outputs["instruction_text"]
 
-        # The response will be set to this variable if we can identify it.
-        decoded = None
-
-        # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
-        if response_key_token_id and end_key_token_id:
-            # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
-            # prompt, we should definitely find it. We will return the tokens found after this token.
-            response_pos = None
-            response_positions = np.where(sequence == response_key_token_id)[0]
-            if len(response_positions) == 0:
-                logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
-            else:
-                response_pos = response_positions[0]
-
-            if response_pos:
-                # Next find where "### End" is located. The model has been trained to end its responses with this
-                # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
-                # this token, as the response could be truncated. If we don't find it then just return everything
-                # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
-                end_pos = None
-                end_positions = np.where(sequence == end_key_token_id)[0]
-                if len(end_positions) > 0:
-                    end_pos = end_positions[0]
-
-                decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
-        else:
-            # Otherwise we'll decode everything and use a regex to find the response and end.
-
-            fully_decoded = self.tokenizer.decode(sequence)
-
-            # The response appears after "### Response:". The model has been trained to append "### End" at the
-            # end.
-            m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
-
-            if m:
-                decoded = m.group(1).strip()
-            else:
-                # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
-                # return everything after "### Response:".
-                m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
-                if m:
-                    decoded = m.group(1).strip()
-                else:
-                    logger.warn(f"Failed to find response in:\n{fully_decoded}")
-
-        if return_instruction_text:
-            return {"instruction_text": instruction_text, "generated_text": decoded}
-
-        return decoded
+        generated_sequence: List[List[int]] = generated_sequence.numpy().tolist()
+        records = []
+        for sequence in generated_sequence:
+
+            # The response will be set to this variable if we can identify it.
+            decoded = None
+
+            # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+            if response_key_token_id and end_key_token_id:
+                # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+                # prompt, we should definitely find it. We will return the tokens found after this token.
+                try:
+                    response_pos = sequence.index(response_key_token_id)
+                except ValueError:
+                    logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
+                    response_pos = None
+
+                if response_pos:
+                    # Next find where "### End" is located. The model has been trained to end its responses with this
+                    # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                    # this token, as the response could be truncated. If we don't find it then just return everything
+                    # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
+                    try:
+                        end_pos = sequence.index(end_key_token_id)
+                    except ValueError:
+                        end_pos = None
+
+                    decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
+
+            if not decoded:
+                # Otherwise we'll decode everything and use a regex to find the response and end.
+
+                fully_decoded = self.tokenizer.decode(sequence)
+
+                # The response appears after "### Response:". The model has been trained to append "### End" at the
+                # end.
+                m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
+                if m:
+                    decoded = m.group(1).strip()
+                else:
+                    # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                    # return everything after "### Response:".
+                    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                    if m:
+                        decoded = m.group(1).strip()
+                    else:
+                        logger.warn(f"Failed to find response in:\n{fully_decoded}")
+
+            # If the full text is requested, then append the decoded text to the original instruction.
+            # This technically isn't the full text, as we format the instruction in the prompt the model has been
+            # trained on, but to the client it will appear to be the full text.
+            if return_full_text:
+                decoded = f"{instruction_text}\n{decoded}"
+
+            rec = {"generated_text": decoded}
+
+            records.append(rec)
+
+        return records
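
Taken together, these _forward and postprocess changes make the pipeline return a list of {"generated_text": ...} records instead of a bare string. An end-to-end usage sketch; the checkpoint name, padding_side, and dtype are assumptions for illustration, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from instruct_pipeline import InstructionTextGenerationPipeline

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b", torch_dtype=torch.bfloat16)

generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# postprocess now returns a list of records for the input instruction.
records = generate_text("Explain the difference between nuclear fission and fusion.")
print(records[0]["generated_text"])

# With return_full_text=True, the instruction is prepended to the decoded response.
records = generate_text("Explain the difference between nuclear fission and fusion.", return_full_text=True)
print(records[0]["generated_text"])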