From 6ddba7b88af4aae1f7cc91ff125b390170b88652 Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:40:24 +0000 Subject: [PATCH] update torch compile --- .../transformer/crossencoder/test_torch_crossencoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/infinity_emb/tests/unit_test/transformer/crossencoder/test_torch_crossencoder.py b/libs/infinity_emb/tests/unit_test/transformer/crossencoder/test_torch_crossencoder.py index 9d58041e..4d6d329c 100644 --- a/libs/infinity_emb/tests/unit_test/transformer/crossencoder/test_torch_crossencoder.py +++ b/libs/infinity_emb/tests/unit_test/transformer/crossencoder/test_torch_crossencoder.py @@ -11,7 +11,7 @@ device = Device.cpu if torch.backends.mps.is_available() else Device.auto -SHOULD_TORCH_COMPILE = sys.platform == "linux" and sys.version_info < (3, 12) +SHOULD_TORCH_COMPILE = sys.platform == "linux" and sys.version_info < (3, 12) and torch.cuda.is_available() def test_crossencoder(): @@ -43,7 +43,7 @@ def test_crossencoder(): def test_patched_crossencoder_vs_sentence_transformers(): model = CrossEncoderPatched( engine_args=EngineArgs( - model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", compile=True, device=device + model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", compile=SHOULD_TORCH_COMPILE, device=device ) ) model_unpatched = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")