Skip to content

Commit

Permalink
Subscription connection parameters support (#360)
Browse files Browse the repository at this point in the history
Adds support for subscription (websocket) connection parameters support.

We are having an issue with authorization when using subscriptions in
genqlient. Our mobile app team is using Apollo for GraphQL requests, and
Apollo seems to sends its auth in a way that is different from how
genqlient does it:


https://www.apollographql.com/docs/react/data/subscriptions#5-authenticate-over-websocket-optional

This is causing a problems, because we use genqlient for testing the
backend. It seems that the way Apollo is implemented, authentication
(and other headers) are not actually sent as HTTP headers, but as
connection parameters instead.

I have:
- [x] Written a clear PR title and description (above)
- [x] Signed the [Khan Academy CLA](https://www.khanacademy.org/r/cla)
- [x] Added tests covering my changes, if applicable
- [x] Included a link to the issue fixed, if applicable
- [x] Included documentation, for new features
- [x] Added an entry to the changelog

---------

Co-authored-by: Ben Kraft <[email protected]>
  • Loading branch information
HaraldNordgren and benjaminjkraft authored Nov 30, 2024
1 parent adb9dd6 commit 800909d
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 9 deletions.
20 changes: 18 additions & 2 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,41 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
return newClient(endpoint, httpClient, http.MethodGet)
}

type WebSocketOption func(*webSocketClient)

// NewClientUsingWebSocket returns a [WebSocketClient] which makes subscription requests
// to the given endpoint using webSocket.
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
return &webSocketClient{
client := &webSocketClient{
Dialer: wsDialer,
Header: headers,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
}

for _, opt := range opts {
opt(client)
}

return client
}

// WithConnectionParams sets up connection params to be sent to the server
// during the initial connection handshake.
func WithConnectionParams(connParams map[string]interface{}) WebSocketOption {
return func(ws *webSocketClient) {
ws.connParams = connParams
}
}

func newClient(endpoint string, httpClient Doer, method string) Client {
Expand Down
11 changes: 9 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,18 @@ type webSocketClient struct {
Header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
errChan chan error
subscriptions subscriptionMap
isClosing bool
sync.Mutex
}

type webSocketInitMessage struct {
Payload map[string]interface{} `json:"payload"`
Type string `json:"type"`
}

type webSocketSendMessage struct {
Payload *Request `json:"payload"`
Type string `json:"type"`
Expand All @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct {
}

func (w *webSocketClient) sendInit() error {
connInitMsg := webSocketSendMessage{
Type: webSocketTypeConnInit,
connInitMsg := webSocketInitMessage{
Type: webSocketTypeConnInit,
Payload: w.connParams,
}
return w.sendStructAsJSON(connInitMsg)
}
Expand Down
56 changes: 56 additions & 0 deletions internal/integration/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func TestSubscription(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -146,6 +147,93 @@ func TestSubscription(t *testing.T) {
}
}

func TestSubscriptionConnectionParams(t *testing.T) {
_ = `# @genqlient
subscription countAuthorized { countAuthorized }`

authKey := server.AuthKey

ctx := context.Background()
server := server.RunServer()
defer server.Close()

cases := []struct {
connParams map[string]interface{}
name string
expectedError string
opts []graphql.WebSocketOption
}{
{
name: "authorized_user_gets_counter",
opts: []graphql.WebSocketOption{
graphql.WithConnectionParams(map[string]interface{}{
authKey: "authorized-user-token",
}),
},
},
{
name: "unauthorized_user_gets_error",
expectedError: "input: countAuthorized unauthorized\n",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(
t,
server.URL,
tc.opts...,
)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

dataChan, subscriptionID, err := countAuthorized(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()

var (
counter = 0
start = time.Now()
)

for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
loop = false
break
}

if tc.expectedError != "" {
require.Error(t, resp.Errors)
assert.Equal(t, tc.expectedError, resp.Errors.Error())
continue
}

require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.CountAuthorized)
require.Nil(t, resp.Errors)

if time.Since(start) > 5*time.Second {
err := wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
loop = false
}

counter++

case err := <-errChan:
require.NoError(t, err)

case <-time.After(10 * time.Second):
require.NoError(t, fmt.Errorf("subscription timed out"))
}
}
})
}
}

func TestServerError(t *testing.T) {
_ = `# @genqlient
query failingQuery { fail me { id } }`
Expand Down
12 changes: 9 additions & 3 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,20 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql.WebSocketOption) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}

return &roundtripClient{
wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil),
t: t,
wsWrapped: graphql.NewClientUsingWebSocket(
endpoint,
&MyDialer{Dialer: dialer},
nil,
opts...,
),
t: t,
}
}
1 change: 1 addition & 0 deletions internal/integration/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Mutation {

type Subscription {
count: Int!
countAuthorized: Int!
}

type User implements Being & Lucky {
Expand Down
72 changes: 71 additions & 1 deletion internal/integration/server/gqlgen_exec.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 800909d

Please sign in to comment.