add VLM support, refactor common LM code into MLXLMCommon. breaking API changes (#151)
* implement VLM - based on models from https://github.com/Blaizzy/mlx-vlm

There are two new libraries:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. `describe this image`
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`.

```swift
let parameters = GenerateParameters()

// LLM prompt
let input = UserInput(prompt: "tell me a story")

// VLM prompt
let input = UserInput(prompt: "describe the image", images: [.url(url)])

// inference is identical
let result = try await modelContainer.perform { [generate, input] context in
    let input = try await context.processor.prepare(input: input)
    return try generate(input: input, parameters: parameters, context: context) { token in
        // print tokens as they are generated, stop early, etc.
        return .more
    }
}
```

VLM example code is available in the `llm-tool` example:

```
./mlx-run llm-tool eval --help
OVERVIEW: evaluate prompt and images to generate text (VLM)

USAGE: llm-tool eval <options>

OPTIONS:
  --model <model>         Name of the huggingface model or absolute path to directory
  -p, --prompt <prompt>   The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt
  --resize <resize>       Resize images to this size (width, height)
  --image <image>         Paths or urls for input images
  ...
```

Probably no effect on code external to this repo:

- the mlx-swift-examples.xcodeproj now references the local `Package.swift` to build the libraries
- the example code now uses naming that matches external uses of mlx-swift-examples, e.g. `import LLM` -> `import MLXLLM`
- the library directories are now renamed to match their target names, e.g. `LLM` -> `MLXLLM`

Breaking:

- some code will now need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models)
    - `MLXLMCommon` contains the common API between LLM and VLM

```swift
import MLXLLM
import MLXLMCommon
```

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
    - this is `MLXLLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```

- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`
- there is a new `VLMModelFactory` with identical methods for loading VLMs (see the sketch after this list)

```diff
- let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
- {
+ let modelContainer = try await LLMModelFactory.shared.loadContainer(
+     configuration: modelConfiguration
+ ) {
```

- `ModelContainer.perform` is now throwing (and in MLXLMCommon):

```diff
- let result = await modelContainer.perform { model, tokenizer in
-     LLM.generate(
+ let result = try await modelContainer.perform { model, tokenizer in
+     try MLXLMCommon.generate(
```

- `ModelConfiguration` previously had a way to register new configurations. This is now on `LLMModelFactory` (and `VLMModelFactory` has the same):

```swift
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```
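Loading a VLM goes through `VLMModelFactory` in the same way as the LLM diff above. A minimal sketch, assuming a registry constant like `qwen2VL2BInstruct4Bit` is defined in `MLXVLM.ModelRegistry` (the constant name here is illustrative):

```swift
import MLXLMCommon
import MLXVLM

// Load a vision language model container. `VLMModelFactory` mirrors the
// `LLMModelFactory` API; substitute a constant that actually exists in
// `MLXVLM.ModelRegistry` for the illustrative one below.
let modelContainer = try await VLMModelFactory.shared.loadContainer(
    configuration: MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit
)
```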
An example at the end shows all of these deprecations in context.

**Prefer to use the `ModelContext.processor` to prepare prompts.** Previously users would pass in a bare `[Int]` of tokens, but in order to support more complex inputs (VLMs) the use of bare `[Int]` is deprecated and callers should use `UserInput` and `LMInput`.

For example, previously callers might have done something like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```

Now that should be:

```swift
let input = try await context.processor.prepare(input: .init(prompt: prompt))
```

This will initialize a `UserInput` from the prompt text and produce an `LMInput` that can be used to generate tokens.

**This call to `generate()` is now deprecated:**

```swift
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel,
    tokenizer: Tokenizer, extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This consumed the `[Int]` variety of tokens. Now this is preferred:

```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

**This method on `ModelContainer` is now deprecated:**

```swift
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
@available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows -> R
```

Use this one instead (though the former still works):

```swift
/// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
```

Putting all of these deprecations together, previously you might have generated text like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}

let result = await modelContainer.perform { model, tokenizer in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters,
        model: model, tokenizer: tokenizer,
        extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in
        ...
    }
}
```

Now do this:

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in
        ...
    }
}
```
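The same new-style flow applies to a VLM prompt, with the image carried by `UserInput`. A sketch that assumes `modelContainer` was loaded via `VLMModelFactory` as in the earlier sketch, and that `url` and `generateParameters` are defined as in the examples above:

```swift
import MLXLMCommon

// New-style VLM generation: prepare a `UserInput` carrying both the prompt
// and an image, then generate with the shared `MLXLMCommon.generate`.
// Assumes `modelContainer` came from `VLMModelFactory.shared.loadContainer(...)`
// and that `url` and `generateParameters` are defined as in the examples above.
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(
        input: .init(prompt: "describe the image", images: [.url(url)]))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in
        // print tokens as they arrive, stop early, etc.
        return .more
    }
}
```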
Co-authored-by: Awni Hannun <[email protected]>