Skip to content

Commit

Permalink
Add ability to inspect redis ratelimiters (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
aidenwallis authored Aug 18, 2023
1 parent 8c9e41a commit 736baa7
Show file tree
Hide file tree
Showing 6 changed files with 414 additions and 65 deletions.
155 changes: 120 additions & 35 deletions redis/leaky_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
//
// See: https://en.wikipedia.org/wiki/Leaky_bucket
type LeakyBucket interface {
// Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
Inspect(ctx context.Context, bucket *LeakyBucketOptions) (*InspectLeakyBucketResponse, error)

// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
Use(ctx context.Context, bucket *LeakyBucketOptions, takeAmount int) (*UseLeakyBucketResponse, error)
Expand Down Expand Up @@ -50,18 +53,6 @@ type LeakyBucketOptions struct {
WindowSeconds int
}

// UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
type UseLeakyBucketResponse struct {
// Success is true when we were successfully able to take tokens from the bucket.
Success bool

// RemainingTokens defines hwo many tokens are left in the bucket
RemainingTokens int

// ResetAt is the time at which the bucket will be fully refilled
ResetAt time.Time
}

// LeakyBucketImpl implements a leaky bucket ratelimiter in Redis with Lua. This struct is compatible with the LeakyBucket interface
//
// See the LeakyBucket interface for more information about leaky bucket ratelimiters.
Expand All @@ -70,6 +61,8 @@ type LeakyBucketImpl struct {
Adapter adapters.Adapter

// nowFunc is a private helper used to mock out time changes in unit testing
//
// if this is not defined, it falls back to time.Now()
nowFunc func() time.Time
}

Expand All @@ -88,6 +81,80 @@ func (r *LeakyBucketImpl) now() time.Time {
return r.nowFunc()
}

// InspectLeakyBucketResponse defines the response parameters for LeakyBucket.Inspect()
type InspectLeakyBucketResponse struct {
// RemainingTokens defines hwo many tokens are left in the bucket
RemainingTokens int

// ResetAt is the time at which the bucket will be fully refilled
ResetAt time.Time
}

// Inspect atomically inspects the leaky bucket and returns the capacity available. It does not take any tokens.
func (r *LeakyBucketImpl) Inspect(ctx context.Context, bucket *LeakyBucketOptions) (*InspectLeakyBucketResponse, error) {
const script = `
local tokensKey = KEYS[1]
local lastFillKey = KEYS[2]
local capacity = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local tokens = tonumber(redis.call("get", tokensKey))
local lastFilled = tonumber(redis.call("get", lastFillKey))
if (tokens == nil) then
tokens = 0 -- default empty buckets to 0
end
if (tokens > capacity) then
tokens = capacity -- shrink buckets if the capacity is reduced
end
if (lastFilled == nil) then
lastFilled = 0
end
if (tokens < capacity) then
local tokensToFill = math.floor((now - lastFilled) * rate)
if (tokensToFill > 0) then
tokens = math.min(capacity, tokens + tokensToFill)
lastFilled = now
end
end
return {tokens, lastFilled}
`
refillRate := getRefillRate(bucket.MaximumCapacity, bucket.WindowSeconds)
now := r.now().UTC().Unix()

resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey(bucket.KeyPrefix), lastFillKey(bucket.KeyPrefix)}, []interface{}{bucket.MaximumCapacity, refillRate, now})
if err != nil {
return nil, fmt.Errorf("failed to query redis adapter: %w", err)
}

output, err := parseInspectLeakyBucketResponse(resp)
if err != nil {
return nil, fmt.Errorf("parsing redis response: %w", err)
}

return &InspectLeakyBucketResponse{
RemainingTokens: output.remaining,
ResetAt: calculateLeakyBucketFillTime(output.lastFilled, output.remaining, bucket.MaximumCapacity, bucket.WindowSeconds),
}, nil
}

// UseLeakyBucketResponse defines the response parameters for LeakyBucket.Use()
type UseLeakyBucketResponse struct {
// Success is true when we were successfully able to take tokens from the bucket.
Success bool

// RemainingTokens defines hwo many tokens are left in the bucket
RemainingTokens int

// ResetAt is the time at which the bucket will be fully refilled
ResetAt time.Time
}

