Skip to content

Commit 6d9d83c

Browse files
committed
replace getHopByHop with getEndToEnd that does the right thing with headers on not-modified responses
2 parents c4f0acf + 6e141ba commit 6d9d83c

File tree

2 files changed

+103
-11
lines changed

2 files changed

+103
-11
lines changed

‎httpcache.go‎

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
176176
resp, err = transport.RoundTrip(req)
177177
if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
178178
// Replace the 304 response with the one from cache, but update with some new headers
179-
headersToMerge := getHopByHopHeaders(resp)
180-
for _, headerKey := range headersToMerge {
181-
key := http.CanonicalHeaderKey(headerKey)
182-
if v, ok := resp.Header[key]; ok {
183-
cachedResp.Header[key] = v
184-
}
179+
endToEndHeaders := getEndToEndHeaders(resp.Header)
180+
for _, header := range endToEndHeaders {
181+
cachedResp.Header[header] = resp.Header[header]
185182
}
186183
cachedResp.Status = fmt.Sprintf("%d %s", http.StatusOK, http.StatusText(http.StatusOK))
187184
cachedResp.StatusCode = http.StatusOK
@@ -345,17 +342,32 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
345342
return stale
346343
}
347344

348-
func getHopByHopHeaders(resp *http.Response) []string {
345+
func containsHeader(headers []string, header string) bool {
346+
for _, v := range headers {
347+
if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) {
348+
return true
349+
}
350+
}
351+
return false
352+
}
353+
354+
func getEndToEndHeaders(respHeaders http.Header) []string {
349355
// These headers are always hop-by-hop
350-
headers := []string{"connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade"}
356+
hopByHopHeaders := []string{"connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade"}
351357

352-
for _, extra := range strings.Split(resp.Header.Get("connection"), ",") {
358+
for _, extra := range strings.Split(respHeaders.Get("connection"), ",") {
353359
// any header listed in connection, if present, is also considered hop-by-hop
354360
if strings.Trim(extra, " ") != "" {
355-
headers = append(headers, extra)
361+
hopByHopHeaders = append(hopByHopHeaders, extra)
362+
}
363+
}
364+
endToEndHeaders := []string{}
365+
for respHeader, _ := range respHeaders {
366+
if !containsHeader(hopByHopHeaders, respHeader) {
367+
endToEndHeaders = append(endToEndHeaders, respHeader)
356368
}
357369
}
358-
return headers
370+
return endToEndHeaders
359371
}
360372

361373
func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {

‎httpcache_test.go‎

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"net/http"
66
"net/http/httptest"
7+
"strconv"
78
"testing"
89
"time"
910

@@ -89,6 +90,18 @@ func (s *S) SetUpSuite(c *C) {
8990
w.Header().Set("Vary", "X-Madeup-Header")
9091
w.Write([]byte("Some text content"))
9192
}))
93+
94+
updateFieldsCounter := 0
95+
mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
96+
w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter))
97+
w.Header().Set("Etag", `"e"`)
98+
updateFieldsCounter++
99+
if r.Header.Get("if-none-match") != "" {
100+
w.WriteHeader(http.StatusNotModified)
101+
} else {
102+
w.Write([]byte("Some text content"))
103+
}
104+
}))
92105
}
93106

94107
func (s *S) TearDownSuite(c *C) {
@@ -302,6 +315,22 @@ func (s *S) TestGetVaryUnused(c *C) {
302315
c.Assert(resp2.Header.Get(XFromCache), Equals, "1")
303316
}
304317

318+
func (s *S) TestUpdateFields(c *C) {
319+
req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil)
320+
resp, err := s.client.Do(req)
321+
defer resp.Body.Close()
322+
c.Assert(err, IsNil)
323+
counter := resp.Header.Get("x-counter")
324+
325+
resp2, err2 := s.client.Do(req)
326+
defer resp2.Body.Close()
327+
c.Assert(err2, IsNil)
328+
c.Assert(resp2.Header.Get(XFromCache), Equals, "1")
329+
counter2 := resp2.Header.Get("x-counter")
330+
331+
c.Assert(counter, Not(Equals), counter2)
332+
}
333+
305334
func (s *S) TestParseCacheControl(c *C) {
306335
h := http.Header{}
307336
for _ = range parseCacheControl(h) {
@@ -459,3 +488,54 @@ func (s *S) TestMaxStaleValue(c *C) {
459488

460489
c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale)
461490
}
491+
492+
type containsHeaderChecker struct {
493+
*CheckerInfo
494+
}
495+
496+
func (c *containsHeaderChecker) Check(params []interface{}, names []string) (bool, string) {
497+
items, ok := params[0].([]string)
498+
if !ok {
499+
return false, "Expected first param to be []string"
500+
}
501+
value, ok := params[1].(string)
502+
if !ok {
503+
return false, "Expected 2nd param to be string"
504+
}
505+
return containsHeader(items, value), ""
506+
}
507+
508+
var ContainsHeader Checker = &containsHeaderChecker{&CheckerInfo{Name: "Contains", Params: []string{"Container", "expected to contain"}}}
509+
510+
func (s *S) TestGetEndToEndHeaders(c *C) {
511+
var (
512+
headers http.Header
513+
end2end []string
514+
)
515+
516+
headers = http.Header{}
517+
headers.Set("content-type", "text/html")
518+
headers.Set("te", "deflate")
519+
520+
end2end = getEndToEndHeaders(headers)
521+
c.Check(end2end, ContainsHeader, "content-type")
522+
c.Check(end2end, Not(ContainsHeader), "te")
523+
524+
headers = http.Header{}
525+
headers.Set("connection", "content-type")
526+
headers.Set("content-type", "text/csv")
527+
headers.Set("te", "deflate")
528+
end2end = getEndToEndHeaders(headers)
529+
c.Check(end2end, Not(ContainsHeader), "connection")
530+
c.Check(end2end, Not(ContainsHeader), "content-type")
531+
c.Check(end2end, Not(ContainsHeader), "te")
532+
533+
headers = http.Header{}
534+
end2end = getEndToEndHeaders(headers)
535+
c.Check(end2end, HasLen, 0)
536+
537+
headers = http.Header{}
538+
headers.Set("connection", "content-type")
539+
end2end = getEndToEndHeaders(headers)
540+
c.Check(end2end, HasLen, 0)
541+
}

0 commit comments

Comments
 (0)