wrap: torch._inductor settings
michaelfeil committed Dec 2, 2024
1 parent 71b6457 commit cda8bbd
Showing 2 changed files with 11 additions and 7 deletions.
11 changes: 6 additions & 5 deletions libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py
@@ -5,7 +5,7 @@
import json
import subprocess
from typing import Union

+from functools import cache
import numpy as np

from infinity_emb._optional_imports import CHECK_OPTIMUM_NEURON, CHECK_TORCH
@@ -30,6 +30,7 @@
]


+@cache
def get_nc_count() -> Union[int, None]:
"""Returns the number of neuron cores on the current instance."""
try:
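
For context: functools.cache memoizes the decorated function, so the subprocess lookup behind get_nc_count() runs at most once per process and every later call returns the stored result. A minimal standalone sketch of that pattern (count_devices and the nproc command are illustrative stand-ins, not code from this commit):

import subprocess
from functools import cache
from typing import Union


@cache
def count_devices() -> Union[int, None]:
    """Return a device count from an external tool, or None if the lookup fails."""
    try:
        out = subprocess.run(["nproc"], capture_output=True, check=True, text=True)
        return int(out.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError):
        return None


# the first call spawns the subprocess; later calls return the cached value
print(count_devices(), count_devices())
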
@@ -45,7 +46,7 @@ def get_nc_count() -> Union[int, None]:
return None


-def pad_up_to_size(desired_max_bs, input_ids):
+def pad_up_to_size(desired_max_bs: int, input_ids: "torch.Tensor") -> "torch.Tensor":
"""input_ids a 2D array with batch_size on dim=0
makes sure the func runs with self.batch_size
@@ -116,7 +117,7 @@ def __init__(self, *, engine_args: EngineArgs):
)
self.batch_size = self.model.neuron_config.input_shapes["batch_size"]

-def encode_pre(self, sentences: list[str]) -> dict[str, np.ndarray]:
+def encode_pre(self, sentences: list[str]) -> dict[str, "torch.Tensor"]:
input_dict = self.tokenizer(
sentences,
max_length=self.config.max_position_embeddings,
@@ -127,7 +128,7 @@ def encode_pre(self, sentences: list[str]) -> dict[str, np.ndarray]:
)
return input_dict

-def encode_core(self, input_dict: dict[str, np.ndarray]) -> dict:
+def encode_core(self, input_dict: dict[str, "torch.Tensor"]) -> dict:
"""requires constant batch size, which is a bit of extra work"""
for key, tensor in input_dict.items():
actual_bsize = tensor.shape[0]
@@ -140,7 +141,7 @@ def encode_core(self, input_dict: dict[str, np.ndarray]) -> dict:
}

@quant_embedding_decorator()
-def encode_post(self, embedding: dict) -> EmbeddingReturnType:
+def encode_post(self, embedding: dict[str, "torch.Tensor"]) -> EmbeddingReturnType:
embedding = self.pooling( # type: ignore
embedding["token_embeddings"].numpy(), embedding["attention_mask"].numpy()
)
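
The new "torch.Tensor" annotations follow the encode path end to end: encode_pre tokenizes, encode_core pads each batch up to the fixed batch size the neuron model was compiled for before the forward pass, and encode_post pools the token embeddings with the attention mask. A rough sketch of that constant-batch-size padding, assuming padding by repeating rows (pad_batch is a hypothetical name; the real pad_up_to_size in neuron.py may pad differently):

import torch


def pad_batch(desired_max_bs: int, input_ids: torch.Tensor) -> torch.Tensor:
    """Illustrative: pad a (batch, seq_len) tensor up to desired_max_bs rows."""
    actual_bsize = input_ids.shape[0]
    if actual_bsize > desired_max_bs:
        raise ValueError(f"batch {actual_bsize} exceeds compiled batch size {desired_max_bs}")
    if actual_bsize == desired_max_bs:
        return input_ids
    # repeat the last row to fill the fixed batch; the padded rows are dropped after the forward pass
    filler = input_ids[-1:].repeat(desired_max_bs - actual_bsize, 1)
    return torch.cat([input_ids, filler], dim=0)


tokens = torch.randint(0, 30000, (3, 16))  # 3 real sequences of length 16
padded = pad_batch(8, tokens)              # compiled batch size of 8 -> shape (8, 16)
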
@@ -41,8 +41,11 @@ class SentenceTransformer: # type: ignore[no-redef]
import torch._inductor.config

# torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True
+try:
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True
+except Exception:
+    pass


class SentenceTransformerPatched(SentenceTransformer, BaseEmbedder):
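
The try/except above is the commit's headline change: torch._inductor.config is a private PyTorch namespace, so setting triton.unique_kernel_names and fx_graph_cache is wrapped so that a missing or renamed attribute cannot break model loading on a different torch build. A minimal sketch of how such flags typically sit in front of torch.compile (assuming a PyTorch 2.x build; the compile call below is illustrative, not part of this commit):

import torch

try:
    import torch._inductor.config
    # readable Triton kernel names in profiler traces
    torch._inductor.config.triton.unique_kernel_names = True
    # cache compiled FX graphs so warm starts can skip recompilation
    torch._inductor.config.fx_graph_cache = True
except Exception:
    # older or differently-built torch versions may not expose these knobs
    pass

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model)        # the flags above influence this compile
out = compiled(torch.randn(2, 8))
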