Skip to content

Commit

Permalink
Support BERT style multilingual native embeddings (#22)
Browse files Browse the repository at this point in the history
* Add support for BERT style native contextual embeddings

* Update swift.yml

* Add flag for building on macos13

* Fix typo bracket

* Reduce macos version on ci

* Use macos13 not latest
  • Loading branch information
ZachNagengast authored Oct 15, 2023
1 parent 81e176d commit 18b52c6
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 5 deletions.
8 changes: 7 additions & 1 deletion Examples/BasicExample/BasicExample/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ struct ContentView: View {
@State private var querySentence: String = ""
@State private var similarityResults: [SearchResult] = []
@State private var similarityIndex: SimilarityIndex?
@State private var similarityIndexComparison: SimilarityIndex?

func loadIndex() async {
    // Default to the bundled MiniLM model; upgrade to the native
    // contextual (BERT-style) model when the SDK provides it.
    var model: any EmbeddingsProtocol = MiniLMEmbeddings()
    #if canImport(NaturalLanguage.NLContextualEmbedding)
    // Fix: the original assigned to an undeclared `embeddingModel` and left a
    // dangling `#else` with no `#endif`. Assign to `model` and close the block.
    // `canImport` is a compile-time check only, so also gate on runtime
    // availability, since NativeContextualEmbeddings requires macOS 14 / iOS 17.
    if #available(macOS 14.0, iOS 17.0, *) {
        model = NativeContextualEmbeddings(language: .english)
    }
    #endif

    similarityIndex = await SimilarityIndex(
        model: model,
        metric: CosineSimilarity()
    )
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ struct ChatWithFilesExampleSwiftUIView: View {
embeddingModel = MultiQAMiniLMEmbeddings()
currentTokenizer = BertTokenizer()
case .native:
#if canImport(NaturalLanguage.NLContextualEmbedding)
embeddingModel = NativeContextualEmbeddings()
#else
embeddingModel = NativeEmbeddings()
#endif
currentTokenizer = NativeTokenizer()
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//
// NativeContextualEmbeddings.swift
//
//
// Created by Zach Nagengast on 10/11/23.
//

import Foundation
import NaturalLanguage
import CoreML

#if canImport(NaturalLanguage.NLContextualEmbedding)
/// Sentence-level embeddings backed by Apple's `NLContextualEmbedding`
/// (BERT-style, contextual, multilingual). Token vectors are mean-pooled
/// into one fixed-size `[Float]` of length `model.dimension` per sentence.
@available(macOS 14.0, iOS 17.0, *)
public class NativeContextualEmbeddings: EmbeddingsProtocol {
    public let model: ModelActor
    public let tokenizer: any TokenizerProtocol

    /// Creates embeddings for a given language (default: English).
    /// - Parameter language: Language used to select the contextual model.
    public init(language: NLLanguage = .english) {
        self.tokenizer = NativeTokenizer()
        guard let nativeModel = NLContextualEmbedding(language: language) else {
            fatalError("Failed to load the Core ML model.")
        }
        Self.initializeModel(nativeModel)
        self.model = ModelActor(model: nativeModel)
    }

    /// Creates embeddings for a given script.
    /// - Parameter script: Script used to select the contextual model.
    public init(script: NLScript) {
        self.tokenizer = NativeTokenizer()
        guard let nativeModel = NLContextualEmbedding(script: script) else {
            fatalError("Failed to load the Core ML model.")
        }
        Self.initializeModel(nativeModel)
        self.model = ModelActor(model: nativeModel)
    }

    // Requests missing assets (fire-and-forget) and loads the model.
    // NOTE(review): `requestAssets` completes asynchronously, so `load()` can
    // still fail — and trap via `try!` — if assets are not yet on disk when
    // this runs; confirm the intended first-launch behavior.
    private static func initializeModel(_ nativeModel: NLContextualEmbedding) {
        if !nativeModel.hasAvailableAssets {
            nativeModel.requestAssets { _, _ in }
        }
        try! nativeModel.load()
    }

    // MARK: - Dense Embeddings

    /// Serializes access to the underlying `NLContextualEmbedding`.
    public actor ModelActor {
        private let model: NLContextualEmbedding

        init(model: NLContextualEmbedding) {
            self.model = model
        }

        /// Returns the mean-pooled embedding vector for `sentence`.
        func vector(for sentence: String) -> [Float]? {
            // Obtain the embedding result for the given sentence.
            // Shape is [1, embedding.sequenceLength, model.dimension].
            let embedding = try! model.embeddingResult(for: sentence, language: nil)

            // Accumulator: one slot per embedding dimension.
            var meanPooledEmbeddings: [Float] = Array(repeating: 0.0, count: model.dimension)
            let sequenceLength = embedding.sequenceLength

            // Mean pooling, step 1: sum each dimension across all token vectors.
            embedding.enumerateTokenVectors(in: sentence.startIndex ..< sentence.endIndex) { (tokenVector, _) -> Bool in
                for dimension in 0 ..< tokenVector.count {
                    meanPooledEmbeddings[dimension] += Float(tokenVector[dimension])
                }
                return true
            }

            // Mean pooling, step 2: divide every dimension by the token count.
            // Fix: the divide loop must iterate over the embedding dimensions,
            // not over 0..<sequenceLength — the original averaged only a prefix
            // of the vector when a sentence had fewer tokens than dimensions,
            // and trapped out-of-bounds when it had more.
            if sequenceLength > 0 {
                for dimension in meanPooledEmbeddings.indices {
                    meanPooledEmbeddings[dimension] /= Float(sequenceLength)
                }
            }

            // Return the mean-pooled vector.
            return meanPooledEmbeddings
        }
    }

    /// Asynchronously encodes `sentence` into a dense vector.
    public func encode(sentence: String) async -> [Float]? {
        return await model.vector(for: sentence)
    }
}
#endif

Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@ import NaturalLanguage

@available(macOS 11.0, iOS 15.0, *)
public class NativeEmbeddings: EmbeddingsProtocol {
public let model = ModelActor()
public let model: ModelActor
public let tokenizer: any TokenizerProtocol

public init() {
public init(language: NLLanguage = .english) {
self.tokenizer = NativeTokenizer()
self.model = ModelActor(language: language)
}

// MARK: - Dense Embeddings

public actor ModelActor {
private let model: NLEmbedding

init() {
guard let nativeModel = NLEmbedding.sentenceEmbedding(for: .english) else {
init(language: NLLanguage) {
guard let nativeModel = NLEmbedding.sentenceEmbedding(for: language) else {
fatalError("Failed to load the Core ML model.")
}
model = nativeModel
Expand Down
6 changes: 6 additions & 0 deletions Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import CoreML

@available(macOS 13.0, iOS 16.0, *)
class SimilaritySearchKitTests: XCTestCase {

override func setUp() {
    // Report every failed assertion within a test instead of stopping at the first.
    continueAfterFailure = true
    // Give model-loading tests up to a minute of wall-clock time before timing out.
    executionTimeAllowance = 60
}

func testSavingJsonIndex() async {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())

Expand Down

0 comments on commit 18b52c6

Please sign in to comment.