Skip to content

Commit

Permalink
Add connection pooling for controller-runtime controllers
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelattwood committed Feb 18, 2025
1 parent e3bf14d commit 7ad2712
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ $(ENVTEST): $(LOCALBIN)
test: envtest
go vet ./controllers/... ./pkg/natsreloader/... ./internal/controller/...
$(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path ## Get k8s binaries
go test -race -cover -count=1 -timeout 10s ./controllers/... ./pkg/natsreloader/... ./internal/controller/...
go test -race -cover -count=1 -timeout 30s ./controllers/... ./pkg/natsreloader/... ./internal/controller/...

.PHONY: clean
clean:
Expand Down
2 changes: 1 addition & 1 deletion cicd/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#syntax=docker/dockerfile:1.13
ARG GO_APP

FROM alpine:3.21.3 as deps
FROM alpine:3.21.3 AS deps

ARG GO_APP
ARG GORELEASER_DIST_DIR=/go/src/dist
Expand Down
2 changes: 1 addition & 1 deletion cicd/Dockerfile_goreleaser
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#syntax=docker/dockerfile:1.13
FROM --platform=$BUILDPLATFORM golang:1.24.0-bullseye as build
FROM --platform=$BUILDPLATFORM golang:1.24.0-bullseye AS build


RUN <<EOT
Expand Down
117 changes: 86 additions & 31 deletions internal/controller/client.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,90 @@
package controller

import (
"crypto/sha256"
"encoding/json"
"fmt"
"os"

"github.com/nats-io/jsm.go"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)

type NatsConfig struct {
ClientName string
ServerURL string
Certificate string
Key string
TLSFirst bool
CAs []string
Credentials string
NKey string
Token string
User string
Password string
ClientName string `json:"name,omitempty"`
ServerURL string `json:"url,omitempty"`
Certificate string `json:"tls_cert,omitempty"`
Key string `json:"tls_key,omitempty"`
TLSFirst bool `json:"tls_first,omitempty"`
CAs []string `json:"tls_ca,omitempty"`
Credentials string `json:"credential,omitempty"`
NKey string `json:"nkey,omitempty"`
Token string `json:"token,omitempty"`
User string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
}

func (o *NatsConfig) Copy() *NatsConfig {
if o == nil {
return nil
}

cp := *o
return &cp
}

func (o *NatsConfig) Hash() (string, error) {
b, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("error marshaling config to json: %v", err)
}

if o.NKey != "" {
fb, err := os.ReadFile(o.NKey)
if err != nil {
return "", fmt.Errorf("error opening nkey file %s: %v", o.NKey, err)
}
b = append(b, fb...)
}

if o.Credentials != "" {
fb, err := os.ReadFile(o.Credentials)
if err != nil {
return "", fmt.Errorf("error opening creds file %s: %v", o.Credentials, err)
}
b = append(b, fb...)
}

if len(o.CAs) > 0 {
for _, cert := range o.CAs {
fb, err := os.ReadFile(cert)
if err != nil {
return "", fmt.Errorf("error opening ca file %s: %v", cert, err)
}
b = append(b, fb...)
}
}

if o.Certificate != "" {
fb, err := os.ReadFile(o.Certificate)
if err != nil {
return "", fmt.Errorf("error opening cert file %s: %v", o.Certificate, err)
}
b = append(b, fb...)
}

if o.Key != "" {
fb, err := os.ReadFile(o.Key)
if err != nil {
return "", fmt.Errorf("error opening key file %s: %v", o.Key, err)
}
b = append(b, fb...)
}

hash := sha256.New()
hash.Write(b)
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

func (o *NatsConfig) Overlay(overlay *NatsConfig) {
Expand Down Expand Up @@ -125,15 +190,10 @@ type Closable interface {
Close()
}

func CreateJSMClient(cfg *NatsConfig, pedantic bool) (*jsm.Manager, Closable, error) {
nc, err := createNatsConn(cfg, pedantic)
func CreateJSMClient(conn *pooledConnection, pedantic bool) (*jsm.Manager, error) {
major, minor, _, err := versionComponents(conn.nc.ConnectedServerVersion())
if err != nil {
return nil, nil, fmt.Errorf("create nats connection: %w", err)
}

major, minor, _, err := versionComponents(nc.ConnectedServerVersion())
if err != nil {
return nil, nil, fmt.Errorf("parse server version: %w", err)
return nil, fmt.Errorf("parse server version: %w", err)
}

// JetStream pedantic mode unsupported prior to NATS Server 2.11
Expand All @@ -146,28 +206,23 @@ func CreateJSMClient(cfg *NatsConfig, pedantic bool) (*jsm.Manager, Closable, er
jsmOpts = append(jsmOpts, jsm.WithPedanticRequests())
}

jsmClient, err := jsm.New(nc, jsmOpts...)
jsmClient, err := jsm.New(conn.nc, jsmOpts...)
if err != nil {
return nil, nil, fmt.Errorf("new jsm client: %w", err)
return nil, fmt.Errorf("new jsm client: %w", err)
}

return jsmClient, nc, nil
return jsmClient, nil
}

// CreateJetStreamClient creates new Jetstream client with a connection based on the given NatsConfig.
// Returns a jetstream.Jetstream client and the Closable of the underlying connection.
// Close should be called when the client is no longer used.
func CreateJetStreamClient(cfg *NatsConfig, pedantic bool) (jetstream.JetStream, Closable, error) {
nc, err := createNatsConn(cfg, pedantic)
if err != nil {
return nil, nil, fmt.Errorf("create nats connection: %w", err)
}

js, err := jetstream.New(nc)
func CreateJetStreamClient(conn *pooledConnection, pedantic bool) (jetstream.JetStream, error) {
js, err := jetstream.New(conn.nc)
if err != nil {
return nil, nil, fmt.Errorf("new jetstream: %w", err)
return nil, fmt.Errorf("new jetstream: %w", err)
}
return js, nc, nil
return js, nil
}

func createNatsConn(cfg *NatsConfig, pedantic bool) (*nats.Conn, error) {
Expand Down
98 changes: 98 additions & 0 deletions internal/controller/connection_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package controller

import (
"sync"
"time"

"github.com/nats-io/nats.go"
)

type pooledConnection struct {
nc *nats.Conn
pool *connectionPool
hash string
refCount int
}

func (pc *pooledConnection) Close() {
if pc.pool != nil {
pc.pool.release(pc.hash)
} else if pc.nc != nil {
pc.nc.Close() // Close directly if not pool-managed
}
}

type connectionPool struct {
connections map[string]*pooledConnection
gracePeriod time.Duration
mu sync.Mutex
}

func newConnPool(gracePeriod time.Duration) *connectionPool {
return &connectionPool{
connections: make(map[string]*pooledConnection),
gracePeriod: gracePeriod,
}
}

func (p *connectionPool) Get(c *NatsConfig, pedantic bool) (*pooledConnection, error) {
p.mu.Lock()
defer p.mu.Unlock()

hash, err := c.Hash()
if err != nil {
// If hash fails, create a new non-pooled connection
nc, err := createNatsConn(c, pedantic)
if err != nil {
return nil, err
}
return &pooledConnection{nc: nc}, nil
}

if pc, ok := p.connections[hash]; ok && !pc.nc.IsClosed() {
pc.refCount++
return pc, nil
}

nc, err := createNatsConn(c, pedantic)
if err != nil {
return nil, err
}

pc := &pooledConnection{
nc: nc,
pool: p,
hash: hash,
refCount: 1,
}
p.connections[hash] = pc

return pc, nil
}

func (p *connectionPool) release(hash string) {
p.mu.Lock()
defer p.mu.Unlock()

pc, ok := p.connections[hash]
if !ok {
return
}

pc.refCount--
if pc.refCount < 1 {
go func() {
if p.gracePeriod > 0 {
time.Sleep(p.gracePeriod)
}

p.mu.Lock()
defer p.mu.Unlock()

if pc, ok := p.connections[hash]; ok && pc.refCount < 1 {
pc.nc.Close()
delete(p.connections, hash)
}
}()
}
}
98 changes: 98 additions & 0 deletions internal/controller/connection_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package controller

import (
"sync"
"testing"
"time"

natsservertest "github.com/nats-io/nats-server/v2/test"
"github.com/stretchr/testify/require"
)

func TestConnPool(t *testing.T) {
t.Parallel()

s := natsservertest.RunRandClientPortServer()
defer s.Shutdown()

c1 := &NatsConfig{
ClientName: "Client 1",
ServerURL: s.ClientURL(),
}

c2 := &NatsConfig{
ClientName: "Client 1",
ServerURL: s.ClientURL(),
}

c3 := &NatsConfig{
ClientName: "Client 2",
ServerURL: s.ClientURL(),
}

pool := newConnPool(0)

var conn1, conn2, conn3 *pooledConnection
var err1, err2, err3 error

wg := &sync.WaitGroup{}
wg.Add(3)

go func() {
conn1, err1 = pool.Get(c1, true)
wg.Done()
}()
go func() {
conn2, err2 = pool.Get(c2, true)
wg.Done()
}()
go func() {
conn3, err3 = pool.Get(c3, true)
wg.Done()
}()
wg.Wait()

require := require.New(t)

require.NoError(err1)
require.NoError(err2)
require.NoError(err3)

require.Same(conn1, conn2)
require.NotSame(conn1, conn3)
require.NotSame(conn2, conn3)

conn1.Close()
conn3.Close()

time.Sleep(time.Second)

require.False(conn1.nc.IsClosed())
require.False(conn2.nc.IsClosed())
require.True(conn3.nc.IsClosed())

conn4, err4 := pool.Get(c1, true)
require.NoError(err4)
require.Same(conn1, conn4)
require.Same(conn2, conn4)

conn2.Close()
conn4.Close()

time.Sleep(time.Second)

require.True(conn1.nc.IsClosed())
require.True(conn2.nc.IsClosed())
require.True(conn3.nc.IsClosed())
require.True(conn4.nc.IsClosed())

conn5, err5 := pool.Get(c1, true)
require.NoError(err5)
require.NotSame(conn1, conn5)

conn5.Close()

time.Sleep(time.Second)

require.True(conn5.nc.IsClosed())
}
9 changes: 7 additions & 2 deletions internal/controller/consumer_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,14 @@ var _ = Describe("Consumer Controller", func() {
By("setting up the alternative server")
altServer := CreateTestServer()
defer altServer.Shutdown()

connPool := newConnPool(0)
conn, err := connPool.Get(&NatsConfig{ServerURL: altServer.ClientURL()}, true)
Expect(err).NotTo(HaveOccurred())

// Setup altClient for alternate server
altClient, closer, err := CreateJetStreamClient(&NatsConfig{ServerURL: altServer.ClientURL()}, true)
defer closer.Close()
altClient, err := CreateJetStreamClient(conn, true)
defer conn.Close()
Expect(err).NotTo(HaveOccurred())

By("setting up the stream on the alternative server")
Expand Down
Loading

0 comments on commit 7ad2712

Please sign in to comment.