diff --git a/instruct_pipeline.py b/instruct_pipeline.py index cd864b7..f8b2915 100644 --- a/instruct_pipeline.py +++ b/instruct_pipeline.py @@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline): generated_sequence = self.model.generate( input_ids=input_ids.to(self.model.device), - attention_mask=attention_mask, + attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs, )