adding-support-for-mamba2 #1009
base: main
Conversation
…niz-Guelmez/mlx-examples into adding-support-for-mamba2
…t MambaMixer block pass)
…niz-Guelmez/mlx-examples into adding-support-for-mamba2
…niz-Guelmez/mlx-examples into adding-support-for-mamba2
Codestral Mamba and other models rely on the Mamba2 architecture. Hopefully we can get this soon.
How is it going here? Still very slow?
Unfortunately yes. I looked into the transformers implementation and rewrote the slow (but working) Mamba2Mixer class, but I haven't had time to keep working on it; I'll continue over the weekend.
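For context, the expensive part during generation is the per-token SSD state update the mixer has to perform. Below is a minimal sketch of that update in MLX, assuming a single SSM group and illustrative shapes; mamba2_step is a hypothetical helper name, not the code in this PR.

import mlx.core as mx

def mamba2_step(x, dt, A, B, C, D, state):
    # Hypothetical sketch of one Mamba2 recurrent step (single SSM group).
    # x:     (n_heads, head_dim)          one token's input, already split into heads
    # dt:    (n_heads,)                   per-head step size (after softplus)
    # A:     (n_heads,)                   per-head scalar decay (negative)
    # B, C:  (d_state,)                   SSM input / output projections
    # D:     (n_heads,)                   skip connection
    # state: (n_heads, head_dim, d_state) recurrent state carried in the cache
    dA = mx.exp(dt * A)                                         # discretized decay per head
    dBx = dt[:, None, None] * x[:, :, None] * B[None, None, :]  # dt * outer(x, B)
    state = dA[:, None, None] * state + dBx                     # h_t = dA * h_{t-1} + dt * x B^T
    y = (state * C[None, None, :]).sum(axis=-1)                 # read out: y_t = h_t C
    return y + D[:, None] * x, state

# Illustrative shapes only
n_heads, head_dim, d_state = 4, 8, 16
x = mx.random.normal((n_heads, head_dim))
dt = mx.ones((n_heads,)) * 0.1
A = -mx.ones((n_heads,))
B = mx.random.normal((d_state,))
C = mx.random.normal((d_state,))
D = mx.ones((n_heads,))
state = mx.zeros((n_heads, head_dim, d_state))
y, state = mamba2_step(x, dt, A, B, C, D, state)

Keeping this as a few array ops on a cached state, instead of re-running a scan over the whole prefix for every new token, is what the later speed-up in this thread hinges on.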
… is generating gibberish
…t still only one input token and outputs gibberish
…s still a little slow: 0.222 tokens-per-sec
@awni I finally got it to work!

Inference:

python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 65948.18it/s]
==========
Prompt: hello
, I am a little girl, I am a little girl, I am a little girl, I am a
==========
Prompt: 1 tokens, 7.499 tokens-per-sec
Generation: 22 tokens, 28.258 tokens-per-sec
Peak memory: 0.454 GB

python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello world" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 55043.36it/s]
==========
Prompt: hello world
hello world
hello world
hello world
hello world
hello world
hello world
hello world
==========
Prompt: 2 tokens, 5.552 tokens-per-sec
Generation: 22 tokens, 24.904 tokens-per-sec
Peak memory: 0.454 GB

Training:

python -m mlx_lm.lora \
--model rokyang/mamba2-130m-hf \
--train \
--data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/data_tyni \
--iters 5 \
--batch-size 1 \
--num-layers 1 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/mamba2-pretrain \
--max-seq-length 12
Loading pretrained model
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 87381.33it/s]
Loading datasets
Training
Trainable parameters: 0.956% (1.233M/128.988M)
Starting training..., iters: 5
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1508 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1250 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 1: Val loss 7.408, Val took 1.578s
Iter 1: Train loss 7.408, Learning Rate 1.000e-05, It/sec 0.405, Tokens/sec 4.450, Trained Tokens 11, Peak mem 2.173 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1692 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 2: Train loss 7.275, Learning Rate 1.000e-05, It/sec 2.110, Tokens/sec 23.212, Trained Tokens 22, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1397 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 3: Train loss 7.093, Learning Rate 1.000e-05, It/sec 2.694, Tokens/sec 29.637, Trained Tokens 33, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1238 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 4: Train loss 6.880, Learning Rate 1.000e-05, It/sec 2.803, Tokens/sec 30.829, Trained Tokens 44, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1265 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 802 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 5: Val loss 6.641, Val took 0.175s
Iter 5: Train loss 6.641, Learning Rate 1.000e-05, It/sec 2.754, Tokens/sec 30.298, Trained Tokens 55, Peak mem 2.189 GB
Saved final weights to /Users/gokdenizgulmez/Desktop/mamba2-pretrain/adapters.safetensors.
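For reference, the same checkpoint plus the saved adapters can also be exercised from Python, assuming mlx_lm's usual load/generate API (where load takes an adapter_path and generate takes max_tokens); a minimal sketch:

from mlx_lm import load, generate

# Load the base Mamba2 checkpoint together with the LoRA adapters saved above.
model, tokenizer = load(
    "rokyang/mamba2-130m-hf",
    adapter_path="/Users/gokdenizgulmez/Desktop/mamba2-pretrain",
)
print(generate(model, tokenizer, prompt="hello", max_tokens=22))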
Very nice!! What's a good model to test with? The one you are using doesn't look like it generates high-quality responses.
Mamba Codestral or one of the larger base Mamba2 models.
I tried running Codestral and it crashed with a weight size mismatch error:
Looks like the weight shape is not computed correctly for that model? This is what I ran for reference:
Ah ok, yeah, I didn't try Codestral. The model I used is the safetensors conversion of the original state-spaces checkpoint, published as rokyang/mamba2-130m-hf. I'll look into the Codestral shape problem later today.
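For what it's worth, the mismatch is plausibly a group-count issue: in the reference Mamba2 layout the projection sizes depend on n_groups, which is 1 for mamba2-130m but larger for Codestral (8, if I recall its config correctly), so a shape derived without the group factor would only break on Codestral. A rough sketch of the expected shapes; mamba2_proj_shapes is just an illustrative helper, and the kernel size of 4 is the usual default:

def mamba2_proj_shapes(d_model, expand, d_state, n_heads, n_groups, conv_kernel=4):
    # Expected weight shapes for one Mamba2 block in the reference (transformers) layout.
    d_inner = expand * d_model
    conv_dim = d_inner + 2 * n_groups * d_state   # x, B and C all pass through the depthwise conv
    in_proj_out = d_inner + conv_dim + n_heads    # z, (x | B | C), dt
    return {
        "in_proj.weight": (in_proj_out, d_model),
        "conv1d.weight": (conv_dim, 1, conv_kernel),
        "out_proj.weight": (d_model, d_inner),
    }

# mamba2-130m has n_groups = 1, so dropping the group factor goes unnoticed there;
# a model with n_groups > 1 would then trigger a weight-size mismatch on load.
print(mamba2_proj_shapes(d_model=768, expand=2, d_state=128, n_heads=24, n_groups=1))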
…ns from mamba2.py
No description provided.