Skip to content

Commit

Permalink
Add retry hf hub decorator (#35213)
Browse files Browse the repository at this point in the history
* Add retry torch decorator

* New approach

* Empty commit

* Empty commit

* Style

* Use logger.error

* Add a test

* Update src/transformers/testing_utils.py

Co-authored-by: Lucain <[email protected]>

* Fix err

* Update tests/utils/test_modeling_utils.py

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
  • Loading branch information
3 people authored Feb 25, 2025
1 parent 9ebfda3 commit 41925e4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
47 changes: 46 additions & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from unittest.mock import patch

import huggingface_hub.utils
import requests
import urllib3
from huggingface_hub import delete_repo
from packaging import version
Expand Down Expand Up @@ -200,6 +201,8 @@
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False

logger = transformers_logging.get_logger(__name__)


def parse_flag_from_env(key, default=False):
try:
Expand Down Expand Up @@ -2497,7 +2500,49 @@ def wrapper(*args, **kwargs):
return test_func_ref(*args, **kwargs)

except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1

return test_func_ref(*args, **kwargs)

return wrapper

return decorator


def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
"""
To decorate tests that download from the Hub. They can fail due to a
variety of network issues such as timeouts, connection resets, etc.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*, defaults to 2):
If provided, will wait that number of seconds before retrying the test.
"""

def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1

while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
# We catch all exceptions related to network issues from requests
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ReadTimeout,
requests.exceptions.HTTPError,
requests.exceptions.RequestException,
) as err:
logger.error(
f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specied Hub repository."
)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
Expand Down
11 changes: 11 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
)
from transformers.testing_utils import (
CaptureLogger,
hub_retry,
is_flaky,
require_accelerate,
require_bitsandbytes,
Expand Down Expand Up @@ -214,6 +215,16 @@ class ModelTesterMixin:
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]

# Note: for all mixins that utilize the Hub in some way, we should ensure that
# they contain the `hub_retry` decorator in case of failures.
def __init_subclass__(cls, **kwargs):
    """
    Wrap every `test_*` callable of a newly created subclass with the `hub_retry` decorator.

    Runs automatically when a class inheriting from this mixin is defined. All attributes
    reported by `dir(cls)` (including inherited ones) whose name starts with `test_` and
    that are callable are replaced, on the subclass, by their `hub_retry`-decorated version.

    Args:
        **kwargs: Forwarded unchanged to the parent `__init_subclass__`.
    """
    super().__init_subclass__(**kwargs)
    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            continue
        attr = getattr(cls, attr_name)
        if not callable(attr):
            continue
        # `hub_retry` is a decorator *factory*: it must be called first (here with its
        # default arguments) and the decorator it returns is then applied to the test.
        # Calling `hub_retry(attr)` would bind the test function to `max_attempts` and
        # store the inner decorator on the class, so the test body would never execute.
        setattr(cls, attr_name, hub_retry()(attr))

@property
def all_generative_model_classes(self):
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
Expand Down
13 changes: 13 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
LoggingLevel,
TemporaryHubRepo,
TestCasePlus,
hub_retry,
is_staging_test,
require_accelerate,
require_flax,
Expand Down Expand Up @@ -327,6 +328,18 @@ def tearDown(self):
torch.set_default_dtype(self.old_dtype)
super().tearDown()

def test_hub_retry(self):
    """
    Check that `hub_retry` re-runs a function that raised a `requests` connection error.

    The decorated function fails on its first call and succeeds on its second one, so the
    decorated call must return `True` after exactly two invocations.
    """
    attempts = []

    # `wait_before_retry=None` disables the sleep between attempts so that the test
    # does not spend time waiting for a simulated failure.
    @hub_retry(max_attempts=2, wait_before_retry=None)
    def test_func():
        attempts.append(1)
        # First attempt will fail with a connection error
        if len(attempts) == 1:
            raise requests.exceptions.ConnectionError("Connection failed")
        # Second attempt will succeed
        return True

    self.assertTrue(test_func())
    # Make sure the success really comes from a retry and not from the first call.
    self.assertEqual(len(attempts), 2)

@slow
def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased"
Expand Down

0 comments on commit 41925e4

Please sign in to comment.