Merge pull request #187 from michaelfeil/fix-typing
Update deferred moving to CPU & type-hint improvements
michaelfeil authored Apr 1, 2024
2 parents a8e0e29 + 049a86a commit 197ede8
Showing 3 changed files with 35 additions and 17 deletions.
libs/infinity_emb/infinity_emb/transformer/acceleration.py (22 additions, 6 deletions)
@@ -1,12 +1,20 @@
 import os
+from typing import TYPE_CHECKING
 
 from infinity_emb._optional_imports import CHECK_OPTIMUM
 
 if CHECK_OPTIMUM.is_available:
-    from optimum.bettertransformer import BetterTransformer  # type: ignore
+    from optimum.bettertransformer import (  # type: ignore[import-untyped]
+        BetterTransformer,
+    )
 
+if TYPE_CHECKING:
+    from logging import Logger
 
-def to_bettertransformer(model, logger):
+    from transformers import PreTrainedModel  # type: ignore[import-untyped]
+
+
+def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
     if os.environ.get("INFINITY_DISABLE_OPTIMUM", False):
         logger.info(
             "No optimizations via Huggingface optimum,"
@@ -18,8 +26,16 @@ def to_bettertransformer(model, logger):
     try:
         model = BetterTransformer.transform(model)
     except Exception as ex:
-        logger.exception(
-            f"BetterTransformer is not available for model. {ex}."
-            " Continue without bettertransformer modeling code."
-        )
+        # if level is debug then show the exception
+        if logger.level <= 10:
+            logger.exception(
+                f"BetterTransformer is not available for model: {model.__class__} {ex}."
+                " Continue without bettertransformer modeling code."
+            )
+        else:
+            logger.warning(
+                f"BetterTransformer is not available for model: {model.__class__}"
+                " Continue without bettertransformer modeling code."
+            )
+
     return model
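
Note on the new except branch: it keys off the logger's numeric level, and `logging.DEBUG` is the integer 10, so `logger.level <= 10` means debug-or-finer logging is active and the full traceback is worth emitting. A minimal standalone sketch of the same pattern (logger name and error are illustrative, not from the repo):

import logging

logging.basicConfig()
logger = logging.getLogger("example")  # hypothetical logger name
logger.setLevel(logging.DEBUG)  # logging.DEBUG == 10

try:
    raise RuntimeError("transform failed")  # stand-in for BetterTransformer.transform
except Exception as ex:
    if logger.level <= logging.DEBUG:
        # debug runs get the message plus the full traceback
        logger.exception(f"optimization unavailable: {ex}")
    else:
        # otherwise a single warning line, no traceback
        logger.warning("optimization unavailable, continuing without it")

One caveat with reading `logger.level` directly: a logger that inherits its level reports NOTSET (0), which also satisfies `<= 10`; `logger.getEffectiveLevel()` would resolve the inherited value instead.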
libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py (7 additions, 5 deletions)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import copy
+from typing import TYPE_CHECKING
 
 from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
@@ -11,13 +12,14 @@
 if CHECK_TORCH.is_available and CHECK_SENTENCE_TRANSFORMERS.is_available:
     import torch
     from sentence_transformers import CrossEncoder  # type: ignore[import-untyped]
-    from torch import Tensor
 else:
 
     class CrossEncoder:  # type: ignore[no-redef]
         pass
 
-    Tensor = None  # type: ignore
+
+if TYPE_CHECKING:
+    from torch import Tensor
 
 
 from infinity_emb.transformer.acceleration import to_bettertransformer
@@ -74,18 +76,18 @@ def encode_pre(self, input_tuples: list[tuple[str, str]]):
         )
         return tokenized
 
-    def encode_core(self, features: dict[str, Tensor]):
+    def encode_core(self, features: dict[str, "Tensor"]):
         """
         Computes sentence embeddings
         """
         with torch.no_grad():
             features = {k: v.to(self.model.device) for k, v in features.items()}
             out_features = self.model(**features, return_dict=True)["logits"]
 
-        return out_features
+        return out_features.detach().cpu()
 
     def encode_post(self, out_features) -> list[float]:
-        return out_features.detach().cpu().flatten()
+        return out_features.flatten()
 
     def tokenize_lengths(self, sentences: list[str]) -> list[int]:
         tks = self._infinity_tokenizer.batch_encode_plus(
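
The removed `Tensor = None` runtime fallback is replaced by the standard `typing.TYPE_CHECKING` idiom: the torch import is only seen by static type checkers, and the quoted annotation `"Tensor"` stays a plain string at runtime, so torch remains an optional dependency. A minimal sketch of the idiom (the `to_device` helper is hypothetical, for illustration only):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by mypy/pyright only, never executed at runtime,
    # so importing torch here adds no runtime requirement
    from torch import Tensor


def to_device(features: dict[str, Tensor], device: str) -> dict[str, Tensor]:
    # with the __future__ import all annotations stay strings at runtime;
    # quoting individual names ("Tensor") achieves the same without it
    return {key: value.to(device) for key, value in features.items()}

This module imports cleanly even in an environment without torch installed, which is exactly what the `CHECK_TORCH` and `CHECK_SENTENCE_TRANSFORMERS` guards protect.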
Third changed file (filename not shown in this capture):
@@ -82,27 +82,27 @@ def __init__(self, *, engine_args=EngineArgs):
             logger.info("using torch.compile()")
             fm.auto_model = torch.compile(fm.auto_model, dynamic=True)
 
-    def encode_pre(self, sentences) -> Mapping[str, Tensor]:
+    def encode_pre(self, sentences) -> Mapping[str, "Tensor"]:
         features = self.tokenize(sentences)
 
         return features
 
-    def encode_core(self, features: Mapping[str, Tensor]) -> Tensor:
+    def encode_core(self, features: Mapping[str, "Tensor"]) -> "Tensor":
         """
         Computes sentence embeddings
         """
 
         with torch.no_grad():
             features = util.batch_to_device(features, self.device)
-            out_features = self.forward(features)["sentence_embedding"]
+            out_features: "Tensor" = self.forward(features)["sentence_embedding"]
 
-        return out_features
+        return out_features.detach().cpu()
 
     def encode_post(
-        self, out_features: Tensor, normalize_embeddings: bool = True
+        self, out_features: "Tensor", normalize_embeddings: bool = True
     ) -> EmbeddingReturnType:
         with torch.inference_mode():
-            embeddings: Tensor = out_features.detach().cpu().to(torch.float32)
+            embeddings: "Tensor" = out_features.to(torch.float32)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
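
Across both encoder files the `.detach().cpu()` call moves from `encode_post` into `encode_core`, directly after the forward pass: the device-to-host copy now happens as soon as the logits or embeddings exist, and everything downstream (flatten, cast, normalize) operates on detached CPU tensors. A minimal sketch of the resulting shape of the two steps, with a hypothetical `torch.nn.Linear` standing in for the real model:

import torch

model = torch.nn.Linear(4, 8)  # hypothetical stand-in for the embedding model


def encode_core(features: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        out = model(features)
    # moved here from encode_post: detach from the autograd graph and
    # copy to host memory right after the forward pass
    return out.detach().cpu()


def encode_post(out: torch.Tensor) -> torch.Tensor:
    # post-processing now assumes a CPU tensor: cast, then L2-normalize
    embeddings = out.to(torch.float32)
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


print(encode_post(encode_core(torch.randn(2, 4))).shape)  # torch.Size([2, 8])

A plausible motivation for the reordering: results sitting in a queue between the core and post stages no longer pin GPU memory while they wait.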
