From f4f5a0ab9ae2c8f120685fd1f2a0a2b5ea3f772e Mon Sep 17 00:00:00 2001 From: Lukas Lalinsky Date: Sun, 25 Feb 2024 16:35:22 +0100 Subject: [PATCH] Backport search API changes from v2 --- src/index/index.cpp | 4 ++-- src/index/index_reader.cpp | 42 +++++++++++++++++---------------- src/index/index_reader.h | 3 +-- src/index/index_reader_test.cpp | 24 ++++++++----------- src/index/search_result.h | 8 +------ src/index/segment_searcher.cpp | 23 +++++++++--------- src/index/segment_searcher.h | 8 +------ src/server/http.cpp | 10 +++----- src/server/protocol.cpp | 16 ++++++------- src/server/session.cpp | 13 +++++----- src/server/session.h | 6 ++--- src/server/session_test.cpp | 10 ++++---- 12 files changed, 73 insertions(+), 94 deletions(-) diff --git a/src/index/index.cpp b/src/index/index.cpp index 40f90f49..1ef15f21 100644 --- a/src/index/index.cpp +++ b/src/index/index.cpp @@ -187,7 +187,7 @@ void Index::applyUpdates(const OpBatch &batch) { } -std::vector Index::search(const std::vector &terms, int64_t timeoutInMSecs) { +std::vector Index::search(const std::vector &hashes, int64_t timeoutInMSecs) { auto reader = openReader(); - return reader->search(terms.data(), terms.size(), timeoutInMSecs); + return reader->search(hashes, timeoutInMSecs); } diff --git a/src/index/index_reader.cpp b/src/index/index_reader.cpp index 30428e41..bfb960df 100644 --- a/src/index/index_reader.cpp +++ b/src/index/index_reader.cpp @@ -38,31 +38,33 @@ SegmentDataReader* IndexReader::segmentDataReader(const SegmentInfo& segment) return new SegmentDataReader(m_dir->openFile(segment.dataFileName()), BLOCK_SIZE); } -void IndexReader::search(const uint32_t* fingerprint, size_t length, Collector* collector, int64_t timeoutInMSecs) +std::vector IndexReader::search(const std::vector &hashesIn, int64_t timeoutInMSecs) { auto deadline = timeoutInMSecs > 0 ? (QDateTime::currentMSecsSinceEpoch() + timeoutInMSecs) : 0; - std::vector fp(fingerprint, fingerprint + length); - std::sort(fp.begin(), fp.end()); - const SegmentInfoList& segments = m_info.segments(); - for (int i = 0; i < segments.size(); i++) { - if (deadline > 0) { - if (QDateTime::currentMSecsSinceEpoch() > deadline) { - throw TimeoutExceeded(); - } - } - const SegmentInfo& s = segments.at(i); - SegmentSearcher searcher(s.index(), segmentDataReader(s), s.lastKey()); - searcher.search(fp.data(), fp.size(), collector); + + std::vector hashes(hashesIn); + std::sort(hashes.begin(), hashes.end()); + + std::unordered_map hits; + + const SegmentInfoList& segments = m_info.segments(); + for (auto segment : segments) { + if (deadline > 0) { + if (QDateTime::currentMSecsSinceEpoch() > deadline) { + throw TimeoutExceeded(); + } } -} + SegmentSearcher searcher(segment.index(), segmentDataReader(segment), segment.lastKey()); + searcher.search(hashes, hits); + } -std::vector IndexReader::search(const uint32_t* fingerprint, size_t length, int64_t timeoutInMSecs) -{ - TopHitsCollector collector(1000); - search(fingerprint, length, &collector, timeoutInMSecs); std::vector results; - for (const auto result : collector.topResults()) { - results.emplace_back(result.id(), result.score()); + results.reserve(hits.size()); + for (const auto &hit : hits) { + results.emplace_back(hit.first, hit.second); } + + sortSearchResults(results); + return results; } diff --git a/src/index/index_reader.h b/src/index/index_reader.h index f86a7016..92c350a2 100644 --- a/src/index/index_reader.h +++ b/src/index/index_reader.h @@ -29,8 +29,7 @@ class IndexReader return m_index; } - void search(const uint32_t *fingerprint, size_t length, Collector *collector, int64_t timeoutInMSecs = 0); - std::vector search(const uint32_t *fingerprint, size_t length, int64_t timeoutInMSecs = 0); + std::vector search(const std::vector &hashes, int64_t timeoutInMSecs = 0); SegmentDataReader* segmentDataReader(const SegmentInfo& segment); diff --git a/src/index/index_reader_test.cpp b/src/index/index_reader_test.cpp index 730f716b..d1d89ee9 100644 --- a/src/index/index_reader_test.cpp +++ b/src/index/index_reader_test.cpp @@ -18,29 +18,25 @@ TEST(IndexReaderTest, Search) DirectorySharedPtr dir(new RAMDirectory()); IndexSharedPtr index(new Index(dir, true)); - uint32_t fp1[] = { 7, 9, 12 }; - auto fp1len = 3; - - uint32_t fp2[] = { 7, 9, 11 }; - auto fp2len = 3; + std::vector fp1 = { 7, 9, 12 }; + std::vector fp2 = { 7, 9, 11 }; { auto writer = index->openWriter(); - writer->addDocument(1, fp1, fp1len); + writer->addDocument(1, fp1.data(), fp1.size()); writer->commit(); - writer->addDocument(2, fp2, fp2len); + writer->addDocument(2, fp2.data(), fp2.size()); writer->commit(); } { IndexReader reader(index); - TopHitsCollector collector(100); - reader.search(fp1, fp1len, &collector); - ASSERT_EQ(2, collector.topResults().size()); - ASSERT_EQ(1, collector.topResults().at(0).id()); - ASSERT_EQ(3, collector.topResults().at(0).score()); - ASSERT_EQ(2, collector.topResults().at(1).id()); - ASSERT_EQ(2, collector.topResults().at(1).score()); + auto results = reader.search(fp1); + ASSERT_EQ(2, results.size()); + ASSERT_EQ(1, results.at(0).docId()); + ASSERT_EQ(3, results.at(0).score()); + ASSERT_EQ(2, results.at(1).docId()); + ASSERT_EQ(2, results.at(1).score()); } } diff --git a/src/index/search_result.h b/src/index/search_result.h index 18258d57..77cd1142 100644 --- a/src/index/search_result.h +++ b/src/index/search_result.h @@ -29,13 +29,7 @@ class SearchResult { inline void sortSearchResults(std::vector &results) { std::sort(results.begin(), results.end(), [](const SearchResult &a, const SearchResult &b) { - if (a.score() > b.score()) { - return true; - } else if (a.score() < b.score()) { - return false; - } else { - return a.docId() < b.docId(); - } + return a.score() > b.score() || (a.score() == b.score() && a.docId() < b.docId()); }); } diff --git a/src/index/segment_searcher.cpp b/src/index/segment_searcher.cpp index fd7d2c7f..ebe87f94 100644 --- a/src/index/segment_searcher.cpp +++ b/src/index/segment_searcher.cpp @@ -2,7 +2,6 @@ // Distributed under the MIT license, see the LICENSE file for details. #include -#include "collector.h" #include "segment_data_reader.h" #include "segment_searcher.h" @@ -17,17 +16,17 @@ SegmentSearcher::~SegmentSearcher() { } -void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *collector) +void SegmentSearcher::search(const std::vector &hashes, std::unordered_map &hits) { size_t i = 0, block = 0, lastBlock = SIZE_MAX; - while (i < length) { + while (i < hashes.size()) { if (block > lastBlock || lastBlock == SIZE_MAX) { size_t localFirstBlock, localLastBlock; - if (fingerprint[i] > m_lastKey) { + if (hashes[i] > m_lastKey) { // All following items are larger than the last segment's key. return; } - if (m_index->search(fingerprint[i], &localFirstBlock, &localLastBlock)) { + if (m_index->search(hashes[i], &localFirstBlock, &localLastBlock)) { if (block > localLastBlock) { // We already searched this block and the fingerprint item was not found. i++; @@ -48,19 +47,20 @@ void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *co std::unique_ptr blockData(m_dataReader->readBlock(block, firstKey)); while (blockData->next()) { uint32_t key = blockData->key(); - if (key >= fingerprint[i]) { - while (key > fingerprint[i]) { + if (key >= hashes[i]) { + while (key > hashes[i]) { i++; - if (i >= length) { + if (i >= hashes.size()) { return; } - else if (lastKey < fingerprint[i]) { + else if (lastKey < hashes[i]) { // There are no longer any items in this block that we could match. goto nextBlock; } } - if (key == fingerprint[i]) { - collector->collect(blockData->value()); + if (key == hashes[i]) { + auto docId = blockData->value(); + hits[docId]++; } } } @@ -68,4 +68,3 @@ void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *co block++; } } - diff --git a/src/index/segment_searcher.h b/src/index/segment_searcher.h index 6ae17548..d7cd3f54 100644 --- a/src/index/segment_searcher.h +++ b/src/index/segment_searcher.h @@ -10,7 +10,6 @@ namespace Acoustid { class SegmentDataReader; -class Collector; class SegmentSearcher { @@ -18,12 +17,7 @@ class SegmentSearcher SegmentSearcher(SegmentIndexSharedPtr index, SegmentDataReader *dataReader, uint32_t lastKey = UINT32_MAX); virtual ~SegmentSearcher(); - /** - * Search for the fingerprint in one segment. - * - * The fingerprint must be sorted. - */ - void search(uint32_t *fingerprint, size_t length, Collector *collector); + void search(const std::vector &hashes, std::unordered_map &hits); private: SegmentIndexSharedPtr m_index; diff --git a/src/server/http.cpp b/src/server/http.cpp index b9cb0c61..560711b9 100644 --- a/src/server/http.cpp +++ b/src/server/http.cpp @@ -204,17 +204,13 @@ static HttpResponse handleSearchRequest(const HttpRequest &request, const QShare limit = 100; } - auto collector = QSharedPointer::create(limit); - { - auto reader = index->openReader(); - reader->search(query.data(), query.size(), collector.data()); - } - auto results = collector->topResults(); + auto results = index->search(query); + filterSearchResults(results, limit); QJsonArray resultsJson; for (auto &result : results) { resultsJson.append(QJsonObject{ - {"id", qint64(result.id())}, + {"id", qint64(result.docId())}, {"score", result.score()}, }); } diff --git a/src/server/protocol.cpp b/src/server/protocol.cpp index 9d8e4e65..ee098828 100644 --- a/src/server/protocol.cpp +++ b/src/server/protocol.cpp @@ -4,9 +4,9 @@ namespace Acoustid { namespace Server { -QVector parseFingerprint(const QString &input) { - QStringList inputParts = input.split(','); - QVector output; +std::vector parseFingerprint(const QString &input) { + QStringList inputParts = input.split(','); + std::vector output; output.reserve(inputParts.size()); for (int i = 0; i < inputParts.size(); i++) { bool ok; @@ -14,9 +14,9 @@ QVector parseFingerprint(const QString &input) { if (!ok) { throw HandlerException("invalid fingerprint"); } - output.append(value); + output.push_back(value); } - if (output.isEmpty()) { + if (output.empty()) { throw HandlerException("empty fingerprint"); } return output; @@ -92,9 +92,9 @@ ScopedHandlerFunc buildHandler(const QString &command, const QStringList &args) auto results = session->search(hashes); QStringList output; output.reserve(results.size()); - for (int i = 0; i < results.size(); i++) { - output.append(QString("%1:%2").arg(results[i].id()).arg(results[i].score())); - } + for (auto result : results) { + output.append(QString("%1:%2").arg(result.docId()).arg(result.score())); + } QString outputString = output.join(" "); session->clearTraceId(); return outputString; diff --git a/src/server/session.cpp b/src/server/session.cpp index 348ecf47..a8c266b9 100644 --- a/src/server/session.cpp +++ b/src/server/session.cpp @@ -5,7 +5,6 @@ #include "errors.h" #include "index/index.h" #include "index/index_writer.h" -#include "index/top_hits_collector.h" using namespace Acoustid; using namespace Acoustid::Server; @@ -99,7 +98,7 @@ void Session::setAttribute(const QString &name, const QString &value) { m_indexWriter->setAttribute(name, value); } -void Session::insert(uint32_t id, const QVector &hashes) { +void Session::insert(uint32_t id, const std::vector &hashes) { QMutexLocker locker(&m_mutex); if (m_indexWriter.isNull()) { throw NotInTransactionException(); @@ -107,16 +106,16 @@ void Session::insert(uint32_t id, const QVector &hashes) { m_indexWriter->addDocument(id, hashes.data(), hashes.size()); } -QList Session::search(const QVector &hashes) { +std::vector Session::search(const std::vector &hashes) { QMutexLocker locker(&m_mutex); - TopHitsCollector collector(m_maxResults, m_topScorePercent); + std::vector results; try { - auto reader = m_index->openReader(); - reader->search(hashes.data(), hashes.size(), &collector, m_timeout); + results = m_index->search(hashes, m_timeout); } catch (TimeoutExceeded &ex) { throw HandlerException("timeout exceeded"); } - return collector.topResults(); + filterSearchResults(results, m_maxResults, m_topScorePercent); + return results; } QString Session::getTraceId() { diff --git a/src/server/session.h b/src/server/session.h index ee4f3808..f356c6e0 100644 --- a/src/server/session.h +++ b/src/server/session.h @@ -6,7 +6,7 @@ #include #include -#include "index/top_hits_collector.h" +#include "index/search_result.h" namespace Acoustid { @@ -28,8 +28,8 @@ class Session void rollback(); void optimize(); void cleanup(); - void insert(uint32_t id, const QVector &hashes); - QList search(const QVector &hashes); + void insert(uint32_t id, const std::vector &hashes); + std::vector search(const std::vector &hashes); QString getAttribute(const QString &name); void setAttribute(const QString &name, const QString &value); diff --git a/src/server/session_test.cpp b/src/server/session_test.cpp index 2a889a72..e2f65f3c 100644 --- a/src/server/session_test.cpp +++ b/src/server/session_test.cpp @@ -52,18 +52,18 @@ TEST(SessionTest, InsertAndSearch) { auto results = session->search({ 1, 2, 3 }); ASSERT_EQ(2, results.size()); - ASSERT_EQ(1, results[0].id()); + ASSERT_EQ(1, results[0].docId()); ASSERT_EQ(3, results[0].score()); - ASSERT_EQ(2, results[1].id()); + ASSERT_EQ(2, results[1].docId()); ASSERT_EQ(1, results[1].score()); } { auto results = session->search({ 1, 200, 300 }); ASSERT_EQ(2, results.size()); - ASSERT_EQ(2, results[0].id()); + ASSERT_EQ(2, results[0].docId()); ASSERT_EQ(3, results[0].score()); - ASSERT_EQ(1, results[1].id()); + ASSERT_EQ(1, results[1].docId()); ASSERT_EQ(1, results[1].score()); } @@ -72,7 +72,7 @@ TEST(SessionTest, InsertAndSearch) { auto results = session->search({ 1, 2, 3 }); ASSERT_EQ(1, results.size()); - ASSERT_EQ(1, results[0].id()); + ASSERT_EQ(1, results[0].docId()); ASSERT_EQ(3, results[0].score()); } }