-
Notifications
You must be signed in to change notification settings - Fork 909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(mlx_lm): support batch input in generate()
#948
base: main
Are you sure you want to change the base?
Conversation
generate()
generate()
7332759
to
332a713
Compare
332a713
to
12c6066
Compare
12c6066
to
ef92993
Compare
Kind of interesting: for quantized models, the throughput is doesn't go up a lot between small bs (bs=1,2,3,4), but then it starts to go up a lot at higher bs, which is the opposite of what I expected intuitively. For unquantized models the throughput does goes up between small bs. I observe the same on @willccbb's original repo. |
The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
5105b31
to
2caa832
Compare
llms/mlx_lm/utils.py
Outdated
prompt_tokens = mx.array(tokenizer.encode(prompt)) | ||
detokenizer = tokenizer.detokenizer | ||
if is_batch: | ||
tokenizer._tokenizer.padding_side = "left" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see that we left pad shorter prompts here which makes sense. But one thing that I'm wondering is how this is handled in the causal models if at all? Shouldn't the causal mask take into account the padding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't handle it, generation seems OK without it but indeed to be correct I should consume the tokenizer._tokenizer(prompt, padding=True)["attention_mask"]
. To do this I would need to update our model APIs to have attention_mask
as an input similar to how transformers has model.generate
taking attention_mask
. Probably this is involves hitting every file in models/
. Should mostly be copy/paste though. I can look into it.
I think it makes sense to minimize the complexity to the Also maybe more tricky is the fact that I think for this to be correct, the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:
Let me know what you think about the above. |
Makes sense to me, will implement.
Yes, this sounds straightforward enough.
I'll do a bit of thinking if there's an easy way to handle this, otherwise I'll remove that parameter in Will update when these changes are ready! |
@llllvvuu are you coming back to this? |
hey @awni , sorry for the delay, I'd been job hunting this month. I should be able to get back to this in ~a week |
No worries, just checking. I'll follow up in a week or so. |
bea0c4b
to
8fb82fe
Compare
308ad24
to
9ee726c
Compare
Just realised the attention mask has been mentioned in this PR, which is the reason I raised this issue #1044 |
The
prompt
argument can now be either astr
orlist[str]
.This is based on @willccbb's implementation at https://github.com/willccbb/mlx_parallm; I noticed that it aligned with the KVCache upgrades in #911.
The change to
generate()
is backwards-compatible.The changes to
generate_step()
,top_p_sampling()
, andmin_p_sampling()
are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.