diff --git a/Sources/SimilaritySearchKit/Core/Embeddings/EmbeddingProtocols.swift b/Sources/SimilaritySearchKit/Core/Embeddings/EmbeddingProtocols.swift index 6a6abc6..b883b4c 100644 --- a/Sources/SimilaritySearchKit/Core/Embeddings/EmbeddingProtocols.swift +++ b/Sources/SimilaritySearchKit/Core/Embeddings/EmbeddingProtocols.swift @@ -38,9 +38,9 @@ public protocol DistanceMetricProtocol { /// Find the nearest neighbors given a query embedding vector and a list of embeddings vectors. /// /// - Parameters: - /// - queryEmbedding: A `[Float]` array representing the query embedding vector. - /// - itemEmbeddings: A `[[Float]]` array representing the list of embeddings vectors to search within. - /// - resultsCount: An Int representing the number of nearest neighbors to return. + /// - queryEmbedding: A `[Float]` array representing the query embedding vector. + /// - neighborEmbeddings: A `[[Float]]` array representing the list of embeddings vectors to search within. + /// - resultsCount: An Int representing the number of nearest neighbors to return. /// /// - Returns: A `[(Float, Int)]` array, where each tuple contains the similarity score and the index of the corresponding item in `neighborEmbeddings`. The array is sorted by decreasing similarity ranking. func findNearest(for queryEmbedding: [Float], in neighborEmbeddings: [[Float]], resultsCount: Int) -> [(Float, Int)] diff --git a/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/DistanceMetrics.swift b/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/DistanceMetrics.swift index 4cd1ecf..b752aab 100644 --- a/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/DistanceMetrics.swift +++ b/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/DistanceMetrics.swift @@ -76,37 +76,49 @@ public struct EuclideanDistance: DistanceMetricProtocol { /// Helper function to sort scores and return the top K scores with their indices. /// /// - Parameters: -/// - scores: An array of Float values representing scores. 
-/// - topK: The number of top scores to return. +/// - scores: An array of Float values representing scores. +/// - topK: The maximum number of top scores to return. +/// /// - Returns: An array of tuples containing the top K scores and their corresponding indices. public func sortedScores(scores: [Float], topK: Int) -> [(Float, Int)] { // Combine indices & scores let indexedScores = scores.enumerated().map { index, score in (score, index) } // Sort by decreasing score - let sortedIndexedScores = indexedScores.sorted { $0.0 > $1.0 } + func compare(a: (Float, Int), b: (Float, Int)) throws -> Bool { + return a.0 > b.0 + } // Take top k neighbors - let results = Array(sortedIndexedScores.prefix(topK)) - - return results + do { + return try indexedScores.topK(topK, by: compare) + } catch { + print("There has been an error comparing elements in sortedScores") + return [] + } } /// Helper function to sort distances and return the top K distances with their indices. /// /// - Parameters: -/// - distances: An array of Float values representing distances. -/// - topK: The number of top distances to return. +/// - distances: An array of Float values representing distances. +/// - topK: The maximum number of top distances to return. +/// /// - Returns: An array of tuples containing the top K distances and their corresponding indices. 
public func sortedDistances(distances: [Float], topK: Int) -> [(Float, Int)] { // Combine indices & distances let indexedDistances = distances.enumerated().map { index, score in (score, index) } // Sort by increasing distance - let sortedIndexedDistances = indexedDistances.sorted { $0.0 < $1.0 } + func compare(a: (Float, Int), b: (Float, Int)) throws -> Bool { + return a.0 < b.0 + } // Take top k neighbors - let results = Array(sortedIndexedDistances.prefix(topK)) - - return results + do { + return try indexedDistances.topK(topK, by: compare) + } catch { + print("There has been an error comparing elements in sortedDistances") + return [] + } } diff --git a/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/TopK.swift b/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/TopK.swift new file mode 100644 index 0000000..a64cb7f --- /dev/null +++ b/Sources/SimilaritySearchKit/Core/Embeddings/Metrics/TopK.swift @@ -0,0 +1,60 @@ +// +// TopK.swift +// +// +// Created by Bernhard Eisvogel on 31.10.23. +// + +import Foundation + +public extension Collection { + /// Returns up to `count` elements of the collection that come first under the given ordering predicate. + /// + /// The `by` parameter accepts a function of the following form: + /// ```swift + /// (Element, Element) throws -> Bool + /// ``` + /// + /// Adapted from [Stack Overflow](https://stackoverflow.com/questions/65746299/how-do-you-find-the-top-3-maximum-values-in-a-swift-dictionary) + /// + /// - Parameters: + /// - count: the maximum number of elements to return; must not be negative. + /// - by: predicate that returns `true` when its first argument should be ordered before its second. + /// + /// - Returns: an array of at most `count` elements, ordered according to `by`. + /// + /// - Note: elements that compare as equal may end up in a different order than with `sorted(by:)` followed by `prefix(_:)`. + func topK(_ count: Int, by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows -> [Self.Element] { + assert(count >= 0, + """ + Cannot prefix with a negative amount of elements! 
+ """) + + guard count > 0 else { + return [] + } + + let prefixCount = Swift.min(count, self.count) + + guard prefixCount < self.count / 10 else { + return try Array(sorted(by: areInIncreasingOrder).prefix(prefixCount)) + } + + var result = try self.prefix(prefixCount).sorted(by: areInIncreasingOrder) + + for e in self.dropFirst(prefixCount) { + if let last = result.last, try areInIncreasingOrder(last, e) { + continue + } + let insertionIndex = try result.partition { try areInIncreasingOrder(e, $0) } + let isLastElement = insertionIndex == result.endIndex + result.removeLast() + if isLastElement { + result.append(e) + } else { + result.insert(e, at: insertionIndex) + } + } + return result + } +} diff --git a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift index 41094b2..6e85b4d 100644 --- a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift +++ b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift @@ -324,6 +324,24 @@ extension SimilarityIndex { } public func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? { + if let indexPath = try getIndexPath(fromDirectory: path, name: name) { + let loadedIndexItems = try vectorStore.loadIndex(from: indexPath) + addItems(loadedIndexItems) + print("Loaded \(indexItems.count) index items from \(indexPath.absoluteString)") + return loadedIndexItems + } + + return nil + } + + /// Locates the index file that `loadIndex(fromDirectory:name:)` reads for the given directory and name. + /// The result is the first URL from `vectorStore.listIndexes(at:)` whose last path component contains the index name. + /// - Parameters: + /// - fromDirectory: optional directory to search (the default storage path is presumably used when omitted). + /// - name: optional index name; falls back to `self.indexName` when nil. + /// + /// - Returns: the URL of the first matching index file, or `nil` if none matches. + public func getIndexPath(fromDirectory path: URL? = nil, name: String? = nil) throws -> URL? { let indexName = name ?? 
self.indexName let basePath: URL @@ -333,15 +351,7 @@ extension SimilarityIndex { // Default local path basePath = try getDefaultStoragePath() } - - if let vectorStorePath = vectorStore.listIndexes(at: basePath).first(where: { $0.lastPathComponent.contains(indexName) }) { - let loadedIndexItems = try vectorStore.loadIndex(from: vectorStorePath) - addItems(loadedIndexItems) - print("Loaded \(indexItems.count) index items from \(vectorStorePath.absoluteString)") - return loadedIndexItems - } - - return nil + return vectorStore.listIndexes(at: basePath).first(where: { $0.lastPathComponent.contains(indexName) }) } private func getDefaultStoragePath() throws -> URL { diff --git a/Tests/SimilaritySearchKitTests/DistanceTest.swift b/Tests/SimilaritySearchKitTests/DistanceTest.swift new file mode 100644 index 0000000..721eadc --- /dev/null +++ b/Tests/SimilaritySearchKitTests/DistanceTest.swift @@ -0,0 +1,64 @@ +// +// DistanceTest.swift +// +// +// Created by Bernhard Eisvogel on 31.10.23. +// + +@testable import SimilaritySearchKit +import XCTest + +func randomString(_ length: Int) -> String { + let letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789아오우" + return String((0.. Bool { + return a Bool { + return a.hashValue