diff --git a/v3/circuit.go b/v3/circuit.go index 79607e5..187c1f0 100644 --- a/v3/circuit.go +++ b/v3/circuit.go @@ -334,11 +334,30 @@ func (c *Circuit) checkSuccess(runFuncDoneTime time.Time, totalCmdTime time.Dura } } +// checkErrInterrupt returns true if this is considered an interrupt error: interrupt errors do not open the circuit. +// Normally if the parent context is canceled before a timeout is reached, we don't consider the circuit +// unhealthy. But when ExecutionConfig.IgnoreInterrupts set to true we try to classify originalContext.Err() +// with help of ExecutionConfig.IsErrInterrupt function. When this function returns true we do not open the circuit func (c *Circuit) checkErrInterrupt(originalContext context.Context, ret error, runFuncDoneTime time.Time, totalCmdTime time.Duration) bool { - if !c.threadSafeConfig.GoSpecific.IgnoreInterrupts.Get() && ret != nil && originalContext.Err() != nil { + // We need to see an error in both the original context and the return value to consider this an "interrupt" caused + // error. + if ret == nil || originalContext.Err() == nil { + return false + } + + isErrInterrupt := c.notThreadSafeConfig.Execution.IsErrInterrupt + if isErrInterrupt == nil { + isErrInterrupt = func(_ error) bool { + // By default, we consider any error from the original context an interrupt causing error + return true + } + } + + if !c.threadSafeConfig.GoSpecific.IgnoreInterrupts.Get() && isErrInterrupt(originalContext.Err()) { c.CmdMetricCollector.ErrInterrupt(runFuncDoneTime, totalCmdTime) return true } + return false } diff --git a/v3/circuit_test.go b/v3/circuit_test.go index f976ffc..c6113ff 100644 --- a/v3/circuit_test.go +++ b/v3/circuit_test.go @@ -283,22 +283,130 @@ func TestFallbackCircuit(t *testing.T) { } func TestCircuitIgnoreContextFailures(t *testing.T) { - c := NewCircuitFromConfig("TestFailingCircuit", Config{ - Execution: ExecutionConfig{ - Timeout: time.Hour, - }, + + t.Run("ignore context.DeadlineExceeded by default", func(t *testing.T) { + c := circuitFactory(t) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + if err != context.DeadlineExceeded { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() + } + if c.IsOpen() { + t.Error("Parent context cancellations should not close the circuit by default") + } }) - for i := 0; i < 100; i++ { - rootCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) - err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) - if err == nil { - t.Error("saw no error from circuit that should end in an error") + + t.Run("ignore context.Canceled by default", func(t *testing.T) { + c := circuitFactory(t) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Millisecond*3, func() { cancel() }) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + if err != context.Canceled { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() } - cancel() - } - if c.IsOpen() { - t.Error("Parent context cacelations should not close the circuit by default") - } + if c.IsOpen() { + t.Error("Parent context cancellations should not close the circuit by default") + } + }) + + t.Run("open circuit on context.DeadlineExceeded with IgnoreInterrupts", func(t *testing.T) { + c := circuitFactory(t, withIgnoreInterrupts(true)) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + + if err != context.DeadlineExceeded && err != errCircuitOpen { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() + } + if !c.IsOpen() { + t.Error("Parent context cancellations should open the circuit when IgnoreInterrupts sets to true") + } + }) + + t.Run("open circuit on context.Canceled with IgnoreInterrupts", func(t *testing.T) { + c := circuitFactory(t, withIgnoreInterrupts(true)) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Millisecond*3, func() { cancel() }) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + + if err != context.Canceled && err != errCircuitOpen { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() + } + if !c.IsOpen() { + t.Error("Parent context cancellations should open the circuit when IgnoreInterrupts sets to true") + } + }) + + t.Run("open circuit on context.DeadlineExceeded with IgnoreInterrupts and IsErrInterrupt", func(t *testing.T) { + c := circuitFactory( + t, + withIgnoreInterrupts(true), + withIsErrInterrupt(func(err error) bool { return err == context.Canceled }), + ) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + + if err != context.DeadlineExceeded && err != errCircuitOpen { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() + } + if !c.IsOpen() { + t.Error("Parent context cancellations should open the circuit when IgnoreInterrupts sets to true") + } + }) + + t.Run("ignore context.Canceled with IgnoreInterrupts and IsErrInterrupt", func(t *testing.T) { + c := circuitFactory( + t, + withIgnoreInterrupts(true), + withIsErrInterrupt(func(err error) bool { return err == context.Canceled }), + ) + + for i := 0; i < 100; i++ { + rootCtx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(time.Millisecond*3, func() { cancel() }) + err := c.Execute(rootCtx, testhelp.SleepsForX(time.Second), nil) + + if err != context.Canceled && err != errCircuitOpen { + t.Errorf("saw no error from circuit that should end in an error(%d):%v", i, err) + cancel() + break + } + cancel() + } + if c.IsOpen() { + t.Error("Parent context cancellations should not open the circuit when IgnoreInterrupts sets to true") + } + }) + } func TestFallbackCircuitConcurrency(t *testing.T) { @@ -453,3 +561,56 @@ func TestVariousRaceConditions(t *testing.T) { } wg.Wait() } + +func openOnFirstErrorFactory() ClosedToOpen { + return &closeOnFirstErrorOpener{ + ClosedToOpen: neverOpensFactory(), + } +} + +type closeOnFirstErrorOpener struct { + ClosedToOpen + isOpened bool +} + +func (o *closeOnFirstErrorOpener) ShouldOpen(_ time.Time) bool { + o.isOpened = true + return true +} +func (o *closeOnFirstErrorOpener) Prevent(_ time.Time) bool { + return o.isOpened +} + +type configOverride func(*Config) *Config + +func withIgnoreInterrupts(b bool) configOverride { + return func(c *Config) *Config { + c.Execution.IgnoreInterrupts = b + return c + } +} + +func withIsErrInterrupt(fn func(error) bool) configOverride { + return func(c *Config) *Config { + c.Execution.IsErrInterrupt = fn + return c + } +} + +func circuitFactory(t *testing.T, cfgOpts ...configOverride) *Circuit { + t.Helper() + + cfg := Config{ + General: GeneralConfig{ + ClosedToOpenFactory: openOnFirstErrorFactory, + }, + Execution: ExecutionConfig{ + Timeout: time.Hour, + }, + } + for _, co := range cfgOpts { + co(&cfg) + } + + return NewCircuitFromConfig(t.Name(), cfg) +} diff --git a/v3/config.go b/v3/config.go index faff8ba..420a976 100644 --- a/v3/config.go +++ b/v3/config.go @@ -46,10 +46,14 @@ type ExecutionConfig struct { // MaxConcurrentRequests is https://github.com/Netflix/Hystrix/wiki/Configuration#executionisolationsemaphoremaxconcurrentrequests MaxConcurrentRequests int64 // Normally if the parent context is canceled before a timeout is reached, we don't consider the circuit - // unhealth. Set this to true to consider those circuits unhealthy. - // Note: This is a typo: Should be renamed as IgnoreInterrupts. Tracking this in - // https://github.com/cep21/circuit/issues/39 + // unhealthy. Set this to true to consider those circuits unhealthy. IgnoreInterrupts bool `json:",omitempty"` + // IsErrInterrupt should return true if the error from the original context should be considered an interrupt error. + // The error passed in will be a non-nil error returned by calling `Err()` on the context passed into Run. + // The default behavior is to consider all errors from the original context interrupt caused errors. + // Default behaviour: + // IsErrInterrupt: function(e err) bool { return true } + IsErrInterrupt func(originalContextError error) bool `json:"-"` } // FallbackConfig is https://github.com/Netflix/Hystrix/wiki/Configuration#fallback @@ -100,6 +104,9 @@ func (c *ExecutionConfig) merge(other ExecutionConfig) { if !c.IgnoreInterrupts { c.IgnoreInterrupts = other.IgnoreInterrupts } + if c.IsErrInterrupt == nil { + c.IsErrInterrupt = other.IsErrInterrupt + } if c.MaxConcurrentRequests == 0 { c.MaxConcurrentRequests = other.MaxConcurrentRequests } diff --git a/v3/config_test.go b/v3/config_test.go index b9b86a1..d71a56c 100644 --- a/v3/config_test.go +++ b/v3/config_test.go @@ -41,3 +41,28 @@ func TestGeneralConfig_Merge(t *testing.T) { }) } + +func TestExecutionConfig_Merge(t *testing.T) { + + t.Run("isErrInterrupt check function", func(t *testing.T) { + cfg := ExecutionConfig{} + + cfg.merge(ExecutionConfig{IsErrInterrupt: func(e error) bool { return e != nil }}) + + assert.NotNil(t, cfg.IsErrInterrupt) + }) + + t.Run("ignore isErrInterrupt if previously set", func(t *testing.T) { + fn1 := func(err error) bool { return true } + fn2 := func(err error) bool { return false } + + cfg := ExecutionConfig{ + IsErrInterrupt: fn1, + } + + cfg.merge(ExecutionConfig{IsErrInterrupt: fn2}) + + assert.NotNil(t, fn1, cfg.IsErrInterrupt) + assert.True(t, cfg.IsErrInterrupt(nil)) + }) +}