feat(mlx_lm): support batch input in generate() #948

Open · wants to merge 4 commits into base: main · Changes from 2 commits
llms/mlx_lm/sample_utils.py (64 changes: 39 additions & 25 deletions)
@@ -26,7 +26,10 @@ def min_p_sampling(
             0.99-0.8 range.
         min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
             be filtered. Default: ``1``.
+        temperature: Temperature parameter for softmax distribution reshaping.
     Returns:
+        token(s) selected based on the min-p criterion.
+        Shape: same as logits, with the last (vocab) dimension removed.
     """
     if not (0 <= min_p <= 1.0):
         raise ValueError(
@@ -39,14 +42,14 @@ def min_p_sampling(
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

     # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits / temperature, axis=-1)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logits)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

     # Calculate the min_p threshold
     scaled_min_p = min_p * top_probs
@@ -58,43 +61,54 @@
     # Create pool of tokens with probability less than scaled min_p
     selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

-    # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
-    return sorted_indices[sorted_token]
+    # Return sampled token(s)
+    sampled_indices = mx.random.categorical(mx.log(selected_probs))
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
+    )
+    return tokens.squeeze(-1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(
+    logits: mx.array, top_p: float, temperature: float, axis: int = -1
+) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.

     Args:
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
         temperature: Temperature parameter for softmax distribution reshaping.
+        axis: The axis along which to apply top-p sampling.
     Returns:
-        token selected based on the top-p criterion.
+        token(s) selected based on the top-p criterion.
     """
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    # Apply temperature and compute softmax
+    probs = mx.softmax(logits / temperature, axis=axis)

-    # sort probs in ascending order
-    sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    # Sort probs in descending order
+    sorted_indices = mx.argsort(-probs, axis=axis)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

-    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+    # Compute cumulative probabilities
+    cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

-    # select tokens with cumulative probs below threshold
-    top_probs = mx.where(
-        cumulative_probs > 1 - top_p,
-        sorted_probs,
-        0,
-    )
+    # Create a mask for tokens inside the top-p (nucleus) set
+    mask = cumulative_probs <= top_p
+
+    # Apply the mask to the sorted probabilities
+    masked_probs = sorted_probs * mask

-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
+    # Sample from the masked probabilities (zeros become -inf in log space)
+    sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)
+
+    # Gather the original token indices
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
+    )

-    return token
+    return tokens.squeeze(axis)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
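A minimal usage sketch (not part of the diff) of how the reworked samplers behave on batched logits under this change. The import path, the (batch, vocab) shapes, and the sampling values are illustrative assumptions; the key enabler in the diff is replacing fancy indexing with mx.take_along_axis, which gathers per row.

# Hypothetical sketch: batched sampling with the updated functions.
import mlx.core as mx

from mlx_lm.sample_utils import min_p_sampling, top_p_sampling

batch, vocab = 4, 32000
logits = mx.random.normal((batch, vocab))  # one row of logits per sequence

# With this PR, each call returns one token id per row instead of a scalar:
top_p_tokens = top_p_sampling(logits, 0.9, 0.8)  # top_p=0.9, temperature=0.8
min_p_tokens = min_p_sampling(logits, 0.05)      # min_p=0.05, defaults otherwise

print(top_p_tokens.shape, min_p_tokens.shape)  # (4,) (4,)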
llms/mlx_lm/server.py (7 changes: 5 additions & 2 deletions)

@@ -410,7 +410,7 @@ def handle_completion(
         top_tokens = []
         for (token, logprobs), _ in zip(
             generate_step(
-                prompt=prompt,
+                prompts=prompt[None],
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
@@ -420,6 +420,8 @@ def handle_completion(
             ),
             range(self.max_tokens),
         ):
+            token = token.item()
+            logprobs = logprobs.squeeze(0)
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
@@ -497,7 +499,7 @@ def handle_stream(

         for (token, _), _ in zip(
             generate_step(
-                prompt=prompt,
+                prompts=prompt[None],
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
@@ -506,6 +508,7 @@ def handle_stream(
             ),
             range(self.max_tokens),
         ):
+            token = token.item()
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
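The server keeps single-prompt semantics by lifting the prompt to a batch of one and unwrapping the batched outputs. Below is a standalone sketch of that shape bookkeeping, with made-up token ids and stand-ins for one step of generate_step (assumed here, per the diff, to yield a (1,) token array and a (1, vocab) logprobs array):

import mlx.core as mx

prompt = mx.array([1, 15043, 29991])  # token ids for a single prompt, shape (3,)
prompts = prompt[None]                # prompt[None] adds the batch dim: (1, 3)

# Stand-ins for one decoding step of the batched generate_step:
token = mx.array([42])       # (1,): one sampled token per sequence
logprobs = mx.zeros((1, 8))  # (1, vocab): per-token log-probabilities

token = token.item()            # unwrap to a plain int for the detokenizer
logprobs = logprobs.squeeze(0)  # drop the batch dim: (vocab,)

print(token, logprobs.shape)  # 42 (8,)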