
Commit

Merge branch 'main' into optimisations
eleanorTurintech authored Feb 25, 2025
2 parents 8958c02 + 41925e4 commit 59d5692
Showing 6 changed files with 104 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/benchmark.yml
@@ -64,7 +64,7 @@ jobs:
commit_id=$GITHUB_SHA
fi
commit_msg=$(git show -s --format=%s | cut -c1-70)
- python3 benchmark/benchmarks_entrypoint.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
+ python3 benchmark/benchmarks_entrypoint.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
env:
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
# Enable this to see debug logs
@@ -73,3 +73,4 @@ jobs:
PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
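
Passing the branch name through the BRANCH_NAME environment variable, rather than interpolating github.head_ref directly into the run command, is the usual way to keep untrusted branch names out of the shell line; the entrypoint still receives it as its first positional argument. A minimal, hypothetical sketch of how such an entrypoint could read the three arguments (the real benchmarks_entrypoint.py may parse them differently):

import sys

if __name__ == "__main__":
    # The workflow step above passes branch name, commit id and commit message positionally.
    branch_name, commit_id, commit_msg = sys.argv[1], sys.argv[2], sys.argv[3]
    print(f"Running benchmarks for {branch_name} at {commit_id} ({commit_msg})")
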
8 changes: 7 additions & 1 deletion src/transformers/models/vitdet/modeling_vitdet.py
@@ -456,8 +456,14 @@ def __init__(
super().__init__()

dim = config.hidden_size
- input_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)

image_size = config.image_size
image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)

patch_size = config.patch_size
patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)

input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = VitDetAttention(
config, input_size=input_size if window_size == 0 else (window_size, window_size)
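
The hunk above lets image_size and patch_size be either a single int or a (height, width) pair before computing the patch grid passed to VitDetAttention. A standalone sketch of the same normalization, shown here only for illustration:

def compute_input_size(image_size, patch_size):
    # Accept either an int or a (height, width) pair for both values.
    image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
    patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
    # One grid cell per non-overlapping patch along each spatial axis.
    return (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

# A 32x40 image split into 2x2 patches gives a 16x20 grid.
assert compute_input_size((32, 40), 2) == (16, 20)
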
47 changes: 46 additions & 1 deletion src/transformers/testing_utils.py
@@ -43,6 +43,7 @@
from unittest.mock import patch

import huggingface_hub.utils
import requests
import urllib3
from huggingface_hub import delete_repo
from packaging import version
@@ -200,6 +201,8 @@
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False

logger = transformers_logging.get_logger(__name__)


def parse_flag_from_env(key, default=False):
try:
@@ -2497,7 +2500,49 @@ def wrapper(*args, **kwargs):
return test_func_ref(*args, **kwargs)

except Exception as err:
- print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+ logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1

return test_func_ref(*args, **kwargs)

return wrapper

return decorator


def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
"""
To decorate tests that download from the Hub. They can fail due to a
variety of network issues such as timeouts, connection resets, etc.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*, defaults to 2):
If provided, will wait that number of seconds before retrying the test.
"""

def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1

while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
# We catch all exceptions related to network issues from requests
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ReadTimeout,
requests.exceptions.HTTPError,
requests.exceptions.RequestException,
) as err:
logger.error(
f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specied Hub repository."
)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
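
The docstring above covers when hub_retry applies; here is a minimal usage sketch with a hypothetical test class (not part of the diff), assuming the decorator is imported from transformers.testing_utils as added above:

import unittest

from transformers import AutoConfig
from transformers.testing_utils import hub_retry


class ExampleHubTest(unittest.TestCase):
    @hub_retry(max_attempts=3, wait_before_retry=5)
    def test_download_config(self):
        # requests-level network errors raised inside the test body are caught,
        # logged, and the test is re-run after a 5 second pause, up to 3 attempts.
        config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
        self.assertIsNotNone(config)
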
25 changes: 25 additions & 0 deletions tests/models/vitdet/test_modeling_vitdet.py
@@ -290,6 +290,31 @@ def test_feed_forward_chunking(self):
def test_model_from_pretrained(self):
pass

def test_non_square_image(self):
non_square_image_size = (32, 40)
patch_size = (2, 2)
config = self.model_tester.get_config()
config.image_size = non_square_image_size
config.patch_size = patch_size

model = VitDetModel(config=config)
model.to(torch_device)
model.eval()

batch_size = self.model_tester.batch_size
# Create a dummy input tensor with non-square spatial dimensions.
pixel_values = floats_tensor(
[batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]]
)

result = model(pixel_values)

expected_height = non_square_image_size[0] // patch_size[0]
expected_width = non_square_image_size[1] // patch_size[1]
expected_shape = (batch_size, config.hidden_size, expected_height, expected_width)

self.assertEqual(result.last_hidden_state.shape, expected_shape)


@require_torch
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):
11 changes: 11 additions & 0 deletions tests/test_modeling_common.py
@@ -74,6 +74,7 @@
)
from transformers.testing_utils import (
CaptureLogger,
hub_retry,
is_flaky,
require_accelerate,
require_bitsandbytes,
@@ -214,6 +215,16 @@ class ModelTesterMixin:
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]

# Note: for all mixins that utilize the Hub in some way, we should ensure that
# they contain the `hub_retry` decorator in case of failures.
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
for attr_name in dir(cls):
if attr_name.startswith("test_"):
attr = getattr(cls, attr_name)
if callable(attr):
setattr(cls, attr_name, hub_retry()(attr))

@property
def all_generative_model_classes(self):
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
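
The note above explains that every test_* method on Hub-using testers is wrapped with hub_retry when a concrete tester class is created. A small self-contained sketch of this __init_subclass__ wrapping pattern, using a simplified stand-in decorator rather than the real hub_retry:

import functools


def retry(max_attempts=2):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except ConnectionError:
                    # Re-raise once the attempt budget is exhausted.
                    if attempt == max_attempts:
                        raise
        return wrapper
    return decorator


class Base:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap every test_* callable the subclass defines or inherits.
        for attr_name in dir(cls):
            if attr_name.startswith("test_") and callable(getattr(cls, attr_name)):
                setattr(cls, attr_name, retry()(getattr(cls, attr_name)))


class MyTests(Base):
    def test_something(self):
        return "ok"


assert MyTests().test_something() == "ok"  # the method is now retry-wrapped
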
13 changes: 13 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -51,6 +51,7 @@
LoggingLevel,
TemporaryHubRepo,
TestCasePlus,
hub_retry,
is_staging_test,
require_accelerate,
require_flax,
@@ -327,6 +328,18 @@ def tearDown(self):
torch.set_default_dtype(self.old_dtype)
super().tearDown()

def test_hub_retry(self):
@hub_retry(max_attempts=2)
def test_func():
# First attempt will fail with a connection error
if not hasattr(test_func, "attempt"):
test_func.attempt = 1
raise requests.exceptions.ConnectionError("Connection failed")
# Second attempt will succeed
return True

self.assertTrue(test_func())

@slow
def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased"
