Skip to content

Commit 076fa00

Browse files
committed
Misc adjustments
* Rework (some) tests * Add EnableETagPair option to add a pair of eTags, even if the server does not provide one * Add CacheKey option * Remove redundant nil check * Optionally allow caching for other HTTP methods than GET and HEAD
1 parent ef54744 commit 076fa00

File tree

4 files changed

+332
-382
lines changed

4 files changed

+332
-382
lines changed

‎go.mod‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11
module github.com/gohugoio/httpcache
22

33
go 1.22.2
4+
5+
require github.com/frankban/quicktest v1.14.6
6+
7+
require (
8+
github.com/google/go-cmp v0.5.9 // indirect
9+
github.com/kr/pretty v0.3.1 // indirect
10+
github.com/kr/text v0.2.0 // indirect
11+
github.com/rogpeppe/go-internal v1.9.0 // indirect
12+
)

‎go.sum‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
2+
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
3+
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
4+
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
5+
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
6+
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
7+
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
8+
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
9+
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
10+
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
11+
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
12+
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=

‎httpcache.go‎

Lines changed: 99 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ package httpcache
88
import (
99
"bufio"
1010
"bytes"
11+
"crypto/md5"
12+
"encoding/hex"
1113
"errors"
14+
"hash"
1215
"io"
1316
"net/http"
1417
"net/http/httputil"
@@ -23,6 +26,15 @@ const (
2326
transparent
2427
// XFromCache is the header added to responses that are returned from the cache
2528
XFromCache = "X-From-Cache"
29+
30+
// xEtags is the prefix for the header with the custom etag pair set in the cached response.
31+
xEtags = "X-Etags-"
32+
33+
// XETag1 is the key for the first eTag value.
34+
XETag1 = xEtags + "1"
35+
36+
// XETag2 is the key for the second eTag value.
37+
XETag2 = xEtags + "2"
2638
)
2739

2840
// A Cache interface is used by the Transport to store and retrieve responses.
@@ -37,7 +49,16 @@ type Cache interface {
3749
}
3850

3951
// cacheKey returns the cache key for req.
40-
func cacheKey(req *http.Request) string {
52+
func (t *Transport) cacheKey(req *http.Request) string {
53+
if t.CacheKey != nil {
54+
return t.CacheKey(req)
55+
}
56+
57+
cacheable := (req.Method != http.MethodHead || req.Method == "HEAD") && req.Header.Get("range") == ""
58+
if !cacheable {
59+
return ""
60+
}
61+
4162
if req.Method == http.MethodGet {
4263
return req.URL.String()
4364
} else {
@@ -47,8 +68,8 @@ func cacheKey(req *http.Request) string {
4768

4869
// cachedResponse returns the cached http.Response for req if present, and nil
4970
// otherwise.
50-
func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
51-
cachedVal, ok := c.Get(cacheKey(req))
71+
func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) {
72+
cachedVal, ok := t.Cache.Get(t.cacheKey(req))
5273
if !ok {
5374
return
5475
}
@@ -63,6 +84,12 @@ type memoryCache struct {
6384
items map[string][]byte
6485
}
6586

87+
func (c *memoryCache) Size() int {
88+
c.mu.RLock()
89+
defer c.mu.RUnlock()
90+
return len(c.items)
91+
}
92+
6693
// Get returns the []byte representation of the response and true if present, false if not
6794
func (c *memoryCache) Get(key string) (resp []byte, ok bool) {
6895
c.mu.RLock()
@@ -105,11 +132,21 @@ type Transport struct {
105132
// If true, responses returned from the cache will be given an extra header, X-From-Cache
106133
MarkCachedResponses bool
107134

135+
// if EnableETagPair is true, the Transport will store the pair of eTags in the response header.
136+
// These are stored in the X-Etags-1 and X-Etags-2 headers.
137+
// If these are different, the response has been modified.
138+
// If the server does not return an eTag, the MD5 hash of the response body is used.
139+
EnableETagPair bool
140+
141+
// CacheKey is an optional func that returns the key to use to store the response.
142+
// An empty string signals that this request should not be cached.
143+
CacheKey func(req *http.Request) string
144+
108145
// Around is an optional func.
109146
// If set, the Transport will call Around at the start of RoundTrip
110147
// and defer the returned func until the end of RoundTrip.
111148
// Typically used to implement a lock that is held for the duration of the RoundTrip.
112-
Around func(key string) func()
149+
Around func(req *http.Request, key string) func()
113150
}
114151

115152
// varyMatches will return false unless all of the cached values for the headers listed in Vary
@@ -133,14 +170,18 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool {
133170
// to give the server a chance to respond with NotModified. If this happens, then the cached Response
134171
// will be returned.
135172
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
136-
cacheKey := cacheKey(req)
173+
cacheKey := t.cacheKey(req)
137174
if f := t.Around; f != nil {
138-
defer f(cacheKey)()
175+
defer f(req, cacheKey)()
139176
}
140-
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
177+
178+
var cachedXEtag string
179+
180+
cacheable := cacheKey != ""
181+
141182
var cachedResp *http.Response
142183
if cacheable {
143-
cachedResp, err = cachedResponse(t.Cache, req)
184+
cachedResp, err = t.cachedResponse(req)
144185
} else {
145186
// Need to invalidate an existing value
146187
t.Cache.Delete(cacheKey)
@@ -155,6 +196,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
155196
if t.MarkCachedResponses {
156197
cachedResp.Header.Set(XFromCache, "1")
157198
}
199+
if t.EnableETagPair {
200+
cachedXEtag, _ = getXETags(cachedResp.Header)
201+
}
158202

159203
if varyMatches(cachedResp, req) {
160204
// Can only use cached value if the new request doesn't Vary significantly
@@ -185,15 +229,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
185229
}
186230

187231
resp, err = transport.RoundTrip(req)
188-
if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
232+
233+
if err == nil && req.Method != http.MethodHead && resp.StatusCode == http.StatusNotModified {
189234
// Replace the 304 response with the one from cache, but update with some new headers
190235
endToEndHeaders := getEndToEndHeaders(resp.Header)
191236
for _, header := range endToEndHeaders {
192237
cachedResp.Header[header] = resp.Header[header]
193238
}
194239
resp = cachedResp
195-
} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
196-
req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
240+
} else if (err != nil || resp.StatusCode >= 500) &&
241+
req.Method != http.MethodHead && canStaleOnError(cachedResp.Header, req.Header) {
197242
// In case of transport failure and stale-if-error activated, returns cached content
198243
// when available
199244
return cachedResp, nil
@@ -227,24 +272,51 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
227272
}
228273
}
229274
switch req.Method {
230-
case "GET":
231-
// Delay caching until EOF is reached.
232-
resp.Body = &cachingReadCloser{
233-
R: resp.Body,
275+
case http.MethodHead:
276+
respBytes, err := httputil.DumpResponse(resp, true)
277+
if err == nil {
278+
t.Cache.Set(cacheKey, respBytes)
279+
}
280+
default:
281+
var etagHash hash.Hash
282+
r := resp.Body
283+
if t.EnableETagPair {
284+
if etag := resp.Header.Get("etag"); etag != "" {
285+
resp.Header.Set(XETag1, etag)
286+
resp.Header.Set(XETag2, cachedXEtag)
287+
} else {
288+
etagHash = md5.New()
289+
r = struct {
290+
io.Reader
291+
io.Closer
292+
}{
293+
io.TeeReader(r, etagHash),
294+
resp.Body,
295+
}
296+
}
297+
}
298+
299+
r = &cachingReadCloser{
300+
R: r,
234301
OnEOF: func(r io.Reader) {
302+
if etagHash != nil {
303+
md5Str := hex.EncodeToString(etagHash.Sum(nil))
304+
resp.Header.Set(XETag1, md5Str)
305+
resp.Header.Set(XETag2, cachedXEtag)
306+
307+
}
235308
resp := *resp
236309
resp.Body = io.NopCloser(r)
237310
respBytes, err := httputil.DumpResponse(&resp, true)
238311
if err == nil {
239312
t.Cache.Set(cacheKey, respBytes)
240313
}
241314
},
315+
buf: &bytes.Buffer{},
242316
}
243-
default:
244-
respBytes, err := httputil.DumpResponse(resp, true)
245-
if err == nil {
246-
t.Cache.Set(cacheKey, respBytes)
247-
}
317+
// Delay caching until EOF is reached.
318+
resp.Body = r
319+
248320
}
249321
} else {
250322
t.Cache.Delete(cacheKey)
@@ -278,6 +350,10 @@ type timer interface {
278350

279351
var clock timer = &realClock{}
280352

353+
func getXETags(h http.Header) (string, string) {
354+
return h.Get(XETag1), h.Get(XETag2)
355+
}
356+
281357
// getFreshness will return one of fresh/stale/transparent based on the cache-control
282358
// values of the request and the response
283359
//
@@ -522,7 +598,7 @@ type cachingReadCloser struct {
522598
// OnEOF is called with a copy of the content of R when EOF is reached.
523599
OnEOF func(io.Reader)
524600

525-
buf bytes.Buffer // buf stores a copy of the content of R.
601+
buf *bytes.Buffer // buf stores a copy of the content of R.
526602
}
527603

528604
// Read reads the next len(p) bytes from R or until R is drained. The
@@ -533,7 +609,7 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
533609
n, err = r.R.Read(p)
534610
r.buf.Write(p[:n])
535611
if err == io.EOF {
536-
r.OnEOF(bytes.NewReader(r.buf.Bytes()))
612+
r.OnEOF(r.buf)
537613
}
538614
return n, err
539615
}
@@ -545,6 +621,6 @@ func (r *cachingReadCloser) Close() error {
545621
// newMemoryCacheTransport returns a new Transport using the in-memory cache implementation
546622
func newMemoryCacheTransport() *Transport {
547623
c := newMemoryCache()
548-
t := &Transport{Cache: c, MarkCachedResponses: true}
624+
t := &Transport{Cache: c}
549625
return t
550626
}

0 commit comments

Comments
 (0)