Skip to content

Commit 3819f72

Browse files
committed
Fix potential data race in Walk
The construct in #3 was added to allow setting values while walking but also allow other goroutines to walk the tree concurrently. But that premiss was flawed, as any walk would read the leaf values, hence causing a potential data race. This changes makes the Walk take an interface with an option al pre Check method that takes only the key, and a Handle method that takes both key and value. This allows the caller to skip reading the value if not needed.
1 parent 26639f3 commit 3819f72

File tree

2 files changed

+233
-54
lines changed

2 files changed

+233
-54
lines changed

‎radix.go‎

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ type WalkFlag uint32
1010

1111
const (
1212
WalkContinue WalkFlag = 0
13-
WalkStop WalkFlag = 1 << 0
14-
WalkSet WalkFlag = 1 << 1
13+
WalkSkip WalkFlag = 1 << 0
14+
WalkStop WalkFlag = 1 << 1
15+
WalkSet WalkFlag = 1 << 2
1516
)
1617

1718
// ShouldStop returns true if the walk should terminate.
@@ -24,11 +25,32 @@ func (w WalkFlag) ShouldSet() bool {
2425
return w&WalkSet != 0
2526
}
2627

28+
// ShouldSkip returns true if the walk should skip this node.
29+
func (w WalkFlag) ShouldSkip() bool {
30+
return w&WalkSkip != 0
31+
}
32+
2733
// WalkFn is used when walking the tree. Takes a
2834
// key and value, returning a WalkFlag to indicate
2935
// how to proceed and possibly a new value to set.
3036
type WalkFn[T any] func(s string, v T) (WalkFlag, T)
3137

