Skip to content

Commit

Permalink
feat: automatic TLS certificate reloading (ory#2744)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr authored Oct 4, 2022
1 parent d612612 commit 09751e6
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 149 deletions.
8 changes: 4 additions & 4 deletions cmd/daemon/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
handler = cors.New(options).Handler(handler)
}

certs := c.GetTSLCertificatesForPublic(ctx)
certs := c.GetTLSCertificatesForPublic(ctx)

if tracer := r.Tracer(ctx); tracer.IsLoaded() {
handler = x.TraceHandler(handler)
Expand All @@ -130,7 +130,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
// #nosec G112 - the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{Certificates: certs, MinVersion: tls.VersionTLS12},
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
})
addr := c.PublicListenOn(ctx)

Expand Down Expand Up @@ -186,7 +186,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
r.PrometheusManager().RegisterRouter(router.Router)

n.UseHandler(router)
certs := c.GetTSLCertificatesForAdmin(ctx)
certs := c.GetTLSCertificatesForAdmin(ctx)

var handler http.Handler = n
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
Expand All @@ -196,7 +196,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
// #nosec G112 - the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{Certificates: certs, MinVersion: tls.VersionTLS12},
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
})

addr := c.AdminListenOn(ctx)
Expand Down
26 changes: 2 additions & 24 deletions cmd/serve/root_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
package serve_test

import (
"encoding/base64"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"

"github.com/ory/kratos/x"

"github.com/ory/kratos/internal/testhelpers"
)

Expand All @@ -18,19 +11,7 @@ func TestServe(t *testing.T) {
}

func TestServeTLSBase64(t *testing.T) {
certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)

certRaw, err := os.ReadFile(certPath)
require.NoError(t, err)

keyRaw, err := os.ReadFile(keyPath)
require.NoError(t, err)

