Merge pull request #49 from michaelfeil/fix-fp16-crossencoder
fix: fp16 was not available for the cross-encoder
michaelfeil authored Dec 20, 2023
2 parents 7e6f6c3 + 577c3d1 commit ac251fe
Showing 6 changed files with 522 additions and 439 deletions.
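
The substance of the fix is the half-precision gate in the cross-encoder backend below: the old code compared `self.model.device` (a `torch.device`) with the string "cuda", a comparison a torch.device never satisfies, and inside `__init__` the underlying model may not have been moved to the GPU yet, so fp16 was never enabled. The new code checks `self._target_device.type` instead. A minimal sketch of the difference, assuming standard PyTorch semantics (not code from this commit):

    import torch

    device = torch.device("cuda:0")
    print(device == "cuda")       # False: a torch.device does not compare equal to a plain string
    print(device.type == "cuda")  # True: .type is the "cuda"/"cpu" string the new check uses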
Changed file: torch cross-encoder backend (CrossEncoderPatched)

@@ -32,14 +32,14 @@ class Module:
 class CrossEncoderPatched(CrossEncoder, BaseCrossEncoder):
     """CrossEncoder with .encode_core() and no microbatching"""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model_name_or_path, **kwargs):
         if not TORCH_AVAILABLE:
             raise ImportError(
                 "torch is not installed."
                 " `pip install infinity-emb[torch]` "
                 "or pip install infinity-emb[torch,optimum]`"
             )
-        super().__init__(*args, **kwargs)
+        super().__init__(model_name_or_path, **kwargs)
 
         # make a copy of the tokenizer,
         # to be able to count the tokens in another thread
@@ -48,9 +48,10 @@ def __init__(self, *args, **kwargs):
         self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
         self.model.eval()
 
-        self.model = to_bettertransformer(self.model, logger)
-
-        if self.model.device == "cuda" and not os.environ.get(
+        self.model = to_bettertransformer(self.model, logger)
+
+        if self._target_device.type == "cuda" and not os.environ.get(
             "INFINITY_DISABLE_HALF", ""
         ):
             logger.info(
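
As the changed condition shows, the fp16 path stays opt-out: it is skipped whenever the INFINITY_DISABLE_HALF environment variable is set to a non-empty value. A hypothetical opt-out, using only the variable name taken from the diff:

    import os

    # Any non-empty value makes os.environ.get("INFINITY_DISABLE_HALF", "") truthy,
    # so the backend skips the half-precision conversion even on CUDA.
    os.environ["INFINITY_DISABLE_HALF"] = "1"  # set before the model is loaded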
Changed file: fastembed backend (Fastembed)

@@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs) -> None:
 
 
 class Fastembed(DefaultEmbedding, BaseEmbedder):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model_name_or_path, **kwargs):
         if not FASTEMBED_AVAILABLE:
             raise ImportError(
                 "fastembed is not installed." "`pip install infinity-emb[fastembed]`"
@@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs):
             kwargs["cache_dir"] = infinity_cache_dir()
         if kwargs.pop("device", None) != "cpu":
             providers = ["CUDAExecutionProvider"] + providers
-        super(DefaultEmbedding, self).__init__(*args, **kwargs)
+        super(DefaultEmbedding, self).__init__(model_name_or_path, **kwargs)
         self._infinity_tokenizer = copy.deepcopy(self.model.tokenizer)
         self.model.model.set_providers(providers)
 
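
For context, the unchanged lines around the Fastembed constructor prepend "CUDAExecutionProvider" to the ONNX Runtime provider list whenever the requested device is not "cpu". A small sketch of that priority idea, assuming onnxruntime is installed; the availability guard is an addition here and not part of the diff:

    import onnxruntime as ort

    providers = ["CPUExecutionProvider"]
    requested_device = "cuda"  # hypothetical value; the diff pops it from kwargs
    # Prefer the CUDA provider, but only when this onnxruntime build ships it.
    if requested_device != "cpu" and "CUDAExecutionProvider" in ort.get_available_providers():
        providers = ["CUDAExecutionProvider"] + providers
    print(providers)  # ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a CUDA build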
Changed file: sentence-transformers torch embedding backend (SentenceTransformerPatched)

@@ -37,14 +37,14 @@ class Module: # type: ignore
 class SentenceTransformerPatched(SentenceTransformer, BaseEmbedder):
     """SentenceTransformer with .encode_core() and no microbatching"""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model_name_or_path, **kwargs):
         if not TORCH_AVAILABLE:
             raise ImportError(
                 "torch is not installed."
                 " `pip install infinity-emb[torch]` "
                 "or pip install infinity-emb[torch,optimum]`"
             )
-        super().__init__(*args, **kwargs)
+        super().__init__(model_name_or_path, **kwargs)
         device = self._target_device
         self.to(device)
         # make a copy of the tokenizer,
(Diffs for the remaining changed files are not shown.)
