From 049a86a585fe3e5e533e7426104f63996c9e1a1a Mon Sep 17 00:00:00 2001
From: michaelfeil
Date: Sun, 31 Mar 2024 18:55:34 -0700
Subject: [PATCH] update defered moving to cpu

---
 .../infinity_emb/transformer/acceleration.py | 28 +++++++++++++++----
 .../transformer/crossencoder/torch.py        | 12 ++++----
 .../embedder/sentence_transformer.py         | 12 ++++----
 3 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
index 972e67bb..4ee97495 100644
--- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py
+++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -1,12 +1,20 @@
 import os
+from typing import TYPE_CHECKING
 
 from infinity_emb._optional_imports import CHECK_OPTIMUM
 
 if CHECK_OPTIMUM.is_available:
-    from optimum.bettertransformer import BetterTransformer  # type: ignore
+    from optimum.bettertransformer import (  # type: ignore[import-untyped]
+        BetterTransformer,
+    )
 
+if TYPE_CHECKING:
+    from logging import Logger
 
-def to_bettertransformer(model, logger):
+    from transformers import PreTrainedModel  # type: ignore[import-untyped]
+
+
+def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
     if os.environ.get("INFINITY_DISABLE_OPTIMUM", False):
         logger.info(
             "No optimizations via Huggingface optimum,"
@@ -18,8 +26,16 @@ def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
     try:
         model = BetterTransformer.transform(model)
     except Exception as ex:
-        logger.exception(
-            f"BetterTransformer is not available for model. {ex}."
-            " Continue without bettertransformer modeling code."
-        )
+        # if level is debug then show the exception
+        if logger.level <= 10:
+            logger.exception(
+                f"BetterTransformer is not available for model: {model.__class__} {ex}."
+                " Continue without bettertransformer modeling code."
+            )
+        else:
+            logger.warning(
+                f"BetterTransformer is not available for model: {model.__class__}"
+                " Continue without bettertransformer modeling code."
+            )
+
     return model
diff --git a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
index dbb72b14..e1b80605 100644
--- a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
+++ b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import copy
+from typing import TYPE_CHECKING
 
 from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
@@ -11,13 +12,14 @@
 if CHECK_TORCH.is_available and CHECK_SENTENCE_TRANSFORMERS.is_available:
     import torch
     from sentence_transformers import CrossEncoder  # type: ignore[import-untyped]
-    from torch import Tensor
 else:
 
     class CrossEncoder:  # type: ignore[no-redef]
         pass
 
-    Tensor = None  # type: ignore
+
+if TYPE_CHECKING:
+    from torch import Tensor
 
 from infinity_emb.transformer.acceleration import to_bettertransformer
 
@@ -74,7 +76,7 @@ def encode_pre(self, input_tuples: list[tuple[str, str]]):
         )
         return tokenized
 
-    def encode_core(self, features: dict[str, Tensor]):
+    def encode_core(self, features: dict[str, "Tensor"]):
         """
         Computes sentence embeddings
         """
@@ -82,10 +84,10 @@ def encode_core(self, features: dict[str, Tensor]):
             features = {k: v.to(self.model.device) for k, v in features.items()}
 
             out_features = self.model(**features, return_dict=True)["logits"]
-        return out_features
+        return out_features.detach().cpu()
 
     def encode_post(self, out_features) -> list[float]:
-        return out_features.detach().cpu().flatten()
+        return out_features.flatten()
 
     def tokenize_lengths(self, sentences: list[str]) -> list[int]:
         tks = self._infinity_tokenizer.batch_encode_plus(
diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
index 83155204..721688aa 100644
--- a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
+++ b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
@@ -82,27 +82,27 @@ def __init__(self, *, engine_args=EngineArgs):
             logger.info("using torch.compile()")
             fm.auto_model = torch.compile(fm.auto_model, dynamic=True)
 
-    def encode_pre(self, sentences) -> Mapping[str, Tensor]:
+    def encode_pre(self, sentences) -> Mapping[str, "Tensor"]:
         features = self.tokenize(sentences)
 
         return features
 
-    def encode_core(self, features: Mapping[str, Tensor]) -> Tensor:
+    def encode_core(self, features: Mapping[str, "Tensor"]) -> "Tensor":
         """
         Computes sentence embeddings
        """
 
         with torch.no_grad():
             features = util.batch_to_device(features, self.device)
-            out_features = self.forward(features)["sentence_embedding"]
+            out_features: "Tensor" = self.forward(features)["sentence_embedding"]
 
-        return out_features
+        return out_features.detach().cpu()
 
     def encode_post(
-        self, out_features: Tensor, normalize_embeddings: bool = True
+        self, out_features: "Tensor", normalize_embeddings: bool = True
     ) -> EmbeddingReturnType:
         with torch.inference_mode():
-            embeddings: Tensor = out_features.detach().cpu().to(torch.float32)
+            embeddings: "Tensor" = out_features.to(torch.float32)
 
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
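
Note (illustrative, not part of the patch): the change defers the move to CPU from encode_post into encode_core. Each backend's encode_core now returns a detached CPU tensor right after the forward pass, and encode_post only casts, flattens, or normalizes the already-CPU result. Below is a minimal sketch of that pattern; ToyEncoder and the free-standing encode_core/encode_post helpers are hypothetical stand-ins, not the library's classes.

import torch


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in for the real model; returns one vector per input row."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.linear(inputs)


def encode_core(model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # Forward pass on the model's device, then detach and move to CPU
    # immediately, mirroring the patch: later stages never see a device tensor.
    device = next(model.parameters()).device
    with torch.no_grad():
        out = model(inputs.to(device))
    return out.detach().cpu()


def encode_post(out_features: torch.Tensor) -> torch.Tensor:
    # out_features is already a detached CPU tensor; only cast and normalize here.
    with torch.inference_mode():
        embeddings = out_features.to(torch.float32)
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)


if __name__ == "__main__":
    print(encode_post(encode_core(ToyEncoder(), torch.randn(2, 4))).shape)

Keeping the device-to-host copy inside encode_core means post-processing operates only on CPU tensors, which is the behavior the diff installs in both the CrossEncoder and SentenceTransformer backends.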