Skip to content

Commit 787624d

Browse files
apelissedmitshur
authored andcommitted
Cache GET requests when body is fully read. (gregjones#71)
Lazy read GET response body and dump the full response in cache once EOF has been reached. This fixes a bug with infinite streams, that the cache tries to eager read and save, and hangs the request. Fixes #70.
1 parent efb97ba commit 787624d

File tree

2 files changed

+169
-3
lines changed

2 files changed

+169
-3
lines changed

‎httpcache.go‎

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"bytes"
1212
"errors"
1313
"fmt"
14+
"io"
15+
"io/ioutil"
1416
"net/http"
1517
"net/http/httputil"
1618
"strings"
@@ -227,9 +229,25 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
227229
resp.Header.Set(fakeHeader, reqValue)
228230
}
229231
}
230-
respBytes, err := httputil.DumpResponse(resp, true)
231-
if err == nil {
232-
t.Cache.Set(cacheKey, respBytes)
232+
switch req.Method {
233+
case "GET":
234+
// Delay caching until EOF is reached.
235+
resp.Body = &cachingReadCloser{
236+
R: resp.Body,
237+
OnEOF: func(r io.Reader) {
238+
resp := *resp
239+
resp.Body = ioutil.NopCloser(r)
240+
respBytes, err := httputil.DumpResponse(&resp, true)
241+
if err == nil {
242+
t.Cache.Set(cacheKey, respBytes)
243+
}
244+
},
245+
}
246+
default:
247+
respBytes, err := httputil.DumpResponse(resp, true)
248+
if err == nil {
249+
t.Cache.Set(cacheKey, respBytes)
250+
}
233251
}
234252
} else {
235253
t.Cache.Delete(cacheKey)
@@ -498,6 +516,35 @@ func headerAllCommaSepValues(headers http.Header, name string) []string {
498516
return vals
499517
}
500518

519+
// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF
520+
// handler with a full copy of the content read from R when EOF is
521+
// reached.
522+
type cachingReadCloser struct {
523+
// Underlying ReadCloser.
524+
R io.ReadCloser
525+
// OnEOF is called with a copy of the content of R when EOF is reached.
526+
OnEOF func(io.Reader)
527+
528+
buf bytes.Buffer // buf stores a copy of the content of R.
529+
}
530+
531+
// Read reads the next len(p) bytes from R or until R is drained. The
532+
// return value n is the number of bytes read. If R has no data to
533+
// return, err is io.EOF and OnEOF is called with a full copy of what
534+
// has been read so far.
535+
func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
536+
n, err = r.R.Read(p)
537+
r.buf.Write(p[:n])
538+
if err == io.EOF {
539+
r.OnEOF(bytes.NewReader(r.buf.Bytes()))
540+
}
541+
return n, err
542+
}
543+
544+
func (r *cachingReadCloser) Close() error {
545+
return r.R.Close()
546+
}
547+
501548
// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
502549
func NewMemoryCacheTransport() *Transport {
503550
c := NewMemoryCache()

‎httpcache_test.go‎

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ var s struct {
1818
server *httptest.Server
1919
client http.Client
2020
transport *Transport
21+
done chan struct{} // Closed to unlock infinite handlers.
2122
}
2223

2324
type fakeClock struct {
@@ -41,6 +42,7 @@ func setup() {
4142
client := http.Client{Transport: tp}
4243
s.transport = tp
4344
s.client = client
45+
s.done = make(chan struct{})
4446

4547
mux := http.NewServeMux()
4648
s.server = httptest.NewServer(mux)
@@ -134,9 +136,21 @@ func setup() {
134136
mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135137
time.Sleep(3 * time.Second)
136138
}))
139+
140+
mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141+
for {
142+
select {
143+
case <-s.done:
144+
return
145+
default:
146+
w.Write([]byte{0})
147+
}
148+
}
149+
}))
137150
}
138151

139152
func teardown() {
153+
close(s.done)
140154
s.server.Close()
141155
}
142156

@@ -316,6 +330,58 @@ func TestDontStorePartialRangeInCache(t *testing.T) {
316330
}
317331
}
318332

