
[Kernel] Enhance MoE benchmarking & tuning script #4921

Merged: 17 commits from bench-moe into main on Jun 4, 2024

Conversation

@WoosukKwon (Collaborator) commented May 20, 2024

This PR enhances the MoE tuning & benchmarking script, which is a bit hacky at the moment. It also enables using multiple GPUs for benchmarking via Ray.
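For context, a minimal sketch of the multi-GPU pattern described above, assuming Ray and PyTorch are installed and at least one CUDA GPU is visible; benchmark_one_config, the search space, and the dummy matmul are illustrative placeholders, not the PR's actual code:

import ray
import torch

@ray.remote(num_gpus=1)
def benchmark_one_config(config: dict) -> float:
    # Each task reserves one GPU, so configs from the search space are
    # benchmarked in parallel across all GPUs visible to Ray.
    x = torch.randn(config["num_tokens"], config["hidden_size"], device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    _ = x @ x.t()  # placeholder for the actual fused MoE kernel call
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds

ray.init()
search_space = [{"num_tokens": n, "hidden_size": 4096} for n in (512, 1024, 2048)]
latencies = ray.get([benchmark_one_config.remote(c) for c in search_space])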

@pcmoritz pcmoritz self-assigned this May 20, 2024
@WoosukKwon (Collaborator Author)

@pcmoritz The PR is not ready yet. I will ping you once it's ready.

@pcmoritz (Collaborator)

Sounds great, thank you :)

@WoosukKwon WoosukKwon marked this pull request as ready for review June 3, 2024 18:22
@WoosukKwon (Collaborator Author)

@pcmoritz This PR is ready now. Sorry for the delay.

@pcmoritz (Collaborator) commented Jun 3, 2024

One small gotcha I ran into while trying this out is that fp8 currently can't be benchmarked with an FP16 checkpoint, e.g.

python benchmark_moe.py --dtype fp8

errors out, since mistralai/Mixtral-8x7B-Instruct-v0.1 is FP16. I think what we should do here is:

diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 6796ea401..3f3005e20 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -46,6 +46,8 @@ def benchmark_config(
         w2_scale = torch.randn(num_experts, dtype=torch.float32)
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
+        w1 = w1.to(torch.float8_e4m3fn)
+        w2 = w2.to(torch.float8_e4m3fn)
 
     input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
 

since FP8 checkpoints are not widely available yet, and vLLM's FP8 support also allows running FP16 checkpoints in FP8 :)
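For reference, a minimal sketch of the cast the diff above applies, assuming PyTorch 2.1+ (where torch.float8_e4m3fn is available); the shapes below are illustrative, not Mixtral's real dimensions:

import torch

num_experts, hidden, intermediate = 8, 4096, 14336

# Random FP16 weights stand in for an FP16 checkpoint; casting them to FP8
# lets the FP8 kernel path be benchmarked without an FP8 checkpoint.
w1 = torch.randn(num_experts, intermediate, hidden, dtype=torch.float16)
w2 = torch.randn(num_experts, hidden, intermediate, dtype=torch.float16)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)

# Per-expert and per-tensor scales stay in FP32, as in the diff.
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)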

@WoosukKwon (Collaborator Author)

@pcmoritz I addressed your comments. PTAL.

@WoosukKwon WoosukKwon requested a review from pcmoritz June 4, 2024 02:47
@pcmoritz (Collaborator) left a comment

Thanks! I've been using the new script to do some tuning for FP8 and it works like a charm; thanks a lot for improving it. I'll open a PR with the new configs shortly, after I have tested them!

Btw, in order to get progress bars, I've been using this modification:

from ray.experimental.tqdm_ray import tqdm

and then where we iterate over the configs:

for config in tqdm(search_space):

This prints progress bars without messing up stdout; it works as described here: https://docs.ray.io/en/latest/ray-observability/user-guides/configure-logging.html#distributed-progress-bars-tqdm

Feel free to add it (don't worry that it is currently in the experimental namespace; I think it is one of the APIs that should be stabilized, and I'll look into that).
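For reference, a minimal standalone sketch of that pattern, assuming Ray is installed; the process_config task and the sleep are placeholders for the real per-config benchmark:

import time
import ray
from ray.experimental.tqdm_ray import tqdm

@ray.remote
def process_config(worker_id: int) -> None:
    # tqdm_ray forwards progress updates to the driver process, so bars from
    # multiple workers render cleanly instead of interleaving on stdout.
    for _ in tqdm(range(20), desc=f"worker {worker_id}"):
        time.sleep(0.1)

ray.init()
ray.get([process_config.remote(i) for i in range(4)])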

@WoosukKwon (Collaborator Author)

@pcmoritz Ray tqdm is really cool! I actually wanted to have exactly the same feature. Happy to add that!

@WoosukKwon WoosukKwon merged commit 3a434b0 into main Jun 4, 2024
18 of 23 checks passed
@WoosukKwon WoosukKwon deleted the bench-moe branch June 4, 2024 03:07
blinkbear pushed a commit to blinkbear/vllm that referenced this pull request Jun 6, 2024
pcmoritz pushed a commit that referenced this pull request Jun 13, 2024
Tune Qwen2-57B-A14B configs based on #4921

Throughput Performance
command: python benchmarks/benchmark_throughput.py --model=Qwen/Qwen2-57B-A14B-Instruct --input-len 1000 --output-len 50 -tp 2

A100 GPU

| benchmark | no config | w/ PR |
|-----------|-----------|-------|
| tp=2 | 10.53 requests/s, 11058.17 tokens/s | 12.47 requests/s, 13088.57 tokens/s |
| tp=4 | 17.77 requests/s, 18662.95 tokens/s | 20.20 requests/s, 21212.32 tokens/s |