Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ type Transport struct {
// An empty string signals that this request should not be cached.
CacheKey func(req *http.Request) string

// AlwaysUseCachedResponse is an optional func that when it returns true
// a successful response from the cache will be returned without connecting to the server.
AlwaysUseCachedResponse func(req *http.Request, key string) bool

// Around is an optional func.
// If set, the Transport will call Around at the start of RoundTrip
// and defer the returned func until the end of RoundTrip.
Expand Down Expand Up @@ -141,6 +145,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
var cachedResp *http.Response
if cacheable {
cachedResp, err = t.cachedResponse(req)
if err == nil && cachedResp != nil && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) {
return cachedResp, nil
}
} else {
// Need to invalidate an existing value
t.Cache.Delete(cacheKey)
Expand Down Expand Up @@ -242,7 +249,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
if t.EnableETagPair {
if etag := resp.Header.Get("etag"); etag != "" {
resp.Header.Set(XETag1, etag)
resp.Header.Set(XETag2, cachedXEtag)
etag2 := cachedXEtag
if etag2 == "" {
etag2 = etag
}
resp.Header.Set(XETag2, etag2)
} else {
etagHash = md5.New()
r = struct {
Expand All @@ -261,8 +272,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
if etagHash != nil {
md5Str := hex.EncodeToString(etagHash.Sum(nil))
resp.Header.Set(XETag1, md5Str)
resp.Header.Set(XETag2, cachedXEtag)

etag2 := cachedXEtag
if etag2 == "" {
etag2 = md5Str
}
resp.Header.Set(XETag2, etag2)
}
resp := *resp
resp.Body = io.NopCloser(r)
Expand Down
25 changes: 23 additions & 2 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestEnableETagPair(t *testing.T) {
_, resp := doMethod(t, "GET", "/etag", nil)
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "124567")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "124567")
}
{
_, resp := doMethod(t, "GET", "/etag", nil)
Expand All @@ -238,7 +238,7 @@ func TestEnableETagPair(t *testing.T) {
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
}
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
Expand All @@ -248,6 +248,27 @@ func TestEnableETagPair(t *testing.T) {
}
}

func TestAlwaysUseCachedResponse(t *testing.T) {
resetTest()
c := qt.New(t)
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return req.Header.Get("Hello") == "world2"
}

{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"})
c.Assert(s, qt.Equals, "world1")
}
{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(s, qt.Equals, "world1")
}
{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world3"})
c.Assert(s, qt.Equals, "world3")
}
}

func TestAround(t *testing.T) {
resetTest()
c := qt.New(t)
Expand Down