diff --git a/interceptors/retry/options.go b/interceptors/retry/options.go index 649db2028..e14665986 100644 --- a/interceptors/retry/options.go +++ b/interceptors/retry/options.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -22,11 +23,11 @@ var ( max: 0, // disabled perCallTimeout: 0, // disabled includeHeader: true, - codes: DefaultRetriableCodes, backoffFunc: BackoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10), onRetryCallback: OnRetryCallback(func(ctx context.Context, attempt uint, err error) { logTrace(ctx, "grpc_retry attempt: %d, backoff for %v", attempt, err) }), + retriableFunc: newRetriableFuncForCodes(DefaultRetriableCodes), } ) @@ -41,6 +42,9 @@ type BackoffFunc func(ctx context.Context, attempt uint) time.Duration // OnRetryCallback is the type of function called when a retry occurs. type OnRetryCallback func(ctx context.Context, attempt uint, err error) +// RetriableFunc denotes a family of functions that control which error should be retried. +type RetriableFunc func(err error) bool + // Disable disables the retry behaviour on this call, or this interceptor. // // Its semantically the same to `WithMax` @@ -78,7 +82,7 @@ func WithOnRetryCallback(fn OnRetryCallback) CallOption { // You cannot automatically retry on Cancelled and Deadline, please use `WithPerRetryTimeout` for these. func WithCodes(retryCodes ...codes.Code) CallOption { return CallOption{applyFunc: func(o *options) { - o.codes = retryCodes + o.retriableFunc = newRetriableFuncForCodes(retryCodes) }} } @@ -100,13 +104,20 @@ func WithPerRetryTimeout(timeout time.Duration) CallOption { }} } +// WithRetriable sets which error should be retried. +func WithRetriable(retriableFunc RetriableFunc) CallOption { + return CallOption{applyFunc: func(o *options) { + o.retriableFunc = retriableFunc + }} +} + type options struct { max uint perCallTimeout time.Duration includeHeader bool - codes []codes.Code backoffFunc BackoffFunc onRetryCallback OnRetryCallback + retriableFunc RetriableFunc } // CallOption is a grpc.CallOption that is local to grpc_retry. @@ -137,3 +148,20 @@ func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOp } return grpcOptions, retryOptions } + +// newRetriableFuncForCodes returns retriable function for specific Codes. +func newRetriableFuncForCodes(codes []codes.Code) func(err error) bool { + return func(err error) bool { + errCode := status.Code(err) + if isContextError(err) { + // context errors are not retriable based on user settings. + return false + } + for _, code := range codes { + if code == errCode { + return true + } + } + return false + } +} diff --git a/interceptors/retry/retry.go b/interceptors/retry/retry.go index 9ea2e80a9..2b7708415 100644 --- a/interceptors/retry/retry.go +++ b/interceptors/retry/retry.go @@ -267,15 +267,8 @@ func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options } func isRetriable(err error, callOpts *options) bool { - errCode := status.Code(err) - if isContextError(err) { - // context errors are not retriable based on user settings. - return false - } - for _, code := range callOpts.codes { - if code == errCode { - return true - } + if callOpts.retriableFunc != nil { + return callOpts.retriableFunc(err) } return false } diff --git a/interceptors/retry/retry_test.go b/interceptors/retry/retry_test.go index 92645b21e..459c33270 100644 --- a/interceptors/retry/retry_test.go +++ b/interceptors/retry/retry_test.go @@ -6,6 +6,7 @@ package retry import ( "context" "io" + "strings" "sync" "testing" "time" @@ -178,6 +179,16 @@ func (s *RetrySuite) TestUnary_OverrideFromDialOpts() { require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made") } +func (s *RetrySuite) TestUnary_OverrideFromDialOpts2() { + s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors + out, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing, WithRetriable(func(err error) bool { + return strings.Contains(err.Error(), "maybeFailRequest") + }), WithMax(5)) + require.NoError(s.T(), err, "the fifth invocation should succeed") + require.NotNil(s.T(), out, "Pong must be not nil") + require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made") +} + func (s *RetrySuite) TestUnary_OnRetryCallbackCalled() { retryCallbackCount := 0 @@ -209,6 +220,16 @@ func (s *RetrySuite) TestServerStream_OverrideFromContext() { require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") } +func (s *RetrySuite) TestServerStream_OverrideFromContext2() { + s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors + stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList, WithRetriable(func(err error) bool { + return strings.Contains(err.Error(), "maybeFailRequest") + }), WithMax(5)) + require.NoError(s.T(), err, "establishing the connection must always succeed") + s.assertPingListWasCorrect(stream) + require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") +} + func (s *RetrySuite) TestServerStream_OnRetryCallbackCalled() { retryCallbackCount := 0