Skip to content

Commit

Permalink
Backport search API changes from v2
Browse files Browse the repository at this point in the history
  • Loading branch information
lalinsky committed Feb 25, 2024
1 parent b85d96d commit f4f5a0a
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 94 deletions.
4 changes: 2 additions & 2 deletions src/index/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void Index::applyUpdates(const OpBatch &batch) {

}

std::vector<SearchResult> Index::search(const std::vector<uint32_t> &terms, int64_t timeoutInMSecs) {
std::vector<SearchResult> Index::search(const std::vector<uint32_t> &hashes, int64_t timeoutInMSecs) {
auto reader = openReader();
return reader->search(terms.data(), terms.size(), timeoutInMSecs);
return reader->search(hashes, timeoutInMSecs);
}
42 changes: 22 additions & 20 deletions src/index/index_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,33 @@ SegmentDataReader* IndexReader::segmentDataReader(const SegmentInfo& segment)
return new SegmentDataReader(m_dir->openFile(segment.dataFileName()), BLOCK_SIZE);
}

void IndexReader::search(const uint32_t* fingerprint, size_t length, Collector* collector, int64_t timeoutInMSecs)
std::vector<SearchResult> IndexReader::search(const std::vector<uint32_t> &hashesIn, int64_t timeoutInMSecs)
{
auto deadline = timeoutInMSecs > 0 ? (QDateTime::currentMSecsSinceEpoch() + timeoutInMSecs) : 0;
std::vector<uint32_t> fp(fingerprint, fingerprint + length);
std::sort(fp.begin(), fp.end());
const SegmentInfoList& segments = m_info.segments();
for (int i = 0; i < segments.size(); i++) {
if (deadline > 0) {
if (QDateTime::currentMSecsSinceEpoch() > deadline) {
throw TimeoutExceeded();
}
}
const SegmentInfo& s = segments.at(i);
SegmentSearcher searcher(s.index(), segmentDataReader(s), s.lastKey());
searcher.search(fp.data(), fp.size(), collector);

std::vector<uint32_t> hashes(hashesIn);
std::sort(hashes.begin(), hashes.end());

std::unordered_map<uint32_t, int> hits;

const SegmentInfoList& segments = m_info.segments();
for (auto segment : segments) {
if (deadline > 0) {
if (QDateTime::currentMSecsSinceEpoch() > deadline) {
throw TimeoutExceeded();
}
}
}
SegmentSearcher searcher(segment.index(), segmentDataReader(segment), segment.lastKey());
searcher.search(hashes, hits);
}

std::vector<SearchResult> IndexReader::search(const uint32_t* fingerprint, size_t length, int64_t timeoutInMSecs)
{
TopHitsCollector collector(1000);
search(fingerprint, length, &collector, timeoutInMSecs);
std::vector<SearchResult> results;
for (const auto result : collector.topResults()) {
results.emplace_back(result.id(), result.score());
results.reserve(hits.size());
for (const auto &hit : hits) {
results.emplace_back(hit.first, hit.second);
}

sortSearchResults(results);

return results;
}
3 changes: 1 addition & 2 deletions src/index/index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class IndexReader
return m_index;
}

void search(const uint32_t *fingerprint, size_t length, Collector *collector, int64_t timeoutInMSecs = 0);
std::vector<SearchResult> search(const uint32_t *fingerprint, size_t length, int64_t timeoutInMSecs = 0);
std::vector<SearchResult> search(const std::vector<uint32_t> &hashes, int64_t timeoutInMSecs = 0);

SegmentDataReader* segmentDataReader(const SegmentInfo& segment);

Expand Down
24 changes: 10 additions & 14 deletions src/index/index_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,25 @@ TEST(IndexReaderTest, Search)
DirectorySharedPtr dir(new RAMDirectory());
IndexSharedPtr index(new Index(dir, true));

uint32_t fp1[] = { 7, 9, 12 };
auto fp1len = 3;

uint32_t fp2[] = { 7, 9, 11 };
auto fp2len = 3;
std::vector<uint32_t> fp1 = { 7, 9, 12 };
std::vector<uint32_t> fp2 = { 7, 9, 11 };

{
auto writer = index->openWriter();
writer->addDocument(1, fp1, fp1len);
writer->addDocument(1, fp1.data(), fp1.size());
writer->commit();
writer->addDocument(2, fp2, fp2len);
writer->addDocument(2, fp2.data(), fp2.size());
writer->commit();
}

{
IndexReader reader(index);
TopHitsCollector collector(100);
reader.search(fp1, fp1len, &collector);
ASSERT_EQ(2, collector.topResults().size());
ASSERT_EQ(1, collector.topResults().at(0).id());
ASSERT_EQ(3, collector.topResults().at(0).score());
ASSERT_EQ(2, collector.topResults().at(1).id());
ASSERT_EQ(2, collector.topResults().at(1).score());
auto results = reader.search(fp1);
ASSERT_EQ(2, results.size());
ASSERT_EQ(1, results.at(0).docId());
ASSERT_EQ(3, results.at(0).score());
ASSERT_EQ(2, results.at(1).docId());
ASSERT_EQ(2, results.at(1).score());
}
}

