attempt 2 at preparing for strict concurrency (#90)
* attempt 2 at preparing for strict concurrency

- see also #83
- this marks many things as Sendable (which I think we can take regardless)
- creates an actor container for models and tokenizers, which are not Sendable (though perhaps Tokenizers could be); a sketch of this pattern follows below
davidkoski authored Aug 2, 2024
1 parent 885e520 commit fb5ee82
Showing 29 changed files with 506 additions and 278 deletions.
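The heart of the change is the actor container mentioned in the commit message: the non-Sendable model and tokenizer are confined inside an actor, and callers hand the actor a closure and get back only Sendable results, so the wrapped values never cross an isolation boundary directly. Below is a minimal sketch of that idea, assuming a perform-style API like the one used in the diffs; Model and Tokenizer are illustrative stand-ins for the real LLMModel and Tokenizers.Tokenizer types, and the actual ModelContainer in this commit may differ in signatures and detail.

final class Model {                 // stand-in for the real, non-Sendable model type
    var weightsLoaded = false
}

final class Tokenizer {             // stand-in for the real tokenizer type
    func encode(text: String) -> [Int] { Array(text.utf8).map(Int.init) }
}

actor ModelContainer {
    // The wrapped values never escape the actor, so they do not have to be Sendable.
    private let model: Model
    private let tokenizer: Tokenizer

    init(model: Model, tokenizer: Tokenizer) {
        self.model = model
        self.tokenizer = tokenizer
    }

    // Run work against the model and tokenizer inside the actor's isolation,
    // handing only a Sendable result back to the caller.
    func perform<R: Sendable>(_ action: @Sendable (Model, Tokenizer) throws -> R) rethrows -> R {
        try action(model, tokenizer)
    }
}

// Usage sketch: only Sendable values (here, an Int) cross the actor boundary.
func promptTokenCount(_ prompt: String, in container: ModelContainer) async -> Int {
    await container.perform { _, tokenizer in
        tokenizer.encode(text: prompt).count
    }
}

This is why load() in the diff below now returns a ModelContainer rather than a (model, tokenizer) tuple, and why generate(prompt:) runs tokenization and generation through modelContainer.perform.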
Applications/LLMEval/ContentView.swift (86 changes: 41 additions & 45 deletions)
@@ -148,9 +148,9 @@ struct ContentView: View {
}

@Observable
@MainActor
class LLMEvaluator {

@MainActor
var running = false

var output = ""
@@ -172,91 +172,87 @@ class LLMEvaluator {

enum LoadState {
case idle
case loaded(LLMModel, Tokenizers.Tokenizer)
case loaded(ModelContainer)
}

var loadState = LoadState.idle

/// load and return the model -- can be called multiple times, subsequent calls will
/// just return the loaded model
func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
func load() async throws -> ModelContainer {
switch loadState {
case .idle:
// limit the buffer cache
MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
{
[modelConfiguration] progress in
DispatchQueue.main.sync {
Task { @MainActor in
self.modelInfo =
"Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
}
}
self.modelInfo =
"Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
loadState = .loaded(model, tokenizer)
return (model, tokenizer)
loadState = .loaded(modelContainer)
return modelContainer

case .loaded(let model, let tokenizer):
return (model, tokenizer)
case .loaded(let modelContainer):
return modelContainer
}
}

func generate(prompt: String) async {
let canGenerate = await MainActor.run {
if running {
return false
} else {
running = true
self.output = ""
return true
}
}
guard !running else { return }

guard canGenerate else { return }
running = true
self.output = ""

do {
let (model, tokenizer) = try await load()
let modelContainer = try await load()

// augment the prompt as needed
let prompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = tokenizer.encode(text: prompt)

let promptTokens = await modelContainer.perform { _, tokenizer in
tokenizer.encode(text: prompt)
}

// each time you generate you will get something new
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

let result = await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
) { tokens in
// update the output -- this will make the view show the text as it generates
if tokens.count % displayEveryNTokens == 0 {
let text = tokenizer.decode(tokens: tokens)
await MainActor.run {
self.output = text
let result = await modelContainer.perform { model, tokenizer in
LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
) { tokens in
// update the output -- this will make the view show the text as it generates
if tokens.count % displayEveryNTokens == 0 {
let text = tokenizer.decode(tokens: tokens)
Task { @MainActor in
self.output = text
}
}
}

if tokens.count >= maxTokens {
return .stop
} else {
return .more
if tokens.count >= maxTokens {
return .stop
} else {
return .more
}
}
}

// update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
await MainActor.run {
if result.output != self.output {
self.output = result.output
}
running = false
self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
if result.output != self.output {
self.output = result.output
}
self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"

} catch {
await MainActor.run {
running = false
output = "Failed: \(error)"
}
output = "Failed: \(error)"
}

running = false
}
}
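A second recurring change in the diff above is how main-actor state is updated from non-isolated callbacks: DispatchQueue.main.sync and await MainActor.run are replaced by posting the mutation onto the main actor with an unstructured Task. A small sketch of the new pattern, with an illustrative ViewModel standing in for the real LLMEvaluator:

@MainActor
final class ViewModel {
    var output = ""
}

// Called from a non-isolated context, e.g. the token callback inside generate().
func publish(_ text: String, to viewModel: ViewModel) {
    // Hop to the main actor; the mutation itself runs with main-actor isolation.
    Task { @MainActor in
        viewModel.output = text
    }
}

Unlike DispatchQueue.main.sync, this does not block the calling thread; the update lands on the main actor asynchronously, which is fine for progressively displaying generated text.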
Applications/LLMEval/ViewModels/DeviceStat.swift (25 changes: 7 additions & 18 deletions)
@@ -3,33 +3,22 @@ import LLM
import MLX

@Observable
class DeviceStat {
final class DeviceStat: @unchecked Sendable {

@MainActor
var gpuUsage = GPU.snapshot()
private var initialGPUSnapshot = GPU.snapshot()

private let initialGPUSnapshot = GPU.snapshot()
private var timer: Timer?

init() {
startTimer()
}

deinit {
stopTimer()
}

private func startTimer() {
timer?.invalidate()
timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
self?.updateStats()
self?.updateGPUUsages()
}
}

private func stopTimer() {
deinit {
timer?.invalidate()
timer = nil
}

private func updateStats() {
updateGPUUsages()
}

private func updateGPUUsages() {

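DeviceStat takes the other route the commit message mentions: rather than an actor, the class is marked @unchecked Sendable, with its UI-facing mutable state isolated to the main actor and the compiler told to trust the class's own discipline for the rest. A condensed, illustrative sketch of that shape (the type and property names and the polled value are placeholders, not the real GPU snapshot API):

import Foundation

final class StatsMonitor: @unchecked Sendable {

    // Mutable state that views read; isolated to the main actor.
    @MainActor
    var latestReading = 0

    private var timer: Timer?

    init() {
        // Poll periodically; hop to the main actor to mutate the published state.
        timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
            Task { @MainActor in
                self?.latestReading += 1    // stand-in for taking a real GPU snapshot
            }
        }
    }

    deinit {
        timer?.invalidate()
    }
}

The real DeviceStat is also @Observable and keeps an initial GPU snapshot around for comparison, as the (truncated) diff above shows; @unchecked Sendable simply tells the compiler not to verify the conformance itself.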