Skip to content

Commit

Permalink
Add Redis sliding window ratelimiter (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
aidenwallis authored Aug 18, 2023
1 parent f743f03 commit 8c9e41a
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 3 deletions.
13 changes: 12 additions & 1 deletion redis/leaky_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ import (
)

// LeakyBucket defines an interface compatible with LeakyBucketImpl
//
// Leaky buckets have the advantage of being able to burst up to the max tokens you define, and then slowly leak out tokens at a constant rate. This makes
// it a good fit for situations where you want caller buckets to slowly fill if they decide to burst your service, whereas a sliding window ratelimiter will
// free all tokens at once.
//
// Leaky buckets slowly fill your window over time, and will not fill above the size of the window. For example, if you allow 10 tokens per a window of 1 second,
// your bucket fills at a fixed rate of 100ms.
//
// See: https://en.wikipedia.org/wiki/Leaky_bucket
type LeakyBucket interface {
// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
Expand Down Expand Up @@ -53,7 +62,9 @@ type UseLeakyBucketResponse struct {
ResetAt time.Time
}

// LeakyBucketImpl implements a leaky bucket ratelimiter, this struct is compatible with the LeakyBucket interface
// LeakyBucketImpl implements a leaky bucket ratelimiter in Redis with Lua. This struct is compatible with the LeakyBucket interface
//
// See the LeakyBucket interface for more information about leaky bucket ratelimiters.
type LeakyBucketImpl struct {
// Adapter defines the Redis adapter
Adapter adapters.Adapter
Expand Down
4 changes: 2 additions & 2 deletions redis/leaky_bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestUseLeakyBucket(t *testing.T) {
t.Parallel()

testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
"goredis": func(t *miniredis.Miniredis) adapters.Adapter {
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
},
"redigo": func(t *miniredis.Miniredis) adapters.Adapter {
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestParseLeakyBucketResponse_Errors(t *testing.T) {
}
}

// leakyBucketOptions provides quick sane defaults for testing leaky bucket options
// leakyBucketOptions provides quick sane defaults for testing leaky buckets
func leakyBucketOptions() *LeakyBucketOptions {
return &LeakyBucketOptions{
KeyPrefix: "test-bucket",
Expand Down
160 changes: 160 additions & 0 deletions redis/sliding_window.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package redis

import (
"context"
"fmt"
"math"
"time"

"github.com/aidenwallis/go-ratelimiting/redis/adapters"
)

// SlidingWindow provides an interface for the redis sliding window ratelimiter, compatible with SlidingWindowImpl
//
// The sliding window ratelimiter is a fixed size window that holds a set of timestamps. When a token is taken, the current time is added to the window.
// The window is constantly cleaned, and evicting old tokens, which allows new ones to be added as the window discards old tokens.
type SlidingWindow interface {
// Use atomically attempts to use the sliding window. Sliding window ratelimiters always take 1 token at a time, as the key is inferred
// from when it would expire in nanoseconds.
Use(ctx context.Context, bucket *SlidingWindowOptions) (*UseSlidingWindowResponse, error)
}

var _ SlidingWindow = (*SlidingWindowImpl)(nil)

// SlidingWindowImpl implements a sliding window ratelimiter for Redis using Lua. This struct is compatible with the SlidingWindow interface.
//
// Refer to the SlidingWindow interface for more information about this ratelimiter.
type SlidingWindowImpl struct {
// Adapter defines the Redis adapter
Adapter adapters.Adapter

// nowFunc is a private helper used to mock out time changes in unit testing
//
// if this is not defined, it falls back to time.Now()
nowFunc func() time.Time
}

// SlidingWindowOptions defines the options available to a sliding window bucket.
type SlidingWindowOptions struct {
// Key defines the Redis key used for this sliding window ratelimiter
Key string

// MaximumCapacity defines the max size of the sliding window, no more tokens than this may be stored in the sliding
// window at any time.
MaximumCapacity int

// Window defines the size of the sliding window, resolution is available up to nanoseconds.
Window time.Duration
}

// NewSlidingWindow creates a new sliding window instance
func NewSlidingWindow(adapter adapters.Adapter) *SlidingWindowImpl {
return &SlidingWindowImpl{
Adapter: adapter,
nowFunc: time.Now,
}
}

func (r *SlidingWindowImpl) now() time.Time {
if r.nowFunc == nil {
return time.Now()
}
return r.nowFunc()
}

// UseSlidingWindowResponse defines the response parameters for SlidingWindow.Use()
type UseSlidingWindowResponse struct {
// Success defines whether the sliding window was successfully used
Success bool

// RemainingCapacity defines the remaining amount of capacity left in the bucket
RemainingCapacity int
}

// Use atomically attempts to use the sliding window.
func (r *SlidingWindowImpl) Use(ctx context.Context, bucket *SlidingWindowOptions) (*UseSlidingWindowResponse, error) {
const script = `
local key = KEYS[1]
local now = ARGV[1]
local expiresAt = ARGV[2]
local window = ARGV[3]
local max = tonumber(ARGV[4])
redis.call("zremrangebyscore", key, "-inf", now) -- clear expired tokens
local tokens = tonumber(redis.call("zcard", key))
if (tokens == nil) then
tokens = 0 -- default tokens to 0
end
local success = 0
if (tokens < max) then
-- room available: add a token, bump ttl, and include newly added token in count
redis.call("zadd", key, expiresAt, expiresAt)
redis.call("expire", key, window)
success = 1
tokens = tokens + 1
end
return {success, tokens}
`

now := r.now()
current := now.UnixNano()
expiresAt := now.Add(bucket.Window).UnixNano()
windowTTL := int(math.Ceil(bucket.Window.Seconds()))

resp, err := r.Adapter.Eval(ctx, script, []string{bucket.Key}, []interface{}{
current, expiresAt, windowTTL, bucket.MaximumCapacity,
})
if err != nil {
return nil, fmt.Errorf("failed to query redis adapter: %w", err)
}

output, err := parseSlidingWindowResponse(resp)
if err != nil {
return nil, fmt.Errorf("parsing redis response: %w", err)
}

remaining := 0
if v := bucket.MaximumCapacity - output.tokens; v > remaining {
remaining = v
}

return &UseSlidingWindowResponse{
Success: output.success,
RemainingCapacity: remaining,
}, nil
}

type slidingWindowOutput struct {
success bool
tokens int
}

func parseSlidingWindowResponse(v interface{}) (*slidingWindowOutput, error) {
args, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("expected []interface{} but got %T", v)
}

if len(args) != 2 {
return nil, fmt.Errorf("expected 2 args but got %d", len(args))
}

argInts := make([]int64, len(args))
for i, argValue := range args {
intValue, ok := argValue.(int64)
if !ok {
return nil, fmt.Errorf("expected int64 in arg[%d] but got %T", i, argValue)
}

argInts[i] = intValue
}

return &slidingWindowOutput{
success: argInts[0] == 1,
tokens: int(argInts[1]),
}, nil
}
158 changes: 158 additions & 0 deletions redis/sliding_window_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package redis

import (
"context"
"testing"
"time"

"github.com/aidenwallis/go-ratelimiting/redis/adapters"
goredisadapter "github.com/aidenwallis/go-ratelimiting/redis/adapters/go-redis"
redigoadapter "github.com/aidenwallis/go-ratelimiting/redis/adapters/redigo"
"github.com/alicebob/miniredis/v2"
redigo "github.com/gomodule/redigo/redis"
goredis "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
)

func TestSlidingWindow_Now(t *testing.T) {
adapter := NewSlidingWindow(nil)
adapter.nowFunc = nil
assert.WithinDuration(t, adapter.now(), time.Now(), time.Minute)
}

func TestUseSlidingWindow(t *testing.T) {
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
},
"redigo": func(t *miniredis.Miniredis) adapters.Adapter {
conn, err := redigo.Dial("tcp", t.Addr())
if err != nil {
panic(err)
}
return redigoadapter.NewAdapter(conn)
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
ctx := context.Background()
now := time.Now().UTC()
limiter := NewSlidingWindow(testCase(miniredis.RunT(t)))
limiter.nowFunc = func() time.Time { return now }

{
resp, err := useSlidingWindow(ctx, limiter)
assert.NoError(t, err)
assert.True(t, resp.Success)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingCapacity)
}

// move forward 3 seconds
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 3) }

{
resp, err := useSlidingWindow(ctx, limiter)
assert.NoError(t, err)
assert.True(t, resp.Success)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-2, resp.RemainingCapacity, "tokens shouldn't have expired yet")
}

// move forward 60 seconds
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 60) }