certBase64 := base64.StdEncoding.EncodeToString(certRaw)
keyBase64 := base64.StdEncoding.EncodeToString(keyRaw)
_, _, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)
publicPort, adminPort := testhelpers.StartE2EServerOnly(t,
"./stub/kratos.yml",
true,
Expand All @@ -45,10 +26,7 @@ func TestServeTLSBase64(t *testing.T) {
}

func TestServeTLSPaths(t *testing.T) {
certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)
certPath, keyPath, _, _ := testhelpers.GenerateTLSCertificateFilesForTests(t)

publicPort, adminPort := testhelpers.StartE2EServerOnly(t,
"./stub/kratos.yml",
Expand Down
43 changes: 31 additions & 12 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,8 +1284,11 @@ func (p *Config) CipherAlgorithm(ctx context.Context) string {
}
}

func (p *Config) GetTSLCertificatesForPublic(ctx context.Context) []tls.Certificate {
return p.getTSLCertificates(
type CertFunc = func(*tls.ClientHelloInfo) (*tls.Certificate, error)

func (p *Config) GetTLSCertificatesForPublic(ctx context.Context) CertFunc {
return p.getTLSCertificates(
ctx,
"public",
p.GetProvider(ctx).String(ViperKeyPublicTLSCertBase64),
p.GetProvider(ctx).String(ViperKeyPublicTLSKeyBase64),
Expand All @@ -1294,8 +1297,9 @@ func (p *Config) GetTSLCertificatesForPublic(ctx context.Context) []tls.Certific
)
}

func (p *Config) GetTSLCertificatesForAdmin(ctx context.Context) []tls.Certificate {
return p.getTSLCertificates(
func (p *Config) GetTLSCertificatesForAdmin(ctx context.Context) CertFunc {
return p.getTLSCertificates(
ctx,
"admin",
p.GetProvider(ctx).String(ViperKeyAdminTLSCertBase64),
p.GetProvider(ctx).String(ViperKeyAdminTLSKeyBase64),
Expand All @@ -1304,16 +1308,31 @@ func (p *Config) GetTSLCertificatesForAdmin(ctx context.Context) []tls.Certifica
)
}

func (p *Config) getTSLCertificates(daemon, certBase64, keyBase64, certPath, keyPath string) []tls.Certificate {
cert, err := tlsx.Certificate(certBase64, keyBase64, certPath, keyPath)

if err == nil {
func (p *Config) getTLSCertificates(ctx context.Context, daemon, certBase64, keyBase64, certPath, keyPath string) CertFunc {
if certBase64 != "" && keyBase64 != "" {
cert, err := tlsx.CertificateFromBase64(certBase64, keyBase64)
if err != nil {
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return nil // reachable in unit tests when Fatalf is hooked
}
p.l.Infof("Setting up HTTPS for %s", daemon)
return cert
} else if !errors.Is(err, tlsx.ErrNoCertificatesConfigured) {
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return &cert, nil }
}
if certPath != "" && keyPath != "" {
errs := make(chan error, 1)
getCert, err := tlsx.GetCertificate(ctx, certPath, keyPath, errs)
if err != nil {
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return nil // reachable in unit tests when Fatalf is hooked
}
go func() {
for err := range errs {
p.l.WithError(err).Error("Failed to reload TLS certificates, using previous certificates")
}
}()
p.l.Infof("Setting up HTTPS for %s (automatic certificate reloading active)", daemon)
return getCert
}

p.l.Infof("TLS has not been configured for %s, skipping", daemon)
return nil
}
Expand Down
128 changes: 53 additions & 75 deletions driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (

"github.com/ory/x/watcherx"

"github.com/ory/kratos/x"

"github.com/ory/kratos/internal/testhelpers"

"github.com/ory/x/configx"
Expand Down Expand Up @@ -775,126 +773,106 @@ func TestViperProvider_HaveIBeenPwned(t *testing.T) {
})
}

func newTestConfig(t *testing.T) (_ *config.Config, _ *test.Hook, exited *bool) {
l := logrusx.New("", "")
h := new(test.Hook)
exited = new(bool)
l.Logger.Hooks.Add(h)
l.Logger.ExitFunc = func(code int) { *exited = true }
config := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
return config, h, exited
}

func TestLoadingTLSConfig(t *testing.T) {
ctx := context.Background()
t.Parallel()

certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)

certRaw, err := os.ReadFile(certPath)
assert.Nil(t, err)
certPath, keyPath, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)

keyRaw, err := os.ReadFile(keyPath)
assert.Nil(t, err)
t.Run("case=public: no TLS config", func(t *testing.T) {
p, hook, exited := newTestConfig(t)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.False(t, *exited)
})

certBase64 := base64.StdEncoding.EncodeToString(certRaw)
keyBase64 := base64.StdEncoding.EncodeToString(keyRaw)
t.Run("case=admin: no TLS config", func(t *testing.T) {
p, hook, exited := newTestConfig(t)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: loading inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyBase64, keyBase64)
p.MustSet(ctx, config.ViperKeyPublicTLSCertBase64, certBase64)
assert.NotNil(t, p.GetTSLCertificatesForPublic(ctx))
assert.NotNil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: loading certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyPath, keyPath)
p.MustSet(ctx, config.ViperKeyPublicTLSCertPath, certPath)
assert.NotNil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public", hook.LastEntry().Message)
assert.NotNil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public (automatic certificate reloading active)", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: failing to load inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyBase64, "empty")
p.MustSet(ctx, config.ViperKeyPublicTLSCertBase64, certBase64)
assert.Nil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=public: failing to load certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyPath, "/dev/null")
p.MustSet(ctx, config.ViperKeyPublicTLSCertPath, certPath)
assert.Nil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=admin: loading inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyBase64, keyBase64)
p.MustSet(ctx, config.ViperKeyAdminTLSCertBase64, certBase64)
assert.NotNil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.NotNil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=admin: loading certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyPath, keyPath)
p.MustSet(ctx, config.ViperKeyAdminTLSCertPath, certPath)
assert.NotNil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin", hook.LastEntry().Message)
assert.NotNil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin (automatic certificate reloading active)", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=admin: failing to load inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyBase64, "empty")
p.MustSet(ctx, config.ViperKeyAdminTLSCertBase64, certBase64)
assert.Nil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=admin: failing to load certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyPath, "/dev/null")
p.MustSet(ctx, config.ViperKeyAdminTLSCertPath, certPath)
assert.Nil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

}
Expand Down
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ require (
github.com/ory/kratos-client-go v0.6.3-alpha.1
github.com/ory/mail/v3 v3.0.0
github.com/ory/nosurf v1.2.7
github.com/ory/x v0.0.470
github.com/ory/x v0.0.474
github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.3.0
Expand Down Expand Up @@ -199,7 +199,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/serf v0.9.7 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.12.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
Expand Down Expand Up @@ -267,11 +267,11 @@ require (
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/viper v1.12.0 // indirect
github.com/subosito/gotenv v1.3.0 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/timtadh/data-structures v0.5.3 // indirect
Expand Down Expand Up @@ -315,7 +315,7 @@ require (
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.17.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 // indirect
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
Expand All @@ -327,7 +327,7 @@ require (
gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
Expand Down
Loading

0 comments on commit 09751e6

Please sign in to comment.