Skip to content

Commit

Permalink
Merge pull request #114 from PathwayCommons/iss61_restrict-to-documents
Browse files Browse the repository at this point in the history
Score and return only documents provided
  • Loading branch information
jvwong authored Dec 14, 2021
2 parents ba99400 + 332c1ea commit f5bec52
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 8 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ The return value is a JSON representation of the `top_k` most similar documents

If `"text"` is not provided, we assume `"uid"`s are valid PMIDs and fetch the title and abstract text before embedding, indexing and searching.

- Notes on optional parameters
- `top_k`: A positive integer (default is `10`) that limits the search results to this many of the most similar neighbours (articles)
- `docs_only`: A boolean (default is `False`) that instructs the service to return scores only for the provided `documents`, instead of the most similar articles in the index. If true, `top_k` is disregarded.

### Running via Docker

#### Setup
Expand Down
25 changes: 18 additions & 7 deletions semantic_search/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from operator import itemgetter
from typing import List, Optional, Tuple, Union, cast

import logging
import faiss
import torch
from fastapi import FastAPI, Request
Expand All @@ -23,6 +22,7 @@
from pathlib import Path
from dotenv import load_dotenv
import os
from fastapi import HTTPException

dot_env_filepath = Path(__file__).absolute().parent.parent / ".env"
load_dotenv(dot_env_filepath)
Expand All @@ -34,7 +34,6 @@
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | {level} | <level>{message}</level>",
level=os.getenv("LOG_LEVEL", "DEBUG"),
)
from fastapi import HTTPException

app = FastAPI(
title="Scientific Semantic Search",
Expand Down Expand Up @@ -133,7 +132,7 @@ def index(request: Request):
@app.post("/search", tags=["Search"], response_model=List[TopMatch])
async def search(search: Search):
"""Returns the `top_k` most similar documents to `query` from the provided list of `documents`
and the index.
and the index. When docs_only is True, returns all `documents` provided, and disregards `top_k`.
"""
ids = [int(doc.uid) for doc in search.documents]
texts = [document.text for document in search.documents]
Expand All @@ -152,7 +151,7 @@ async def search(search: Search):
texts[i] = normalize_documents([str(id_)])
except HTTPException:
# Some bogus PMID - set text as empty string
logging.warn(f"Error encountered in normalize_documents: {id_}")
logger.warn(f"Error encountered in normalize_documents: {id_}")
texts[i] = ""

# We then embed the corresponding text and update the index
Expand All @@ -162,15 +161,27 @@ async def search(search: Search):
embeddings = encode(texts).cpu().numpy() # type: ignore
add_to_faiss_index(ids, embeddings, model.index)

# Can't search for more items than exist in the index
top_k = min(model.index.ntotal, search.top_k)
# Embed the query and perform the search
# Embed the query
query_embedding = encode(search.query.text).cpu().numpy() # type: ignore
num_indexed = model.index.ntotal
# Can't search for more items than exist in the index
top_k = min(num_indexed, search.top_k)

if search.docs_only:
top_k = num_indexed

# Perform the search
top_k_scores, top_k_indicies = model.index.search(query_embedding, top_k)

top_k_indicies = top_k_indicies.reshape(-1).tolist()
top_k_scores = top_k_scores.reshape(-1).tolist()

# Pick out results for the incoming ids in search.documents
if search.docs_only:
documents_positions = [top_k_indicies.index(id) for id in ids]
top_k_indicies = ids
top_k_scores = [top_k_scores[position] for position in documents_positions]

if int(search.query.uid) in top_k_indicies:
index = top_k_indicies.index(int(search.query.uid))
del top_k_indicies[index], top_k_scores[index]
Expand Down
1 change: 1 addition & 0 deletions semantic_search/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Search(BaseModel):
query: Document
documents: List[Document] = []
top_k: int = Field(10, gt=0, description="top_k must be greater than 0")
docs_only: bool = False

class Config:
schema_extra = {
Expand Down
32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,35 @@ def dummy_request_with_test() -> Request:
# We don't actually test scores, so use a dummy value of -1
response = [{"uid": "30049242", "score": -1}, {"uid": "22936248", "score": -1}]
return json.dumps(request), response


@pytest.fixture(scope="module")
def followup_request_with_test() -> Request:
    """Build a `docs_only` search request and the response expected for it.

    Returns a pair: the JSON-encoded request body, and the list of matches
    (one entry per submitted document) the response is compared against.
    """
    query = {
        "uid": "9813169",
        "text": "TGF-beta signaling from the cell surface to the nucleus is mediated by the SMAD...",
    }
    documents = [
        {
            "uid": "10320478",
            "text": "Much is known about the three subfamilies of the TGFbeta superfamily in vertebrates...",
        },
        {
            "uid": "10357889",
            "text": "The transforming growth factor-beta (TGF-beta) superfamily encompasses a large...",
        },
        {
            "uid": "15473904",
            "text": "Members of TGFbeta superfamily are found to play important roles in many cellular...",
        },
    ]
    payload = {"query": query, "documents": documents, "docs_only": True}

    # Scores are not compared by the tests, so -1 is only a placeholder value.
    expected = [{"uid": document["uid"], "score": -1} for document in documents]
    return json.dumps(payload), expected
18 changes: 17 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Tuple


import numpy as np
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -48,3 +47,20 @@ def test_search_with_text(self, dummy_request_with_test: Request) -> None:
assert len(expected_uids) == len(actual_uids)
assert set(actual_uids) == set(expected_uids)
assert all(0 <= score <= 1 for score in actual_scores)

def test_restrict_search_to_documents(
    self, dummy_request_with_test: Request, followup_request_with_test: Request
) -> None:
    """Check that a `docs_only` search returns exactly the submitted documents.

    A first request is posted before the one under test, then the uids
    returned for the `docs_only` request are compared with the expected ones.
    Scores are not checked.
    """
    # Seed the index with a first request, so that (presumably) it holds
    # documents other than the ones submitted by the request under test.
    dummy_request, _ = dummy_request_with_test
    dummy_response = client.post("/search", dummy_request)
    assert dummy_response.status_code == 200

    # Do the search of interest
    request, expected_response = followup_request_with_test
    actual_response = client.post("/search", request)
    assert actual_response.status_code == 200

    actual_uids = [item["uid"] for item in actual_response.json()]
    expected_uids = [item["uid"] for item in expected_response]
    # Compare lengths as well as sets: set equality alone would not notice
    # a uid that is returned more than once.
    assert len(actual_uids) == len(expected_uids)
    assert set(actual_uids) == set(expected_uids)

0 comments on commit f5bec52

Please sign in to comment.