
[WIP] Add a speculative decoding generator #1155

Open · wants to merge 4 commits into main
Conversation

awni (Member) commented on Dec 14, 2024

Benchmarks on M3 Max:

Baseline:

```shell
mlx_lm.generate --model mlx-community/Qwen2.5-32B-Instruct-4bit --prompt "Write a quick sort in C++" -m 256
```

Prompt: 36 tokens, 86.936 tokens-per-sec
Generation: 256 tokens, 19.680 tokens-per-sec
Peak memory: 18.573 GB

With speculative decoding:

```shell
mlx_lm.generate --model mlx-community/Qwen2.5-32B-Instruct-4bit --prompt "Write a quick sort in C++" -m 256 --draft-model mlx-community/Qwen2.5-0.5B-Instruct-8bit --num-draft-tokens 4
```

Prompt: 36 tokens, 87.853 tokens-per-sec
Generation: 256 tokens, 35.738 tokens-per-sec
Peak memory: 19.112 GB

The outputs are identical, as expected: with greedy sampling, verification only ever keeps draft tokens the target model would have generated itself, so speculative decoding changes speed, not output.
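For reference, the same speculative run driven from Python rather than the CLI might look like the sketch below. This is not code from this PR: it assumes the `--draft-model` and `--num-draft-tokens` flags map to same-named keyword arguments on `generate`, which the final API may not match.

```python
from mlx_lm import load, generate

# Target model and a much smaller draft model from the same family,
# so the two share a tokenizer.
model, tokenizer = load("mlx-community/Qwen2.5-32B-Instruct-4bit")
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-8bit")

# Assumed kwargs mirroring --draft-model / --num-draft-tokens; the WIP API may differ.
text = generate(
    model,
    tokenizer,
    prompt="Write a quick sort in C++",
    max_tokens=256,
    draft_model=draft_model,
    num_draft_tokens=4,
)
print(text)
```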

A note on the implementation: it seemed simpler to start with a separate `speculative_generate_step` rather than try to merge everything; I might refactor a bit so the two paths can share more functionality. I'm also not sold on wiring this through `stream_generate`. It could start as a standalone thing that either builds on top of MLX LM or is fully standalone. Let me know your thoughts, if any.
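For anyone unfamiliar with the technique, here is a minimal, framework-agnostic sketch of one greedy speculative decoding step. It is not this PR's `speculative_generate_step`: the model callables `draft_next_logits` and `target_logits` are hypothetical stand-ins for real forward passes (with KV caches, batching, etc.), kept as plain Python for clarity.

```python
from typing import Callable, List

Logits = List[float]

def argmax(xs: Logits) -> int:
    return max(range(len(xs)), key=xs.__getitem__)

def speculative_step(
    tokens: List[int],
    target_logits: Callable[[List[int]], List[Logits]],  # hypothetical: logits for the last N+1 positions
    draft_next_logits: Callable[[List[int]], Logits],    # hypothetical: next-token logits
    num_draft_tokens: int = 4,
) -> List[int]:
    # 1) Draft: the small model autoregressively proposes num_draft_tokens tokens.
    draft: List[int] = []
    ctx = list(tokens)
    for _ in range(num_draft_tokens):
        tok = argmax(draft_next_logits(ctx))
        draft.append(tok)
        ctx.append(tok)

    # 2) Verify: a single big-model forward over the drafted span scores every
    #    drafted position (plus one extra position) in parallel.
    logits = target_logits(tokens + draft)  # len(logits) == num_draft_tokens + 1

    # 3) Accept the longest prefix of drafts the big model agrees with; on the
    #    first disagreement, keep the big model's own token and stop.
    out = list(tokens)
    for i, tok in enumerate(draft):
        target_tok = argmax(logits[i])
        out.append(target_tok)
        if target_tok != tok:
            return out
    # All drafts accepted: the final logits yield one bonus token for free.
    out.append(argmax(logits[-1]))
    return out
```

Each big-model forward thus yields between 1 and `num_draft_tokens + 1` tokens, which is where the roughly 1.8x generation speedup above comes from when the 0.5B draft model's guesses are frequently accepted.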
