Skip to content

Commit

Permalink
fix(router): Handle repeated headers in response
Browse files Browse the repository at this point in the history
When multiple headers with the same name were sent in the response from
a subgraph, only one would be included in the response to the client
when forwarding headers. Fix it so that they're all forwarded.
  • Loading branch information
cmtm committed Jan 31, 2025
1 parent 8e78a37 commit 3807826
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 19 deletions.
88 changes: 83 additions & 5 deletions router-tests/header_propagation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func TestHeaderPropagation(t *testing.T) {
const (
customHeader = "X-Custom-Header"
employeeVal = "employee-value"
employeeVal2 = "employee-value-2"
hobbyVal = "hobby-value"
hobbyVal2 = "hobby-value-2"
)

const queryEmployeeWithHobby = `{
Expand Down Expand Up @@ -97,20 +99,20 @@ func TestHeaderPropagation(t *testing.T) {
}
}

setSubgraphPropagateHeader := func(header, valA, valB string) testenv.SubgraphsConfig {
setSubgraphPropagateHeader := func(header string, valA, valB []string) testenv.SubgraphsConfig {
return testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header, valA)
w.Header()[header] = valA
handler.ServeHTTP(w, r)
})
},
},
Hobbies: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header, valB)
w.Header()[header] = valB
handler.ServeHTTP(w, r)
})
},
Expand Down Expand Up @@ -140,11 +142,12 @@ func TestHeaderPropagation(t *testing.T) {
}

cacheOptions := func(cacheControlEmployees, cacheControlHobbies string) testenv.SubgraphsConfig {
return setSubgraphPropagateHeader("Cache-Control", cacheControlEmployees, cacheControlHobbies)
return setSubgraphPropagateHeader("Cache-Control", []string{cacheControlEmployees}, []string{cacheControlHobbies})
}

var (
subgraphsPropagateCustomHeader = setSubgraphPropagateHeader(customHeader, employeeVal, hobbyVal)
subgraphsPropagateCustomHeader = setSubgraphPropagateHeader(customHeader, []string{employeeVal}, []string{hobbyVal})
subgraphsPropagateRepeatedCustomHeader = setSubgraphPropagateHeader(customHeader, []string{employeeVal, employeeVal2}, []string{hobbyVal, hobbyVal2})
)

t.Run(" no propagate", func(t *testing.T) {
Expand Down Expand Up @@ -234,6 +237,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names last write wins", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmLastWrite, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "hobby-value,hobby-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

// Test for the First Write Wins Algorithm
Expand Down Expand Up @@ -283,6 +301,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names first write wins", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: partial(config.ResponseHeaderRuleAlgorithmFirstWrite, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "employee-value,employee-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

// Test for the Append Algorithm
Expand Down Expand Up @@ -332,6 +365,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names append headers", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmAppend, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "employee-value,employee-value-2,hobby-value,hobby-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

t.Run("Cache Control Propagation", func(t *testing.T) {
Expand Down Expand Up @@ -718,4 +766,34 @@ func TestHeaderPropagation(t *testing.T) {
})
})
})

t.Run("header name canonicalization", func(t *testing.T) {
t.Parallel()
nonCanonicalCustomHeader := "x-Custom-header"
subgraphsNonCanonicalHeader := testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header()[nonCanonicalCustomHeader] = []string{employeeVal}
handler.ServeHTTP(w, r)
})
},
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmAppend, nonCanonicalCustomHeader, ""),
Subgraphs: subgraphsNonCanonicalHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
cch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, employeeVal, cch)
ncch := strings.Join(res.Response.Header[nonCanonicalCustomHeader], ",")
require.Equal(t, "", ncch)

require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
}
29 changes: 16 additions & 13 deletions router/core/header_rule_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,11 @@ func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropaga
return
}

value := res.Header.Get(rule.Named)
if value != "" {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, value)
values := res.Header.Values(rule.Named)
if len(values) > 0 {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, values)
} else if rule.Default != "" {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, rule.Default)
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, []string{rule.Default})
}

return
Expand All @@ -324,31 +324,34 @@ func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropaga
if slices.Contains(ignoredHeaders, name) {
continue
}
h.applyResponseRuleKeyValue(res, propagation, rule, name, res.Header.Get(name))
values := res.Header.Values(name)
h.applyResponseRuleKeyValue(res, propagation, rule, name, values)
}
}
}
} else if rule.Algorithm == config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl {
// Explicitly apply the CacheControl algorithm on the headers
h.applyResponseRuleKeyValue(res, propagation, rule, "", "")
h.applyResponseRuleKeyValue(res, propagation, rule, "", []string{""})
}
}

func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key, value string) {
func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key string, values []string) {
// Since we'll be setting the header map directly, we need to canonicalize the key
key = http.CanonicalHeaderKey(key)
switch rule.Algorithm {
case config.ResponseHeaderRuleAlgorithmFirstWrite:
propagation.m.Lock()
if val := propagation.header.Get(key); val == "" {
propagation.header.Set(key, value)
propagation.header[key] = values
}
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmLastWrite:
propagation.m.Lock()
propagation.header.Set(key, value)
propagation.header[key] = values
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmAppend:
propagation.m.Lock()
propagation.header.Add(key, value)
propagation.header[key] = append(propagation.header[key], values...)
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl:
h.applyResponseRuleMostRestrictiveCacheControl(res, propagation, rule)
Expand Down Expand Up @@ -407,9 +410,9 @@ func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.R
return
}

value := ctx.Request().Header.Get(rule.Named)
if value != "" {
request.Header.Set(rule.Named, ctx.Request().Header.Get(rule.Named))
values := ctx.Request().Header.Values(rule.Named)
if len(values) > 0 {
request.Header[http.CanonicalHeaderKey(rule.Named)] = values
} else if rule.Default != "" {
request.Header.Set(rule.Named, rule.Default)
}
Expand Down
38 changes: 37 additions & 1 deletion router/core/header_rule_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,42 @@ func TestPropagateHeaderRule(t *testing.T) {

})

t.Run("Should propagate repeated header names", func(t *testing.T) {

ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{
Request: []*config.RequestHeaderRule{
{
Operation: "propagate",
Named: "X-Test-1",
},
},
},
})
assert.Nil(t, err)

rr := httptest.NewRecorder()

clientReq, err := http.NewRequest("POST", "http://localhost", nil)
require.NoError(t, err)
clientReq.Header.Add("X-Test-1", "test1")
clientReq.Header.Add("X-Test-1", "test2")

originReq, err := http.NewRequest("POST", "http://localhost", nil)
assert.Nil(t, err)

updatedClientReq, _ := ht.OnOriginRequest(originReq, &requestContext{
logger: zap.NewNop(),
responseWriter: rr,
request: clientReq,
operation: &operationContext{},
subgraphResolver: NewSubgraphResolver(nil),
})

assert.Len(t, updatedClientReq.Header, 1)
assert.Equal(t, []string{"test1", "test2"}, updatedClientReq.Header.Values("X-Test-1"))
})

t.Run("Should propagate based on matching regex / matching", func(t *testing.T) {
ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{
Expand Down Expand Up @@ -175,7 +211,7 @@ func TestPropagateHeaderRule(t *testing.T) {

})

t.Run("Should handle nil resonses", func(t *testing.T) {
t.Run("Should handle nil responses", func(t *testing.T) {
ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{},
})
Expand Down

0 comments on commit 3807826

Please sign in to comment.