Skip to content

Commit

Permalink
Support multiple rate strategy for one route. (long2ice#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed Mar 30, 2021
1 parent 3828de6 commit dd385d6
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 0.1

### 0.1.3

- Support multiple rate strategy for one route. (#3)

### 0.1.2

- Use milliseconds instead of seconds as default unit of expiration.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def default_identifier(request: Request):
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0]
return request.client.host
return request.client.host + ":" + request.scope["path"]
```

### callback
Expand Down
11 changes: 11 additions & 0 deletions examples/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,16 @@ async def index():
return {"msg": "Hello World"}


@app.get(
"/multiple",
dependencies=[
Depends(RateLimiter(times=1, seconds=5)),
Depends(RateLimiter(times=2, seconds=15)),
],
)
async def multiple():
return {"msg": "Hello World"}


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
2 changes: 1 addition & 1 deletion fastapi_limiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def default_identifier(request: Request):
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0]
return request.client.host
return request.client.host + ":" + request.scope["path"]


async def default_callback(request: Request, response: Response, pexpire: int):
Expand Down
9 changes: 8 additions & 1 deletion fastapi_limiter/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@ def __init__(
async def __call__(self, request: Request, response: Response):
if not FastAPILimiter.redis:
raise Exception("You must call FastAPILimiter.init in startup event of fastapi!")
index = 0
for route in request.app.routes:
if route.path == request.scope["path"]:
for idx, dependency in enumerate(route.dependencies):
if self is dependency.dependency:
index = idx
break
# moved here because constructor run before app startup
identifier = self.identifier or FastAPILimiter.identifier
callback = self.callback or FastAPILimiter.callback
redis = FastAPILimiter.redis
rate_key = await identifier(request)
key = FastAPILimiter.prefix + ":" + rate_key
key = f"{FastAPILimiter.prefix}:{rate_key}:{index}"
tr = redis.multi_exec()
tr.incrby(key, 1)
tr.pttl(key)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages = [
]
readme = "README.md"
repository = "https://github.com/long2ice/fastapi-limiter.git"
version = "0.1.2"
version = "0.1.3"

[tool.poetry.dependencies]
aioredis = "*"
Expand Down
22 changes: 20 additions & 2 deletions tests/test_depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def test_limiter():
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"msg": "Hello World"}

client.get("/")

Expand All @@ -19,4 +18,23 @@ def test_limiter():

response = client.get("/")
assert response.status_code == 200
assert response.json() == {"msg": "Hello World"}


def test_limiter_multiple():
with TestClient(app) as client:
response = client.get("/multiple")
assert response.status_code == 200

response = client.get("/multiple")
assert response.status_code == 429
sleep(5)

response = client.get("/multiple")
assert response.status_code == 200

response = client.get("/multiple")
assert response.status_code == 429
sleep(10)

response = client.get("/multiple")
assert response.status_code == 200

0 comments on commit dd385d6

Please sign in to comment.