diff --git a/fxtest/lifecycle.go b/fxtest/lifecycle.go index 111840727..37dc08775 100644 --- a/fxtest/lifecycle.go +++ b/fxtest/lifecycle.go @@ -66,18 +66,45 @@ func (t *panicT) FailNow() { panic("test lifecycle failed") } +// LifecycleOption modifies the behavior of the [Lifecycle] +// when passed to [NewLifecycle]. +type LifecycleOption interface { + apply(*Lifecycle) +} + +// EnforceTimeout will cause the [Lifecycle]'s Start and Stop methods +// to return an error as soon as context expires, +// regardless of whether specific hooks respect the timeout. +func EnforceTimeout(enforce bool) LifecycleOption { + return &enforceTimeout{ + enforce: enforce, + } +} + +type enforceTimeout struct { + enforce bool +} + +func (e *enforceTimeout) apply(lc *Lifecycle) { + lc.enforceTimeout = e.enforce +} + +var _ LifecycleOption = (*enforceTimeout)(nil) + // Lifecycle is a testing spy for fx.Lifecycle. It exposes Start and Stop // methods (and some test-specific helpers) so that unit tests can exercise // hooks. type Lifecycle struct { t TB lc *lifecycle.Lifecycle + + enforceTimeout bool } var _ fx.Lifecycle = (*Lifecycle)(nil) // NewLifecycle creates a new test lifecycle. -func NewLifecycle(t TB) *Lifecycle { +func NewLifecycle(t TB, opts ...LifecycleOption) *Lifecycle { var w io.Writer if t != nil { w = testutil.WriteSyncer{T: t} @@ -85,15 +112,45 @@ func NewLifecycle(t TB) *Lifecycle { w = os.Stderr t = &panicT{W: os.Stderr} } - return &Lifecycle{ + lc := &Lifecycle{ lc: lifecycle.New(fxlog.DefaultLogger(w), fxclock.System), t: t, } + for _, opt := range opts { + opt.apply(lc) + } + return lc +} + +func (l *Lifecycle) withTimeout(ctx context.Context, fn func(context.Context) error) error { + if !l.enforceTimeout { + return fn(ctx) + } + + // Cancel on timeout in case function only respects + // cancellation and not deadline exceeded. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + c := make(chan error, 1) // buffered to avoid goroutine leak + go func() { + c <- fn(ctx) + }() + + var err error + select { + case err = <-c: + case <-ctx.Done(): + err = ctx.Err() + } + return err } // Start executes all registered OnStart hooks in order, halting at the first // hook that doesn't succeed. -func (l *Lifecycle) Start(ctx context.Context) error { return l.lc.Start(ctx) } +func (l *Lifecycle) Start(ctx context.Context) error { + return l.withTimeout(ctx, l.lc.Start) +} // RequireStart calls Start with context.Background(), failing the test if an // error is encountered. @@ -114,7 +171,9 @@ func (l *Lifecycle) RequireStart() *Lifecycle { // If any hook returns an error, execution continues for a best-effort // cleanup. Any errors encountered are collected into a single error and // returned. -func (l *Lifecycle) Stop(ctx context.Context) error { return l.lc.Stop(ctx) } +func (l *Lifecycle) Stop(ctx context.Context) error { + return l.withTimeout(ctx, l.lc.Stop) +} // RequireStop calls Stop with context.Background(), failing the test if an error // is encountered. diff --git a/fxtest/lifecycle_test.go b/fxtest/lifecycle_test.go index a5746b291..97df4960b 100644 --- a/fxtest/lifecycle_test.go +++ b/fxtest/lifecycle_test.go @@ -26,6 +26,7 @@ import ( "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -117,6 +118,102 @@ func TestLifecycle(t *testing.T) { }) } +func TestEnforceTimeout(t *testing.T) { + // These tests directly call Start and Stop + // rather than RequireStart and RequireStop + // because EnforceTimeout does not apply to those. + + t.Run("StartHookTimeout", func(t *testing.T) { + t.Parallel() + + wait := make(chan struct{}) + defer close(wait) // force timeout by blocking OnStart until end of test + + spy := newTB() + lc := NewLifecycle(spy, EnforceTimeout(true)) + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + <-wait + return nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + assert.ErrorIs(t, lc.Start(ctx), context.DeadlineExceeded) + assert.Zero(t, spy.failures) + }) + + t.Run("StopHookTimeout", func(t *testing.T) { + t.Parallel() + + wait := make(chan struct{}) + defer close(wait) // force timeout by blocking OnStop until end of test + + spy := newTB() + lc := NewLifecycle(spy, EnforceTimeout(true)) + lc.Append(fx.Hook{ + OnStop: func(context.Context) error { + <-wait + return nil + }, + }) + + require.NoError(t, lc.Start(context.Background())) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + assert.ErrorIs(t, lc.Stop(ctx), context.DeadlineExceeded) + assert.Zero(t, spy.failures) + }) + + t.Run("NoTimeout", func(t *testing.T) { + t.Parallel() + + var ( + started bool + stopped bool + ) + + spy := newTB() + lc := NewLifecycle(spy, EnforceTimeout(true)) + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + started = true + return nil + }, + OnStop: func(context.Context) error { + stopped = true + return nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + require.NoError(t, lc.Start(ctx)) + require.NoError(t, lc.Stop(ctx)) + assert.True(t, started) + assert.True(t, stopped) + assert.Zero(t, spy.failures) + }) + + t.Run("OtherError", func(t *testing.T) { + t.Parallel() + + spy := newTB() + lc := NewLifecycle(spy, EnforceTimeout(true)) + lc.Append(fx.Hook{ + OnStart: func(context.Context) error { + return errors.New("NOT a context-related error") + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + assert.ErrorContains(t, lc.Start(ctx), "NOT a context-related error") + assert.Zero(t, spy.failures) + }) +} + func TestLifecycle_OptionalT(t *testing.T) { t.Parallel()