-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Redis sliding window ratelimiter (#6)
- Loading branch information
1 parent
f743f03
commit 8c9e41a
Showing
4 changed files
with
332 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package redis | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"math" | ||
"time" | ||
|
||
"github.com/aidenwallis/go-ratelimiting/redis/adapters" | ||
) | ||
|
||
// SlidingWindow provides an interface for the redis sliding window ratelimiter, compatible with SlidingWindowImpl | ||
// | ||
// The sliding window ratelimiter is a fixed size window that holds a set of timestamps. When a token is taken, the current time is added to the window. | ||
// The window is constantly cleaned, and evicting old tokens, which allows new ones to be added as the window discards old tokens. | ||
type SlidingWindow interface { | ||
// Use atomically attempts to use the sliding window. Sliding window ratelimiters always take 1 token at a time, as the key is inferred | ||
// from when it would expire in nanoseconds. | ||
Use(ctx context.Context, bucket *SlidingWindowOptions) (*UseSlidingWindowResponse, error) | ||
} | ||
|
||
var _ SlidingWindow = (*SlidingWindowImpl)(nil) | ||
|
||
// SlidingWindowImpl implements a sliding window ratelimiter for Redis using Lua. This struct is compatible with the SlidingWindow interface. | ||
// | ||
// Refer to the SlidingWindow interface for more information about this ratelimiter. | ||
type SlidingWindowImpl struct { | ||
// Adapter defines the Redis adapter | ||
Adapter adapters.Adapter | ||
|
||
// nowFunc is a private helper used to mock out time changes in unit testing | ||
// | ||
// if this is not defined, it falls back to time.Now() | ||
nowFunc func() time.Time | ||
} | ||
|
||
// SlidingWindowOptions defines the options available to a sliding window bucket. | ||
type SlidingWindowOptions struct { | ||
// Key defines the Redis key used for this sliding window ratelimiter | ||
Key string | ||
|
||
// MaximumCapacity defines the max size of the sliding window, no more tokens than this may be stored in the sliding | ||
// window at any time. | ||
MaximumCapacity int | ||
|
||
// Window defines the size of the sliding window, resolution is available up to nanoseconds. | ||
Window time.Duration | ||
} | ||
|
||
// NewSlidingWindow creates a new sliding window instance | ||
func NewSlidingWindow(adapter adapters.Adapter) *SlidingWindowImpl { | ||
return &SlidingWindowImpl{ | ||
Adapter: adapter, | ||
nowFunc: time.Now, | ||
} | ||
} | ||
|
||
func (r *SlidingWindowImpl) now() time.Time { | ||
if r.nowFunc == nil { | ||
return time.Now() | ||
} | ||
return r.nowFunc() | ||
} | ||
|
||
// UseSlidingWindowResponse defines the response parameters for SlidingWindow.Use() | ||
type UseSlidingWindowResponse struct { | ||
// Success defines whether the sliding window was successfully used | ||
Success bool | ||
|
||
// RemainingCapacity defines the remaining amount of capacity left in the bucket | ||
RemainingCapacity int | ||
} | ||
|
||
// Use atomically attempts to use the sliding window. | ||
func (r *SlidingWindowImpl) Use(ctx context.Context, bucket *SlidingWindowOptions) (*UseSlidingWindowResponse, error) { | ||
const script = ` | ||
local key = KEYS[1] | ||
local now = ARGV[1] | ||
local expiresAt = ARGV[2] | ||
local window = ARGV[3] | ||
local max = tonumber(ARGV[4]) | ||
redis.call("zremrangebyscore", key, "-inf", now) -- clear expired tokens | ||
local tokens = tonumber(redis.call("zcard", key)) | ||
if (tokens == nil) then | ||
tokens = 0 -- default tokens to 0 | ||
end | ||
local success = 0 | ||
if (tokens < max) then | ||
-- room available: add a token, bump ttl, and include newly added token in count | ||
redis.call("zadd", key, expiresAt, expiresAt) | ||
redis.call("expire", key, window) | ||
success = 1 | ||
tokens = tokens + 1 | ||
end | ||
return {success, tokens} | ||
` | ||
|
||
now := r.now() | ||
current := now.UnixNano() | ||
expiresAt := now.Add(bucket.Window).UnixNano() | ||
windowTTL := int(math.Ceil(bucket.Window.Seconds())) | ||
|
||
resp, err := r.Adapter.Eval(ctx, script, []string{bucket.Key}, []interface{}{ | ||
current, expiresAt, windowTTL, bucket.MaximumCapacity, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to query redis adapter: %w", err) | ||
} | ||
|
||
output, err := parseSlidingWindowResponse(resp) | ||
if err != nil { | ||
return nil, fmt.Errorf("parsing redis response: %w", err) | ||
} | ||
|
||
remaining := 0 | ||
if v := bucket.MaximumCapacity - output.tokens; v > remaining { | ||
remaining = v | ||
} | ||
|
||
return &UseSlidingWindowResponse{ | ||
Success: output.success, | ||
RemainingCapacity: remaining, | ||
}, nil | ||
} | ||
|
||
type slidingWindowOutput struct { | ||
success bool | ||
tokens int | ||
} | ||
|
||
func parseSlidingWindowResponse(v interface{}) (*slidingWindowOutput, error) { | ||
args, ok := v.([]interface{}) | ||
if !ok { | ||
return nil, fmt.Errorf("expected []interface{} but got %T", v) | ||
} | ||
|
||
if len(args) != 2 { | ||
return nil, fmt.Errorf("expected 2 args but got %d", len(args)) | ||
} | ||
|
||
argInts := make([]int64, len(args)) | ||
for i, argValue := range args { | ||
intValue, ok := argValue.(int64) | ||
if !ok { | ||
return nil, fmt.Errorf("expected int64 in arg[%d] but got %T", i, argValue) | ||
} | ||
|
||
argInts[i] = intValue | ||
} | ||
|
||
return &slidingWindowOutput{ | ||
success: argInts[0] == 1, | ||
tokens: int(argInts[1]), | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package redis | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/aidenwallis/go-ratelimiting/redis/adapters" | ||
goredisadapter "github.com/aidenwallis/go-ratelimiting/redis/adapters/go-redis" | ||
redigoadapter "github.com/aidenwallis/go-ratelimiting/redis/adapters/redigo" | ||
"github.com/alicebob/miniredis/v2" | ||
redigo "github.com/gomodule/redigo/redis" | ||
goredis "github.com/redis/go-redis/v9" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestSlidingWindow_Now(t *testing.T) { | ||
adapter := NewSlidingWindow(nil) | ||
adapter.nowFunc = nil | ||
assert.WithinDuration(t, adapter.now(), time.Now(), time.Minute) | ||
} | ||
|
||
func TestUseSlidingWindow(t *testing.T) { | ||
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{ | ||
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter { | ||
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()})) | ||
}, | ||
"redigo": func(t *miniredis.Miniredis) adapters.Adapter { | ||
conn, err := redigo.Dial("tcp", t.Addr()) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return redigoadapter.NewAdapter(conn) | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
|
||
t.Run(name, func(t *testing.T) { | ||
ctx := context.Background() | ||
now := time.Now().UTC() | ||
limiter := NewSlidingWindow(testCase(miniredis.RunT(t))) | ||
limiter.nowFunc = func() time.Time { return now } | ||
|
||
{ | ||
resp, err := useSlidingWindow(ctx, limiter) | ||
assert.NoError(t, err) | ||
assert.True(t, resp.Success) | ||
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingCapacity) | ||
} | ||
|
||
// move forward 3 seconds | ||
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 3) } | ||
|
||
{ | ||
resp, err := useSlidingWindow(ctx, limiter) | ||
assert.NoError(t, err) | ||
assert.True(t, resp.Success) | ||
assert.Equal(t, leakyBucketOptions().MaximumCapacity-2, resp.RemainingCapacity, "tokens shouldn't have expired yet") | ||
} | ||
|
||
// move forward 60 seconds | ||
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 60) } | ||
|
||
{ | ||
resp, err := useSlidingWindow(ctx, limiter) | ||
assert.NoError(t, err) | ||
assert.True(t, resp.Success) | ||
assert.Equal(t, leakyBucketOptions().MaximumCapacity-2, resp.RemainingCapacity, "one token should've expired, so including this request, 2 should be used") | ||
} | ||
|
||
// move forward 120 seconds | ||
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 120) } | ||
|
||
{ | ||
resp, err := useSlidingWindow(ctx, limiter) | ||
assert.NoError(t, err) | ||
assert.True(t, resp.Success) | ||
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingCapacity, "all tokens should've expired by now, so only this one is left") | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestUseSlidingWindow_Errors(t *testing.T) { | ||
testCases := map[string]struct { | ||
errorMessage string | ||
mockAdapter adapters.Adapter | ||
}{ | ||
"redis error": { | ||
errorMessage: "failed to query redis adapter: " + assert.AnError.Error(), | ||
mockAdapter: &mockAdapter{ | ||
returnError: assert.AnError, | ||
}, | ||
}, | ||
"parsing error": { | ||
errorMessage: "parsing redis response: expected []interface{} but got string", | ||
mockAdapter: &mockAdapter{ | ||
returnValue: "foo", | ||
}, | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
|
||
t.Run(name, func(t *testing.T) { | ||
out, err := useSlidingWindow(context.Background(), NewSlidingWindow(testCase.mockAdapter)) | ||
assert.Nil(t, out) | ||
assert.EqualError(t, err, testCase.errorMessage) | ||
}) | ||
} | ||
} | ||
|
||
func TestParseSlidingWindowResponse_Errors(t *testing.T) { | ||
testCases := map[string]struct { | ||
errorMessage string | ||
in interface{} | ||
}{ | ||
"invalid type": { | ||
errorMessage: "expected []interface{} but got string", | ||
in: "foo", | ||
}, | ||
"invalid length": { | ||
errorMessage: "expected 2 args but got 3", | ||
in: []interface{}{1, 2, 3}, | ||
}, | ||
"invalid item type": { | ||
errorMessage: "expected int64 in arg[1] but got float64", | ||
in: []interface{}{int64(1), float64(2)}, | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
|
||
t.Run(name, func(t *testing.T) { | ||
out, err := parseSlidingWindowResponse(testCase.in) | ||
assert.Nil(t, out) | ||
assert.EqualError(t, err, testCase.errorMessage) | ||
}) | ||
} | ||
} | ||
|
||
// slidingWindowOptions provides quick sane defaults for testing sliding windows | ||
func slidingWindowOptions() *SlidingWindowOptions { | ||
return &SlidingWindowOptions{ | ||
Key: "test-bucket", | ||
MaximumCapacity: 60, | ||
Window: time.Minute, | ||
} | ||
} | ||
|
||
// useSlidingWindow is a helper to test your sliding window with some predefined options | ||
func useSlidingWindow(ctx context.Context, limiter SlidingWindow) (*UseSlidingWindowResponse, error) { | ||
return limiter.Use(ctx, slidingWindowOptions()) | ||
} |