
adding-support-for-mamba2 #1009

Open
Goekdeniz-Guelmez wants to merge 49 commits into main from adding-support-for-mamba2
Conversation

Goekdeniz-Guelmez (Contributor)

No description provided.

Goekdeniz-Guelmez changed the title from "Create mamba2.py" to "adding-support-for-mamba2" on Oct 2, 2024
hg0428 commented Oct 22, 2024

Codestral Mamba and other models rely on the Mamba2 architecture. Hopefully we can get this soon.

awni (Member) commented Nov 4, 2024

How is it going here? Still very slow?

Goekdeniz-Guelmez (Contributor, Author) replied:

> How is it going here? Still very slow?

Unfortunately, yes. I looked into the transformers implementation and rewrote the slow Mamba2Mixer class, but I haven't had time to keep working on it. I'll continue over the weekend.
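For context, the core of a Mamba2 mixer is the per-token SSD recurrence: each head carries a running state that is decayed by a scalar A and updated from the current input. Below is a minimal single-token sketch in MLX, with hypothetical shapes and names (illustrative only, not this PR's actual Mamba2Mixer):

```python
# Minimal per-token Mamba2/SSD state update (illustrative, not the PR's code).
import mlx.core as mx

def ssd_step(h, x, dt, A, B, C, D):
    # h:    (n_heads, head_dim, d_state)  running SSM state
    # x:    (n_heads, head_dim)           current token's input
    # dt:   (n_heads,)                    softplus-activated step size
    # A:    (n_heads,)                    negative scalar decay per head (Mamba2 uses scalar A)
    # B, C: (d_state,)                    this token's input/output projections
    # D:    (n_heads,)                    skip-connection scale
    dA = mx.exp(dt * A)                        # per-head decay in (0, 1)
    dBx = (dt[:, None] * x)[..., None] * B     # (n_heads, head_dim, d_state)
    h = dA[:, None, None] * h + dBx            # decay old state, add new input
    y = (h * C).sum(axis=-1) + D[:, None] * x  # read out with C, skip with D
    return h, y
```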

Goekdeniz-Guelmez (Contributor, Author) commented:

@awni I finally got it to work!

Inference:

python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 65948.18it/s]
==========
Prompt: hello
, I am a little girl, I am a little girl, I am a little girl, I am a
==========
Prompt: 1 tokens, 7.499 tokens-per-sec
Generation: 22 tokens, 28.258 tokens-per-sec
Peak memory: 0.454 GB
python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello world" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 55043.36it/s]
==========
Prompt: hello world


hello world
hello world
hello world
hello world
hello world
hello world
hello world
==========
Prompt: 2 tokens, 5.552 tokens-per-sec
Generation: 22 tokens, 24.904 tokens-per-sec
Peak memory: 0.454 GB
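The same generations can be reproduced from Python with the standard mlx_lm entry points (assuming this branch is installed locally):

```python
# Reproduce the CLI run above via the mlx_lm Python API.
from mlx_lm import load, generate

model, tokenizer = load("rokyang/mamba2-130m-hf")
text = generate(model, tokenizer, prompt="hello", max_tokens=22)
print(text)
```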

Training

python -m mlx_lm.lora \
    --model rokyang/mamba2-130m-hf \
    --train \
    --data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/data_tyni \
    --iters 5 \
    --batch-size 1 \
    --num-layers 1 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/mamba2-pretrain \
    --max-seq-length 12
Loading pretrained model
Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 87381.33it/s]
Loading datasets
Training
Trainable parameters: 0.956% (1.233M/128.988M)
Starting training..., iters: 5
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1508 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1250 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 1: Val loss 7.408, Val took 1.578s
Iter 1: Train loss 7.408, Learning Rate 1.000e-05, It/sec 0.405, Tokens/sec 4.450, Trained Tokens 11, Peak mem 2.173 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1692 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 2: Train loss 7.275, Learning Rate 1.000e-05, It/sec 2.110, Tokens/sec 23.212, Trained Tokens 22, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1397 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 3: Train loss 7.093, Learning Rate 1.000e-05, It/sec 2.694, Tokens/sec 29.637, Trained Tokens 33, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1238 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 4: Train loss 6.880, Learning Rate 1.000e-05, It/sec 2.803, Tokens/sec 30.829, Trained Tokens 44, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1265 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 802 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 5: Val loss 6.641, Val took 0.175s
Iter 5: Train loss 6.641, Learning Rate 1.000e-05, It/sec 2.754, Tokens/sec 30.298, Trained Tokens 55, Peak mem 2.189 GB
Saved final weights to /Users/gokdenizgulmez/Desktop/mamba2-pretrain/adapters.safetensors.
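The saved adapters can then be loaded back for generation via mlx_lm's `adapter_path` argument (paths are the ones from this run; adjust to your machine):

```python
# Load the base model plus the LoRA adapters trained above.
from mlx_lm import load, generate

model, tokenizer = load(
    "rokyang/mamba2-130m-hf",
    adapter_path="/Users/gokdenizgulmez/Desktop/mamba2-pretrain",
)
print(generate(model, tokenizer, prompt="hello", max_tokens=22))
```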

Goekdeniz-Guelmez marked this pull request as ready for review on November 21, 2024
awni (Member) commented Nov 23, 2024

Very nice!! What's a good model to test with? The one you are using doesn't look like it generates high-quality responses.

hg0428 commented Nov 23, 2024

> Very nice!! What's a good model to test with? The one you are using doesn't look like it generates high-quality responses.

Mamba Codestral or one of the larger base Mamba2 models.

awni (Member) commented Nov 23, 2024

I tried running Codestral and it crashed with a weight-size mismatch error:

ValueError: Expected shape (16768, 4096) but received shape (18560, 4096) for parameter backbone.layers.0.mixer.in_proj.weight

Looks like the weight shape is not computed correctly for that model?

This is what I ran for reference:

mlx_lm.generate --model mistralai/Mamba-Codestral-7B-v0.1 --prompt "Write a quick sort in c++" -m 128
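A guess from the numbers only, not confirmed against the PR: the reference Mamba2 packs z, x, B, C, and dt into a single in_proj whose output width is 2*d_inner + 2*n_groups*d_state + n_heads. Assuming Codestral's hyperparameters are d_model=4096, expand=2, d_state=128, n_heads=128, n_groups=8, the mismatch is exactly what you'd get if n_groups were left at its default of 1:

```python
# Sanity check with assumed Codestral hyperparameters (verify against its config.json).
d_model, expand, d_state, n_heads = 4096, 2, 128, 128
d_inner = expand * d_model  # 8192

def in_proj_rows(n_groups):
    # z and x each take d_inner; B and C each take n_groups * d_state; dt takes n_heads
    return 2 * d_inner + 2 * n_groups * d_state + n_heads

print(in_proj_rows(1))  # 16768 -> the shape the PR currently expects
print(in_proj_rows(8))  # 18560 -> the shape in the Codestral checkpoint
```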

Goekdeniz-Guelmez (Contributor, Author) commented:

Ah ok, yeah, I didn't try Codestral. The model I used is the safetensors conversion of the original state-spaces checkpoint, rokyang/mamba2-130m-hf. I'll look into the Codestral shape problem later today.
