
Distributed inference example #890

Draft: angeloskath wants to merge 4 commits into main from distributed-layers

Conversation

@angeloskath (Member)

Simple distributed inference on top of ml-explore/mlx#1270. Again, a draft PR so we can iterate on the design. The communication will be very latency-bound (probably impractical), so no need to be particularly excited yet.

@Blaizzy (Contributor) commented Jul 15, 2024

Thanks @angeloskath!

This is very timely, as I was looking for such an example for a couple of days.

@mzbac (Contributor) commented Jul 16, 2024

Amazing, time to buy a second M2 Ultra :P

@mzbac (Contributor) commented Jul 18, 2024

@angeloskath, please correct me if I am wrong. Looking at the implementation, it seems like we are sharding vertically: for o_proj, we have to wait for all nodes to complete the forward pass before moving on to the next layer. This would create a bottleneck, as the slowest node would slow down the entire process. Would it be better to shard by layers instead?

Edit:
I think I understand it now. It makes sense to shard the model across identical hardware over a fast connection that maximizes parallelization. This should be a good fit for MoE models. For dense models, maybe we have to do something similar to exo: shard over layers and make the inference sequential.
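
A minimal, hypothetical sketch of the layer (pipeline) sharding described above, not code from this PR: each node owns a contiguous block of layers and the activations flow through the nodes one after another, so a single token is processed sequentially across nodes. The 126-layer / 2-node numbers are illustrative.

```python
def layer_ranges(num_layers: int, num_nodes: int) -> list[tuple[int, int]]:
    """Split the layers into contiguous [start, end) blocks, one per node."""
    per_node, extra = divmod(num_layers, num_nodes)
    ranges, start = [], 0
    for node in range(num_nodes):
        end = start + per_node + (1 if node < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges

# With 2 nodes and 126 layers, node 0 runs layers [0, 63) and node 1 runs
# [63, 126); node 1 can only start once node 0 has produced its activations.
print(layer_ranges(126, 2))  # [(0, 63), (63, 126)]
```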

@Blaizzy (Contributor) commented Jul 18, 2024

I think you might be correct here, @mzbac!

Though I would also like to benchmark @angeloskath's approach.

I have been researching this topic for weeks to support it in FastMLX. According to the papers I read and the Accelerate docs, layer-group sharding is the best approach for distributed inference and training.

But it requires every single node/machine to have quick access to its model weights/shard on device.

@angeloskath (Member, Author)

They are just different approaches, really. Pipelining gives perfect scaling in throughput but not in latency.

This means that if you are running evaluations or simply running batch generations, then it is perfect. But it will still take the same amount of time to see the first output. Basically, for a single generation, and assuming the model fits on one device, it doesn't provide any speedup. Another way to say it is that the tokens per second per client are not sped up; the aggregate tokens per second scale pretty much perfectly, though.

The approach in this PR is called model parallelism or tensor parallelism. The goal is to improve latency as well as throughput. However, it depends heavily on the latency of the interconnect, so over Ethernet this will probably not achieve speedups (we are looking into it). Indeed, we need to communicate 2 * num_layers times to produce a single output. With pipelining, on the other hand, we only communicate num_shards times per output. So if the communication latency is large, pipelining may be a better approach.
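
As a rough illustration of the tensor-parallel idea described above (a sketch, not the actual code in this PR): each rank holds a 1/N slice of every linear layer, and the partial results are combined with one all-reduce per sharded block, which is roughly where the 2 * num_layers communications per output come from (presumably one for the attention block and one for the MLP in each layer). A minimal sketch using mlx.core.distributed, with illustrative layer sizes:

```python
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()       # singleton group when not launched via MPI
N = group.size()

dims, hidden = 4096, 14336          # illustrative sizes, not taken from the PR
shard = hidden // N                 # each rank holds 1/N of the hidden dimension

up = nn.Linear(dims, shard)                 # column-sharded: each rank computes its slice
down = nn.Linear(shard, dims, bias=False)   # row-sharded: outputs are partial sums

def mlp_block(x):
    y = down(nn.silu(up(x)))            # local partial result on this rank
    return mx.distributed.all_sum(y)    # one all-reduce combines the partial sums

x = mx.random.normal((1, dims))
print(mlp_block(x).shape)               # (1, 4096) on every rank
```

With the attention block sharded the same way, each layer needs two such all-reduces per token, which is why the interconnect latency ends up dominating.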

@mzbac (Contributor) commented Jul 19, 2024

> They are just different approaches, really. Pipelining gives perfect scaling in throughput but not in latency. […]

@angeloskath, thank you for the detailed explanation. I may try to get another M2 Ultra and test it via the Thunderbolt 4 connection :)

@fblissjr

> They are just different approaches, really. Pipelining gives perfect scaling in throughput but not in latency. […]

IMO this is exactly what we need in the long run.

In the short term, the hype is around the 400B Llama, but that will fade eventually. Latency optimization is what I think fits the overall MLX ethos.

@mzbac (Contributor) commented Jul 30, 2024

I tried clustering one M2 Ultra 192GB with another M2 Ultra 128GB, splitting the weights into 160GB and 67GB (not tensor parallelism) for Llama 3 405B. I got around 0.3 t/s, but I expected it to be closer to 1 or 2 t/s. I'm not sure if this is related to MLX or some system-level issue.

PS:
I tried to run sudo sysctl iogpu.disable_wired_collector=1 but got the error sysctl: unknown oid 'iogpu.disable_wired_collector'. Maybe that is a potential issue.

@Blaizzy (Contributor) commented Jul 30, 2024

Was this over WiFi or Thunderbolt 4, @mzbac?

@mzbac (Contributor) commented Jul 30, 2024

> Was this over WiFi or Thunderbolt 4, @mzbac?

TB4. I ran some tests, and I feel there may be a memory issue: when MLX's memory consumption reaches a certain limit, the tokens per second slow down to 0.x. I am not exactly sure what the issue is, but sharding DeepSeek Coder v2 4-bit was working fine (60+ GB VRAM and up to 1xx GB RAM cache).

@awni (Member) commented Jul 30, 2024

Which OS are you on? A couple of things that might help:

  1. Restart the machine(s)
  2. Upgrade to macOS 15.0 (Sequoia)
  3. Set some sysctls:
     sudo sysctl iogpu.wired_limit_mb=200000
     sudo sysctl iogpu.disable_wired_collector=1

The disable_wired_collector sysctl is macOS 15.0+ only. With that combination I was able to get DeepSeek Coder v2 (236B params) running pretty fast on a single M2 Ultra.

@awni (Member) commented Jul 30, 2024

> one M2 Ultra 192GB with another M2 Ultra 128GB, splitting the weights to 160GB and 67GB

Maybe putting more on the 128GB machine will help also. Like 140 and 87 or something.
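
A hypothetical helper (not from the thread) for that kind of rebalancing: split the layers roughly in proportion to each machine's memory budget. The 126-layer count is taken from the sharding commands mentioned later in the thread; the function name is made up for illustration.

```python
def proportional_split(num_layers: int, budgets_gb: list[float]) -> list[int]:
    """Assign each machine a layer count proportional to its memory budget."""
    total = sum(budgets_gb)
    counts = [int(num_layers * b / total) for b in budgets_gb]
    counts[0] += num_layers - sum(counts)  # hand rounding leftovers to machine 0
    return counts

# 126 layers across a 192GB and a 128GB machine: roughly a 60/40 split.
print(proportional_split(126, [192, 128]))  # [76, 50]
```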

@mzbac (Contributor) commented Jul 30, 2024

> Which OS are you on? A couple of things that might help: […]

@awni Thanks for the pointers. I will try to upgrade macOS; currently it's on version 14.5.

@mzbac (Contributor) commented Jul 31, 2024

Just to share an update: upgrading to macOS 15.0 solved the memory issue, and now I am able to run the 405B 4-bit at around 3.4 t/s - not bad at all.

https://www.youtube.com/watch?v=_9vP7CS3TI4

@awni (Member) commented Jul 31, 2024

Nice!! Did you keep the sharding you had or rebalance it? I wonder if we could make it faster with a more even balance 🤔 . But 3.4 t/s is a great start. Only faster from here 💪

@mzbac (Contributor) commented Jul 31, 2024

> Nice!! Did you keep the sharding you had or rebalance it? […]

I added a bit more weight to the 128GB machine as you suggested in my layer sharding configuration:

Shard server (128GB machine):
    mlx-sharding-server --model Meta-Llama-3.1-405B-Instruct-4bit-mlx -s 70 -e 126

API server (192GB machine):
    mlx-sharding-api --model mlx_sharding/Meta-Llama-3.1-405B-Instruct-4bit-mlx -sl 0 -el 70 -s <tb4 ip>:49112 --host 0.0.0.0

@DamascusGit

> I added a bit more weight to the 128GB machine as you suggested in my layer sharding configuration: […]

Any update on speed since? I got my hands on two 192GB machines and am getting ready to run some tests over the weekend.

@mzbac (Contributor) commented Aug 17, 2024

> Any update on speed since? I got my hands on two 192GB machines and am getting ready to run some tests over the weekend.

Nothing on the mlx-sharding side. I am still waiting for MLX to support pipeline parallelism over MPI. Once that is supported, there may be some performance improvements compared to using gRPC.

angeloskath force-pushed the distributed-layers branch 5 times, most recently from 0f40077 to 9d7e80b on November 5, 2024
@Blaizzy (Contributor) commented Nov 5, 2024

LFG 🚀🔥

angeloskath force-pushed the distributed-layers branch 9 times, most recently from a14db45 to 8e3d9f3 on November 6, 2024