add VLM support, refactor common LM code into MLXLMCommon. breaking API changes (#151)

* implement VLM

- based on models from https://github.com/Blaizzy/mlx-vlm

There are two new libraries:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. `describe this image`
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`.

```swift
let parameters = GenerateParameters()

// LLM prompt
let input = UserInput(prompt: "tell me a story")

// VLM prompt
let input = UserInput(prompt: "describe the image", images: [.url(url)])

// inference is identical
let result = try await modelContainer.perform { [generate, input] context in
    let input = try await context.processor.prepare(input: input)
    return try generate(input: input, parameters: parameters, context: context) { token in
        // print tokens as they are generated, stop early, etc.
        return .more
    }
}
```
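
For multi-image prompts, or to downscale large images before they reach the model, something like this should work. This is a sketch: the `processing.resize` property is an assumption based on the `--resize` option in `llm-tool` below.

```swift
// Sketch: multiple images plus optional resizing.
// `processing.resize` is assumed here, mirroring llm-tool's --resize option.
var input = UserInput(
    prompt: "compare these two images",
    images: [.url(firstURL), .url(secondURL)]
)
input.processing.resize = CGSize(width: 448, height: 448)
```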

VLM example code is available in the `llm-tool` example:

```
./mlx-run llm-tool eval --help
OVERVIEW: evaluate prompt and images to generate text (VLM)

USAGE: llm-tool eval <options>

OPTIONS:
  --model <model>         Name of the huggingface model or absolute path to directory
  -p, --prompt <prompt>   The message to be processed by the model.  Use @path,@path to load from files, e.g. @/tmp/prompt.txt
  --resize <resize>       Resize images to this size (width, height)
  --image <image>         Paths or urls for input images
...
```
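
A typical invocation might look like this (the model id and image path are illustrative):

```
./mlx-run llm-tool eval \
    --model mlx-community/Qwen2-VL-2B-Instruct-4bit \
    --prompt "describe the image" \
    --image support/test.jpg
```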

Probably no effect on code external to this repo:

- the mlx-swift-examples.xcodeproj now references the local `Package.swift` to build the libraries
- the example code now uses naming that matches external uses of mlx-swift-examples, e.g. `import LLM` -> `import MLXLLM`
- the library directories are now renamed to match their target names, e.g. `LLM` -> `MLXLLM`

Breaking:

- some code will now need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models)
- `MLXLMCommon` contains the common API between LLM and VLM

```swift
import MLXLLM
import MLXLMCommon
```

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
- this is `MLXLLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry` (see the sketch below)

```diff
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
```
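
When both libraries are imported, the registries can be disambiguated by module. A sketch (the VLM constant name is illustrative, not confirmed):

```swift
import MLXLLM
import MLXVLM

let llmConfiguration = MLXLLM.ModelRegistry.phi3_5_4bit
// illustrative name for a VLM configuration constant
let vlmConfiguration = MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit
```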

- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`
- there is a new `VLMModelFactory` with identical methods for loading VLMs

```diff
-     let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-    {
+     let modelContainer = try await LLMModelFactory.shared.loadContainer(
+          configuration: modelConfiguration
+    ) {
```
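
Loading a VLM should look the same apart from the factory. A sketch (the configuration constant is illustrative):

```swift
// Sketch: loading a VLM container mirrors the LLM path.
let vlmContainer = try await VLMModelFactory.shared.loadContainer(
    configuration: MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit
) { progress in
    print("Downloading: \(Int(progress.fractionCompleted * 100))%")
}
```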

- `ModelContainer.perform` is now throwing (and in MLXLMCommon):

```diff
-     let result = await modelContainer.perform { model, tokenizer in
-          LLM.generate(
+     let result = try await modelContainer.perform { model, tokenizer in
+          try MLXLMCommon.generate(
```

- `ModelConfiguration` previously had a way to register new configurations.  This is now on `LLMModelFactory` (and `VLMModelFactory` has the same):

```swift
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```
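
For example, a configuration for a model that is not in the built-in registry might be registered like this (a sketch; the Hugging Face id is illustrative and the `ModelConfiguration(id:)` initializer shown is an assumption):

```swift
// Sketch: register a custom configuration, then load it as usual.
let modelConfiguration = ModelConfiguration(id: "mlx-community/MyFineTune-4bit")
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```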

An example at the end shows all of these deprecations in context.

**Prefer the `ModelContext.processor` to prepare prompts.**  Previously callers passed in a bare `[Int]` of tokens; to support more complex inputs (VLMs), the bare `[Int]` form is deprecated and callers should use `UserInput` and `LMInput` instead.

For example, previously callers might have done something like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```

Now that should be:

```swift
let input = try await context.processor.prepare(input: .init(prompt: prompt))
```

Which will initialize a `UserInput` from the prompt text and produce an `LMInput` that can be used to generate tokens.
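
If you were building a chat-template messages array before, a messages-based `UserInput` should work the same way. A sketch (the `messages:` initializer is an assumption):

```swift
// Sketch: messages-based input, assumed to mirror the old chat-template path.
let input = try await context.processor.prepare(
    input: .init(messages: [["role": "user", "content": prompt]])
)
```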

**This call to `generate()` is now deprecated:**

```swift
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel,
    tokenizer: Tokenizer,
    extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This consumed the `[Int]` variety of tokens.  Now this is preferred:

```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

**This method on `ModelContainer` is now deprecated:**

```swift
    /// Perform an action on the model and/or tokenizer.  Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    @available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
    public func perform<R>(_ action: @sendable (any LanguageModel, Tokenizer) throws -> R) rethrows
        -> R
```

use this one instead (though the former still works):

```swift
    /// Perform an action on the ``ModelContext``.  Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    public func perform<R>(_ action: @sendable (ModelContext) async throws -> R) async rethrows -> R
```
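
For example, the updated LLMEval example reads the parameter count through the context:

```swift
let numParams = await modelContainer.perform { context in
    context.model.numParameters()
}
```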

Putting all of these deprecations together, previously you might have generated text like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}

let result = await modelContainer.perform { model, tokenizer in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters, model: model,
        tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in ... }
}
```

now do this:

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in ... }
}
```
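
A minimal end-to-end flow, combining the pieces above into one sketch:

```swift
// Sketch: load a model and generate text with the new API.
let modelContainer = try await LLMModelFactory.shared.loadContainer(
    configuration: ModelRegistry.phi3_5_4bit
) { progress in
    print("Downloading: \(Int(progress.fractionCompleted * 100))%")
}

let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: "tell me a story"))
    return try MLXLMCommon.generate(
        input: input, parameters: GenerateParameters(temperature: 0.6), context: context
    ) { tokens in
        // stop after a fixed number of tokens
        tokens.count >= 200 ? .stop : .more
    }
}
print(result.output)
```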

Co-authored-by: Awni Hannun <[email protected]>
davidkoski and awni authored Dec 10, 2024
1 parent 318044f commit 6ef303b
Showing 65 changed files with 5,152 additions and 2,628 deletions.
7 changes: 4 additions & 3 deletions .circleci/config.yml
```diff
@@ -15,7 +15,7 @@ jobs:
 
   mac_build_and_test:
     macos:
-      xcode: 15.2.0
+      xcode: 16.0.0
     resource_class: macos.m1.medium.gen1
     steps:
       - checkout
@@ -35,8 +35,9 @@ jobs:
             xcrun --show-sdk-build-version
             swift --version
             find . -name Package.resolved -exec rm {} \;
-            xcodebuild -skipPackagePluginValidation -scheme llm-tool
-            xcodebuild -skipPackagePluginValidation -scheme mnist-tool
+            xcodebuild -scheme llm-tool
+            xcodebuild -scheme image-tool
+            xcodebuild -scheme mnist-tool
 
 workflows:
   build_and_test:
```

30 changes: 13 additions & 17 deletions Applications/LLMEval/ContentView.swift
```diff
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXRandom
 import MarkdownUI
 import Metal
@@ -159,7 +160,7 @@ class LLMEvaluator {
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -185,17 +186,17 @@ class LLMEvaluator {
         // limit the buffer cache
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
             [modelConfiguration] progress in
             Task { @MainActor in
                 self.modelInfo =
                     "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
             }
         }
-        let numParams = await modelContainer.perform {
-            [] model, _ in
-            return model.numParameters()
+        let numParams = await modelContainer.perform { context in
+            context.model.numParameters()
         }
 
         self.modelInfo =
@@ -217,22 +218,17 @@ class LLMEvaluator {
         do {
             let modelContainer = try await load()
 
-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }
-
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
                 ) { tokens in
                     // update the output -- this will make the view show the text as it generates
                     if tokens.count % displayEveryNTokens == 0 {
-                        let text = tokenizer.decode(tokens: tokens)
+                        let text = context.tokenizer.decode(tokens: tokens)
                         Task { @MainActor in
                             self.output = text
                         }
```

2 changes: 1 addition & 1 deletion Applications/LLMEval/README.md
````diff
@@ -30,7 +30,7 @@ The example application uses Phi2 model by default, see [ContentView.swift](Cont
 let modelConfiguration = ModelConfiguration.phi4bit
 ```
 
-There are some pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62)
+There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78)
 and you can load any weights from Hugging Face where there
 is a model architecture defined and you have enough
 memory.
````

2 changes: 1 addition & 1 deletion Applications/LLMEval/ViewModels/DeviceStat.swift
```diff
@@ -1,6 +1,6 @@
 import Foundation
-import LLM
 import MLX
+import MLXLLM
 
 @Observable
 final class DeviceStat: @unchecked Sendable {
```

57 changes: 27 additions & 30 deletions Applications/LoRATrainingExample/ContentView.swift
```diff
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXNN
 import MLXOptimizers
 import MLXRandom
@@ -122,7 +123,7 @@ class LoRAEvaluator {
 
     var output = ""
 
-    private let modelConfiguration = ModelConfiguration.mistral7B4bit
+    private let modelConfiguration = ModelRegistry.mistral7B4bit
     private var model: ModelState = .idle
 
     private let loraLayers = 4
@@ -141,8 +142,9 @@ class LoRAEvaluator {
             progress = .init(title: "Loading \(name)", current: 0, limit: 1)
         }
 
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
             progress in
             Task { @MainActor in
                 self.progress = .init(
@@ -160,7 +162,7 @@ class LoRAEvaluator {
 
     private func loadLoRAData(name: String) throws -> [String]? {
         if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
-            return try LLM.loadLoRAData(url: url)
+            return try MLXLLM.loadLoRAData(url: url)
         }
         return nil
     }
@@ -196,9 +198,9 @@ class LoRAEvaluator {
             let modelContainer = try await loadModel()
 
             // apply LoRA adapters and train
-            await modelContainer.perform { model, _ in
+            await modelContainer.perform { context in
                 LoRATrain.convert(
-                    model: model, layers: loraLayers(model: model))
+                    model: context.model, layers: loraLayers(model: context.model))
             }
 
             let train = try loadLoRAData(name: "train")
@@ -208,11 +210,11 @@ class LoRAEvaluator {
                 return
             }
 
-            try await modelContainer.perform { model, tokenizer in
+            try await modelContainer.perform { context in
                 let optimizer = Adam(learningRate: learningRate)
                 try LoRATrain.train(
-                    model: model, train: train, validate: valid, optimizer: optimizer,
-                    tokenizer: tokenizer,
+                    model: context.model, train: train, validate: valid, optimizer: optimizer,
+                    tokenizer: context.tokenizer,
                     parameters: parameters
                 ) { progress in
                     Task { @MainActor in
@@ -240,9 +242,10 @@ class LoRAEvaluator {
                 return
             }
 
-            let loss = await modelContainer.perform { model, tokenizer in
+            let loss = await modelContainer.perform { context in
                 LoRATrain.evaluate(
-                    model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
+                    model: context.model, dataset: test,
+                    tokenizer: context.tokenizer, batchSize: 1, batchCount: 0)
             }
 
             self.progress = nil
@@ -269,26 +272,20 @@ class LoRAEvaluator {
 
             let modelContainer = try await loadModel()
 
-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }
-
             // evaluate
-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer,
-                    extraEOSTokens: modelConfiguration.extraEOSTokens,
-                    didGenerate: { tokens in
-                        if tokens.count % evaluateShowEvery == 0 {
-                            let fullOutput = tokenizer.decode(tokens: tokens)
-                            Task { @MainActor in
-                                self.output = fullOutput
-                            }
-                        }
-                        return tokens.count >= maxTokens ? .stop : .more
-                    })
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
+                ) { tokens in
+                    if tokens.count % evaluateShowEvery == 0 {
+                        let fullOutput = context.tokenizer.decode(tokens: tokens)
+                        Task { @MainActor in
+                            self.output = fullOutput
+                        }
+                    }
+                    return tokens.count >= maxTokens ? .stop : .more
+                }
             }
 
             self.output = result.output
```

2 changes: 1 addition & 1 deletion Applications/MNISTTrainer/ContentView.swift
```diff
@@ -1,10 +1,10 @@
 // Copyright © 2024 Apple Inc.
 
 import MLX
+import MLXMNIST
 import MLXNN
 import MLXOptimizers
 import MLXRandom
-import MNIST
 import SwiftUI
 
 struct TrainingView: View {
```

2 changes: 1 addition & 1 deletion Applications/MNISTTrainer/PredictionView.swift
```diff
@@ -6,8 +6,8 @@
 //
 
 import MLX
+import MLXMNIST
 import MLXNN
-import MNIST
 import SwiftUI
 
 struct Canvas: View {
```