[WIP] Add a speculative decoding generator #1155
Benchmarks on M3 Max:

Baseline:

```
mlx_lm.generate --model mlx-community/Qwen2.5-32B-Instruct-4bit --prompt "Write a quick sort in C++" -m 256
```

```
Prompt: 36 tokens, 86.936 tokens-per-sec
Generation: 256 tokens, 19.680 tokens-per-sec
Peak memory: 18.573 GB
```

With speculative decoding:

```
Prompt: 36 tokens, 87.853 tokens-per-sec
Generation: 256 tokens, 35.738 tokens-per-sec
Peak memory: 19.112 GB
```
The outputs are identical.
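For reference, the speculative run above might be invoked along these lines. The `--draft-model` and `--num-draft-tokens` flags are assumptions about how this PR exposes the feature, as is the choice of draft model:

```
mlx_lm.generate --model mlx-community/Qwen2.5-32B-Instruct-4bit \
  --draft-model mlx-community/Qwen2.5-0.5B-Instruct-4bit \
  --num-draft-tokens 4 \
  --prompt "Write a quick sort in C++" -m 256
```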
A note on the implementation: it seemed simpler, to start, to have a separate `speculative_generate_step` rather than trying to merge everything into the existing generator. I might refactor a bit so the two can share more functionality. I'm also not sold on wiring this through `stream_generate`. It could start as a standalone thing that either builds on top of MLX LM or is more self-contained. Let me know your thoughts, if any.
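For concreteness, here's a minimal sketch of the draft-and-verify loop that a separate `speculative_generate_step` could implement, assuming greedy decoding (which is what makes the outputs match the baseline exactly). The `Model` callable interface is a hypothetical stand-in for illustration, not the actual MLX LM API:

```python
from typing import Callable, List

# Hypothetical stand-in: a "model" maps a token sequence to the argmax
# next token (greedy decoding). Not the real MLX LM interface.
Model = Callable[[List[int]], int]

def speculative_generate_step(
    tokens: List[int],
    target: Model,
    draft: Model,
    num_draft: int = 4,
) -> List[int]:
    """Propose `num_draft` tokens with the draft model, then verify them
    with the target model. Each step emits between 1 and num_draft + 1
    tokens: the accepted prefix of the draft plus one target token."""
    # 1. Draft: the cheap model proposes a short continuation.
    proposed: List[int] = []
    ctx = list(tokens)
    for _ in range(num_draft):
        t = draft(ctx)
        proposed.append(t)
        ctx.append(t)

    # 2. Verify: with greedy decoding, a proposal is accepted iff the
    #    target would have produced the same token. The first mismatch
    #    is replaced by the target's own token and the rest discarded.
    accepted: List[int] = []
    ctx = list(tokens)
    for t in proposed:
        target_t = target(ctx)
        if target_t == t:
            accepted.append(t)
            ctx.append(t)
        else:
            accepted.append(target_t)  # target's correction
            return accepted

    # All proposals accepted; emit one bonus token from the target.
    accepted.append(target(ctx))
    return accepted
```

In a real implementation the verification runs all drafted positions through the target model in a single batched forward pass, which is where the speedup comes from; the per-token verify loop here is only for readability.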