Prompt Lookup Decoding - merged under Speculative example #237

Open

wants to merge 12 commits into base: main
1 change: 1 addition & 0 deletions llms/speculative_decoding/.gitignore
@@ -0,0 +1 @@
*.npz
21 changes: 19 additions & 2 deletions llms/speculative_decoding/README.md
@@ -34,12 +34,12 @@ python convert.py --model t5-small
You can run with the default arguments:

```
python main.py
python speculative.py
```

To see a full list of options use:
```
python main.py --help
python speculative.py --help
```

### Notes
@@ -64,3 +64,20 @@ draft tokens at the expense of more large model evaluations.
Decoding](https://arxiv.org/abs/2211.17192)
[^2]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).

## Prompt Lookup Decoding
When speculative decoding works, it significantly accelerates inference, but choosing an appropriate draft model can be challenging. Prompt lookup decoding[^3] replaces the draft model with a simple sliding-window search over the prompt, eliminating the need for a draft model while offering comparable speedups on the right tasks. It excels at *input-grounded* tasks such as summarization, document Q/A, and code editing, where there is substantial overlap between the input and the output.
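
To make the idea concrete, here is a minimal sketch in plain Python (an illustration with hypothetical names, not the MLX implementation in `decoder.py`): draft tokens are proposed by finding the most recent n-gram elsewhere in the token history and copying the tokens that followed it.

```
def find_draft(tokens, n_draft=10, ngram_max=3, ngram_min=1):
    # Try the longest trailing n-gram first, then progressively shorter ones.
    for n in range(ngram_max, ngram_min - 1, -1):
        ngram = tokens[-n:]
        # Slide a window over the earlier tokens; stop at the first match.
        for start in range(len(tokens) - n):
            if tokens[start:start + n] == ngram:
                # Propose up to n_draft tokens that followed the match.
                return tokens[start + n:start + n + n_draft]
    return []  # no match: draft nothing this step
```

The proposed tokens are then verified with a single forward pass of the main model, exactly as in speculative decoding.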

## Run
[Setup](#setup) is the same as for Speculative Decoding. You can then run with the default arguments:

```
python prompt_lookup.py
```

To see a full list of options use:
```
python prompt_lookup.py --help
```
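
For example, to allow longer drafts and larger n-gram matches (hypothetical option values; see `--help` for the full list):
```
python prompt_lookup.py --n-draft 15 --ngram-max 5 --max-tokens 200
```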

[^3]: Check out the [original implementation](https://github.com/apoorvumang/prompt-lookup-decoding).
183 changes: 174 additions & 9 deletions llms/speculative_decoding/decoder.py
@@ -46,6 +46,7 @@ def __init__(
model: Model,
draft_model: Model,
tokenizer: str,
color: bool,
num_draft: int = 5,
delta: float = 0.0,
):
@@ -54,6 +55,7 @@
self.draft_model = draft_model
self.num_draft = num_draft
self.delta = delta
self.color = color

def _generate(
self,
@@ -91,6 +93,7 @@ def generate(
print()
self.model.reset_cache()

# Accept / Reject criteria (see Section 2.3 https://arxiv.org/pdf/2211.17192.pdf)
def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
# accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens
model_probs = mx.take_along_axis(
@@ -120,9 +123,9 @@ def sample(logits):
tokens = mx.array([self.tokenizer.decoder_start_id])

n_steps = 0
ntoks = 0
n_generated = 0
n_accepted = 0
n_draft = 0
n_drafted = 0

outputs = []
skip = 0
@@ -133,7 +136,7 @@ def sample(logits):
draft_tokens = []
draft_probs = []
for _, (t, p) in zip(
range(ntoks, min(ntoks + self.num_draft, max_tokens)),
range(n_generated, min(n_generated + self.num_draft, max_tokens)),
self._generate(draft_inputs, draft_memory, draft=True),
):
draft_tokens.append(t)
@@ -163,7 +166,7 @@ def sample(logits):
)

n_accepted += num_to_accept
n_draft += draft_tokens.size
n_drafted += draft_tokens.size

# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
@@ -172,17 +175,35 @@

n_steps += 1

truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or ntoks >= max_tokens:
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
ntoks += 1
n_generated += 1

str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)

if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m" + str_output[skip:-model_token] + "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
print(
"\033[34m" + str_output[skip:] + "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)

skip = len(str_output)

if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
inputs = draft_inputs[-1:]
@@ -192,4 +213,148 @@ def sample(logits):

self.model.reset_cache()
self.draft_model.reset_cache()
return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}
return {"n_accepted": n_accepted, "n_draft": n_drafted, "n_steps": n_steps}


########################################################

class PromptLookupDecoder:
def __init__(
self,
model: Model,
tokenizer: str,
n_draft: int,
ngram_max: int,
ngram_min: int,
temp: float,
seed: int,
color: bool,
):
self.model = model
self.tokenizer = Tokenizer(tokenizer)
self.n_draft = n_draft
self.ngram_max = ngram_max
self.ngram_min = ngram_min
self.temp = temp
self.seed = seed
self.color = color

@staticmethod
def window_compare(start_idx, input_ids, ngram):
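# Elementwise-compare the window of `input_ids` starting at `start_idx`
# against `ngram`; vectorized over start indices via `mx.vmap` in `prompt_lookup`.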
return input_ids[mx.arange(ngram.size) + start_idx] == ngram

def generate_draft(self, input_ids):
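# Search the history for the most recent n-gram (longest size first); if it
# occurs earlier, return up to `n_draft` tokens that followed that occurrence.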
for ngram_size in range(self.ngram_max, self.ngram_min - 1, -1):
ngram = input_ids[-ngram_size:]

start_indices = mx.arange(0, input_ids.size - self.ngram_max)
matches = self.vmap_compare(start_indices, input_ids, ngram)

# check for full `ngram` matches
matches = matches.all(axis=1)
# get idx of first match; 0 if no match
idx_match = matches.argmax()

# double check idx
if matches[idx_match]:
start_idx = idx_match.item() + ngram_size
end_idx = start_idx + self.n_draft
return input_ids[start_idx:end_idx]

return mx.array([], dtype=mx.uint32)

def prompt_lookup(
self,
prompt: str,
max_tokens: int,
):
def sample(logits):
if self.temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / self.temp))

# used in draft generation
self.vmap_compare = mx.vmap(self.window_compare, in_axes=(0, None, None))

prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
memory = self.model.encode(prompt)

history = prompt.squeeze(0)[
:-1
] # remove eos token from prompt lookup search space

n_steps = 0
n_generated = 0
n_accepted = 0
n_drafted = 0

outputs = []
skip = 0
inputs = mx.array([self.tokenizer.decoder_start_id])
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
draft_tokens = self.generate_draft(history)

# Verify draft tokens with the last verified token
verify_tokens = mx.concatenate([inputs, draft_tokens])
logits = self.model.decode(verify_tokens[None], memory)

# Only keep samples that match the draft:
# draft tokens aren't sampled, hence no accept / reject criteria
sampled = sample(logits).squeeze(0)
equal_toks = sampled[:-1] == draft_tokens
num_to_accept = (equal_toks.tolist() + [False]).index(False)
new_tokens = sampled[
: max(1, num_to_accept + 1)
] # accepted draft tokens + next token from main model

n_accepted += num_to_accept
n_drafted += draft_tokens.size

# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
self.model.truncate_cache(n - new_tokens.size + 1)

n_steps += 1

truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
n_generated += 1

str_output = self.tokenizer.decode(outputs)

if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m" + str_output[skip:-model_token] + "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
print(
"\033[34m" + str_output[skip:] + "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)

skip = len(str_output)

if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break

history = mx.concatenate([history, new_tokens])
inputs = history[-1:]

print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()

self.model.reset_cache()

return {"n_accepted": n_accepted, "n_draft": n_drafted, "n_steps": n_steps}
12 changes: 11 additions & 1 deletion llms/speculative_decoding/model.py
@@ -4,7 +4,7 @@
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import AutoTokenizer, T5Config
from transformers import T5Config


def _relative_position_bucket(
@@ -339,3 +339,13 @@ def __call__(
decoder_inputs: mx.array,
):
return self.decode(decoder_inputs, self.encode(inputs))[0]


def load_model(model_name: str):
config = T5Config.from_pretrained(model_name)
model = Model(config)
weights = mx.load(f"{model_name}.npz")
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model
88 changes: 88 additions & 0 deletions llms/speculative_decoding/prompt_lookup.py
@@ -0,0 +1,88 @@
import argparse
import time

import mlx.core as mx
from decoder import PromptLookupDecoder
from model import load_model


def main(args):
mx.random.seed(args.seed)

lookup_decoder = PromptLookupDecoder(
model=load_model(args.model_name),
tokenizer=args.model_name,
n_draft=args.n_draft,
ngram_max=args.ngram_max,
ngram_min=args.ngram_min,
temp=args.temp,
seed=args.seed,
color=args.color,
)

tic = time.time()
print(args.prompt)

stats = lookup_decoder.prompt_lookup(args.prompt, max_tokens=args.max_tokens)
print("=" * 10)
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
print(f"Decoding steps {stats['n_steps']}.")

toc = time.time()
print("=" * 10)
print(f"Full generation time {toc - tic:.3f}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prompt Lookup Decoding")

parser.add_argument(
"--n-draft",
type=int,
default=10,
help="Number of draft tokens to generate upon prompt lookup match",
)
parser.add_argument(
"--model-name",
help="Name of the model.",
default="t5-base",
)

parser.add_argument(
"--prompt",
help="The prompt processed by the model.",
default="Repeat the following sentence five times in a row: 'The quick brown fox jumped over the fence.'",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--ngram-max",
type=int,
default=3,
help="Maximum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--ngram-min",
type=int,
default=1,
help="Minimum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
parser.add_argument(
"--color", type=bool, default=False, help="Color the accepted draft tokens"
)

args = parser.parse_args()

main(args)