Skip to content

Commit

Permalink
add batching to sqs Transport publishing
Browse files Browse the repository at this point in the history
  • Loading branch information
jsteenb2 committed Jun 24, 2018
1 parent a373b72 commit 3a016f2
Show file tree
Hide file tree
Showing 3 changed files with 4,225 additions and 17 deletions.
91 changes: 78 additions & 13 deletions queues/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
package sqs

import (
"fmt"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -14,7 +16,9 @@ import (

// Transport is a vice.Transport for Amazon's SQS
type Transport struct {
wg sync.WaitGroup
batchSize int
batchInterval time.Duration
wg *sync.WaitGroup

sm sync.Mutex
sendChans map[string]chan []byte
Expand All @@ -34,14 +38,25 @@ type Transport struct {
// Credentials are automatically sourced using the AWS SDK credential chain,
// for more info see the AWS SDK docs:
// https://godoc.org/github.com/aws/aws-sdk-go#hdr-Configuring_Credentials
func New() *Transport {
func New(batchSize int, batchInterval time.Duration) *Transport {
if batchSize > 10 {
batchSize = 10
}

if batchInterval == 0 {
batchInterval = 200 * time.Millisecond
}

return &Transport{
sendChans: make(map[string]chan []byte),
receiveChans: make(map[string]chan []byte),
errChan: make(chan error, 10),
stopchan: make(chan struct{}),
stopPubChan: make(chan struct{}),
stopSubChan: make(chan struct{}),
wg: &sync.WaitGroup{},
batchSize: batchSize,
batchInterval: batchInterval,
sendChans: make(map[string]chan []byte),
receiveChans: make(map[string]chan []byte),
errChan: make(chan error, 10),
stopchan: make(chan struct{}),
stopPubChan: make(chan struct{}),
stopSubChan: make(chan struct{}),

NewService: func(region string) (sqsiface.SQSAPI, error) {
awsConfig := aws.NewConfig().WithRegion(region)
Expand Down Expand Up @@ -169,6 +184,16 @@ func (t *Transport) makePublisher(name string) (chan []byte, error) {
t.wg.Add(1)
go func() {
defer t.wg.Done()
var accum []*sqs.SendMessageBatchRequestEntry

defer func() {
if t.batchSize == 0 && len(accum) == 0 {
return
}

t.sendBatch(svc, name, accum)
accum = make([]*sqs.SendMessageBatchRequestEntry, 0, t.batchSize)
}()
for {
select {
case <-t.stopPubChan:
Expand All @@ -178,21 +203,61 @@ func (t *Transport) makePublisher(name string) (chan []byte, error) {
return

case msg := <-ch:
params := &sqs.SendMessageInput{
if t.batchSize == 0 {
params := &sqs.SendMessageInput{
MessageBody: aws.String(string(msg)),
QueueUrl: aws.String(name),
}
_, err := svc.SendMessage(params)
if err != nil {
t.errChan <- vice.Err{Message: msg, Name: name, Err: err}
}
continue
}

id := fmt.Sprintf("%d", time.Now().UnixNano())
accum = append(accum, &sqs.SendMessageBatchRequestEntry{
MessageBody: aws.String(string(msg)),
QueueUrl: aws.String(name),
Id: aws.String(id),
})
if len(accum) < t.batchSize {
continue
}
_, err := svc.SendMessage(params)
if err != nil {
t.errChan <- vice.Err{Message: msg, Name: name, Err: err}

t.sendBatch(svc, name, accum)
accum = make([]*sqs.SendMessageBatchRequestEntry, 0, t.batchSize)
case <-time.After(t.batchInterval):
if len(accum) == 0 {
continue
}

t.sendBatch(svc, name, accum)
accum = make([]*sqs.SendMessageBatchRequestEntry, 0, t.batchSize)
}
}
}()

return ch, nil
}

func (t *Transport) sendBatch(svc sqsiface.SQSAPI, name string, entries []*sqs.SendMessageBatchRequestEntry) {
batchParams := &sqs.SendMessageBatchInput{
Entries: entries,
QueueUrl: aws.String(name),
}

resp, err := svc.SendMessageBatch(batchParams)
if err != nil {
t.errChan <- vice.Err{Name: name, Err: err}
return
}

for _, v := range resp.Failed {
err := fmt.Errorf("%s", *v.Message)
t.errChan <- vice.Err{Name: name, Err: err}
}
}

// ErrChan gets the channel on which errors are sent.
func (t *Transport) ErrChan() <-chan error {
return t.errChan
Expand Down
71 changes: 67 additions & 4 deletions queues/sqs/sqs_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package sqs

import (
"strconv"
"sync"
"testing"
"time"

"github.com/matryer/vice"
"github.com/matryer/vice/queues/sqs/sqsfakes"
"github.com/matryer/vice/vicetest"

"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/matryer/is"
"github.com/matryer/vice"
"github.com/matryer/vice/vicetest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTransport(t *testing.T) {
Expand All @@ -18,7 +24,7 @@ func TestTransport(t *testing.T) {
}

new := func() vice.Transport {
transport := New()
transport := New(0, 10*time.Second)
transport.NewService = func(region string) (sqsiface.SQSAPI, error) {
return svc, nil
}
Expand All @@ -30,12 +36,69 @@ func TestTransport(t *testing.T) {
close(svc.finish)
}

func Test_Transport_BatchesWrites(t *testing.T) {
newTransport := New(10, 200*time.Millisecond)
svc := new(sqsfakes.FakeSQSAPI)
svc.SendMessageBatchReturns(&sqs.SendMessageBatchOutput{}, nil)
newTransport.NewService = func(region string) (sqsiface.SQSAPI, error) { return svc, nil }

stream := newTransport.Send("svcWriter")
for i := 0; i < 30; i++ {
stream <- []byte(strconv.Itoa(i))
}
time.Sleep(100 * time.Millisecond)

require.Equal(t, 3, svc.SendMessageBatchCallCount())

batch0 := svc.SendMessageBatchArgsForCall(0)
require.Equal(t, 10, len(batch0.Entries))
for i := 0; i < 10; i++ {
assert.Equal(t, strconv.Itoa(i), *batch0.Entries[i].MessageBody)
}

batch1 := svc.SendMessageBatchArgsForCall(1)
require.Equal(t, 10, len(batch1.Entries))
for i := 0; i < 10; i++ {
assert.Equal(t, strconv.Itoa(i+10), *batch1.Entries[i].MessageBody)
}

batch2 := svc.SendMessageBatchArgsForCall(2)
require.Equal(t, 10, len(batch2.Entries))
for i := 0; i < 10; i++ {
assert.Equal(t, strconv.Itoa(i+20), *batch2.Entries[i].MessageBody)
}
}

func Test_Transport_BatchFlushesOnReturn(t *testing.T) {
newTransport := New(10, 1000*time.Millisecond)
svc := new(sqsfakes.FakeSQSAPI)
svc.SendMessageBatchReturns(&sqs.SendMessageBatchOutput{}, nil)
newTransport.NewService = func(region string) (sqsiface.SQSAPI, error) { return svc, nil }

stream := newTransport.Send("svcWriter")
for i := 0; i < 9; i++ {
stream <- []byte(strconv.Itoa(i))
}
time.Sleep(100 * time.Millisecond)
newTransport.Stop()

time.Sleep(200 * time.Millisecond)

require.Equal(t, 1, svc.SendMessageBatchCallCount())

batch0 := svc.SendMessageBatchArgsForCall(0)
require.Equal(t, 9, len(batch0.Entries))
for i := 0; i < 9; i++ {
assert.Equal(t, strconv.Itoa(i), *batch0.Entries[i].MessageBody)
}
}

func TestParseRegion(t *testing.T) {
is := is.New(t)
reg := RegionFromURL("http://sqs.us-east-2.amazonaws.com/123456789012/MyQueue")
is.Equal("us-east-2", reg)

reg = RegionFromURL("http://localhost/foo")
reg = RegionFromURL("http://localhost/svcWriter")
is.Equal("", reg)
}

Expand Down
Loading

0 comments on commit 3a016f2

Please sign in to comment.