From cbe0ea59f34572e4f494ebb2ed910fbc40c742ab Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 25 Feb 2025 17:22:09 +0100
Subject: [PATCH 1/3] Security fix for `benchmark.yml` (#36402)

security

Co-authored-by: ydshieh
---
 .github/workflows/benchmark.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index bb5281778bf2..6b5555097c09 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -64,7 +64,7 @@ jobs:
             commit_id=$GITHUB_SHA
           fi
           commit_msg=$(git show -s --format=%s | cut -c1-70)
-          python3 benchmark/benchmarks_entrypoint.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
+          python3 benchmark/benchmarks_entrypoint.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
         env:
           HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
           # Enable this to see debug logs
@@ -73,3 +73,4 @@
           PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
           PGUSER: transformers_benchmarks
           PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
+          BRANCH_NAME: ${{ github.head_ref || github.ref_name }}

From 9ebfda3263fcdc4e05fd87fad1aadc8a08294608 Mon Sep 17 00:00:00 2001
From: "Chulhwa (Evan) Han"
Date: Wed, 26 Feb 2025 04:31:24 +0900
Subject: [PATCH 2/3] Fixed VitDet for non-square Images (#35969)

* size tuple

* delete original input_size

* use zip

* process the other case

* Update src/transformers/models/vitdet/modeling_vitdet.py

Co-authored-by: Pavel Iakubovskii

* [VITDET] Test non-square image

* [Fix] Make Quality

* make fix style

* Update src/transformers/models/vitdet/modeling_vitdet.py

---------

Co-authored-by: Pavel Iakubovskii
---
 .../models/vitdet/modeling_vitdet.py        |  8 +++++-
 tests/models/vitdet/test_modeling_vitdet.py | 25 +++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py
index 9bd7ca2ff1c9..9585c295e18a 100644
--- a/src/transformers/models/vitdet/modeling_vitdet.py
+++ b/src/transformers/models/vitdet/modeling_vitdet.py
@@ -456,8 +456,14 @@ def __init__(
         super().__init__()
         dim = config.hidden_size
 
-        input_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+        image_size = config.image_size
+        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
+
+        patch_size = config.patch_size
+        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
+
+        input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
         self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.attention = VitDetAttention(
             config, input_size=input_size if window_size == 0 else (window_size, window_size)
         )
diff --git a/tests/models/vitdet/test_modeling_vitdet.py b/tests/models/vitdet/test_modeling_vitdet.py
index 2c46b60f7e73..4b5ac0f3378c 100644
--- a/tests/models/vitdet/test_modeling_vitdet.py
+++ b/tests/models/vitdet/test_modeling_vitdet.py
@@ -290,6 +290,31 @@ def test_feed_forward_chunking(self):
     def test_model_from_pretrained(self):
         pass
 
+    def test_non_square_image(self):
+        non_square_image_size = (32, 40)
+        patch_size = (2, 2)
+        config = self.model_tester.get_config()
+        config.image_size = non_square_image_size
+        config.patch_size = patch_size
+
+        model = VitDetModel(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        batch_size = self.model_tester.batch_size
+        # Create a dummy input tensor with non-square spatial dimensions.
+        pixel_values = floats_tensor(
+            [batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]]
+        )
+
+        result = model(pixel_values)
+
+        expected_height = non_square_image_size[0] // patch_size[0]
+        expected_width = non_square_image_size[1] // patch_size[1]
+        expected_shape = (batch_size, config.hidden_size, expected_height, expected_width)
+
+        self.assertEqual(result.last_hidden_state.shape, expected_shape)
+
 
 @require_torch
 class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):

From 41925e42135257361b7f02aa20e3bbdab3f7b923 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 25 Feb 2025 14:53:11 -0500
Subject: [PATCH 3/3] Add retry hf hub decorator (#35213)

* Add retry torch decorator

* New approach

* Empty commit

* Empty commit

* Style

* Use logger.error

* Add a test

* Update src/transformers/testing_utils.py

Co-authored-by: Lucain

* Fix err

* Update tests/utils/test_modeling_utils.py

---------

Co-authored-by: Lucain
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
---
 src/transformers/testing_utils.py  | 47 +++++++++++++++++++++++++++++-
 tests/test_modeling_common.py      | 11 +++++++
 tests/utils/test_modeling_utils.py | 13 +++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index bb0b3d3b2f86..17223278eb11 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -43,6 +43,7 @@
 from unittest.mock import patch
 
 import huggingface_hub.utils
+import requests
 import urllib3
 from huggingface_hub import delete_repo
 from packaging import version
@@ -200,6 +201,8 @@
 IS_ROCM_SYSTEM = False
 IS_CUDA_SYSTEM = False
 
+logger = transformers_logging.get_logger(__name__)
+
 
 def parse_flag_from_env(key, default=False):
     try:
@@ -2497,7 +2500,49 @@ def wrapper(*args, **kwargs):
                 return test_func_ref(*args, **kwargs)
 
             except Exception as err:
-                print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+                logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
+                if wait_before_retry is not None:
+                    time.sleep(wait_before_retry)
+                retry_count += 1
+
+            return test_func_ref(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
+    """
+    To decorate tests that download from the Hub. They can fail due to a
+    variety of network issues such as timeouts, connection resets, etc.
+
+    Args:
+        max_attempts (`int`, *optional*, defaults to 5):
+            The maximum number of attempts to retry the flaky test.
+        wait_before_retry (`float`, *optional*, defaults to 2):
+            If provided, will wait that number of seconds before retrying the test.
+    """
+
+    def decorator(test_func_ref):
+        @functools.wraps(test_func_ref)
+        def wrapper(*args, **kwargs):
+            retry_count = 1
+
+            while retry_count < max_attempts:
+                try:
+                    return test_func_ref(*args, **kwargs)
+                # We catch all exceptions related to network issues from requests
+                except (
+                    requests.exceptions.ConnectionError,
+                    requests.exceptions.Timeout,
+                    requests.exceptions.ReadTimeout,
+                    requests.exceptions.HTTPError,
+                    requests.exceptions.RequestException,
+                ) as err:
+                    logger.error(
+                        f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository."
+                    )
                 if wait_before_retry is not None:
                     time.sleep(wait_before_retry)
                 retry_count += 1
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 80e7bd144714..8de8b0584c0d 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -74,6 +74,7 @@
 )
 from transformers.testing_utils import (
     CaptureLogger,
+    hub_retry,
     is_flaky,
     require_accelerate,
     require_bitsandbytes,
@@ -214,6 +215,16 @@ class ModelTesterMixin:
     _is_composite = False
     model_split_percents = [0.5, 0.7, 0.9]
 
+    # Note: for all mixins that utilize the Hub in some way, we should ensure that
+    # they contain the `hub_retry` decorator in case of failures.
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        for attr_name in dir(cls):
+            if attr_name.startswith("test_"):
+                attr = getattr(cls, attr_name)
+                if callable(attr):
+                    setattr(cls, attr_name, hub_retry(attr))
+
     @property
     def all_generative_model_classes(self):
         return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 72434a192226..7d8906fa5936 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -51,6 +51,7 @@
     LoggingLevel,
     TemporaryHubRepo,
     TestCasePlus,
+    hub_retry,
     is_staging_test,
     require_accelerate,
     require_flax,
@@ -327,6 +328,18 @@ def tearDown(self):
         torch.set_default_dtype(self.old_dtype)
         super().tearDown()
 
+    def test_hub_retry(self):
+        @hub_retry(max_attempts=2)
+        def test_func():
+            # First attempt will fail with a connection error
+            if not hasattr(test_func, "attempt"):
+                test_func.attempt = 1
+                raise requests.exceptions.ConnectionError("Connection failed")
+            # Second attempt will succeed
+            return True
+
+        self.assertTrue(test_func())
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "google-bert/bert-base-uncased"