// Use atomically attempts to use the leaky bucket. Use takeAmount to set how many tokens should be attempted to be removed
// from the bucket: they are atomic, either all tokens are taken, or the ratelimit is unsuccessful.
func (r *LeakyBucketImpl) Use(ctx context.Context, bucket *LeakyBucketOptions, takeAmount int) (*UseLeakyBucketResponse, error) {
Expand Down Expand Up @@ -139,15 +206,14 @@ return {success, tokens, lastFilled}
refillRate := getRefillRate(bucket.MaximumCapacity, bucket.WindowSeconds)
now := r.now().UTC().Unix()

tokensKey := bucket.KeyPrefix + "::tokens"
lastFillKey := bucket.KeyPrefix + "::last_fill"

resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey, lastFillKey}, []interface{}{bucket.MaximumCapacity, refillRate, now, takeAmount, bucket.WindowSeconds})
resp, err := r.Adapter.Eval(ctx, script, []string{tokensKey(bucket.KeyPrefix), lastFillKey(bucket.KeyPrefix)}, []interface{}{
bucket.MaximumCapacity, refillRate, now, takeAmount, bucket.WindowSeconds,
})
if err != nil {
return nil, fmt.Errorf("failed to query redis adapter: %w", err)
}

output, err := parseLeakyBucketResponse(resp)
output, err := parseUseLeakyBucketResponse(resp)
if err != nil {
return nil, fmt.Errorf("parsing redis response: %w", err)
}
Expand All @@ -159,6 +225,14 @@ return {success, tokens, lastFilled}
}, nil
}

func tokensKey(prefix string) string {
return prefix + "::tokens"
}

func lastFillKey(prefix string) string {
return prefix + "::last_fill"
}

