Add a fix for special added tokens #1163

Closed
Changes from all commits (31 commits)
76fac6e
Add a fix for special added tokens
mokeddembillel Dec 5, 2024
f5af472
[MLX LM] Fix f-string formatting in memory warning message (#1105)
skeetmtp Nov 13, 2024
6497a5d
Pass seed to sd img2img (#1114)
louen Nov 20, 2024
20e87fe
Fix format (#1115)
angeloskath Nov 21, 2024
54020c7
Tencent HunYuan MOE model (#1100)
awni Nov 23, 2024
ad9948b
Generation refactor: part 2 (#1099)
awni Nov 23, 2024
9960d47
Fix object property value in mlx_lm.server chat completions response …
kconner Nov 25, 2024
a0c11fe
Allow converting models from local directories (#1118)
remixer-dec Nov 25, 2024
47456f9
docs: update stream_generate return type annotation (#1121)
madroidmaq Nov 25, 2024
149fdcc
Put prompt processing in same stream (#1122)
awni Nov 25, 2024
5d841fd
Accept mx.array type for prompt argument for stream_generate (#1125)
neilmehta24 Nov 27, 2024
278884f
Add olmo2 (#1128)
awni Dec 2, 2024
7ee0a55
Fix bug in FluxSampler.timesteps method (#1131)
hehua2008 Dec 2, 2024
a0e7965
Allow loading from diffusers ckpt (#1117)
angeloskath Dec 2, 2024
a73de93
Fix data_iter in prepare_dataset from speechcommands example (#1113)
sakares Dec 3, 2024
e08c470
Allow prompt callback to `generate_step` (#1133)
awni Dec 4, 2024
7bb1298
Add mentions of MLX-my-repo. (#1129)
Vaibhavs10 Dec 4, 2024
e61847e
`mlx_lm.evaluate` (#1140)
barronalex Dec 8, 2024
1bf8129
Mixed Quantizations (#1132)
barronalex Dec 8, 2024
21f5f66
Fix flux training with batch size (#1135)
hehua2008 Dec 9, 2024
2932980
Fix final message at end of flux training (#1143)
petersibley Dec 9, 2024
5d561f1
Change Flux default max_shift to 1.15 to match the official one (#1137)
hehua2008 Dec 9, 2024
f1d730a
Adds EXAONE architecture. (#1145)
N8python Dec 9, 2024
64781fd
Support for multiple EOS tokens (#1141)
madroidmaq Dec 9, 2024
b83a730
Fix max_tokens (#1148)
barronalex Dec 10, 2024
6eb95c6
fix llava (#1149)
awni Dec 12, 2024
3f89574
Add finish_reason in GenerationResponse (#1153)
madroidmaq Dec 12, 2024
3858956
Replace unicode errors instead of raising exception (#1146)
angeloskath Dec 12, 2024
1026cc5
[mlx-lm] Use top p in server (#1144)
awni Dec 12, 2024
46fd8b7
* rebase with main
awni Dec 12, 2024
fb3d052
chore: update evaluate.py (#1159)
eltociear Dec 15, 2024
2 changes: 1 addition & 1 deletion flux/dreambooth.py
@@ -289,4 +289,4 @@ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
tic = time.time()

save_adapters("final_adapters.safetensors", flux, args)
print(f"Training successful. Saved final weights to {args.adapter_file}.")
print("Training successful.")
2 changes: 2 additions & 0 deletions flux/flux/model.py
@@ -85,6 +85,8 @@ def __init__(self, params: FluxParams):
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if k.startswith("model.diffusion_model."):
k = k[22:]
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
5 changes: 3 additions & 2 deletions flux/flux/sampler.py
@@ -7,7 +7,7 @@


class FluxSampler:
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
@@ -25,7 +25,7 @@ def timesteps(
):
t = mx.linspace(start, stop, num_steps + 1)

if self._schnell:
if not self._schnell:
t = self._time_shift(image_sequence_length, t)

return t.tolist()
@@ -50,6 +50,7 @@ def add_noise(self, x, t, noise=None, key=None):
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise

def step(self, pred, x_t, t, t_prev):
7 changes: 3 additions & 4 deletions llava/generate.py
@@ -79,10 +79,10 @@ def load_image(image_source):
def prepare_inputs(processor, image, prompt):
if isinstance(image, str):
image = load_image(image)
inputs = processor(prompt, image, return_tensors="np")
inputs = processor(image, prompt, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
return input_ids, pixel_values
return pixel_values, input_ids


def load_model(model_path, tokenizer_config={}):
@@ -126,8 +126,7 @@ def main():
processor, model = load_model(args.model, tokenizer_config)

prompt = codecs.decode(args.prompt, "unicode_escape")

input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)

print(prompt)
generated_text = generate_text(
26 changes: 8 additions & 18 deletions llava/llava.py
@@ -104,31 +104,21 @@ def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, num_image_patches, embed_dim = image_features.shape

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
image_positions = mx.array(
np.where(input_ids[0] == image_token_index)[0], mx.uint32
)

if len(image_positions) != num_images:
if len(image_positions) != num_image_patches:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})."
f" match the number of image patches ({num_image_patches})."
)

text_segments = []
start_idx = 0

for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1

image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]

# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
inputs_embeds[0, image_positions] = image_features
return inputs_embeds

def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
17 changes: 11 additions & 6 deletions llms/README.md
@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

response = generate(model, tokenizer, prompt=prompt, verbose=True)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
```

To see a description of all the arguments you can do:
@@ -77,7 +77,7 @@ to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.

You can convert models in the Python API with:
You can convert models using the Python API:

```python
from mlx_lm import convert
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:

#### Streaming

For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text, token, and log probabilities.
For streaming generation, use the `stream_generate` function. This yields
a generation response object.

For example,

```python
Expand All @@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(response.text, end="", flush=True)
print()
```
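
The yielded response carries more than the text fragment. Below is a minimal sketch of reading one extra field after streaming finishes; the `finish_reason` attribute is an assumption inferred from the commit list above ("Add finish_reason in GenerationResponse (#1153)"), not part of this diff, and the model repo is reused from the server docs as a placeholder.

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

prompt = "Write a haiku about the ocean."

for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    # `text` holds the newly decoded fragment for this step.
    print(response.text, end="", flush=True)
print()

# The last yielded object reports why generation stopped
# (assumed field; see #1153 in the commit list).
print("finish_reason:", response.finish_reason)
```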

@@ -162,6 +163,10 @@ mlx_lm.convert \
--upload-repo mlx-community/my-4bit-mistral
```

Models can also be converted and quantized directly in the
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
Face Space.

### Long Prompts and Generations

`mlx-lm` has some tools to scale efficiently to long prompts and generations:
2 changes: 1 addition & 1 deletion llms/mlx_lm/SERVER.md
@@ -92,7 +92,7 @@ curl localhost:8080/v1/chat/completions \

- `system_fingerprint`: A unique identifier for the system.

- `object`: Any of "chat.completions", "chat.completions.chunk" (for
- `object`: Any of "chat.completion", "chat.completion.chunk" (for
streaming), or "text.completion".

- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
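
To make the corrected `object` values concrete, here is a minimal sketch of the documented top-level fields in a non-streaming response; the values are illustrative assumptions, not output captured from the server, and the remaining fields are omitted.

```python
# Shape of the documented top-level fields in a chat completions response
# (illustrative values only).
response = {
    "system_fingerprint": "fp_0",  # unique identifier for the system (placeholder)
    "object": "chat.completion",   # "chat.completion.chunk" when streaming
    "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
    # ...remaining fields omitted in this sketch
}
```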
2 changes: 1 addition & 1 deletion llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.19.3"
__version__ = "0.20.4"
35 changes: 14 additions & 21 deletions llms/mlx_lm/cache_prompt.py
@@ -8,7 +8,7 @@
import mlx.core as mx

from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load, maybe_quantize_kv_cache
from .utils import generate_step, load

DEFAULT_QUANTIZED_KV_START = 5000

@@ -50,12 +50,6 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -99,9 +93,6 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()

if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
y = mx.array(tokenizer.encode(prompt))

# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:

model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)

maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
for _ in generate_step(
y,
model,
max_tokens=0,
prompt_cache=cache,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
prompt_progress_callback=callback,
):
pass

print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
12 changes: 6 additions & 6 deletions llms/mlx_lm/chat.py
@@ -5,7 +5,8 @@

import mlx.core as mx

from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response, *_ in stream_generate(
for response in stream_generate(
model,
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
max_tokens=args.max_tokens,
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response, flush=True, end="")
print(response.text, flush=True, end="")
print()