333+
func TestCacheOnlyIfBodyRead(t *testing.T) {
334+
resetTest()
335+
{
336+
req, err := http.NewRequest("GET", s.server.URL, nil)
337+
if err != nil {
338+
t.Fatal(err)
339+
}
340+
resp, err := s.client.Do(req)
341+
if err != nil {
342+
t.Fatal(err)
343+
}
344+
if resp.Header.Get(XFromCache) != "" {
345+
t.Fatal("XFromCache header isn't blank")
346+
}
347+
// We do not read the body
348+
resp.Body.Close()
349+
}
350+
{
351+
req, err := http.NewRequest("GET", s.server.URL, nil)
352+
if err != nil {
353+
t.Fatal(err)
354+
}
355+
resp, err := s.client.Do(req)
356+
if err != nil {
357+
t.Fatal(err)
358+
}
359+
defer resp.Body.Close()
360+
if resp.Header.Get(XFromCache) != "" {
361+
t.Fatalf("XFromCache header isn't blank")
362+
}
363+
}
364+
}
365+
366+
func TestOnlyReadBodyOnDemand(t *testing.T) {
367+
resetTest()
368+
369+
req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil)
370+
if err != nil {
371+
t.Fatal(err)
372+
}
373+
resp, err := s.client.Do(req) // This shouldn't hang forever.
374+
if err != nil {
375+
t.Fatal(err)
376+
}
377+
buf := make([]byte, 10) // Only partially read the body.
378+
_, err = resp.Body.Read(buf)
379+
if err != nil {
380+
t.Fatal(err)
381+
}
382+
resp.Body.Close()
383+
}
384+
319385
func TestGetOnlyIfCachedHit(t *testing.T) {
320386
resetTest()
321387
{
@@ -331,6 +397,10 @@ func TestGetOnlyIfCachedHit(t *testing.T) {
331397
if resp.Header.Get(XFromCache) != "" {
332398
t.Fatal("XFromCache header isn't blank")
333399
}
400+
_, err = ioutil.ReadAll(resp.Body)
401+
if err != nil {
402+
t.Fatal(err)
403+
}
334404
}
335405
{
336406
req, err := http.NewRequest("GET", s.server.URL, nil)
@@ -444,6 +514,11 @@ func TestGetWithEtag(t *testing.T) {
444514
if resp.Header.Get(XFromCache) != "" {
445515
t.Fatal("XFromCache header isn't blank")
446516
}
517+
_, err = ioutil.ReadAll(resp.Body)
518+
if err != nil {
519+
t.Fatal(err)
520+
}
521+
447522
}
448523
{
449524
resp, err := s.client.Do(req)
@@ -479,6 +554,10 @@ func TestGetWithLastModified(t *testing.T) {
479554
if resp.Header.Get(XFromCache) != "" {
480555
t.Fatal("XFromCache header isn't blank")
481556
}
557+
_, err = ioutil.ReadAll(resp.Body)
558+
if err != nil {
559+
t.Fatal(err)
560+
}
482561
}
483562
{
484563
resp, err := s.client.Do(req)
@@ -508,6 +587,10 @@ func TestGetWithVary(t *testing.T) {
508587
if resp.Header.Get("Vary") != "Accept" {
509588
t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary"))
510589
}
590+
_, err = ioutil.ReadAll(resp.Body)
591+
if err != nil {
592+
t.Fatal(err)
593+
}
511594
}
512595
{
513596
resp, err := s.client.Do(req)
@@ -560,6 +643,10 @@ func TestGetWithDoubleVary(t *testing.T) {
560643
if resp.Header.Get("Vary") == "" {
561644
t.Fatalf(`Vary header is blank`)
562645
}
646+
_, err = ioutil.ReadAll(resp.Body)
647+
if err != nil {
648+
t.Fatal(err)
649+
}
563650
}
564651
{
565652
resp, err := s.client.Do(req)
@@ -618,6 +705,10 @@ func TestGetWith2VaryHeaders(t *testing.T) {
618705
if resp.Header.Get("Vary") == "" {
619706
t.Fatalf(`Vary header is blank`)
620707
}
708+
_, err = ioutil.ReadAll(resp.Body)
709+
if err != nil {
710+
t.Fatal(err)
711+
}
621712
}
622713
{
623714
resp, err := s.client.Do(req)
@@ -673,6 +764,10 @@ func TestGetWith2VaryHeaders(t *testing.T) {
673764
if resp.Header.Get(XFromCache) != "" {
674765
t.Fatal("XFromCache header isn't blank")
675766
}
767+
_, err = ioutil.ReadAll(resp.Body)
768+
if err != nil {
769+
t.Fatal(err)
770+
}
676771
}
677772
{
678773
resp, err := s.client.Do(req)
@@ -702,6 +797,10 @@ func TestGetVaryUnused(t *testing.T) {
702797
if resp.Header.Get("Vary") == "" {
703798
t.Fatalf(`Vary header is blank`)
704799
}
800+
_, err = ioutil.ReadAll(resp.Body)
801+
if err != nil {
802+
t.Fatal(err)
803+
}
705804
}
706805
{
707806
resp, err := s.client.Do(req)
@@ -729,6 +828,10 @@ func TestUpdateFields(t *testing.T) {
729828
}
730829
defer resp.Body.Close()
731830
counter = resp.Header.Get("x-counter")
831+
_, err = ioutil.ReadAll(resp.Body)
832+
if err != nil {
833+
t.Fatal(err)
834+
}
732835
}
733836
{
734837
resp, err := s.client.Do(req)
@@ -1053,6 +1156,10 @@ func TestStaleIfErrorRequest(t *testing.T) {
10531156
if resp == nil {
10541157
t.Fatal("resp is nil")
10551158
}
1159+
_, err = ioutil.ReadAll(resp.Body)
1160+
if err != nil {
1161+
t.Fatal(err)
1162+
}
10561163

10571164
// On failure, response is returned from the cache
10581165
tmock.response = nil
@@ -1094,6 +1201,10 @@ func TestStaleIfErrorRequestLifetime(t *testing.T) {
10941201
if resp == nil {
10951202
t.Fatal("resp is nil")
10961203
}
1204+
_, err = ioutil.ReadAll(resp.Body)
1205+
if err != nil {
1206+
t.Fatal(err)
1207+
}
10971208

10981209
// On failure, response is returned from the cache
10991210
tmock.response = nil
@@ -1152,6 +1263,10 @@ func TestStaleIfErrorResponse(t *testing.T) {
11521263
if resp == nil {
11531264
t.Fatal("resp is nil")
11541265
}
1266+
_, err = ioutil.ReadAll(resp.Body)
1267+
if err != nil {
1268+
t.Fatal(err)
1269+
}
11551270

11561271
// On failure, response is returned from the cache
11571272
tmock.response = nil
@@ -1192,6 +1307,10 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) {
11921307
if resp == nil {
11931308
t.Fatal("resp is nil")
11941309
}
1310+
_, err = ioutil.ReadAll(resp.Body)
1311+
if err != nil {
1312+
t.Fatal(err)
1313+
}
11951314

11961315
// On failure, response is returned from the cache
11971316
tmock.response = nil

0 commit comments

Comments
 (0)