Skip to content

Commit

Permalink
feat: support batch input in generate()
Browse files Browse the repository at this point in the history
The `prompt` argument can now be either a `str` or `list[str]`.

The change to `generate()` is backwards-compatible.

The changes to `generate_step()`, `top_p_sampling()`, and
`min_p_sampling()` are backwards-incompatible in order to unify shapes;
this could be changed by adding a few if-statements, if preferred.
  • Loading branch information
llllvvuu committed Aug 21, 2024
1 parent 0164d20 commit 12c6066
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 89 deletions.
64 changes: 39 additions & 25 deletions llms/mlx_lm/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def min_p_sampling(
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token(s) selected based on the min-p criterion.
Shape: same as logits, but with the last dimension having size 1.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
Expand All @@ -39,14 +42,14 @@ def min_p_sampling(
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits / temperature, axis=-1)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
sorted_indices = mx.argsort(-logits)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

# Top probability
top_probs = probs[..., sorted_indices[0]]
top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
Expand All @@ -58,43 +61,54 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
return sorted_indices[sorted_token]
# Return sampled token(s)
sampled_indices = mx.random.categorical(mx.log(selected_probs))
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
)
return tokens.squeeze(-1)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def top_p_sampling(
logits: mx.array, top_p: float, temperature: float, axis: int = -1
) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
axis: The axis along which to apply top-p sampling.
Returns:
token selected based on the top-p criterion.
token(s) selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# Apply temperature and compute softmax
probs = mx.softmax(logits / temperature, axis=axis)

# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
# Sort probs in descending order
sorted_indices = mx.argsort(-probs, axis=axis)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# Compute cumulative probabilities
cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
)
# Create a mask for probs above the threshold
mask = cumulative_probs <= top_p

# Apply the mask to the sorted probabilities
masked_probs = sorted_probs * mask

sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
# Sample from the normalized probabilities
sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)

# Gather the original token indices
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
)

return token
return tokens.squeeze(axis)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
Expand Down
7 changes: 5 additions & 2 deletions llms/mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def handle_completion(
top_tokens = []
for (token, logprobs), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
Expand All @@ -420,6 +420,8 @@ def handle_completion(
),
range(self.max_tokens),
):
token = token.item()
logprobs = logprobs.squeeze()
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
Expand Down Expand Up @@ -497,7 +499,7 @@ def handle_stream(

for (token, _), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
Expand All @@ -506,6 +508,7 @@ def handle_stream(
),
range(self.max_tokens),
):
token = token.item()
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
Expand Down
Loading

0 comments on commit 12c6066

Please sign in to comment.