Skip to content

Commit

Permalink
update torch compile
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Jan 3, 2025
1 parent 9544138 commit 6ddba7b
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

device = Device.cpu if torch.backends.mps.is_available() else Device.auto

SHOULD_TORCH_COMPILE = sys.platform == "linux" and sys.version_info < (3, 12)
SHOULD_TORCH_COMPILE = sys.platform == "linux" and sys.version_info < (3, 12) and torch.cuda.is_available()


def test_crossencoder():
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_crossencoder():
def test_patched_crossencoder_vs_sentence_transformers():
model = CrossEncoderPatched(
engine_args=EngineArgs(
model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", compile=True, device=device
model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", compile=SHOULD_TORCH_COMPILE, device=device
)
)
model_unpatched = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
Expand Down

0 comments on commit 6ddba7b

Please sign in to comment.