Skip to content

Commit ac54fac

Browse files
committed
Fix potential deadlock with nested GetOrCreate calls
The new revived lock logic is also faster: ``` name old time/op new time/op delta NewGetOrCreate-10 301µs ± 0% 101µs ± 0% -66.59% (p=0.004 n=6+5) name old alloc/op new alloc/op delta NewGetOrCreate-10 19.3B ± 3% 19.7B ± 3% ~ (p=0.567 n=6+6) name old allocs/op new allocs/op delta NewGetOrCreate-10 0.00 0.00 ~ (all equal) ```
1 parent e6d6418 commit ac54fac

File tree

2 files changed

+97
-6
lines changed

2 files changed

+97
-6
lines changed

‎lazycache.go‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,16 @@ func (c *Cache[K, V]) Get(key K) (V, bool) {
8686
// Note that create, the cache prime function, is called once and then not called again for a given key
8787
// unless the cache entry is evicted; it does not block other goroutines from calling GetOrCreate,
8888
// it is not called with the cache lock held.
89+
// Note that any error returned by create will be returned by GetOrCreate and repeated calls with the same key will
90+
// receive the same error.
8991
func (c *Cache[K, V]) GetOrCreate(key K, create func(key K) (V, error)) (V, bool, error) {
9092
c.mu.Lock()
9193
w := c.get(key)
9294
if w != nil {
95+
c.mu.Unlock()
9396
w.wait()
94-
// if w.ready is set then w comes from a concurrent GetOrCreate call.
95-
if w.found || w.ready != nil {
96-
c.mu.Unlock()
97-
return w.value, w.found, nil
98-
}
97+
// If w.ready is nil, we will repeat any error from the create function to concurrent callers.
98+
return w.value, w.found, w.err
9999
}
100100

101101
w = &valueWrapper[V]{

‎lazycache_test.go‎

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,59 @@ func TestGetOrCreateConcurrent(t *testing.T) {
200200
wg.Wait()
201201
}
202202

203-
func BenchmarkGetOrCreate(b *testing.B) {
203+
func TestGetOrCreateRecursive(t *testing.T) {
204+
c := qt.New(t)
205+
206+
var wg sync.WaitGroup
207+
208+
n := 200
209+
210+
for i := 0; i < 30; i++ {
211+
cache := New[int, any](Options{MaxEntries: 1000})
212+
213+
for j := 0; j < 10; j++ {
214+
wg.Add(1)
215+
go func() {
216+
defer wg.Done()
217+
for k := 0; k < 10; k++ {
218+
// This test was added to test a deadlock situation with nested GetOrCreate calls on the same cache.
219+
// Note that the keys below are carefully selected to not overlap, as this case may still deadlock:
220+
// goroutine 1: GetOrCreate(1) => GetOrCreate(2)
221+
// goroutine 2: GetOrCreate(2) => GetOrCreate(1)
222+
key1, key2 := rand.Intn(n), rand.Intn(n)+n
223+
if key2 == key1 {
224+
key2++
225+
}
226+
shouldFail := key1%10 == 0
227+
v, found, err := cache.GetOrCreate(key1, func(key int) (any, error) {
228+
if shouldFail {
229+
return nil, fmt.Errorf("failed")
230+
}
231+
v, _, err := cache.GetOrCreate(key2, func(key int) (any, error) {
232+
return "inner", nil
233+
})
234+
c.Assert(err, qt.IsNil)
235+
return v, nil
236+
})
237+
238+
if shouldFail {
239+
c.Assert(err, qt.ErrorMatches, "failed")
240+
c.Assert(v, qt.IsNil)
241+
c.Assert(found, qt.IsFalse)
242+
} else {
243+
c.Assert(err, qt.IsNil)
244+
c.Assert(found, qt.IsTrue)
245+
c.Assert(v, qt.Equals, "inner")
246+
}
247+
}
248+
}()
249+
250+
}
251+
wg.Wait()
252+
}
253+
}
254+
255+
func BenchmarkGetOrCreateAndGet(b *testing.B) {
204256
const maxSize = 1000
205257

206258
runBenchmark := func(b *testing.B, cache *Cache[int, any], getOrCreate func(key int, create func(key int) (any, error)) (any, bool, error)) {
@@ -249,6 +301,45 @@ func BenchmarkGetOrCreate(b *testing.B) {
249301

250302
}
251303

304+
func BenchmarkGetOrCreate(b *testing.B) {
305+
const maxSize = 1000
306+
307+
r := rand.New(rand.NewSource(99))
308+
var mu sync.Mutex
309+
310+
cache := New[int, any](Options{MaxEntries: maxSize})
311+
312+
// Partially fill the cache.
313+
for i := 0; i < maxSize/3; i++ {
314+
cache.Set(i, i)
315+
}
316+
b.ResetTimer()
317+
318+
b.RunParallel(func(pb *testing.PB) {
319+
for pb.Next() {
320+
mu.Lock()
321+
i2 := r.Intn(maxSize)
322+
mu.Unlock()
323+
324+
res2, found, err := cache.GetOrCreate(i2, func(key int) (any, error) {
325+
if i2%100 == 0 {
326+
// Simulate a slow create.
327+
time.Sleep(1 * time.Second)
328+
}
329+
return i2, nil
330+
})
331+
332+
if err != nil {
333+
b.Fatal(err)
334+
}
335+
336+
if v := res2; !found || v != i2 {
337+
b.Fatalf("got %v, want %v", v, i2)
338+
}
339+
}
340+
})
341+
}
342+
252343
func BenchmarkCacheSerial(b *testing.B) {
253344
const maxSize = 1000
254345

0 commit comments

Comments
 (0)