Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Score and return only documents provided #114

Merged
merged 6 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions semantic_search/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from operator import itemgetter
from typing import List, Optional, Tuple, Union, cast

import logging
import faiss
import torch
from fastapi import FastAPI, Request
Expand All @@ -23,6 +22,7 @@
from pathlib import Path
from dotenv import load_dotenv
import os
from fastapi import HTTPException

dot_env_filepath = Path(__file__).absolute().parent.parent / ".env"
load_dotenv(dot_env_filepath)
Expand All @@ -34,7 +34,6 @@
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | {level} | <level>{message}</level>",
level=os.getenv("LOG_LEVEL", "DEBUG"),
)
from fastapi import HTTPException

app = FastAPI(
title="Scientific Semantic Search",
Expand Down Expand Up @@ -152,7 +151,7 @@ async def search(search: Search):
texts[i] = normalize_documents([str(id_)])
except HTTPException:
# Some bogus PMID - set text as empty string
logging.warn(f"Error encountered in normalize_documents: {id_}")
logger.warn(f"Error encountered in normalize_documents: {id_}")
texts[i] = ""

# We then embed the corresponding text and update the index
Expand All @@ -162,15 +161,27 @@ async def search(search: Search):
embeddings = encode(texts).cpu().numpy() # type: ignore
add_to_faiss_index(ids, embeddings, model.index)

# Can't search for more items than exist in the index
top_k = min(model.index.ntotal, search.top_k)
# Embed the query and perform the search
# Embed the query
query_embedding = encode(search.query.text).cpu().numpy() # type: ignore
num_indexed = model.index.ntotal
# Can't search for more items than exist in the index
top_k = min(num_indexed, search.top_k)

if search.use_docs:
top_k = num_indexed

# Perform the search
top_k_scores, top_k_indicies = model.index.search(query_embedding, top_k)

top_k_indicies = top_k_indicies.reshape(-1).tolist()
top_k_scores = top_k_scores.reshape(-1).tolist()

# Pick out results for the incoming ids in search.documents
if search.use_docs:
documents_positions = [top_k_indicies.index(id) for id in ids]
top_k_indicies = ids
top_k_scores = [top_k_scores[position] for position in documents_positions]

if int(search.query.uid) in top_k_indicies:
index = top_k_indicies.index(int(search.query.uid))
del top_k_indicies[index], top_k_scores[index]
Expand Down
1 change: 1 addition & 0 deletions semantic_search/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Search(BaseModel):
query: Document
documents: List[Document] = []
top_k: int = Field(10, gt=0, description="top_k must be greater than 0")
use_docs: bool = False
jvwong marked this conversation as resolved.
Show resolved Hide resolved

class Config:
schema_extra = {
Expand Down
32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,35 @@ def dummy_request_with_test() -> Request:
# We don't actually test scores, so use a dummy value of -1
response = [{"uid": "30049242", "score": -1}, {"uid": "22936248", "score": -1}]
return json.dumps(request), response


@pytest.fixture(scope="module")
def followup_request_with_test() -> Request:
    # NOTE(review): annotation says Request, but this actually returns a
    # (json_payload_str, expected_response) tuple — annotation appears to be
    # copied from the sibling dummy_request_with_test fixture; confirm and fix.
    #
    # Follow-up /search request that sets use_docs=True, so scoring is
    # restricted to exactly the three documents supplied in the payload.
    request = {
        "query": {
            "uid": "9813169",
            "text": "TGF-beta signaling from the cell surface to the nucleus is mediated by the SMAD...",
        },
        "documents": [
            {
                "uid": "10320478",
                "text": "Much is known about the three subfamilies of the TGFbeta superfamily in vertebrates...",
            },
            {
                "uid": "10357889",
                "text": "The transforming growth factor-beta (TGF-beta) superfamily encompasses a large...",
            },
            {
                "uid": "15473904",
                "text": "Members of TGFbeta superfamily are found to play important roles in many cellular...",
            },
        ],
        # Restrict scoring/ranking to the documents listed above.
        "use_docs": True,
    }
    # We don't actually test scores, so use a dummy value of -1
    response = [
        {"uid": "10320478", "score": -1},
        {"uid": "10357889", "score": -1},
        {"uid": "15473904", "score": -1},
    ]
    return json.dumps(request), response
18 changes: 17 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Tuple


import numpy as np
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -48,3 +47,20 @@ def test_search_with_text(self, dummy_request_with_test: Request) -> None:
assert len(expected_uids) == len(actual_uids)
assert set(actual_uids) == set(expected_uids)
assert all(0 <= score <= 1 for score in actual_scores)

def test_restrict_search_to_documents(
self, dummy_request_with_test: Request, followup_request_with_test: Request
):
# Dope the index
dummy_request, _ = dummy_request_with_test
dummy_response = client.post("/search", dummy_request)
assert dummy_response.status_code == 200

# Do the search of interest
request, expected_response = followup_request_with_test
actual_response = client.post("/search", request)
assert actual_response.status_code == 200

actual_uids = [item["uid"] for item in actual_response.json()]
expected_uids = [item["uid"] for item in expected_response]
assert set(actual_uids) == set(expected_uids)