diff --git a/README.md b/README.md index 5eb043f..f29b502 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # DynamicTLS + [![License](https://img.shields.io/badge/license-mit-blue.svg?style=for-the-badge)](https://raw.githubusercontent.com/abursavich/dynamictls/master/LICENSE) [![GoDev Reference](https://img.shields.io/static/v1?logo=go&logoColor=white&color=00ADD8&label=dev&message=reference&style=for-the-badge)](https://pkg.go.dev/bursavich.dev/dynamictls) [![Go Report Card](https://goreportcard.com/badge/bursavich.dev/dynamictls?style=for-the-badge)](https://goreportcard.com/report/bursavich.dev/dynamictls) @@ -15,16 +16,16 @@ It provides simple integrations with HTTP/1.1, HTTP/2, gRPC, and Prometheus. ```go // create metrics -metrics, err := tlsprom.NewMetrics( +observer, err := tlsprom.NewObserver( tlsprom.WithHTTP(), tlsprom.WithServer(), ) check(err) -prometheus.MustRegister(metrics) +prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithCertificate(primaryCertFile, primaryKeyFile), dynamictls.WithCertificate(secondaryCertFile, secondaryKeyFile), dynamictls.WithRootCAs(caFile), @@ -43,16 +44,16 @@ check(http.Serve(lis, http.DefaultServeMux)) ```go // create metrics -metrics, err := tlsprom.NewMetrics( +observer, err := tlsprom.NewObserver( tlsprom.WithHTTP(), tlsprom.WithClient(), ) check(err) -prometheus.MustRegister(metrics) +prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ MinVersion: tls.VersionTLS12, }), @@ -77,16 +78,16 @@ defer client.CloseIdleConnections() ```go // create metrics -metrics, err := tlsprom.NewMetrics( +observer, err := tlsprom.NewObserver( tlsprom.WithGRPC(), tlsprom.WithServer(), ) check(err) -prometheus.MustRegister(metrics) +prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS13, @@ -115,16 +116,16 @@ check(srv.Serve(lis)) ```go // create metrics -metrics, err := tlsprom.NewMetrics( +observer, err := tlsprom.NewObserver( tlsprom.WithGRPC(), tlsprom.WithClient(), ) check(err) -prometheus.MustRegister(metrics) +prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ MinVersion: tls.VersionTLS13, }), @@ -146,4 +147,4 @@ conn, err := grpc.Dial( check(err) defer conn.Close() client := pb.NewTestServiceClient(conn) -``` \ No newline at end of file +``` diff --git a/dynamictls.go b/dynamictls.go index f6b12af..ef59695 100644 --- a/dynamictls.go +++ b/dynamictls.go @@ -27,9 +27,16 @@ import ( const hashSize = 16 // 128-bit -// NotifyFunc is a function that is called when new config data -// is loaded or an error occurs loading new config data. -type NotifyFunc func(cfg *tls.Config, err error) +// An Observer observes when new config data is loaded or an error occurs loading new config data. +type Observer interface { + ObserveConfig(cfg *tls.Config) + ObserveReadError(err error) +} + +type noopObserver struct{} + +func (noopObserver) ObserveConfig(cfg *tls.Config) {} +func (noopObserver) ObserveReadError(err error) {} // An Option applies optional configuration. type Option interface { @@ -110,10 +117,10 @@ func WithCertificate(certFile, keyFile string) Option { }) } -// WithNotifyFunc returns an Option that registers the notify function. -func WithNotifyFunc(notify NotifyFunc) Option { +// WithObserver returns an Option that registers the Observer. +func WithObserver(observer Observer) Option { return optionFunc(func(c *Config) error { - c.notifyFns = append(c.notifyFns, notify) + c.observer = observer return nil }) } @@ -183,7 +190,7 @@ type Config struct { rootCAs []string clientCAs []string certs []keyPair - notifyFns []NotifyFunc + observer Observer log logr.Logger watcher *fsnotify.Watcher @@ -207,6 +214,7 @@ func NewConfig(options ...Option) (cfg *Config, err error) { }() cfg = &Config{ base: &tls.Config{}, + observer: noopObserver{}, log: logr.Discard(), watcher: w, close: make(chan struct{}), @@ -322,9 +330,7 @@ func (cfg *Config) read() error { } cfg.latest.Store(config) - for _, fn := range cfg.notifyFns { - fn(config, nil) - } + cfg.observer.ObserveConfig(config) return nil } @@ -337,9 +343,7 @@ func (cfg *Config) watch() { // TODO: ignore unrelated events if err := cfg.read(); err != nil { cfg.log.Error(err, "Read failure") // errors already decorated - for _, fn := range cfg.notifyFns { - fn(nil, err) - } + cfg.observer.ObserveReadError(err) } case err := <-cfg.watcher.Errors: cfg.log.Error(err, "Watch failure") diff --git a/dynamictls_test.go b/dynamictls_test.go index c0aeaf8..394203a 100644 --- a/dynamictls_test.go +++ b/dynamictls_test.go @@ -221,6 +221,36 @@ func certPoolEqual(x, y *x509.CertPool) bool { return reflect.DeepEqual(xs, ys) } +type testObserver struct { + configCh chan *tls.Config + errCh chan error +} + +func newTestObserver() *testObserver { + return &testObserver{ + configCh: make(chan *tls.Config, 1), + errCh: make(chan error, 1), + } +} + +func (o *testObserver) ObserveConfig(cfg *tls.Config) { + timeout := time.NewTimer(10 * time.Second) + defer timeout.Stop() + select { + case <-timeout.C: + case o.configCh <- cfg: + } +} + +func (o *testObserver) ObserveReadError(err error) { + timeout := time.NewTimer(10 * time.Second) + defer timeout.Stop() + select { + case <-timeout.C: + case o.errCh <- err: + } +} + func TestNotifyError(t *testing.T) { // create temp dir dir, err := ioutil.TempDir("", "") @@ -234,10 +264,10 @@ func TestNotifyError(t *testing.T) { keyFile := createFile(t, dir, "key.pem", keyPEMBlock) // create config - errCh := make(chan error, 1) + obs := newTestObserver() cfg, err := NewConfig( WithCertificate(certFile, keyFile), - WithNotifyFunc(func(_ *tls.Config, err error) { errCh <- err }), + WithObserver(obs), ) check(t, "Failed to initialize config", err) defer cfg.Close() @@ -246,7 +276,11 @@ func TestNotifyError(t *testing.T) { defer timeout.Stop() select { - case err := <-errCh: + case cfg := <-obs.configCh: + if cfg == nil { + t.Fatalf("Unexpected nil config") + } + case err := <-obs.errCh: if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -257,7 +291,11 @@ func TestNotifyError(t *testing.T) { check(t, "Failed to remove cert file", os.Remove(certFile)) select { - case err := <-errCh: + case cfg := <-obs.configCh: + if cfg != nil { + t.Fatalf("Unexpected config") + } + case err := <-obs.errCh: if err == nil { t.Fatal("Expected an error after deleting certs") } @@ -307,14 +345,7 @@ func TestKubernetes(t *testing.T) { check(t, "Failed to create symlink", os.Symlink(data0, data)) // create config - ch := make(chan result, 1) - notifyFn := func(config *tls.Config, err error) { - select { - case <-ch: - default: - } - ch <- result{config: config, err: err} - } + obs := newTestObserver() wantCert := func(want *tls.Certificate) { t.Helper() timeout := time.NewTimer(5 * time.Second) @@ -322,23 +353,21 @@ func TestKubernetes(t *testing.T) { var err error for { select { - case res := <-ch: - if res.err != nil { - // An error can occur if a filesystem event triggers a reload and a - // symlink flip happens between reading the public and private keys. - // The keys won't match due to this race, but a subsequent reload - // will also be triggered and they will match the next time. - t.Logf("Unexpected error, may be transient: %v", res.err) - err = res.err - continue - } - if res.config == nil { + case err = <-obs.errCh: + // An error can occur if a filesystem event triggers a reload and a + // symlink flip happens between reading the public and private keys. + // The keys won't match due to this race, but a subsequent reload + // will also be triggered and they will match the next time. + t.Logf("Unexpected error, may be transient: %v", err) + continue + case cfg := <-obs.configCh: + if cfg == nil { t.Fatal("Config missing") } - if len(res.config.Certificates) == 0 { + if len(cfg.Certificates) == 0 { t.Fatal("Config missing certs") } - got := res.config.Certificates[0] + got := cfg.Certificates[0] if !reflect.DeepEqual(got.Certificate, want.Certificate) { t.Fatal("Unexpected cert") } @@ -358,7 +387,7 @@ func TestKubernetes(t *testing.T) { cfg, err := NewConfig( WithCertificate(certFile, keyFile), WithRootCAs(caFile), - WithNotifyFunc(notifyFn), + WithObserver(obs), ) check(t, "Failed to initialize config", err) defer cfg.Close() diff --git a/example_test.go b/example_test.go index 288a5d8..8b66197 100644 --- a/example_test.go +++ b/example_test.go @@ -4,6 +4,7 @@ // Use of this source code is governed by The MIT License // which can be found in the LICENSE file. +//go:build go1.14 // +build go1.14 package dynamictls_test @@ -19,15 +20,15 @@ import ( ) func ExampleConfig_Listen() { - metrics, err := tlsprom.NewMetrics( + observer, err := tlsprom.NewObserver( tlsprom.WithHTTP(), tlsprom.WithServer(), ) check(err) - prometheus.MustRegister(metrics) + prometheus.MustRegister(observer) cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithCertificate(primaryCertFile, primaryKeyFile), dynamictls.WithCertificate(secondaryCertFile, secondaryKeyFile), dynamictls.WithRootCAs(caFile), @@ -42,15 +43,15 @@ func ExampleConfig_Listen() { } func ExampleConfig_Dial() { - metrics, err := tlsprom.NewMetrics( + observer, err := tlsprom.NewObserver( tlsprom.WithHTTP(), tlsprom.WithClient(), ) check(err) - prometheus.MustRegister(metrics) + prometheus.MustRegister(observer) cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ MinVersion: tls.VersionTLS12, }), diff --git a/grpctls/example_test.go b/grpctls/example_test.go index a91dd65..0d26078 100644 --- a/grpctls/example_test.go +++ b/grpctls/example_test.go @@ -21,13 +21,13 @@ import ( func Example() { // create shared metrics - metrics, err := tlsprom.NewMetrics(tlsprom.WithGRPC()) + observer, err := tlsprom.NewObserver(tlsprom.WithGRPC()) check(err) - prometheus.MustRegister(metrics) + prometheus.MustRegister(observer) // create shared TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS13, @@ -65,16 +65,16 @@ func Example() { func Example_client() { // create metrics - metrics, err := tlsprom.NewMetrics( + observer, err := tlsprom.NewObserver( tlsprom.WithGRPC(), tlsprom.WithClient(), ) check(err) - prometheus.MustRegister(metrics) + prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ MinVersion: tls.VersionTLS13, }), @@ -103,16 +103,16 @@ func Example_client() { func Example_server() { // create metrics - metrics, err := tlsprom.NewMetrics( + observer, err := tlsprom.NewObserver( tlsprom.WithGRPC(), tlsprom.WithServer(), ) check(err) - prometheus.MustRegister(metrics) + prometheus.MustRegister(observer) // create TLS config cfg, err := dynamictls.NewConfig( - dynamictls.WithNotifyFunc(metrics.Update), + dynamictls.WithObserver(observer), dynamictls.WithBase(&tls.Config{ ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS13, diff --git a/tlsprom/tlsprom.go b/tlsprom/tlsprom.go index d58c835..b02b730 100644 --- a/tlsprom/tlsprom.go +++ b/tlsprom/tlsprom.go @@ -13,6 +13,7 @@ import ( "sort" "time" + "bursavich.dev/dynamictls" "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" ) @@ -137,8 +138,13 @@ func WithKeyUsages(usages ...x509.ExtKeyUsage) Option { }) } -// Metrics is a collection of TLS config metrics. -type Metrics struct { +// Observer is a collection of TLS config metrics. +type Observer interface { + dynamictls.Observer + prometheus.Collector +} + +type observer struct { updateError prometheus.Gauge verifyError prometheus.Gauge expiration prometheus.Gauge @@ -147,8 +153,8 @@ type Metrics struct { log logr.Logger } -// NewMetrics returns new Metrics with the given options. -func NewMetrics(options ...Option) (*Metrics, error) { +// NewObserver returns a new Observer with the given options. +func NewObserver(options ...Option) (Observer, error) { cfg := &config{ usages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, log: logr.Discard(), @@ -158,7 +164,7 @@ func NewMetrics(options ...Option) (*Metrics, error) { return nil, err } } - m := &Metrics{ + o := &observer{ updateError: prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: cfg.namespace, Subsystem: cfg.subsystem, @@ -180,59 +186,59 @@ func NewMetrics(options ...Option) (*Metrics, error) { usages: cfg.usages, log: cfg.log, } - return m, nil + return o, nil } // Describe sends the super-set of all possible descriptors of metrics // to the provided channel and returns once the last descriptor has been sent. -func (m *Metrics) Describe(ch chan<- *prometheus.Desc) { - m.updateError.Describe(ch) - m.verifyError.Describe(ch) - m.expiration.Describe(ch) +func (o *observer) Describe(ch chan<- *prometheus.Desc) { + o.updateError.Describe(ch) + o.verifyError.Describe(ch) + o.expiration.Describe(ch) } // Collect sends each collected metric via the provided channel // and returns once the last metric has been sent. -func (m *Metrics) Collect(ch chan<- prometheus.Metric) { - m.updateError.Collect(ch) - m.verifyError.Collect(ch) - m.expiration.Collect(ch) +func (o *observer) Collect(ch chan<- prometheus.Metric) { + o.updateError.Collect(ch) + o.verifyError.Collect(ch) + o.expiration.Collect(ch) } -// Update updates the metrics with the new TLS config or error. -func (m *Metrics) Update(cfg *tls.Config, err error) { - if err != nil { - m.updateError.Set(1) - return - } - m.updateError.Set(0) +func (o *observer) ObserveConfig(cfg *tls.Config) { + o.updateError.Set(0) - t, err := m.earliestExpiration(cfg) + t, err := o.earliestExpiration(cfg) if err != nil || t.IsZero() { - m.verifyError.Set(1) + o.verifyError.Set(1) + o.expiration.Set(0) return } - m.verifyError.Set(0) - m.expiration.Set(float64(t.Unix())) + o.verifyError.Set(0) + o.expiration.Set(float64(t.Unix())) +} + +func (o *observer) ObserveReadError(err error) { + o.updateError.Set(1) } -func (m *Metrics) earliestExpiration(cfg *tls.Config) (time.Time, error) { +func (o *observer) earliestExpiration(cfg *tls.Config) (time.Time, error) { var t time.Time for _, cert := range cfg.Certificates { x509Cert := cert.Leaf if x509Cert == nil { var err error if x509Cert, err = x509.ParseCertificate(cert.Certificate[0]); err != nil { - m.log.Error(err, "Failed to parse TLS certificate") + o.log.Error(err, "Failed to parse TLS certificate") return time.Time{}, err } } chains, err := x509Cert.Verify(x509.VerifyOptions{ Roots: cfg.RootCAs, - KeyUsages: m.usages, + KeyUsages: o.usages, }) if err != nil { - m.log.Error(err, "Failed to validate TLS certificate") + o.log.Error(err, "Failed to validate TLS certificate") return time.Time{}, err } for _, chain := range chains { @@ -244,7 +250,7 @@ func (m *Metrics) earliestExpiration(cfg *tls.Config) (time.Time, error) { } } if t.IsZero() { - m.log.Error(nil, "Failed to find a certificate in the TLS config") + o.log.Error(nil, "Failed to find a certificate in the TLS config") } return t, nil } diff --git a/tlsprom/tlsprom_test.go b/tlsprom/tlsprom_test.go index d76b4fd..6d8e5ad 100644 --- a/tlsprom/tlsprom_test.go +++ b/tlsprom/tlsprom_test.go @@ -19,7 +19,7 @@ import ( ) func TestCollector(t *testing.T) { - m, err := NewMetrics() + m, err := NewObserver() check(t, "Failed to create metrics", err) reg := prometheus.NewRegistry() check(t, "Failed to register metrics", reg.Register(m)) @@ -90,12 +90,13 @@ func TestMetricNames(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - m, err := NewMetrics(tt.options...) + m, err := NewObserver(tt.options...) check(t, "Failed to create metrics", err) + o := m.(*observer) for baseName, metric := range map[string]*gauge{ - updateErrorName: readGauge(t, m.updateError), - verifyErrorName: readGauge(t, m.verifyError), - expirationName: readGauge(t, m.expiration), + updateErrorName: readGauge(t, o.updateError), + verifyErrorName: readGauge(t, o.verifyError), + expirationName: readGauge(t, o.expiration), } { got := metric.name want := tt.namespace + "_" + tt.subsystem + "_" + baseName @@ -108,22 +109,23 @@ func TestMetricNames(t *testing.T) { } func TestUpdateError(t *testing.T) { - m, err := NewMetrics() + m, err := NewObserver() check(t, "Failed to create metrics", err) + o := m.(*observer) - metric := readGauge(t, m.updateError) + metric := readGauge(t, o.updateError) if got, want := metric.value, float64(0); got != want { t.Fatalf("Unexpected %s value: got: %v; want: %v", metric.name, metric.value, want) } - m.Update(nil, fmt.Errorf("testing")) - metric = readGauge(t, m.updateError) + o.ObserveReadError(fmt.Errorf("testing")) + metric = readGauge(t, o.updateError) if got, want := metric.value, float64(1); got != want { t.Fatalf("Unexpected %s value: got: %v; want: %v", metric.name, metric.value, want) } - m.Update(&tls.Config{}, nil) - metric = readGauge(t, m.updateError) + o.ObserveConfig(&tls.Config{}) + metric = readGauge(t, o.updateError) if got, want := metric.value, float64(0); got != want { t.Fatalf("Unexpected %s value: got: %v; want: %v", metric.name, metric.value, want) } @@ -200,10 +202,11 @@ func TestValidation(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - m, err := NewMetrics(tt.options...) + m, err := NewObserver(tt.options...) check(t, "Failed to create metrics", err) - m.Update(tt.config, nil) - got := readGauge(t, m.verifyError) + o := m.(*observer) + o.ObserveConfig(tt.config) + got := readGauge(t, o.verifyError) want := float64(0) if tt.invalid { want = 1 @@ -286,10 +289,11 @@ func TestExpiration(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - m, err := NewMetrics(WithLogger(tlstest.Logr(t))) + m, err := NewObserver(WithLogger(tlstest.Logr(t))) check(t, "Failed to create metrics", err) - m.Update(tt.config, nil) - got := readGauge(t, m.expiration) + o := m.(*observer) + o.ObserveConfig(tt.config) + got := readGauge(t, o.expiration) want := float64(tt.expiry.Unix()) if got.value != want { t.Fatalf("Unexpected %s value: got: %v; want: %v", got.name, got.value, want)