handle partially quantized models #76

Merged · 2 commits · May 28, 2024

Conversation

davidkoski (Collaborator) commented May 20, 2024:

Note: this isn't directly usable yet, as it requires ml-explore/mlx-swift#73 (I was using a local checkout for development). This will need an update to the mlx-swift version after #73 is merged.

- fixes for #53, #71, #69, #74
- in order to test the models:
	- I added a default prompt of an appropriate form
	- while working on the model configurations, I also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
davidkoski requested a review from awni on May 20, 2024 at 23:44
@@ -125,6 +125,8 @@ struct ContentView: View {

}
.task {
+ self.prompt = llm.modelConfiguration.defaultPrompt
davidkoski (Collaborator, Author):
Use the new default prompt (in case the model changes)

@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
logits = logits.asType(.float32)
}

- let probs = softMax(logits / temp, axis: -1)
+ let probs = softmax(logits / temp, axis: -1)
davidkoski (Collaborator, Author):
deprecation warning fix

if repetitionContext.shape[0] > parameters.repetitionContextSize {
- repetitionContext = repetitionContext[1...]
+ repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
davidkoski (Collaborator, Author):
Fix for #71 -- this is a direct port of the Python code. Previously it looks like there were workarounds for the lack of full numpy-style array indexing; see take() above.
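For illustration, a minimal sketch (helper name assumed) of the fixed behavior, using the same negative-index subscript as the diff:

    import MLX

    // Keep only the most recent `contextSize` tokens, mirroring the Python
    // `context[-context_size:]` slice rather than dropping one element at a time.
    func trimRepetitionContext(_ context: MLXArray, to contextSize: Int) -> MLXArray {
        guard context.shape[0] > contextSize else { return context }
        return context[(-contextSize)...]
    }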

}
.map {
$0.last!
})
davidkoski (Collaborator, Author):
Fix for #74

@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
)

+ /// Optionally preprocess the weights and modify / remove values as needed.
+ func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
davidkoski (Collaborator, Author):
To match the mlx_lm implementation
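A sketch of how this can stay non-invasive, assuming a protocol extension supplies a pass-through default so only models that need preprocessing override it:

    import MLX

    // Pass-through default: existing models compile unchanged; only models that
    // need to rename or drop checkpoint keys override sanitize(weights:).
    extension LLMModel {
        public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
            weights
        }
    }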

- quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+ quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+     path, module in
+     weights["\(path).scales"] != nil
davidkoski (Collaborator, Author):
This is the fix for loading partially quantized models -- if a layer has companion "scales" arrays then it is quantized. We don't need to test the type of the layer because the quantize() method will only convert layers that can be converted.
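Put another way (a sketch with hypothetical weight keys): a layer counts as quantized exactly when the checkpoint ships a companion "scales" array for its path.

    import MLX

    // The predicate from the diff, factored out for illustration.
    func isQuantizedLayer(path: String, weights: [String: MLXArray]) -> Bool {
        weights["\(path).scales"] != nil
    }

    // e.g. "model.layers.0.self_attn.q_proj" -> true when
    //      "model.layers.0.self_attn.q_proj.scales" exists in the checkpoint,
    //      while a plain norm layer without scales is left un-quantized.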

davidkoski (Collaborator, Author):
Note that this depends on ml-explore/mlx-swift#73


- private func quantizeIfNeeded(
-     model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
- ) {
davidkoski (Collaborator, Author):
No longer needed.

Member:
Nice simplification!


public static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
davidkoski (Collaborator, Author):
I verified each of these models and added a defaultPrompt so I don't have to hunt around each time
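A sketch of the resulting flow (initializer labels follow the diff above; other parameters omitted): each configuration carries its prompt, and the UI seeds its text field from it.

    // Declared once per model:
    let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
        overrideTokenizer: "PreTrainedTokenizer",
        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
    )

    // Picked up by ContentView when the model loads:
    // .task { self.prompt = llm.modelConfiguration.defaultPrompt }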

Member:
Beautiful, thanks!

@@ -111,40 +130,53 @@ extension ModelConfiguration {
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
}

public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
prompt in
"Instruct: \(prompt)\nOutput: "
davidkoski (Collaborator, Author):
This wasn't giving good results (any more?) and isn't used on the Python side.

@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+ self.tieWordEmbeddings =
+     try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
davidkoski (Collaborator, Author):
The Qwen2 model implementation has changed slightly.
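For context, a sketch (not the exact Qwen2 code) of what tied word embeddings change: the output projection reuses the embedding weights via Embedding.asLinear, the mlx-swift addition this PR depends on, instead of a separate lm_head.

    import MLX
    import MLXNN

    // Compute output logits either through the tied embedding or a dedicated head.
    func outputLogits(
        _ hidden: MLXArray, embedTokens: Embedding, lmHead: Linear?, tieWordEmbeddings: Bool
    ) -> MLXArray {
        if tieWordEmbeddings {
            return embedTokens.asLinear(hidden)
        }
        return lmHead!(hidden)  // a separate head is expected when embeddings are untied
    }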

davidkoski (Collaborator, Author):
Note: the CI tests are failing because this depends on the mlx-swift change that needs to merge first.

solume commented May 21, 2024:
Tested it with the mlx-swift fork; this breaks for me with Libraries/LLM/Load.swift:62:13 Cannot find 'quantize' in scope
and
Libraries/LLM/Qwen2.swift:198:37 Value of type 'Embedding' has no member 'asLinear'

davidkoski (Collaborator, Author):
> Tested it with the mlx-swift fork; this breaks for me with Libraries/LLM/Load.swift:62:13 Cannot find 'quantize' in scope and Libraries/LLM/Qwen2.swift:198:37 Value of type 'Embedding' has no member 'asLinear'

That sounds like the hookup with that other branch didn't work. See:

solume commented May 25, 2024:
Was able to fix the dependencies, but I'm getting a runtime error:
libc++abi: terminating due to uncaught exception of type std::invalid_argument: [matmul] Last dimension of first input with shape (1,42,1280) must match second to last dimension of second input with shape (160,32000).

Comment on lines 208 to 210
if configuration.tieWordEmbeddings && weights["lm_head.weight"] == nil {
weights["lm_head.weight"] = nil
}
Member:
This looks kind of odd. If it's equal to nil, why set it to nil? Is that a Swift thing?

davidkoski (Collaborator, Author):
No, you are right that is odd.

davidkoski (Collaborator, Author):
Here is the original Python:

        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

I think I can just remove the && ...
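A sketch of the corrected version, mirroring the Python pop (and answering the question above: assigning nil to a Swift dictionary key removes the entry, so the extra check was redundant):

    import MLX

    // Drop lm_head.weight from the checkpoint when embeddings are tied;
    // dict[key] = nil removes the key and is a no-op if the key is absent.
    func dropTiedHead(_ weights: [String: MLXArray], tieWordEmbeddings: Bool) -> [String: MLXArray] {
        var weights = weights
        if tieWordEmbeddings {
            weights["lm_head.weight"] = nil
        }
        return weights
    }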

/// - didGenerate: visitor for the tokens as they are generated
public func generate(
promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
configuration: ModelConfiguration,
awni (Member) commented May 28, 2024:
I'm not certain about the way this model configuration gets passed around or what the intention for it is (the naming is a bit ambiguous). Somehow we don't have a need for it in Python, so I'm wondering why you need it here. What should go in it vs. in the tokenizer/model directly?

From what I can tell it's more like the default arguments for a given model (eos token / prompt). The prompt gets handled outside this function. So here it's just for the additional eos tokens?

davidkoski (Collaborator, Author):
Yeah, just for the additional eos tokens. I can switch that to just pass in the additional eos tokens (optional). If we need more parameters in the future, I can rethink that slightly.
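A sketch of that narrower alternative (names assumed; Tokenizer here is the swift-transformers protocol, assuming its eosTokenId and decode(tokens:) members):

    import Tokenizers

    // Stop when the token is the tokenizer's EOS id or decodes to one of the
    // optional extra EOS strings supplied by the model configuration.
    func isStopToken(_ token: Int, tokenizer: Tokenizer, extraEOSTokens: Set<String>?) -> Bool {
        if token == tokenizer.eosTokenId { return true }
        guard let extra = extraEOSTokens else { return false }
        return extra.contains(tokenizer.decode(tokens: [token]))
    }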

awni (Member) left a review:
Thanks!! LGTM