func calculateLeakyBucketFillTime(lastFillUnix, currentTokens, maxCapacity, windowSeconds int) time.Time {
resetAt := lastFillUnix // if delta is 0 (thus, all tokens are filled), then the bucket is already reset
if delta := maxCapacity - currentTokens; delta > 0 {
Expand All @@ -182,35 +256,46 @@ func getRefillRate(maxCapacity, windowSeconds int) float64 {
return float64(maxCapacity) / float64(windowSeconds)
}

type leakyBucketOutput struct {
type useLeakyBucketOutput struct {
success bool
remaining int
lastFilled int
}

func parseLeakyBucketResponse(v interface{}) (*leakyBucketOutput, error) {
args, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("expected []interface{} but got %T", v)
func parseUseLeakyBucketResponse(v interface{}) (*useLeakyBucketOutput, error) {
ints, err := parseRedisInt64Slice(v)
if err != nil {
return nil, err
}

if len(args) != 3 {
return nil, fmt.Errorf("expected 3 args but got %d", len(args))
if len(ints) != 3 {
return nil, fmt.Errorf("expected 3 args but got %d", len(ints))
}

argInts := make([]int64, len(args))
for i, argValue := range args {
intValue, ok := argValue.(int64)
if !ok {
return nil, fmt.Errorf("expected int64 in arg[%d] but got %T", i, argValue)
}
return &useLeakyBucketOutput{
success: ints[0] == 1,
remaining: int(ints[1]),
lastFilled: int(ints[2]),
}, nil
}

type inspectLeakyBucketOutput struct {
remaining int
lastFilled int
}

func parseInspectLeakyBucketResponse(v interface{}) (*inspectLeakyBucketOutput, error) {
ints, err := parseRedisInt64Slice(v)
if err != nil {
return nil, err
}

argInts[i] = intValue
if len(ints) != 2 {
return nil, fmt.Errorf("expected 2 args but got %d", len(ints))
}

return &leakyBucketOutput{
success: argInts[0] == 1,
remaining: int(argInts[1]),
lastFilled: int(argInts[2]),
return &inspectLeakyBucketOutput{
remaining: int(ints[0]),
lastFilled: int(ints[1]),
}, nil
}
113 changes: 105 additions & 8 deletions redis/leaky_bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,84 @@ import (
"github.com/stretchr/testify/assert"
)

func TestUseLeakyBucket(t *testing.T) {
t.Parallel()
func TestInspectLeakyBucket(t *testing.T) {
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
},
"redigo": func(t *miniredis.Miniredis) adapters.Adapter {
conn, err := redigo.Dial("tcp", t.Addr())
if err != nil {
panic(err)
}
return redigoadapter.NewAdapter(conn)
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
ctx := context.Background()
now := time.Now().UTC()
limiter := NewLeakyBucket(testCase(miniredis.RunT(t)))
limiter.nowFunc = func() time.Time { return now }

{
resp, err := limiter.Inspect(ctx, leakyBucketOptions())
assert.NoError(t, err)
assert.Equal(t, leakyBucketOptions().MaximumCapacity, resp.RemainingTokens)
assert.Equal(t, now.Unix(), resp.ResetAt.Unix())
}

{
resp, err := useLeakyBucket(ctx, limiter)
assert.NoError(t, err)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingTokens)
assert.Equal(t, now.Add(time.Second*1).Unix(), resp.ResetAt.Unix())
}

{
resp, err := limiter.Inspect(ctx, leakyBucketOptions())
assert.NoError(t, err)
assert.Equal(t, leakyBucketOptions().MaximumCapacity-1, resp.RemainingTokens)
assert.Equal(t, now.Add(time.Second*1).Unix(), resp.ResetAt.Unix())
}
})
}
}

func TestInspectLeakyBucket_Errors(t *testing.T) {
testCases := map[string]struct {
errorMessage string
mockAdapter adapters.Adapter
}{
"redis error": {
errorMessage: "failed to query redis adapter: " + assert.AnError.Error(),
mockAdapter: &mockAdapter{
returnError: assert.AnError,
},
},
"parsing error": {
errorMessage: "parsing redis response: expected []interface{} but got string",
mockAdapter: &mockAdapter{
returnValue: "foo",
},
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
out, err := NewLeakyBucket(testCase.mockAdapter).Inspect(context.Background(), leakyBucketOptions())
assert.Nil(t, out)
assert.EqualError(t, err, testCase.errorMessage)
})
}
}

func TestUseLeakyBucket(t *testing.T) {
testCases := map[string]func(*miniredis.Miniredis) adapters.Adapter{
"go-redis": func(t *miniredis.Miniredis) adapters.Adapter {
return goredisadapter.NewAdapter(goredis.NewClient(&goredis.Options{Addr: t.Addr()}))
Expand Down Expand Up @@ -108,7 +183,7 @@ func TestRefillRate(t *testing.T) {
assert.EqualValues(t, 5, getRefillRate(300, 60))
}

func TestParseLeakyBucketResponse_Errors(t *testing.T) {
func TestParseUseLeakyBucketResponse_Errors(t *testing.T) {
testCases := map[string]struct {
errorMessage string
in interface{}
Expand All @@ -119,19 +194,41 @@ func TestParseLeakyBucketResponse_Errors(t *testing.T) {
},
"invalid length": {
errorMessage: "expected 3 args but got 2",
in: []interface{}{1, 2},
in: []interface{}{int64(1), int64(2)},
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
out, err := parseUseLeakyBucketResponse(testCase.in)
assert.Nil(t, out)
assert.EqualError(t, err, testCase.errorMessage)
})
}
}

func TestParseInspectLeakyBucketResponse_Errors(t *testing.T) {
testCases := map[string]struct {
errorMessage string
in interface{}
}{
"invalid type": {
errorMessage: "expected []interface{} but got string",
in: "foo",
},
"invalid item type": {
errorMessage: "expected int64 in arg[1] but got float64",
in: []interface{}{int64(1), float64(2), "three"},
"invalid length": {
errorMessage: "expected 2 args but got 3",
in: []interface{}{int64(1), int64(2), int64(3)},
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
out, err := parseLeakyBucketResponse(testCase.in)
out, err := parseInspectLeakyBucketResponse(testCase.in)
assert.Nil(t, out)
assert.EqualError(t, err, testCase.errorMessage)
})
Expand Down
22 changes: 22 additions & 0 deletions redis/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package redis

import "fmt"

func parseRedisInt64Slice(v interface{}) ([]int64, error) {
args, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("expected []interface{} but got %T", v)
}

out := make([]int64, len(args))
for i, arg := range args {
value, ok := arg.(int64)
if !ok {
return nil, fmt.Errorf("expected int64 in args[%d] but got %T", i, arg)
}

out[i] = value
}

return out, nil
}
Loading

0 comments on commit 736baa7

Please sign in to comment.