handle partially quantized models #76
Conversation
@@ -125,6 +125,8 @@ struct ContentView: View {
        }
        .task {
            self.prompt = llm.modelConfiguration.defaultPrompt
        }
Use the new default prompt (in case the model changes)
@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
        logits = logits.asType(.float32)
    }

-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
deprecation warning fix
    if repetitionContext.shape[0] > parameters.repetitionContextSize {
-        repetitionContext = repetitionContext[1...]
+        repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
Fix for #71 -- this uses a direct port of the Python code. Previously it looks like there were workarounds for the lack of full numpy-style array indexing; see take() above.
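For illustration, a minimal standalone sketch of that slice (assuming the numpy-style negative indexing added in ml-explore/mlx-swift#73; the buffer contents here are made up):

import MLX

// Keep only the most recent `repetitionContextSize` tokens, mirroring the
// Python slice `repetition_context = repetition_context[-repetition_context_size:]`.
let repetitionContextSize = 20
var repetitionContext = MLXArray(Array(0 ..< 64))   // hypothetical token buffer

if repetitionContext.shape[0] > repetitionContextSize {
    repetitionContext = repetitionContext[(-repetitionContextSize)...]
}

print(repetitionContext.shape)   // [20]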
        }
        .map {
            $0.last!
        })
Fix for #74
@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    )

    /// Optionally preprocess the weights and modify / remove values as needed.
    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
To match the mlx_lm implementation.
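As a rough illustration (not code from this PR), a conforming model's sanitize(weights:) can drop checkpoint keys it does not need before loading; the key name below is only an example:

import MLX

// Hypothetical sanitize implementation: filter out precomputed rotary
// frequencies that the runtime recomputes itself.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    weights.filter { !$0.key.contains("rotary_emb.inv_freq") }
}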
-    quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+    quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+        path, module in
+        weights["\(path).scales"] != nil
+    }
This is the fix for loading partially quantized models -- if a layer has "scales" in the weights then it is quantized. We don't need to test the type of layer, as the quantize() method will only convert layers that can be converted.
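A small self-contained sketch of the idea (the weight dictionary below is made up): in a partially quantized checkpoint only some layers carry ".scales" (and ".biases") arrays, so the presence of scales is the signal that a layer should be converted.

import MLX

// Hypothetical checkpoint contents for a partially quantized model.
let weights: [String: MLXArray] = [
    "model.layers.0.mlp.gate_proj.weight": MLXArray([0, 0, 0, 0]),
    "model.layers.0.mlp.gate_proj.scales": MLXArray([1, 1]),
    "lm_head.weight": MLXArray([0, 0, 0, 0]),   // left unquantized upstream
]

// The same predicate used as the quantize(...) filter closure above.
func shouldQuantize(path: String) -> Bool {
    weights["\(path).scales"] != nil
}

print(shouldQuantize(path: "model.layers.0.mlp.gate_proj"))  // true
print(shouldQuantize(path: "lm_head"))                       // false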
Note that this depends on ml-explore/mlx-swift#73
-private func quantizeIfNeeded(
-    model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
-) {
No longer needed.
Nice simplification!
    public static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
I verified each of these models and added a defaultPrompt so I don't have to hunt around each time
Beautiful, thanks!
@@ -111,40 +130,53 @@ extension ModelConfiguration {
        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    }

    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
        prompt in
        "Instruct: \(prompt)\nOutput: "
This wasn't giving good results (any more?) and isn't used on the python side.
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
            Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
The qwen2 model implementation has changed slightly
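For context, a toy sketch of what tie_word_embeddings means (hypothetical code, assuming MLXNN's Embedding.asLinear as used elsewhere in this repo for tied embeddings): the input embedding matrix doubles as the output projection, so the checkpoint ships no lm_head.weight.

import MLX
import MLXNN

let vocabSize = 8
let hiddenSize = 4

let embedTokens = Embedding(embeddingCount: vocabSize, dimensions: hiddenSize)
let lmHead = Linear(hiddenSize, vocabSize, bias: false)

let tokens = MLXArray([1, 2, 3])
let hidden = embedTokens(tokens)                 // [3, hiddenSize]

let tiedLogits = embedTokens.asLinear(hidden)    // reuse the embedding weights
let untiedLogits = lmHead(hidden)                // separate lm_head projection
print(tiedLogits.shape, untiedLogits.shape)      // [3, 8] [3, 8]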
note: the CI tests are failing because this depends on the mlx-swift change that needs to merge first
Tested it with the mlx-swift fork; this breaks for me with Libraries/LLM/Load.swift:62:13 Cannot find 'quantize' in scope
That sounds like the hookup with that other branch didn't work. See:
Was able to fix the dependencies but am getting a runtime error:
Libraries/LLM/Qwen2.swift (Outdated)
    if configuration.tieWordEmbeddings && weights["lm_head.weight"] == nil {
        weights["lm_head.weight"] = nil
    }
This looks kind of odd. If it's equal to nil, why set it to nil? Is that a Swift thing?
No, you are right that is odd.
Here is the original python:

if self.args.tie_word_embeddings:
    weights.pop("lm_head.weight", None)

I think I can just remove the && ...
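For what it's worth, a minimal sketch of the corrected step (the helper name is illustrative): assigning nil removes the key if present and is a no-op otherwise, which is what makes the extra check redundant and mirrors Python's weights.pop("lm_head.weight", None).

import MLX

// Hypothetical helper showing the tied-embeddings cleanup on its own.
func dropTiedLMHead(_ weights: [String: MLXArray], tieWordEmbeddings: Bool) -> [String: MLXArray] {
    var weights = weights
    if tieWordEmbeddings {
        // Removes "lm_head.weight" if it exists; does nothing if it is already absent.
        weights["lm_head.weight"] = nil
    }
    return weights
}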
Libraries/LLM/Evaluate.swift (Outdated)
/// - didGenerate: visitor for the tokens as they are generated
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
    configuration: ModelConfiguration,
I'm not certain about the way this model configuration gets passed around and what the intention for it is (the naming is a bit ambiguous). Somehow we don't have a need for it in Python, so I'm wondering why you need it here. What should go in it vs. in the tokenizer/model directly?
From what I can tell it's more like the default arguments for a given model (eos token / prompt). The prompt gets handled outside this function. So here it's just for the additional eos tokens?
yeah, just for the additional eos tokens. I can switch that to just pass in additional eos tokens (optional). If we need more parameters in the future I can rethink that slightly.
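A tiny sketch of what that stop-token check could look like with optional extra EOS tokens (the names and token ids below are illustrative, not the final API):

// Hypothetical stop-token check combining the tokenizer's EOS id with
// model-specific extra stop tokens (ids made up for the example).
let eosTokenId = 2
let extraEOSTokenIds: Set<Int> = [32000, 32007]

func isStopToken(_ token: Int) -> Bool {
    token == eosTokenId || extraEOSTokenIds.contains(token)
}

print(isStopToken(2))       // true
print(isStopToken(32007))   // true
print(isStopToken(15))      // false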
Thanks!! LGTM
Linked issues: generate silently fails on OpenELM 3B after latest commit #69; Phi-3 mini stop token not recognized #74

Note: this isn't directly usable as it requires ml-explore/mlx-swift#73 (I was using the local checkout for development). This will need an update to the mlx version after 73 is merged.