{
resp, err := useSlidingWindow(ctx, limiter)
assert.NoError(t, err)
assert.True(t, resp.Success)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-2, resp.RemainingCapacity, "one token should've expired, so including this request, 2 should be used")
}

// move forward 120 seconds
limiter.nowFunc = func() time.Time { return now.Add(time.Second * 120) }

{
resp, err := useSlidingWindow(ctx, limiter)
assert.NoError(t, err)
assert.True(t, resp.Success)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingCapacity, "all tokens should've expired by now, so only this one is left")
}
})
}
}

func TestUseSlidingWindow_Errors(t *testing.T) {
testCases := map[string]struct {
errorMessage string
mockAdapter adapters.Adapter
}{
"redis error": {
errorMessage: "failed to query redis adapter: " + assert.AnError.Error(),
mockAdapter: &mockAdapter{
returnError: assert.AnError,
},
},
"parsing error": {
errorMessage: "parsing redis response: expected []interface{} but got string",
mockAdapter: &mockAdapter{
returnValue: "foo",
},
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
out, err := useSlidingWindow(context.Background(), NewSlidingWindow(testCase.mockAdapter))
assert.Nil(t, out)
assert.EqualError(t, err, testCase.errorMessage)
})
}
}

func TestParseSlidingWindowResponse_Errors(t *testing.T) {
testCases := map[string]struct {
errorMessage string
in interface{}
}{
"invalid type": {
errorMessage: "expected []interface{} but got string",
in: "foo",
},
"invalid length": {
errorMessage: "expected 2 args but got 3",
in: []interface{}{1, 2, 3},
},
"invalid item type": {
errorMessage: "expected int64 in arg[1] but got float64",
in: []interface{}{int64(1), float64(2)},
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
out, err := parseSlidingWindowResponse(testCase.in)
assert.Nil(t, out)
assert.EqualError(t, err, testCase.errorMessage)
})
}
}

// slidingWindowOptions provides quick sane defaults for testing sliding windows
func slidingWindowOptions() *SlidingWindowOptions {
return &SlidingWindowOptions{
Key: "test-bucket",
MaximumCapacity: 60,
Window: time.Minute,
}
}

// useSlidingWindow is a helper to test your sliding window with some predefined options
func useSlidingWindow(ctx context.Context, limiter SlidingWindow) (*UseSlidingWindowResponse, error) {
return limiter.Use(ctx, slidingWindowOptions())
}

0 comments on commit 8c9e41a

Please sign in to comment.