From d0aa7ea43dc3548a8b499cc88605e150c4a2704d Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Tue, 25 Apr 2023 21:21:53 +0000 Subject: [PATCH] Update instruct_pipeline.py Send attention_mask to device --- instruct_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instruct_pipeline.py b/instruct_pipeline.py index cd864b7..f8b2915 100644 --- a/instruct_pipeline.py +++ b/instruct_pipeline.py @@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline): generated_sequence = self.model.generate( input_ids=input_ids.to(self.model.device), - attention_mask=attention_mask, + attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs, )