From 12c6066c977a6a951cbd811cacc35f9cf81ecfbc Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Wed, 21 Aug 2024 00:46:19 +0800 Subject: [PATCH] feat: support batch input in `generate()` The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred. --- llms/mlx_lm/sample_utils.py | 64 +++++++++------ llms/mlx_lm/server.py | 7 +- llms/mlx_lm/utils.py | 157 ++++++++++++++++++++++-------------- 3 files changed, 139 insertions(+), 89 deletions(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fac..1b403b393 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -26,7 +26,10 @@ def min_p_sampling( 0.99-0.8 range. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered. Default: ``1``. - + temperature: Temperature parameter for softmax distribution reshaping. + Returns: + token(s) selected based on the min-p criterion. + Shape: same as logits, but with the last dimension having size 1. """ if not (0 <= min_p <= 1.0): raise ValueError( @@ -39,14 +42,14 @@ def min_p_sampling( # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + probs = mx.softmax(logits / temperature, axis=-1) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logits) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) # Top probability - top_probs = probs[..., sorted_indices[0]] + top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1) # Calculate the min_p threshold scaled_min_p = min_p * top_probs @@ -58,13 +61,18 @@ def min_p_sampling( # Create pool of tokens with probability less than scaled min_p selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) - # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) - return sorted_indices[sorted_token] + # Return sampled token(s) + sampled_indices = mx.random.categorical(mx.log(selected_probs)) + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1 + ) + return tokens.squeeze(-1) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: +def top_p_sampling( + logits: mx.array, top_p: float, temperature: float, axis: int = -1 +) -> mx.array: """ Apply top-p (nucleus) sampling to logits. @@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. temperature: Temperature parameter for softmax distribution reshaping. + axis: The axis along which to apply top-p sampling. Returns: - token selected based on the top-p criterion. + token(s) selected based on the top-p criterion. """ - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + # Apply temperature and compute softmax + probs = mx.softmax(logits / temperature, axis=axis) - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + # Sort probs in descending order + sorted_indices = mx.argsort(-probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + # Compute cumulative probabilities + cumulative_probs = mx.cumsum(sorted_probs, axis=axis) - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) + # Create a mask for probs above the threshold + mask = cumulative_probs <= top_p + + # Apply the mask to the sorted probabilities + masked_probs = sorted_probs * mask - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + # Sample from the normalized probabilities + sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis) + + # Gather the original token indices + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis + ) - return token + return tokens.squeeze(axis) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 79ac18361..28dd20495 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -410,7 +410,7 @@ def handle_completion( top_tokens = [] for (token, logprobs), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -420,6 +420,8 @@ def handle_completion( ), range(self.max_tokens), ): + token = token.item() + logprobs = logprobs.squeeze() detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) @@ -497,7 +499,7 @@ def handle_stream( for (token, _), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -506,6 +508,7 @@ def handle_stream( ), range(self.max_tokens), ): + token = token.item() detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 441967667..a13b68d44 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -9,7 +9,7 @@ import time from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union import mlx.core as mx import mlx.nn as nn @@ -117,17 +117,17 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f logits (mx.array): Logits with repetition penalty applied to generated tokens. """ if len(generated_tokens) > 0: - indices = mx.array([token for token in generated_tokens]) - selected_logits = logits[:, indices] + indices = generated_tokens + selected_logits = mx.take_along_axis(logits, indices, axis=-1) selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, indices] = selected_logits + logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits return logits def generate_step( - prompt: mx.array, + prompts: mx.array, model: nn.Module, temp: float = 0.0, repetition_penalty: Optional[float] = None, @@ -143,7 +143,7 @@ def generate_step( A generator producing token ids based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. + prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``. model (nn.Module): The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. @@ -164,27 +164,33 @@ def generate_step( Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + one token and a vector of log probabilities per prompt. + Shapes: ``(bs, 1), (bs, vocab_size)``. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: + if prompts.ndim != 2: + raise ValueError( + f"Shape of prompts should be (bs, seq_len), got {prompts.shape}" + ) + + def sample(logits: mx.array) -> Tuple[mx.array, mx.array]: if logit_bias: indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) logits[:, indices] += values - logprobs = logits - mx.logsumexp(logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) if temp == 0: - token = mx.argmax(logits, axis=-1) + tokens = mx.argmax(logits, axis=-1) else: if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) + tokens = top_p_sampling(logits, top_p, temp) elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) + tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) else: - token = categorical_sampling(logits, temp) + tokens = categorical_sampling(logits, temp) - return token, logprobs + return mx.expand_dims(tokens, axis=-1), logprobs if repetition_penalty and ( repetition_penalty < 0 or not isinstance(repetition_penalty, float) @@ -193,7 +199,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) - y = prompt + y = prompts if hasattr(model, "make_cache"): cache = model.make_cache() else: @@ -210,14 +216,14 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: else: cache = [KVCache(model.head_dim, n) for n in kv_heads] - repetition_context = prompt.tolist() + repetition_context = prompts if repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] + repetition_context = repetition_context[:, -repetition_context_size:] def _step(y): nonlocal repetition_context - logits = model(y[None], cache=cache) + logits = model(y, cache=cache) logits = logits[:, -1, :] if repetition_penalty: @@ -225,27 +231,27 @@ def _step(y): logits, repetition_context, repetition_penalty ) y, logprobs = sample(logits) - repetition_context.append(y.item()) + repetition_context = mx.concatenate([repetition_context, y], axis=-1) else: y, logprobs = sample(logits) if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - return y, logprobs.squeeze(0) + if repetition_context.shape[1] > repetition_context_size: + repetition_context = repetition_context[:, -repetition_context_size:] + return y, logprobs - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) + while y.shape[1] > prefill_step_size: + model(y[:, :prefill_step_size], cache=cache) mx.eval([c.state for c in cache]) - y = y[prefill_step_size:] + y = y[:, prefill_step_size:] y, logprobs = _step(y) - mx.async_eval(y) while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y) - yield y.item(), logprobs + mx.eval(y) + yield y, logprobs y, logprobs = next_y, next_logprobs @@ -277,9 +283,10 @@ def stream_generate( detokenizer.reset() for (token, _), n in zip( - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens[None], model, **kwargs), range(max_tokens), ): + token = token.item() if token == tokenizer.eos_token_id: break detokenizer.add_token(token) @@ -294,19 +301,19 @@ def stream_generate( def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[str]], max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Union[str, List[str]]: """ Generate a complete response from the model. Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. + prompts (str): The string prompt(s). max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. @@ -315,56 +322,82 @@ def generate( kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. """ + is_batch = isinstance(prompt, list) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - if verbose: - print("=" * 10) - print("Prompt:", prompt) - - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer + if is_batch: + tokenizer._tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer._tokenizer.pad_token = tokenizer.eos_token + tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id + prompt_tokens = mx.array( + tokenizer._tokenizer(prompt, padding=True)["input_ids"] + ) + output_toks = [] + else: + prompt_tokens = mx.array(tokenizer.encode(prompt))[None] + detokenizer = tokenizer.detokenizer + detokenizer.reset() + if verbose: + print("=" * 10) + print("Prompt:", prompt) tic = time.perf_counter() - detokenizer.reset() - for (token, logprobs), n in zip( + for (tokens, logprobs), n in zip( generate_step(prompt_tokens, model, **kwargs), range(max_tokens), ): if n == 0: prompt_time = time.perf_counter() - tic tic = time.perf_counter() - if token == tokenizer.eos_token_id: + if (tokens == tokenizer.eos_token_id).all(): break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() + if is_batch: + output_toks.append(tokens) + else: + token = tokens.item() + logprobs = logprobs.squeeze() + detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + else: + print(detokenizer.last_segment, end="", flush=True) + + if is_batch: + output_toks = mx.concatenate(output_toks, axis=1) + token_count = output_toks.size + response = [ + response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0] + for response in tokenizer.batch_decode(output_toks.tolist()) + ] + else: + token_count = n + detokenizer.finalize() + response = detokenizer.text if verbose: gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: + if token_count <= 0: print("No tokens generated for this prompt") - return + if is_batch: + for p, resp in zip(prompt, response): + print("=" * 10) + print("Prompt:", p) + print(resp) + else: + print(detokenizer.last_segment, flush=True) prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") + gen_tps = token_count / gen_time + print("=" * 10) + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") - return detokenizer.text + return response def load_config(model_path: Path) -> dict: