Skip to content

Commit

Permalink
Merge pull request #78 from michaelfeil/default-device
Browse files Browse the repository at this point in the history
rename device name
  • Loading branch information
michaelfeil authored Jan 30, 2024
2 parents 74b1cbd + 8cdda8a commit 0ff0200
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, model_name_or_path, **kwargs):
self.model, logger, disable=self._target_device.type == "mps"
)

if self._target_device.type == "cuda" and not os.environ.get(
if self.device.type == "cuda" and not os.environ.get(
"INFINITY_DISABLE_HALF", ""
):
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def __init__(self, model_name_or_path, trust_remote_code=True, **kwargs):
super().__init__(
model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
device = self._target_device
self.to(device)
self.to(self.device)
# make a copy of the tokenizer,
# to be able to count the tokens in another thread
# without corrupting the original.
Expand All @@ -59,10 +58,12 @@ def __init__(self, model_name_or_path, trust_remote_code=True, **kwargs):
self.eval()

fm.auto_model = to_bettertransformer(
fm.auto_model, logger, disable=device.type == "mps"
fm.auto_model, logger, disable=self.device.type == "mps"
)

if device.type == "cuda" and not os.environ.get("INFINITY_DISABLE_HALF", ""):
if self.device.type == "cuda" and not os.environ.get(
"INFINITY_DISABLE_HALF", ""
):
logger.info(
"Switching to half() precision (cuda: fp16). "
"Disable by the setting the env var `INFINITY_DISABLE_HALF`"
Expand All @@ -80,8 +81,7 @@ def encode_core(self, features: Dict[str, Tensor]) -> Tensor:
"""

with torch.inference_mode():
device = self._target_device
features = util.batch_to_device(features, device)
features = util.batch_to_device(features, self.device)
out_features = self.forward(features)["sentence_embedding"]

return out_features
Expand Down

0 comments on commit 0ff0200

Please sign in to comment.