A JAX transform for LoRA-fying functions #15840
Replies: 4 comments 16 replies
-
This is really cool, thanks for sharing! It's a great idea, and an amazing package name too 😁 I looked through the source code – the jaxpr interpreter approach looks really solid. I have some very minor comments about how I'd structure things to make it more easily extensible (e.g. keep all the …). Do you see this as kind of a one-off experiment, or something that you're hoping to put more time and development effort into? If it's the latter, I'd be happy to do a deeper code review if you'd find it helpful. Thanks again for sharing!
-
I just want to ask a different question. I also implemented this feature. Materializing W + BA directly is simple to implement and doesn't require modifying the original code, but I'm unsure about its efficiency. Based on a multiplication-count analysis, when in_features * out_features < batch_size * (in_features + out_features), materializing W + BA is more efficient than keeping the product factored as Wx + B(Ax), even though the analysis is simplistic. Taking in_features = out_features = 4096, as is typical for billion-parameter transformers, the materialized version wins once batch_size > 2048. For the addition count, batch_size has to exceed 4096. I'd like to know whether this kind of analysis holds up in practice, and I'd appreciate hearing about your experience if you've implemented both approaches. Thank you
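As a concrete sketch of the comparison described above (plain JAX; the rank value and shapes are illustrative assumptions, not taken from the thread), counting only the multiplications added on top of the shared `x @ W`:

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes: a 4096x4096 projection with a rank-16 LoRA update.
d_in = d_out = 4096
rank = 16
batch = 4096  # tokens processed per step

key = jax.random.PRNGKey(0)
k_w, k_a, k_b, k_x = jax.random.split(key, 4)
W = jax.random.normal(k_w, (d_in, d_out)) * 0.02
A = jax.random.normal(k_a, (d_in, rank)) * 0.01
B = jax.random.normal(k_b, (rank, d_out)) * 0.01
x = jax.random.normal(k_x, (batch, d_in))

# Approach 1: materialize the updated weight once, then do a single matmul.
y1 = x @ (W + A @ B)

# Approach 2: keep the update factored and add its contribution separately.
y2 = x @ W + (x @ A) @ B

# Extra multiplications on top of the x @ W both approaches share:
extra_mults_materialized = rank * d_in * d_out        # forming A @ B
extra_mults_factored = batch * rank * (d_in + d_out)  # (x @ A) @ B
# Materializing wins when d_in * d_out < batch * (d_in + d_out),
# i.e. batch > 2048 for d_in = d_out = 4096.
print(extra_mults_materialized, extra_mults_factored)
print(jnp.allclose(y1, y2, atol=1e-2))  # loose tolerance for float32 accumulation
```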
-
Hey @davisyoshida, I was in the early stages of working on a library with exactly the same goal in mind 😅 But I'm a JAX noob, so your implementation is leagues ahead of mine (and the library name is much cooler!), and I think I'll move on to something else. Super cool stuff and I'll study the code closely! Would you say this library is ready for general use, or are there still a few sharp edges? The API shown in the README looks pretty solid.
-
Hi @davisyoshida, thanks for sharing the library. In the LoRA paper, the authors mention that gradients and optimizer states don't need to be maintained for the frozen parameters.
I read through your implementation, but I could not figure out how you avoid storing optimizer states for the frozen parameters. I think your implementation focuses on transforming the original model, but I still need help understanding how the optimization part works in your examples.
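One way this typically works (a sketch, not necessarily how the Lorax examples do it; the parameter names and the use of optax here are assumptions): since the loss is differentiated only with respect to the trainable LoRA parameters, the optimizer is initialized on that pytree alone, so moments such as Adam's are never allocated for the frozen weights.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical setup: frozen base weight W plus trainable LoRA factors A and B.
frozen_params = {"W": jnp.ones((512, 512)) * 0.02}
trainable_params = {
    "A": jax.random.normal(jax.random.PRNGKey(0), (512, 8)) * 0.01,
    "B": jnp.zeros((8, 512)),
}

def loss_fn(trainable, frozen, x):
    # Effective weight W + AB; only A and B receive gradients.
    y = x @ (frozen["W"] + trainable["A"] @ trainable["B"])
    return jnp.mean(y ** 2)

optimizer = optax.adam(1e-4)
# Optimizer state is created only for the trainable pytree, so no Adam
# moments are ever stored for the frozen weight W.
opt_state = optimizer.init(trainable_params)

@jax.jit
def train_step(trainable, opt_state, frozen, x):
    grads = jax.grad(loss_fn)(trainable, frozen, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(trainable, updates), opt_state

x = jnp.ones((4, 512))
trainable_params, opt_state = train_step(trainable_params, opt_state, frozen_params, x)
```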
-
I wrote a transformation to automate using LoRA for JAX models: Lorax (I didn't only do this because of the naming opportunity).
LoRA basically replaces products like `Wx` with `(W + BA)x`, where `A` and `B` are skinny, allowing you to save memory by not updating `W`. Lorax also supports some convs and gathers in addition to the basic matmul. Any time there's an op which Lorax doesn't know how to handle, it will raise a warning and just directly calculate the value of `W + BA`.
Minimal example:
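A minimal sketch of the idea in plain JAX (illustrative only, not the Lorax API; `forward`, `frozen_params`, and `trainable_params` are hypothetical names):

```python
import jax
import jax.numpy as jnp

d_in, d_out, rank = 1024, 1024, 8
key = jax.random.PRNGKey(0)
k_w, k_a, k_x = jax.random.split(key, 3)

# Frozen full-rank weight; only the skinny factors A and B are trained.
frozen_params = {"W": jax.random.normal(k_w, (d_out, d_in)) * 0.02}
trainable_params = {
    "A": jax.random.normal(k_a, (rank, d_in)) * 0.01,  # rank x d_in
    "B": jnp.zeros((d_out, rank)),                     # d_out x rank, zero-initialized
}

def forward(trainable_params, frozen_params, x):
    W = frozen_params["W"]
    A, B = trainable_params["A"], trainable_params["B"]
    # (W + B A) x, computed without materializing B A:
    return W @ x + B @ (A @ x)

x = jax.random.normal(k_x, (d_in,))
y = forward(trainable_params, frozen_params, x)

# Gradients flow only into the LoRA factors:
grads = jax.grad(lambda t: jnp.sum(forward(t, frozen_params, x) ** 2))(trainable_params)
```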
Afterwards, you can train by differentiating a loss function w.r.t. `trainable_params` only. I tested with my personal Haiku models while writing this, and have an example using it with a HuggingFace Flax model as well.
I'm generally open to any feedback, since I definitely felt like I was fumbling around a bit getting this working.