Update instruct_pipeline.py

Send attention_mask to device
This commit is contained in:
Matthew Hayes 2023-04-25 21:21:53 +00:00 committed by huggingface-web
parent a7077365ca
commit d0aa7ea43d

@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline):
generated_sequence = self.model.generate( generated_sequence = self.model.generate(
input_ids=input_ids.to(self.model.device), input_ids=input_ids.to(self.model.device),
attention_mask=attention_mask, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
**generate_kwargs, **generate_kwargs,
) )