Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further reduce the HTTP calls to huggingface.co #13107

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 79 additions & 56 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import json
import os
import time
from functools import cache
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Type, Union
from typing import Any, Callable, Dict, Literal, Optional, Type, Union

import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
try_to_load_from_cache)
from huggingface_hub import hf_hub_download
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub import try_to_load_from_cache
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, LocalEntryNotFoundError,
RepositoryNotFoundError,
Expand Down Expand Up @@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
MISTRAL = "mistral"


def with_retry(func: Callable[[], Any],
log_msg: str,
max_retries: int = 2,
retry_delay: int = 2):
for attempt in range(max_retries):
try:
return func()
except Exception as e:
if attempt == max_retries - 1:
logger.error("%s: %s", log_msg, e)
raise
logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1,
max_retries)
time.sleep(retry_delay)
retry_delay *= 2


# @cache doesn't cache exceptions
@cache
def list_repo_files(
repo_id: str,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Union[str, bool, None] = None,
) -> list[str]:

def lookup_files():
try:
return hf_list_repo_files(repo_id,
revision=revision,
repo_type=repo_type,
token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode,
# all we know is that we don't have this
# file cached.
return []

return with_retry(lookup_files, "Error retrieving file list")


def file_exists(
repo_id: str,
file_name: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
) -> bool:

file_list = list_repo_files(repo_id,
repo_type=repo_type,
revision=revision,
token=token)
return file_name in file_list


# In offline mode the result can be a false negative
def file_or_path_exists(model: Union[str, Path], config_name: str,
revision: Optional[str]) -> bool:
if Path(model).exists():
Expand All @@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
# hf_hub. This will fail in offline mode.

# Call HF to check if the file exists
# 2 retries and exponential backoff
max_retries = 2
retry_delay = 2
for attempt in range(max_retries):
try:
return file_exists(model,
config_name,
revision=revision,
token=HF_TOKEN)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode,
# all we know is that we don't have this
# file cached.
return False
except Exception as e:
logger.error(
"Error checking file existence: %s, retrying %d of %d", e,
attempt + 1, max_retries)
if attempt == max_retries - 1:
logger.error("Error checking file existence: %s", e)
raise
time.sleep(retry_delay)
retry_delay *= 2
continue
return False
return file_exists(str(model),
config_name,
revision=revision,
token=HF_TOKEN)


def patch_rope_scaling(config: PretrainedConfig) -> None:
Expand Down Expand Up @@ -208,32 +248,7 @@ def get_config(
revision=revision):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().

# Call HF to check if the file exists
# 2 retries and exponential backoff
max_retries = 2
retry_delay = 2
for attempt in range(max_retries):
try:
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=HF_TOKEN)
except Exception as e:
logger.error(
"Error checking file existence: %s, retrying %d of %d",
e, attempt + 1, max_retries)
if attempt == max_retries:
logger.error("Error checking file existence: %s", e)
raise e
time.sleep(retry_delay)
retry_delay *= 2

raise ValueError(f"No supported config format found in {model}")
raise ValueError(f"No supported config format found in {model}.")

if config_format == ConfigFormat.HF:
config_dict, _ = PretrainedConfig.get_config_dict(
Expand Down Expand Up @@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
file_name=file_name,
revision=revision)

if file_path is None and file_or_path_exists(
model=model, config_name=file_name, revision=revision):
if file_path is None:
try:
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
except huggingface_hub.errors.OfflineModeIsEnabled:
return None
except (RepositoryNotFoundError, RevisionNotFoundError,
EntryNotFoundError, LocalEntryNotFoundError) as e:
logger.debug("File or repository not found in hf_hub_download", e)
Expand All @@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
return None


@cache
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
"""
This function gets the pooling and normalize
Expand Down Expand Up @@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
if modules_dict is None:
return None

logger.info("Found sentence-transformers modules configuration.")

pooling = next((item for item in modules_dict
if item["type"] == "sentence_transformers.models.Pooling"),
None)
Expand All @@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
if pooling_type_name is not None:
pooling_type_name = get_pooling_config_name(pooling_type_name)

logger.info("Found pooling configuration.")
return {"pooling_type": pooling_type_name, "normalize": normalize}

return None
Expand Down Expand Up @@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
return None


@cache
def get_sentence_transformer_tokenizer_config(model: str,
revision: Optional[str] = 'main'
):
Expand Down Expand Up @@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
if not encoder_dict:
return None

logger.info("Found sentence-transformers tokenize configuration.")

if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
return encoder_dict
return None
Expand Down