Skip to content

Commit

Permalink
♻️ Update Ratelimiter
Browse files Browse the repository at this point in the history
  • Loading branch information
BalconyJH committed Jan 10, 2025
1 parent b330a95 commit 6c7b6cc
Show file tree
Hide file tree
Showing 5 changed files with 521 additions and 254 deletions.
163 changes: 100 additions & 63 deletions aioarxiv/utils/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from collections import deque
from dataclasses import dataclass
import functools
from typing import Any, Callable, ClassVar, Optional, TypeVar, cast

from aioarxiv.config import default_config
from typing import Any, Callable, TypeVar, cast

from .log import logger

Expand All @@ -13,12 +11,12 @@

@dataclass
class RateLimitState:
"""速率限制状态
"""Data class to represent the current rate limit state.
Attributes:
remaining: 剩余可用请求数
reset_at: 下次重置时间
window_start: 当前窗口开始时间
Args:
remaining (int): The number of remaining allowed requests in this window.
reset_at (float): The timestamp (epoch) when the rate-limit window resets.
window_start (float): The timestamp (epoch) when the current window starts.
"""

remaining: int
Expand All @@ -27,99 +25,138 @@ class RateLimitState:


class RateLimiter:
"""速率限制装饰器, 使用类级别共享状态实现请求限流
使用类级别变量确保所有实例共享同一个限流状态。支持并发控制和时间窗口限流。
Attributes:
timestamps: 请求时间戳队列
_lock: 用于保护共享状态的锁
_semaphore: 控制并发请求数的信号量
_calls: 时间窗口内允许的最大请求数
_period: 时间窗口大小(秒)
"""A rate limiter that restricts the number of requests within a given time window
and controls concurrent requests with a semaphore.
This class uses instance-level locks and semaphores, ensuring each RateLimiter
instance can manage its own concurrency and request timestamps without
interfering with other instances or event loops.
Args:
calls (int): The maximum number of allowed requests in one time window.
period (float): The time window length in seconds.
Usage:
```python
# Create a rate limiter allowing 3 calls every 10 seconds
limiter = RateLimiter(calls=3, period=10.0)
# Decorate your async function
@limiter.limit()
async def fetch_data(url: str) -> str:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
# Use in concurrent operations
urls = ["http://example.com/1", "http://example.com/2", "http://example.com/3"]
results = await asyncio.gather(*(fetch_data(url) for url in urls))
# Use with custom error handling
async def fetch_with_retry():
try:
return await fetch_data("http://example.com")
except Exception as e:
logger.error(f"Failed to fetch: {e}")
return None
```
"""

timestamps: ClassVar[deque[float]] = deque()
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
_semaphore: ClassVar[asyncio.Semaphore] = asyncio.Semaphore(1)
_calls: ClassVar[int] = 0
_period: ClassVar[float] = 0.0
def __init__(self, calls: int, period: float):
"""Initialize the RateLimiter.
Args:
calls (int): The maximum number of allowed requests in one time window.
period (float): The time window length in seconds.
"""
self._calls: int = calls
self._period: float = period
self._timestamps: deque[float] = deque(maxlen=calls)
self._lock: asyncio.Lock = asyncio.Lock()
self._semaphore: asyncio.Semaphore = asyncio.Semaphore(calls)

@classmethod
def _clean_expired(cls, now: float) -> None:
"""清理过期的时间戳
def _clean_expired(self, now: float) -> None:
"""Remove timestamps that are outside the current rate limit window.
Args:
now: 当前时间戳
now (float): The current time (epoch).
"""
cutoff = now - cls._period
while cls.timestamps and cls.timestamps[0] <= cutoff:
cls.timestamps.popleft()
cutoff = now - self._period
while self._timestamps and self._timestamps[0] <= cutoff:
self._timestamps.popleft()

@classmethod
async def _wait_if_needed(cls, now: float) -> float:
"""根据需要等待并返回新的当前时间
async def _wait_if_needed(self, now: float) -> float:
"""If rate limit is reached, block until a request slot is free.
Args:
now: 当前时间戳
now (float): The current time (epoch).
Returns:
float: 等待后的新时间戳
float: Updated current time (after waiting, if needed).
"""
if len(cls.timestamps) >= cls._calls:
wait_time = cls.timestamps[0] + cls._period - now
if len(self._timestamps) >= self._calls:
wait_time = self._timestamps[0] + self._period - now
if wait_time > 0:
logger.debug(
f"Rate limit reached, waiting for {wait_time:.2f}s",
extra={
"wait_time": f"{wait_time:.2f}s",
"current_calls": len(cls.timestamps),
"max_calls": cls._calls,
"period": cls._period,
"current_calls": len(self._timestamps),
"max_calls": self._calls,
"period": self._period,
},
)
await asyncio.sleep(wait_time)
# After sleeping, the "now" might have changed
return asyncio.get_event_loop().time()
return now

@classmethod
def limit(
cls,
calls: Optional[int] = None,
period: Optional[float] = None,
) -> Callable[[T], T]:
"""创建速率限制装饰器
def limit(self) -> Callable[[T], T]:
"""Decorator for rate limiting an async function.
Args:
calls: 时间窗口内最大请求数, 默认使用配置值
period: 时间窗口大小(秒), 默认使用配置值
This decorator ensures that:
1. Concurrency does not exceed `calls`
2. Requests are limited to `calls` within `period` seconds
Returns:
装饰器函数
Callable[[T], T]: A decorator that wraps the original async function.
Usage:
```python
limiter = RateLimiter(calls=5, period=1.0)
@limiter.limit()
async def api_call(endpoint: str) -> dict:
async with aiohttp.ClientSession() as session:
async with session.get(endpoint) as response:
return await response.json()
# The decorated function will automatically respect rate limits
result = await api_call("https://api.example.com/data")
```
"""
cls._calls = calls or default_config.rate_limit_calls
cls._period = period or default_config.rate_limit_period
cls.timestamps = deque(maxlen=cls._calls)
cls._semaphore = asyncio.Semaphore(cls._calls)

def decorator(func: T) -> T:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async with cls._semaphore, cls._lock:
# Acquire semaphore and lock to limit concurrency and ensure timestamp
# checks are atomic
async with self._semaphore, self._lock:
now = asyncio.get_event_loop().time()
cls._clean_expired(now)
now = await cls._wait_if_needed(now)
cls.timestamps.append(now)
self._clean_expired(now)
now = await self._wait_if_needed(now)
self._timestamps.append(now)

logger.debug(
"request rate limit",
"Request rate limit invoked",
extra={
"current_calls": len(cls.timestamps),
"max_calls": cls._calls,
"remaining": cls._calls - len(cls.timestamps),
"current_calls": len(self._timestamps),
"max_calls": self._calls,
"remaining": self._calls - len(self._timestamps),
},
)

# Proceed with the original function call
return await func(*args, **kwargs)

return cast(T, wrapper)
Expand Down
Loading

0 comments on commit 6c7b6cc

Please sign in to comment.