Update instruct_pipeline.py
Send attention_mask to device
This commit is contained in:
parent
a7077365ca
commit
d0aa7ea43d
@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
generated_sequence = self.model.generate(
|
generated_sequence = self.model.generate(
|
||||||
input_ids=input_ids.to(self.model.device),
|
input_ids=input_ids.to(self.model.device),
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user