38+
func (wf WalkFn[T]) Check(s string) WalkFlag {
39+
return WalkContinue
40+
}
41+
42+
func (wf WalkFn[T]) Handle(s string, v T) (WalkFlag, T) {
43+
return wf(s, v)
44+
}
45+
46+
// WalkHandler is the handler passed to Walk functions.
47+
// When split into two steps, this allows us to walk the
48+
// keys withut reading the values.
49+
type WalkHandler[T any] interface {
50+
Check(s string) WalkFlag
51+
Handle(s string, v T) (WalkFlag, T)
52+
}
53+
3254
// leafNode is used to represent a value
3355
type leafNode[T any] struct {
3456
key string
@@ -332,11 +354,12 @@ func (t *Tree[T]) deletePrefix(parent, n *node[T], prefix string) int {
332354
if len(prefix) == 0 {
333355
// Remove the leaf node
334356
subTreeSize := 0
335-
// recursively walk from all edges of the node to be deleted
336-
recursiveWalk(n, func(s string, v T) (WalkFlag, T) {
357+
var fn WalkFn[T] = func(s string, v T) (WalkFlag, T) {
337358
subTreeSize++
338359
return WalkContinue, v
339-
})
360+
}
361+
// recursively walk from all edges of the node to be deleted
362+
recursiveWalk(n, fn)
340363
if n.isLeaf() {
341364
n.leaf = nil
342365
}
@@ -477,18 +500,18 @@ func (t *Tree[T]) Maximum() (string, T, bool) {
477500
}
478501

479502
// Walk is used to walk the tree
480-
func (t *Tree[T]) Walk(fn WalkFn[T]) {
481-
recursiveWalk(t.root, fn)
503+
func (t *Tree[T]) Walk(h WalkHandler[T]) {
504+
recursiveWalk(t.root, h)
482505
}
483506

484507
// WalkPrefix is used to walk the tree under a prefix
485-
func (t *Tree[T]) WalkPrefix(prefix string, fn WalkFn[T]) {
508+
func (t *Tree[T]) WalkPrefix(prefix string, h WalkHandler[T]) {
486509
n := t.root
487510
search := prefix
488511
for {
489512
// Check for key exhaustion
490513
if len(search) == 0 {
491-
recursiveWalk(n, fn)
514+
recursiveWalk(n, h)
492515
return
493516
}
494517

@@ -505,7 +528,7 @@ func (t *Tree[T]) WalkPrefix(prefix string, fn WalkFn[T]) {
505528
}
506529
if strings.HasPrefix(n.prefix, search) {
507530
// Child may be under our search prefix
508-
recursiveWalk(n, fn)
531+
recursiveWalk(n, h)
509532
}
510533
return
511534
}
@@ -515,22 +538,28 @@ func (t *Tree[T]) WalkPrefix(prefix string, fn WalkFn[T]) {
515538
// from the root down to a given leaf. Where WalkPrefix walks
516539
// all the entries *under* the given prefix, this walks the
517540
// entries *above* the given prefix.
518-
func (t *Tree[T]) WalkPath(path string, fn WalkFn[T]) {
541+
func (t *Tree[T]) WalkPath(path string, h WalkHandler[T]) {
519542
n := t.root
520543
search := path
521544
for {
522545
// Visit the leaf values if any
523546
if n.leaf != nil {
524-
f, n2 := fn(n.leaf.key, n.leaf.val)
525-
if f.ShouldSet() {
526-
n.leaf.val = n2
527-
}
547+
f := h.Check(n.leaf.key)
528548
if f.ShouldStop() {
529549
return
530550
}
551+
if !f.ShouldSkip() {
552+
f, n2 := h.Handle(n.leaf.key, n.leaf.val)
553+
if f.ShouldSet() {
554+
n.leaf.val = n2
555+
}
556+
if f.ShouldStop() {
557+
return
558+
}
559+
}
531560
}
532561

533-
// Check for key exhaution
562+
// Check for key exhaustion
534563
if len(search) == 0 {
535564
return
536565
}
@@ -552,32 +581,38 @@ func (t *Tree[T]) WalkPath(path string, fn WalkFn[T]) {
552581

553582
// recursiveWalk is used to do a pre-order walk of a node
554583
// recursively. Returns true if the walk should be aborted
555-
func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
584+
func recursiveWalk[T any](n *node[T], h WalkHandler[T]) bool {
556585
// Visit the leaf values if any
557586
if n.leaf != nil {
558-
f, n2 := fn(n.leaf.key, n.leaf.val)
559-
if f.ShouldSet() {
560-
n.leaf.val = n2
561-
}
587+
f := h.Check(n.leaf.key)
562588
if f.ShouldStop() {
563589
return true
564590
}
591+
if !f.ShouldSkip() {
592+
f, n2 := h.Handle(n.leaf.key, n.leaf.val)
593+
if f.ShouldSet() {
594+
n.leaf.val = n2
595+
}
596+
if f.ShouldStop() {
597+
return true
598+
}
599+
}
565600
}
566601

567602
// Recurse on the children
568603
i := 0
569604
k := len(n.edges) // keeps track of number of edges in previous iteration
570605
for i < k {
571606
e := n.edges[i]
572-
if recursiveWalk(e.node, fn) {
607+
if recursiveWalk(e.node, h) {
573608
return true
574609
}
575610
// It is a possibility that the WalkFn modified the node we are
576611
// iterating on. If there are no more edges, mergeChild happened,
577612
// so the last edge became the current node n, on which we'll
578613
// iterate one last time.
579614
if len(n.edges) == 0 {
580-
return recursiveWalk(n, fn)
615+
return recursiveWalk(n, h)
581616
}
582617
// If there are now less edges than in the previous iteration,
583618
// then do not increment the loop index, since the current index
@@ -594,9 +629,10 @@ func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
594629
func (t *Tree[T]) ToMap() map[string]T {
595630
out := make(map[string]T, t.size)
596631
var zero T
597-
t.Walk(func(k string, v T) (WalkFlag, T) {
632+
var fn WalkFn[T] = func(k string, v T) (WalkFlag, T) {
598633
out[k] = v
599634
return WalkContinue, zero
600-
})
635+
}
636+
t.Walk(fn)
601637
return out
602638
}

0 commit comments

Comments
 (0)