diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5b48abc Binary files /dev/null and b/.DS_Store differ diff --git a/proxy/README.md b/proxy/README.md deleted file mode 120000 index 71bfc07..0000000 --- a/proxy/README.md +++ /dev/null @@ -1 +0,0 @@ -DOC.md \ No newline at end of file diff --git a/proxy/README.md b/proxy/README.md new file mode 100644 index 0000000..85c411a --- /dev/null +++ b/proxy/README.md @@ -0,0 +1,83 @@ +# proxy +-- + import "github.com/mwitkow/grpc-proxy/proxy" + +Package proxy provides a reverse proxy handler for gRPC. + +The implementation allows a `grpc.Server` to pass a received ServerStream to a +ClientStream without understanding the semantics of the messages exchanged. It +basically provides a transparent reverse-proxy. + +This package is intentionally generic, exposing a `StreamDirector` function that +allows users of this package to implement whatever logic of backend-picking, +dialing and service verification to perform. + +See examples on documented functions. + +## Usage + +#### func Codec + +```go +func Codec() grpc.Codec +``` +Codec returns a proxying grpc.Codec with the default protobuf codec as parent. + +See CodecWithParent. + +#### func CodecWithParent + +```go +func CodecWithParent(fallback grpc.Codec) grpc.Codec +``` +CodecWithParent returns a proxying grpc.Codec with a user provided codec as +parent. + +This codec is *crucial* to the functioning of the proxy. It allows the proxy +server to be oblivious to the schema of the forwarded messages. It basically +treats a gRPC message frame as raw bytes. However, if the server handler, or the +client caller are not proxy-internal functions it will fall back to trying to +decode the message using a fallback codec. + +#### func RegisterService + +```go +func RegisterService(server *grpc.Server, director StreamDirector, serviceName string, methodNames ...string) +``` +RegisterService sets up a proxy handler for a particular gRPC service and +method. The behaviour is the same as if you were registering a handler method, +e.g. from a codegenerated pb.go file. + +This can *only* be used if the `server` also uses grpcproxy.CodecForServer() +ServerOption. + +#### func TransparentHandler + +```go +func TransparentHandler(director StreamDirector) grpc.StreamHandler +``` +TransparentHandler returns a handler that attempts to proxy all requests that +are not registered in the server. The indented use here is as a transparent +proxy, where the server doesn't know about the services implemented by the +backends. It should be used as a `grpc.UnknownServiceHandler`. + +This can *only* be used if the `server` also uses grpcproxy.CodecForServer() +ServerOption. + +#### type StreamDirector + +```go +type StreamDirector func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) +``` + +StreamDirector returns a gRPC ClientConn to be used to forward the call to. + +The presence of the `Context` allows for rich filtering, e.g. based on Metadata +(headers). If no handling is meant to be done, a `codes.NotImplemented` gRPC +error should be returned. + +It is worth noting that the StreamDirector will be fired *after* all server-side +stream interceptors are invoked. So decisions around authorization, monitoring +etc. are better to be handled there. + +See the rather rich example. diff --git a/proxy/director.go b/proxy/director.go index 2e1c916..371ca60 100644 --- a/proxy/director.go +++ b/proxy/director.go @@ -13,8 +13,12 @@ import ( // The presence of the `Context` allows for rich filtering, e.g. based on Metadata (headers). // If no handling is meant to be done, a `codes.NotImplemented` gRPC error should be returned. // +// The context returned from this function should be the context for the *outgoing* (to backend) call. In case you want +// to forward any Metadata between the inbound request and outbound requests, you should do it manually. However, you +// *must* propagate the cancel function (`context.WithCancel`) of the inbound context to the one returned. +// // It is worth noting that the StreamDirector will be fired *after* all server-side stream interceptors // are invoked. So decisions around authorization, monitoring etc. are better to be handled there. // // See the rather rich example. -type StreamDirector func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) +type StreamDirector func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) diff --git a/proxy/examples_test.go b/proxy/examples_test.go index ad3dbb4..bef4ce3 100644 --- a/proxy/examples_test.go +++ b/proxy/examples_test.go @@ -35,21 +35,26 @@ func ExampleTransparentHandler() { // Provide sa simple example of a director that shields internal services and dials a staging or production backend. // This is a *very naive* implementation that creates a new connection on every request. Consider using pooling. func ExampleStreamDirector() { - director = func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) { + director = func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) { // Make sure we never forward internal services. if strings.HasPrefix(fullMethodName, "/com.example.internal.") { - return nil, grpc.Errorf(codes.Unimplemented, "Unknown method") + return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method") } - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) + // Copy the inbound metadata explicitly. + outCtx, _ := context.WithCancel(ctx) + outCtx = metadata.NewOutgoingContext(outCtx, md.Copy()) if ok { // Decide on which backend to dial if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" { // Make sure we use DialContext so the dialing can be cancelled/time out together with the context. - return grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec())) + conn, err := grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec())) + return outCtx, conn, err } else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" { - return grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec())) + conn, err := grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec())) + return outCtx, conn, err } } - return nil, grpc.Errorf(codes.Unimplemented, "Unknown method") + return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method") } } diff --git a/proxy/handler.go b/proxy/handler.go index f5868d9..f43b05a 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -9,7 +9,8 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/transport" + // "google.golang.org/grpc/transport" + // "google.golang.org/grpc/transport" ) var ( @@ -17,6 +18,7 @@ var ( ServerStreams: true, ClientStreams: true, } + HandleEndCallback func(context.Context, *grpc.ClientConn) ) // RegisterService sets up a proxy handler for a particular gRPC service and method. @@ -60,25 +62,32 @@ type handler struct { // forwarding it to a ClientStream established against the relevant ClientConn. func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error { // little bit of gRPC internals never hurt anyone - lowLevelServerStream, ok := transport.StreamFromContext(serverStream.Context()) - if !ok { - return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") - } + lowLevelServerStream := grpc.ServerTransportStreamFromContext(serverStream.Context()) + // if !ok { + // return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") + // } fullMethodName := lowLevelServerStream.Method() - clientCtx, clientCancel := context.WithCancel(serverStream.Context()) - backendConn, err := s.director(serverStream.Context(), fullMethodName) + // We require that the director's returned context inherits from the serverStream.Context(). + outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName) + clientCtx, clientCancel := context.WithCancel(outgoingCtx) if err != nil { return err } + if HandleEndCallback == nil { + defer backendConn.Close() + } else { + HandleEndCallback(clientCtx, backendConn) + } // TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For. clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName) if err != nil { return err } + // Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate. + // Channels do not have to be closed, it is just a control flow mechanism, see + // https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ s2cErrChan := s.forwardServerToClient(serverStream, clientStream) - defer close(s2cErrChan) c2sErrChan := s.forwardClientToServer(clientStream, serverStream) - defer close(c2sErrChan) // We don't know which side is going to stop sending first, so we need a select between the two. for i := 0; i < 2; i++ { select { diff --git a/proxy/handler_test.go b/proxy/handler_test.go index ea67ab2..7cb55e7 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -46,7 +46,7 @@ type assertingService struct { func (s *assertingService) PingEmpty(ctx context.Context, _ *pb.Empty) (*pb.PingResponse, error) { // Check that this call has client's metadata. - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) assert.True(s.t, ok, "PingEmpty call must have metadata in context") _, ok = md[clientMdKey] assert.True(s.t, ok, "PingEmpty call must have clients's custom headers in metadata") @@ -99,10 +99,11 @@ func (s *assertingService) PingStream(stream pb.TestService_PingStreamServer) er type ProxyHappySuite struct { suite.Suite - serverListener net.Listener - server *grpc.Server - proxyListener net.Listener - proxy *grpc.Server + serverListener net.Listener + server *grpc.Server + proxyListener net.Listener + proxy *grpc.Server + serverClientConn *grpc.ClientConn client *grpc.ClientConn testClient pb.TestServiceClient @@ -115,12 +116,18 @@ func (s *ProxyHappySuite) ctx() context.Context { } func (s *ProxyHappySuite) TestPingEmptyCarriesClientMetadata() { - ctx := metadata.NewContext(s.ctx(), metadata.Pairs(clientMdKey, "true")) + ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(clientMdKey, "true")) out, err := s.testClient.PingEmpty(ctx, &pb.Empty{}) require.NoError(s.T(), err, "PingEmpty should succeed without errors") require.Equal(s.T(), &pb.PingResponse{Value: pingDefaultValue, Counter: 42}, out) } +func (s *ProxyHappySuite) TestPingEmpty_StressTest() { + for i := 0; i < 50; i++ { + s.TestPingEmptyCarriesClientMetadata() + } +} + func (s *ProxyHappySuite) TestPingCarriesServerHeadersAndTrailers() { headerMd := make(metadata.MD) trailerMd := make(metadata.MD) @@ -141,7 +148,7 @@ func (s *ProxyHappySuite) TestPingErrorPropagatesAppError() { func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() { // See SetupSuite where the StreamDirector has a special case. - ctx := metadata.NewContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true")) + ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true")) _, err := s.testClient.Ping(ctx, &pb.PingRequest{Value: "foo"}) require.Error(s.T(), err, "Director should reject this RPC") assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err)) @@ -175,6 +182,12 @@ func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() { assert.Len(s.T(), trailerMd, 1, "PingList trailer headers user contain metadata") } +func (s *ProxyHappySuite) TestPingStream_StressTest() { + for i := 0; i < 50; i++ { + s.TestPingStream_FullDuplexWorks() + } +} + func (s *ProxyHappySuite) SetupSuite() { var err error @@ -189,16 +202,19 @@ func (s *ProxyHappySuite) SetupSuite() { pb.RegisterTestServiceServer(s.server, &assertingService{t: s.T()}) // Setup of the proxy's Director. - proxyClientConn, err := grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithCodec(proxy.Codec())) + s.serverClientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithCodec(proxy.Codec())) require.NoError(s.T(), err, "must not error on deferred client Dial") - director := func(ctx context.Context, fullName string) (*grpc.ClientConn, error) { - md, ok := metadata.FromContext(ctx) + director := func(ctx context.Context, fullName string) (context.Context, *grpc.ClientConn, error) { + md, ok := metadata.FromIncomingContext(ctx) if ok { if _, exists := md[rejectingMdKey]; exists { - return nil, grpc.Errorf(codes.PermissionDenied, "testing rejection") + return ctx, nil, grpc.Errorf(codes.PermissionDenied, "testing rejection") } } - return proxyClientConn, nil + // Explicitly copy the metadata, otherwise the tests will fail. + outCtx, _ := context.WithCancel(ctx) + outCtx = metadata.NewOutgoingContext(outCtx, md.Copy()) + return outCtx, s.serverClientConn, nil } s.proxy = grpc.NewServer( grpc.CustomCodec(proxy.Codec()), @@ -225,6 +241,14 @@ func (s *ProxyHappySuite) SetupSuite() { } func (s *ProxyHappySuite) TearDownSuite() { + if s.client != nil { + s.client.Close() + } + if s.serverClientConn != nil { + s.serverClientConn.Close() + } + // Close all transports so the logs don't get spammy. + time.Sleep(10 * time.Millisecond) if s.proxy != nil { s.proxy.Stop() s.proxyListener.Close() @@ -233,9 +257,6 @@ func (s *ProxyHappySuite) TearDownSuite() { s.server.Stop() s.serverListener.Close() } - if s.client != nil { - s.client.Close() - } } func TestProxyHappySuite(t *testing.T) {