diff --git a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py index f8b9851b..7b557287 100644 --- a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py +++ b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py @@ -5,7 +5,7 @@ try: from huggingface_hub import HfApi, HfFolder # type: ignore - from huggingface_hub.constants import HF_HUB_CACHE # type: ignore + from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE # type: ignore from optimum.onnxruntime import ORTOptimizer # type: ignore from optimum.onnxruntime.configuration import OptimizationConfig # type: ignore except ImportError: @@ -42,7 +42,7 @@ def optimize_model( path_folder = ( Path(model_name_or_path) if Path(model_name_or_path).exists() - else Path(HF_HUB_CACHE) / "infinity_onnx" / model_name_or_path + else Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / model_name_or_path ) files_optimized = list(path_folder.glob("**/*optimized.onnx")) if files_optimized and not execution_provider == "TensorrtExecutionProvider": @@ -93,19 +93,27 @@ def optimize_model( def get_onnx_files( - model_id: str, + model_name_or_path: str, revision: str, use_auth_token: Union[bool, str] = True, prefer_quantized=False, ) -> Path: """gets the onnx files from the repo""" - if isinstance(use_auth_token, bool): - token = HfFolder().get_token() + if not Path(model_name_or_path).exists(): + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = list( + map( + Path, + HfApi().list_repo_files( + model_name_or_path, revision=revision, token=token + ), + ) + ) else: - token = use_auth_token - repo_files = map( - Path, HfApi().list_repo_files(model_id, revision=revision, token=token) - ) + repo_files = list(Path(model_name_or_path).glob("**/*")) pattern = "**.onnx" onnx_files = [p for p in repo_files if p.match(pattern)] @@ -121,4 +129,4 @@ def get_onnx_files( elif len(onnx_files) == 1: return onnx_files[0] else: - raise ValueError(f"No onnx files found for {model_id}") + raise ValueError(f"No onnx files found for {model_name_or_path}")