
Question: GPU support on apple silicon #196

Open
QuantitativeVirology opened this issue Nov 16, 2023 · 5 comments

Comments

@QuantitativeVirology

I am trying to use the GPU on my M2 Max MacBook Pro for ColabFold predictions. I installed jax-metal 0.0.4, which was incompatible with haiku. After upgrading haiku to 0.0.9, I now get the error: "python3.10[7216:52122] -[MPSGraphExecutable initWithMLIRBytecode:executableDescriptor:]: unrecognized selector sent to instance 0x2b5871a90"
as well as "Could not predict FFAR2_HUMAN. Not Enough GPU memory? Caught an unknown exception!"

Has anyone got localcolabfold running using Apple silicon GPUs?

Thanks!

Jens

CSSB Hamburg

@philipptrepte

Hi Jens,

I just got it to work on an M3 Max MacBook Pro.
I followed the localcolabfold installation instructions for M1 chips, but also installed jax following the Metal installation instructions, with the only change being to install it within the colabfold-conda environment:

conda activate /your_path_to/localcolabfold/colabfold-conda
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal

When running
python -c 'import jax; print(jax.numpy.arange(10))'
I get the following message:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-01-31 09:41:29.351393: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

[0 1 2 3 4 5 6 7 8 9]
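As a quick follow-up check (my own sketch, not from the thread; jax.default_backend() and jax.devices() are real JAX APIs, but the exact backend name reported by the experimental Metal plugin is an assumption), you can confirm which backend JAX actually selected:

```python
import jax

# Print the platform JAX chose as its default backend ("cpu", "gpu",
# "tpu", or the Metal plugin's platform name) and the devices it sees.
print(jax.default_backend())
print(jax.devices())
```

If this still reports "cpu" inside the colabfold-conda environment, the jax-metal plugin did not load and ColabFold will run on the CPU regardless of the hardware.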

Similarly, when I run colabfold_batch input output from the activated colabfold-conda environment, I get:

2024-01-31 09:44:33,393 Running colabfold 1.5.5 (a00ce1bcc477491d7693e3816d21ea3fc2cf40fd)

WARNING: You are welcome to use the default MSA server, however keep in mind that it's a
limited shared resource only capable of processing a few thousand MSAs per day. Please
submit jobs only from a single IP address. We reserve the right to limit access to the
server case-by-case when usage exceeds fair use. If you require more MSAs: You can
precompute all MSAs with colabfold_search or host your own API and pass it to --host-url

2024-01-31 09:44:33.410904: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

2024-01-31 09:44:34,990 Running on GPU

Hope this will also work for you.

Best
Philipp

@QuantitativeVirology
Author

Thanks Philipp!
Jens

@philipptrepte

Unfortunately, the resulting "proteins" do not look like proteins:
[Screenshot 2024-02-01 at 06 23 20: predicted structure from the M3 Max run]

In comparison, when using localcolabfold with an Nvidia Tesla A100 GPU, it looks fine:
[Screenshot 2024-02-01 at 06 19 14: predicted structure from the A100 run]

@QuantitativeVirology
Author

That's a new one! lol
Interesting that it does not bug out but gives weird coordinates instead. Anything in the log file?

@philipptrepte

At least no errors, but the pTMs and ipTMs are 0.99 or similar, compared to pTMs/ipTMs of around 0.4 when using the Tesla A100. Also, the runtime was 2023.9 s vs. 65.8 s for 159 amino acids.
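Since the Metal run finishes without errors yet produces implausible structures with suspiciously high confidence scores, the backend may be silently computing wrong values. A minimal sanity check (my own sketch, not part of ColabFold) is to compare a small matrix multiply on the default JAX backend against a NumPy reference on the CPU:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Run the same float32 matmul on NumPy (CPU) and on whatever backend
# JAX selected, then compare. A large discrepancy would indicate the
# accelerator backend is silently returning wrong numbers.
rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

ref = a @ b                                          # NumPy reference
dev = np.asarray(jnp.asarray(a) @ jnp.asarray(b))    # default JAX backend

max_err = float(np.max(np.abs(ref - dev)))
print(f"backend={jax.default_backend()} max abs error={max_err:.3e}")
```

A near-zero error here does not prove the full model runs correctly, but a large one would confirm the Metal backend is at fault. Setting the environment variable JAX_PLATFORMS=cpu forces JAX onto the CPU, which is one way to cross-check a full prediction against the Metal result.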