8 changes: 1 addition & 7 deletions src/index/search_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@ class SearchResult {
inline void sortSearchResults(std::vector<SearchResult> &results)
{
std::sort(results.begin(), results.end(), [](const SearchResult &a, const SearchResult &b) {
if (a.score() > b.score()) {
return true;
} else if (a.score() < b.score()) {
return false;
} else {
return a.docId() < b.docId();
}
return a.score() > b.score() || (a.score() == b.score() && a.docId() < b.docId());
});
}

Expand Down
23 changes: 11 additions & 12 deletions src/index/segment_searcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Distributed under the MIT license, see the LICENSE file for details.

#include <algorithm>
#include "collector.h"
#include "segment_data_reader.h"
#include "segment_searcher.h"

Expand All @@ -17,17 +16,17 @@ SegmentSearcher::~SegmentSearcher()
{
}

void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *collector)
void SegmentSearcher::search(const std::vector<uint32_t> &hashes, std::unordered_map<uint32_t, int> &hits)
{
size_t i = 0, block = 0, lastBlock = SIZE_MAX;
while (i < length) {
while (i < hashes.size()) {
if (block > lastBlock || lastBlock == SIZE_MAX) {
size_t localFirstBlock, localLastBlock;
if (fingerprint[i] > m_lastKey) {
if (hashes[i] > m_lastKey) {
// All following items are larger than the last segment's key.
return;
}
if (m_index->search(fingerprint[i], &localFirstBlock, &localLastBlock)) {
if (m_index->search(hashes[i], &localFirstBlock, &localLastBlock)) {
if (block > localLastBlock) {
// We already searched this block and the fingerprint item was not found.
i++;
Expand All @@ -48,24 +47,24 @@ void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *co
std::unique_ptr<BlockDataIterator> blockData(m_dataReader->readBlock(block, firstKey));
while (blockData->next()) {
uint32_t key = blockData->key();
if (key >= fingerprint[i]) {
while (key > fingerprint[i]) {
if (key >= hashes[i]) {
while (key > hashes[i]) {
i++;
if (i >= length) {
if (i >= hashes.size()) {
return;
}
else if (lastKey < fingerprint[i]) {
else if (lastKey < hashes[i]) {
// There are no longer any items in this block that we could match.
goto nextBlock;
}
}
if (key == fingerprint[i]) {
collector->collect(blockData->value());
if (key == hashes[i]) {
auto docId = blockData->value();
hits[docId]++;
}
}
}
nextBlock:
block++;
}
}

8 changes: 1 addition & 7 deletions src/index/segment_searcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,14 @@
namespace Acoustid {

class SegmentDataReader;
class Collector;

class SegmentSearcher
{
public:
SegmentSearcher(SegmentIndexSharedPtr index, SegmentDataReader *dataReader, uint32_t lastKey = UINT32_MAX);
virtual ~SegmentSearcher();

/**
* Search for the fingerprint in one segment.
*
* The fingerprint must be sorted.
*/
void search(uint32_t *fingerprint, size_t length, Collector *collector);
void search(const std::vector<uint32_t> &hashes, std::unordered_map<uint32_t, int> &hits);

private:
SegmentIndexSharedPtr m_index;
Expand Down
10 changes: 3 additions & 7 deletions src/server/http.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,13 @@ static HttpResponse handleSearchRequest(const HttpRequest &request, const QShare
limit = 100;
}

auto collector = QSharedPointer<TopHitsCollector>::create(limit);
{
auto reader = index->openReader();
reader->search(query.data(), query.size(), collector.data());
}
auto results = collector->topResults();
auto results = index->search(query);
filterSearchResults(results, limit);

QJsonArray resultsJson;
for (auto &result : results) {
resultsJson.append(QJsonObject{
{"id", qint64(result.id())},
{"id", qint64(result.docId())},
{"score", result.score()},
});
}
Expand Down
16 changes: 8 additions & 8 deletions src/server/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@

namespace Acoustid { namespace Server {

QVector<uint32_t> parseFingerprint(const QString &input) {
QStringList inputParts = input.split(',');
QVector<uint32_t> output;
std::vector<uint32_t> parseFingerprint(const QString &input) {
QStringList inputParts = input.split(',');
std::vector<uint32_t> output;
output.reserve(inputParts.size());
for (int i = 0; i < inputParts.size(); i++) {
bool ok;
auto value = inputParts.at(i).toInt(&ok);
if (!ok) {
throw HandlerException("invalid fingerprint");
}
output.append(value);
output.push_back(value);
}
if (output.isEmpty()) {
if (output.empty()) {
throw HandlerException("empty fingerprint");
}
return output;
Expand Down Expand Up @@ -92,9 +92,9 @@ ScopedHandlerFunc buildHandler(const QString &command, const QStringList &args)
auto results = session->search(hashes);
QStringList output;
output.reserve(results.size());
for (int i = 0; i < results.size(); i++) {
output.append(QString("%1:%2").arg(results[i].id()).arg(results[i].score()));
}
for (auto result : results) {
output.append(QString("%1:%2").arg(result.docId()).arg(result.score()));
}
QString outputString = output.join(" ");
session->clearTraceId();
return outputString;
Expand Down
13 changes: 6 additions & 7 deletions src/server/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "errors.h"
#include "index/index.h"
#include "index/index_writer.h"
#include "index/top_hits_collector.h"

using namespace Acoustid;
using namespace Acoustid::Server;
Expand Down Expand Up @@ -99,24 +98,24 @@ void Session::setAttribute(const QString &name, const QString &value) {
m_indexWriter->setAttribute(name, value);
}

void Session::insert(uint32_t id, const QVector<uint32_t> &hashes) {
void Session::insert(uint32_t id, const std::vector<uint32_t> &hashes) {
QMutexLocker locker(&m_mutex);
if (m_indexWriter.isNull()) {
throw NotInTransactionException();
}
m_indexWriter->addDocument(id, hashes.data(), hashes.size());
}

QList<Result> Session::search(const QVector<uint32_t> &hashes) {
std::vector<SearchResult> Session::search(const std::vector<uint32_t> &hashes) {
QMutexLocker locker(&m_mutex);
TopHitsCollector collector(m_maxResults, m_topScorePercent);
std::vector<SearchResult> results;
try {
auto reader = m_index->openReader();
reader->search(hashes.data(), hashes.size(), &collector, m_timeout);
results = m_index->search(hashes, m_timeout);
} catch (TimeoutExceeded &ex) {
throw HandlerException("timeout exceeded");
}
return collector.topResults();
filterSearchResults(results, m_maxResults, m_topScorePercent);
return results;
}

QString Session::getTraceId() {
Expand Down
6 changes: 3 additions & 3 deletions src/server/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <QMutex>
#include <QSharedPointer>
#include "index/top_hits_collector.h"
#include "index/search_result.h"

namespace Acoustid {

Expand All @@ -28,8 +28,8 @@ class Session
void rollback();
void optimize();
void cleanup();
void insert(uint32_t id, const QVector<uint32_t> &hashes);
QList<Result> search(const QVector<uint32_t> &hashes);
void insert(uint32_t id, const std::vector<uint32_t> &hashes);
std::vector<SearchResult> search(const std::vector<uint32_t> &hashes);

QString getAttribute(const QString &name);
void setAttribute(const QString &name, const QString &value);
Expand Down
10 changes: 5 additions & 5 deletions src/server/session_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,18 @@ TEST(SessionTest, InsertAndSearch)
{
auto results = session->search({ 1, 2, 3 });
ASSERT_EQ(2, results.size());
ASSERT_EQ(1, results[0].id());
ASSERT_EQ(1, results[0].docId());
ASSERT_EQ(3, results[0].score());
ASSERT_EQ(2, results[1].id());
ASSERT_EQ(2, results[1].docId());
ASSERT_EQ(1, results[1].score());
}

{
auto results = session->search({ 1, 200, 300 });
ASSERT_EQ(2, results.size());
ASSERT_EQ(2, results[0].id());
ASSERT_EQ(2, results[0].docId());
ASSERT_EQ(3, results[0].score());
ASSERT_EQ(1, results[1].id());
ASSERT_EQ(1, results[1].docId());
ASSERT_EQ(1, results[1].score());
}

Expand All @@ -72,7 +72,7 @@ TEST(SessionTest, InsertAndSearch)
{
auto results = session->search({ 1, 2, 3 });
ASSERT_EQ(1, results.size());
ASSERT_EQ(1, results[0].id());
ASSERT_EQ(1, results[0].docId());
ASSERT_EQ(3, results[0].score());
}
}

0 comments on commit f4f5a0a

Please sign in to comment.