diff --git a/Examples/BasicExample/BasicExample/ContentView.swift b/Examples/BasicExample/BasicExample/ContentView.swift
index 7feeae5..12e51cb 100644
--- a/Examples/BasicExample/BasicExample/ContentView.swift
+++ b/Examples/BasicExample/BasicExample/ContentView.swift
@@ -16,10 +16,16 @@ struct ContentView: View {
     @State private var querySentence: String = ""
     @State private var similarityResults: [SearchResult] = []
     @State private var similarityIndex: SimilarityIndex?
+    @State private var similarityIndexComparison: SimilarityIndex?
 
     func loadIndex() async {
+        var model: any EmbeddingsProtocol = MiniLMEmbeddings()
+        #if canImport(NaturalLanguage.NLContextualEmbedding)
+        model = NativeContextualEmbeddings(language: .english)
+        #endif
+
         similarityIndex = await SimilarityIndex(
-            model: MiniLMEmbeddings(),
+            model: model,
             metric: CosineSimilarity()
         )
     }
diff --git a/Examples/ChatWithFilesExample/ChatWithFilesExample/ChatWithFilesExampleSwiftUIView.swift b/Examples/ChatWithFilesExample/ChatWithFilesExample/ChatWithFilesExampleSwiftUIView.swift
index 55f344a..b41694d 100644
--- a/Examples/ChatWithFilesExample/ChatWithFilesExample/ChatWithFilesExampleSwiftUIView.swift
+++ b/Examples/ChatWithFilesExample/ChatWithFilesExample/ChatWithFilesExampleSwiftUIView.swift
@@ -272,7 +272,11 @@ struct ChatWithFilesExampleSwiftUIView: View {
             embeddingModel = MultiQAMiniLMEmbeddings()
             currentTokenizer = BertTokenizer()
         case .native:
+            #if canImport(NaturalLanguage.NLContextualEmbedding)
+            embeddingModel = NativeContextualEmbeddings()
+            #else
             embeddingModel = NativeEmbeddings()
+            #endif
             currentTokenizer = NativeTokenizer()
         }
 
diff --git a/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeContextualEmbeddings.swift b/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeContextualEmbeddings.swift
new file mode 100644
index 0000000..e6cdab8
--- /dev/null
+++ b/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeContextualEmbeddings.swift
@@ -0,0 +1,89 @@
+//
+//  NativeContextualEmbeddings.swift
+//
+//
+//  Created by Zach Nagengast on 10/11/23.
+//
+
+import Foundation
+import NaturalLanguage
+import CoreML
+
+#if canImport(NaturalLanguage.NLContextualEmbedding)
+@available(macOS 14.0, iOS 17.0, *)
+public class NativeContextualEmbeddings: EmbeddingsProtocol {
+    public let model: ModelActor
+    public let tokenizer: any TokenizerProtocol
+
+    /// Creates an embedder for `language`; traps via `fatalError` when no contextual model exists for it.
+    public init(language: NLLanguage = .english) {
+        self.tokenizer = NativeTokenizer()
+        guard let nativeModel = NLContextualEmbedding(language: language) else {
+            fatalError("Failed to load the Core ML model.")
+        }
+        Self.initializeModel(nativeModel)
+        self.model = ModelActor(model: nativeModel)
+    }
+
+    /// Creates an embedder for `script`; traps via `fatalError` when no contextual model exists for it.
+    public init(script: NLScript) {
+        self.tokenizer = NativeTokenizer()
+        guard let nativeModel = NLContextualEmbedding(script: script) else {
+            fatalError("Failed to load the Core ML model.")
+        }
+        Self.initializeModel(nativeModel)
+        self.model = ModelActor(model: nativeModel)
+    }
+
+    // Shared setup: requests assets when missing (completion ignored, not awaited), then loads; `try!` traps if load fails
+    private static func initializeModel(_ nativeModel: NLContextualEmbedding) {
+        if !nativeModel.hasAvailableAssets {
+            nativeModel.requestAssets { _, _ in }
+        }
+        try! nativeModel.load()
+    }
+
+    // MARK: - Dense Embeddings
+
+    public actor ModelActor {
+        private let model: NLContextualEmbedding
+
+        init(model: NLContextualEmbedding) {
+            self.model = model
+        }
+
+        func vector(for sentence: String) -> [Float]? {
+            // Obtain embedding result for the given sentence
+            // Shape is [1, embedding.sequenceLength, model.dimension]
+            let embedding = try! model.embeddingResult(for: sentence, language: nil)
+
+            // Initialize an array to store the total embedding values and set the count
+            var meanPooledEmbeddings: [Float] = Array(repeating: 0.0, count: model.dimension)
+            let sequenceLength = embedding.sequenceLength
+
+            // Mean pooling: Loop through each token vector in the embedding and sum the values
+            embedding.enumerateTokenVectors(in: sentence.startIndex ..< sentence.endIndex) { (embedding, _) -> Bool in
+                for tokenEmbeddingIndex in 0 ..< embedding.count {
+                    meanPooledEmbeddings[tokenEmbeddingIndex] += Float(embedding[tokenEmbeddingIndex])
+                }
+                return true
+            }
+
+            // Mean pooling: divide every dimension of the summed vector (not just the first `sequenceLength`) by the sequence length
+            if sequenceLength > 0 {
+                for index in meanPooledEmbeddings.indices {
+                    meanPooledEmbeddings[index] /= Float(sequenceLength)
+                }
+            }
+
+            // Return the mean-pooled vector
+            return meanPooledEmbeddings
+        }
+    }
+
+    public func encode(sentence: String) async -> [Float]? {
+        return await model.vector(for: sentence)
+    }
+}
+#endif
+
diff --git a/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift b/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift
index 94a4f1e..7b8ce65 100644
--- a/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift
+++ b/Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift
@@ -10,11 +10,12 @@ import NaturalLanguage
 
 @available(macOS 11.0, iOS 15.0, *)
 public class NativeEmbeddings: EmbeddingsProtocol {
-    public let model = ModelActor()
+    public let model: ModelActor
     public let tokenizer: any TokenizerProtocol
 
-    public init() {
+    public init(language: NLLanguage = .english) {
         self.tokenizer = NativeTokenizer()
+        self.model = ModelActor(language: language)
     }
 
     // MARK: - Dense Embeddings
@@ -22,8 +23,8 @@ public class NativeEmbeddings: EmbeddingsProtocol {
     public actor ModelActor {
         private let model: NLEmbedding
 
-        init() {
-            guard let nativeModel = NLEmbedding.sentenceEmbedding(for: .english) else {
+        init(language: NLLanguage) {
+            guard let nativeModel = NLEmbedding.sentenceEmbedding(for: language) else {
                 fatalError("Failed to load the Core ML model.")
             }
             model = nativeModel
diff --git a/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift b/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift
index 220c31a..c908040 100644
--- a/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift
+++ b/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift
@@ -14,6 +14,12 @@ import CoreML
 
 @available(macOS 13.0, iOS 16.0, *)
 class SimilaritySearchKitTests: XCTestCase {
+
+    override func setUp() {
+        executionTimeAllowance = 60
+        continueAfterFailure = true
+    }
+
     func testSavingJsonIndex() async {
         let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())
 