
update torch: move to torch only if needed.
michaelfeil committed Oct 12, 2023
1 parent d4c4251 · commit 53a58c9
Showing 3 changed files with 17 additions and 17 deletions.
libs/infinity_emb/infinity_emb/inference/batch_handler.py (4 changes: 2 additions & 2 deletions)
@@ -117,7 +117,7 @@ def __init__(
         self._queue_prio = CustomPrioQueue()
         self._result_store = ResultKVStore()
         self._feature_queue: queue.Queue = queue.Queue(4)
-        self._postprocess_queue: queue.Queue = queue.Queue(5)
+        self._postprocess_queue: queue.Queue = queue.Queue(4)
         self.max_batch_size = max_batch_size
         self.model = model
         self.max_queue_wait = max_queue_wait
@@ -266,7 +266,7 @@ async def _postprocess_batch(self):
             except queue.Empty:
                 # 7 ms, assuming this is below
                 # 3-50ms for inference on avg.
-                await asyncio.sleep(7e-3)
+                await asyncio.sleep(5e-3)
                 continue
             embed, batch = post_batch
             embeddings = self.model.encode_post(embed).tolist()
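The second hunk tightens the loop that drains the post-processing queue from async code. A minimal runnable sketch of that polling pattern, with simplified names that are not the repository's actual code:

```python
import asyncio
import queue

async def poll_queue(q: queue.Queue) -> None:
    """Drain a thread-safe queue.Queue from the event loop without blocking it."""
    while True:
        try:
            item = q.get_nowait()  # non-blocking; raises queue.Empty when idle
        except queue.Empty:
            # Sleep briefly instead of blocking in q.get(), so other coroutines
            # can run; 5 ms stays below the assumed 3-50 ms per-batch inference.
            await asyncio.sleep(5e-3)
            continue
        if item is None:  # sentinel: producer is done
            return
        print("post-processing", item)

q: queue.Queue = queue.Queue(4)  # bounded, like the handler's queues
for x in ("batch-1", "batch-2", None):
    q.put(x)
asyncio.run(poll_queue(q))
```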
libs/infinity_emb/infinity_emb/inference/models.py (20 changes: 10 additions & 10 deletions)
@@ -77,6 +77,7 @@ def __init__(self, *args, **kwargs):
         self._infinity_tokenizer = copy.deepcopy(self._first_module().tokenizer)
 
     def encode_pre(self, sentences) -> Dict[str, Tensor]:
+
         features = self.tokenize(sentences)
 
         return features
@@ -85,23 +86,22 @@ def encode_core(self, features: Dict[str, Tensor]) -> Tensor:
         """
         Computes sentence embeddings
         """
-        device = self._target_device
-        features = util.batch_to_device(features, device)
-        # move forward
-
-        with torch.no_grad():
-            out_features = self.forward(features)
+
+        with torch.inference_mode():
+            device = self._target_device
+            features = util.batch_to_device(features, device)
+            out_features = self.forward(features)["sentence_embedding"]
 
-        return out_features["sentence_embedding"].detach().cpu()
+        return out_features
 
     def encode_post(
         self, out_features: Tensor, normalize_embeddings: bool = True
     ) -> NpEmbeddingType:
-        with torch.no_grad():
-            embeddings = out_features
+        with torch.inference_mode():
+            embeddings = out_features.detach().cpu()
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-            embeddings_out: np.ndarray = embeddings.cpu().numpy()
+            embeddings_out: np.ndarray = embeddings.numpy()
 
         return embeddings_out
 
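This is the substance of the commit: encode_core now runs the forward pass under torch.inference_mode() (which disables autograd tracking and is at least as cheap as torch.no_grad()) and returns the embedding tensor still on the model's device; encode_post then performs the single device-to-host copy before normalizing. A standalone sketch of the pattern, using a stand-in module rather than the project's SentenceTransformer subclass:

```python
import numpy as np
import torch

def encode_core(model: torch.nn.Module, features: torch.Tensor,
                device: torch.device) -> torch.Tensor:
    # Forward pass under inference_mode; keep the result on `device` (no .cpu() yet).
    with torch.inference_mode():
        features = features.to(device)
        return model(features)

def encode_post(out_features: torch.Tensor, normalize: bool = True) -> np.ndarray:
    with torch.inference_mode():
        emb = out_features.detach().cpu()  # the only device-to-host transfer
        if normalize:
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        return emb.numpy()

model = torch.nn.Linear(8, 4)  # stand-in for the embedding model
out = encode_core(model, torch.randn(2, 8), torch.device("cpu"))
print(encode_post(out).shape)  # (2, 4)
```

Deferring .cpu() moves the (potentially slow) transfer out of the inference hot path and into the post-processing stage, which is presumably what "move to torch only if needed" refers to.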
libs/infinity_emb/tests/script_live.py (10 changes: 5 additions & 5 deletions)
@@ -35,16 +35,16 @@ def remote(json_data: bytes):
     remote_resp = [d["embedding"] for d in remote(json_d).json()["data"]]
     np.testing.assert_almost_equal(local_resp, remote_resp, 6)
 
-    print("Measuring latency via SentenceTransformers")
-    latency_st = timeit.timeit("local(sample)", number=10, globals=locals())
-    print("SentenceTransformers latency: ", latency_st)
-    model = None
+    # print("Measuring latency via SentenceTransformers")
+    # latency_st = timeit.timeit("local(sample)", number=10, globals=locals())
+    # print("SentenceTransformers latency: ", latency_st)
+    # model = None
 
     print("Measuring latency via requests")
     latency_request = timeit.timeit("remote(json_d)", number=10, globals=locals())
     print(f"Request latency: {latency_request}")
 
-    assert latency_st * 1.1 > latency_request
+    # assert latency_st * 1.1 > latency_request
 
 
 if __name__ == "__main__":
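With the local SentenceTransformers baseline commented out, the script now only times the HTTP path. For reference, a hedged sketch of that timeit-over-requests measurement; the URL and payload here are illustrative assumptions, not the script's actual values:

```python
import json
import timeit

import requests

URL = "http://localhost:8000/embeddings"  # hypothetical endpoint
json_d = json.dumps({"input": ["This is a sample sentence."]})

def remote(json_data: str) -> requests.Response:
    return requests.post(
        URL, data=json_data, headers={"Content-Type": "application/json"}
    )

latency_request = timeit.timeit("remote(json_d)", number=10, globals=globals())
print(f"Request latency (10 calls): {latency_request:.3f}s")
```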
