Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ const (
XETag1 = xEtags + "1"

// XETag2 is the key for the second eTag value.
// Note that in the cache, XETag1 and XETag2 will always be the same.
// In the Response returned from Response, XETag1 will be the cached value (old) and
// XETag2 will be the eTag value from the server (new).
XETag2 = xEtags + "2"
)

// A Cache interface is used by the Transport to store and retrieve responses.
type Cache interface {
// Get returns the []byte representation of a cached response and a bool
// set to true if the value isn't empty
// set to set to false if the key is not found or the value is stale.
Get(key string) (responseBytes []byte, ok bool)
// Set stores the []byte representation of a response against a key
Set(key string, responseBytes []byte)
Expand All @@ -65,16 +68,19 @@ func (t *Transport) cacheKey(req *http.Request) string {
}
}

// cachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) {
// cachedResponse returns the cached http.Response for req if present and
// a bool set to false if the value is stale.
func (t *Transport) cachedResponse(req *http.Request) (*http.Response, bool, error) {
cachedVal, ok := t.Cache.Get(t.cacheKey(req))
if !ok {
return
if !ok && len(cachedVal) == 0 {
return nil, false, nil
}

b := bytes.NewBuffer(cachedVal)
return http.ReadResponse(bufio.NewReader(b), req)
resp, err := http.ReadResponse(bufio.NewReader(b), req)
if err != nil {
return nil, false, err
}
return resp, ok, nil
}

// Transport is an implementation of http.RoundTripper that will return values from a cache
Expand Down Expand Up @@ -145,10 +151,13 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error

cacheable := cacheKey != ""

var cachedResp *http.Response
var (
cachedResp *http.Response
hasCachedResp bool
)
if cacheable {
cachedResp, err = t.cachedResponse(req)
if err == nil && cachedResp != nil && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) {
cachedResp, hasCachedResp, err = t.cachedResponse(req)
if err == nil && hasCachedResp && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) {
return cachedResp, nil
}
} else {
Expand All @@ -161,13 +170,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
transport = http.DefaultTransport
}

if cacheable && cachedResp != nil && err == nil {
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}
if cachedResp != nil {
if t.EnableETagPair {
cachedXEtag, _ = getXETags(cachedResp.Header)
}
}

if cacheable && hasCachedResp && err == nil {
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}

if varyMatches(cachedResp, req) {
// Can only use cached value if the new request doesn't Vary significantly
Expand Down Expand Up @@ -247,16 +259,19 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
t.Cache.Set(cacheKey, respBytes)
}
default:
var etagHash hash.Hash
var (
etagHash hash.Hash
etag1 = cachedXEtag
etag2 string
)

r := resp.Body
if t.EnableETagPair {
if etag := resp.Header.Get("etag"); etag != "" {
resp.Header.Set(XETag1, etag)
etag2 := cachedXEtag
etag1 = etag
if etag2 == "" {
etag2 = etag
}
resp.Header.Set(XETag2, etag2)
} else {
etagHash = md5.New()
r = struct {
Expand All @@ -274,17 +289,23 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
OnEOF: func(r io.Reader) {
if etagHash != nil {
md5Str := hex.EncodeToString(etagHash.Sum(nil))
etag2 = md5Str
resp.Header.Set(XETag1, md5Str)
etag2 := cachedXEtag
if etag2 == "" {
etag2 = md5Str
resp.Header.Set(XETag2, md5Str)
if etag1 == "" {
etag1 = md5Str
}
resp.Header.Set(XETag2, etag2)
} else {
resp.Header.Set(XETag1, etag1)
resp.Header.Set(XETag2, etag1)
}

resp := *resp
resp.Body = io.NopCloser(r)
respBytes, err := httputil.DumpResponse(&resp, true)
if err == nil {
// Signal any change back to the caller.
resp.Header.Set(XETag1, etag1)
t.Cache.Set(cacheKey, respBytes)
}
},
Expand Down
50 changes: 47 additions & 3 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ func TestEnableETagPair(t *testing.T) {
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")

}
}

Expand Down Expand Up @@ -277,7 +278,6 @@ func TestShouldCache(t *testing.T) {
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return true
}

s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool {
return req.Header.Get("Hello") == "world2"
}
Expand All @@ -295,6 +295,28 @@ func TestShouldCache(t *testing.T) {
}
}

func TestStaleCachedResponse(t *testing.T) {
resetTest()
s.transport.Cache = &staleCache{}
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return true
}
s.transport.EnableETagPair = true
c := qt.New(t)
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
}
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")
}
}

func TestAround(t *testing.T) {
resetTest()
c := qt.New(t)
Expand Down Expand Up @@ -1420,3 +1442,25 @@ func (c *memoryCache) Delete(key string) {
delete(c.items, key)
c.mu.Unlock()
}

var _ Cache = &staleCache{}

type staleCache struct {
val []byte
}

func (c *staleCache) Get(key string) ([]byte, bool) {
return c.val, false
}

func (c *staleCache) Set(key string, resp []byte) {
c.val = resp
}

func (c *staleCache) Delete(key string) {
c.val = nil
}

func (c *staleCache) Size() int {
return 1
}