Speculative Decoding doesn't work yet with Whisper-v3 (#23)
- Speculative Decoding doesn't work yet with Whisper-v3 (3264a14fc680cf148949b17860f045ce2cf81edb)
This commit is contained in:
parent cf8f9cff50
commit 77b8369da9

README.md (51 lines removed)
@@ -258,57 +258,6 @@ result = pipe(sample, return_timestamps=True, generate_kwargs={"language": "french"})
print(result["chunks"])
```

## Speculative Decoding

Whisper `tiny` can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically ensures the exact same outputs as Whisper are obtained while being 2 times faster. This makes it the perfect drop-in replacement for existing Whisper pipelines, since the same outputs are guaranteed.

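To see why the outputs are guaranteed to match, here is a toy, model-agnostic sketch of greedy draft-and-verify (an illustration, not the `transformers` implementation; `main` and `draft` stand in for greedy next-token functions). Every token that reaches the output is either a draft token the main model agrees with, or the main model's own greedy choice, so the result is token-for-token what the main model alone would produce:

```python
from typing import Callable, List

def speculative_decode(
    main: Callable[[List[int]], int],   # main model: context -> greedy next token
    draft: Callable[[List[int]], int],  # faster draft model, same interface
    prompt: List[int],
    n_tokens: int,
    k: int = 4,                         # number of tokens drafted per step
) -> List[int]:
    seq = list(prompt)
    while len(seq) - len(prompt) < n_tokens:
        # 1. the draft model speculates k tokens ahead of the current sequence
        proposal, ctx = [], list(seq)
        for _ in range(k):
            t = draft(ctx)
            proposal.append(t)
            ctx.append(t)
        # 2. the main model verifies the proposal: accept the longest prefix
        #    that matches its own greedy choices
        for t in proposal:
            if main(seq) == t:
                seq.append(t)
            else:
                break
        # 3. append the main model's own token (the corrected token at the
        #    first mismatch, or one extra token if all k were accepted)
        seq.append(main(seq))
    return seq[: len(prompt) + n_tokens]
```

The speed-up in the real implementation comes from the main model verifying all `k` draft tokens in a single forward pass rather than one at a time; this toy loop only demonstrates the output-equivalence property.
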
In the following code snippet, we load the Whisper `tiny` assistant model standalone from the main Whisper pipeline. We then specify it as the "assistant model" for generation:

```python
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# load the assistant (draft) model: Whisper tiny
assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

# load the main model: Whisper large-v3
model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

# pass the assistant model to generation through the pipeline's generate_kwargs
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
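Since the outputs are guaranteed to match, a quick way to sanity-check (and to measure the speed-up on your hardware) is to run the same sample through an identical pipeline without the assistant model. The following is a minimal sketch reusing the objects defined above; `pipe_plain` and the timing code are illustrative additions, not part of the original snippet:

```python
import time

# assisted run: `pipe` was built above with the assistant model attached
start = time.time()
result_assisted = pipe(sample)
print(f"assisted: {time.time() - start:.2f}s")

# plain run: the same pipeline without an assistant model, for comparison
pipe_plain = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)
start = time.time()
result_plain = pipe_plain(sample)
print(f"plain:    {time.time() - start:.2f}s")

# greedy speculative decoding should yield token-identical outputs
assert result_assisted["text"] == result_plain["text"]
```
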
## Additional Speed & Memory Improvements

You can apply additional speed and memory improvements to Whisper-large-v3, which we cover below.