Skip to content

Commit 4bed6ab

Browse files
committed
Rework the signature of WalkFn to allow setting new values
Before this commit, you could change the nodes during walk, assuming they were pointers. But it was not possible to replace a node with a new one (unless they were stored as *interface{}, but that doesn's sound practical). You could do a new `Insert`, but that be data racy if you e.g. partitioned the tree and walked and modified it concurrently.
1 parent eacf5ee commit 4bed6ab

File tree

3 files changed

+117
-45
lines changed

3 files changed

+117
-45
lines changed

‎go.mod‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module github.com/gohugoio/go-radix
22

33
go 1.24.0
44

5+
require github.com/frankban/quicktest v1.14.6
6+
57
require (
6-
github.com/frankban/quicktest v1.14.6 // indirect
78
github.com/google/go-cmp v0.5.9 // indirect
89
github.com/kr/pretty v0.3.1 // indirect
910
github.com/kr/text v0.2.0 // indirect

‎radix.go‎

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,29 @@ import (
55
"strings"
66
)
77

8+
// WalkFlag is a bitmask return value for Walk functions.
9+
type WalkFlag uint32
10+
11+
const (
12+
WalkContinue WalkFlag = 0
13+
WalkStop WalkFlag = 1 << 0
14+
WalkSet WalkFlag = 1 << 1
15+
)
16+
17+
// shouldStop returns true if the walk should terminate.
18+
func (w WalkFlag) shouldStop() bool {
19+
return w&WalkStop != 0
20+
}
21+
22+
// shouldSet returns true if the walk function wants to set a new value.
23+
func (w WalkFlag) shouldSet() bool {
24+
return w&WalkSet != 0
25+
}
26+
827
// WalkFn is used when walking the tree. Takes a
9-
// key and value, returning if iteration should
10-
// be terminated.
11-
type WalkFn[T any] func(s string, v T) bool
28+
// key and value, returning a WalkFlag to indicate
29+
// how to proceed and possibly a new value to set.
30+
type WalkFn[T any] func(s string, v T) (WalkFlag, T)
1231

1332
// leafNode is used to represent a value
1433
type leafNode[T any] struct {
@@ -313,10 +332,10 @@ func (t *Tree[T]) deletePrefix(parent, n *node[T], prefix string) int {
313332
if len(prefix) == 0 {
314333
// Remove the leaf node
315334
subTreeSize := 0
316-
//recursively walk from all edges of the node to be deleted
317-
recursiveWalk(n, func(s string, v T) bool {
335+
// recursively walk from all edges of the node to be deleted
336+
recursiveWalk(n, func(s string, v T) (WalkFlag, T) {
318337
subTreeSize++
319-
return false
338+
return WalkContinue, v
320339
})
321340
if n.isLeaf() {
322341
n.leaf = nil
@@ -501,8 +520,14 @@ func (t *Tree[T]) WalkPath(path string, fn WalkFn[T]) {
501520
search := path
502521
for {
503522
// Visit the leaf values if any
504-
if n.leaf != nil && fn(n.leaf.key, n.leaf.val) {
505-
return
523+
if n.leaf != nil {
524+
f, n2 := fn(n.leaf.key, n.leaf.val)
525+
if f.shouldSet() {
526+
n.leaf.val = n2
527+
}
528+
if f.shouldStop() {
529+
return
530+
}
506531
}
507532

508533
// Check for key exhaution
@@ -529,8 +554,14 @@ func (t *Tree[T]) WalkPath(path string, fn WalkFn[T]) {
529554
// recursively. Returns true if the walk should be aborted
530555
func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
531556
// Visit the leaf values if any
532-
if n.leaf != nil && fn(n.leaf.key, n.leaf.val) {
533-
return true
557+
if n.leaf != nil {
558+
f, n2 := fn(n.leaf.key, n.leaf.val)
559+
if f.shouldSet() {
560+
n.leaf.val = n2
561+
}
562+
if f.shouldStop() {
563+
return true
564+
}
534565
}
535566

536567
// Recurse on the children
@@ -562,9 +593,10 @@ func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
562593
// ToMap is used to walk the tree and convert it into a map
563594
func (t *Tree[T]) ToMap() map[string]T {
564595
out := make(map[string]T, t.size)
565-
t.Walk(func(k string, v T) bool {
596+
var zero T
597+
t.Walk(func(k string, v T) (WalkFlag, T) {
566598
out[k] = v
567-
return false
599+
return WalkContinue, zero
568600
})
569601
return out
570602
}

‎radix_test.go‎

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ func TestRadix(t *testing.T) {
2828
r := NewFromMap(inp)
2929
c.Assert(r.Len(), qt.Equals, len(inp))
3030

31-
r.Walk(func(k string, v int) bool {
32-
return false
31+
r.Walk(func(k string, v int) (WalkFlag, int) {
32+
return WalkContinue, v
3333
})
3434

3535
for k, v := range inp {
@@ -67,6 +67,45 @@ func TestRoot(t *testing.T) {
6767
c.Assert(val, qt.IsTrue)
6868
}
6969

70+
func TestWalkSet(t *testing.T) {
71+
c := qt.New(t)
72+
r := New[int]()
73+
74+
for i := range 10 {
75+
r.Insert(fmt.Sprintf("key%d", i), i)
76+
}
77+
78+
r.Walk(func(s string, v int) (WalkFlag, int) {
79+
k := fmt.Sprintf("key%d", v)
80+
c.Assert(s, qt.Equals, k)
81+
v2 := v
82+
if v%2 == 0 {
83+
v2 = v * 10
84+
}
85+
return WalkSet, v2
86+
})
87+
88+
var ints []int
89+
r.Walk(func(s string, v int) (WalkFlag, int) {
90+
ints = append(ints, v)
91+
return WalkContinue, 0
92+
})
93+
94+
c.Assert(ints, qt.DeepEquals, []int{0, 1, 20, 3, 40, 5, 60, 7, 80, 9})
95+
}
96+
97+
func TestWalkFlag(t *testing.T) {
98+
c := qt.New(t)
99+
100+
f := WalkSet
101+
c.Assert(f.shouldSet(), qt.IsTrue)
102+
c.Assert(f.shouldStop(), qt.IsFalse)
103+
104+
f = WalkSet | WalkStop
105+
c.Assert(f.shouldSet(), qt.IsTrue)
106+
c.Assert(f.shouldStop(), qt.IsTrue)
107+
}
108+
70109
func TestDelete(t *testing.T) {
71110
c := qt.New(t)
72111
r := New[bool]()
@@ -110,9 +149,9 @@ func TestDeletePrefix(t *testing.T) {
110149
c.Assert(deleted, qt.Equals, test.numDeleted)
111150

112151
out := []string{}
113-
fn := func(s string, v bool) bool {
152+
fn := func(s string, v bool) (WalkFlag, bool) {
114153
out = append(out, s)
115-
return false
154+
return WalkContinue, false
116155
}
117156
r.Walk(fn)
118157

@@ -133,7 +172,7 @@ func TestLongestPrefix(t *testing.T) {
133172
"foozip",
134173
}
135174
for _, k := range keys {
136-
r.Insert(k, nil)
175+
r.Insert(k, "")
137176
}
138177
c.Assert(r.Len(), qt.Equals, len(keys))
139178

@@ -167,7 +206,7 @@ func TestLongestPrefix(t *testing.T) {
167206

168207
func TestWalkPrefix(t *testing.T) {
169208
c := qt.New(t)
170-
r := New[any]()
209+
r := New[string]()
171210

172211
keys := []string{
173212
"foobar",
@@ -177,7 +216,7 @@ func TestWalkPrefix(t *testing.T) {
177216
"zipzap",
178217
}
179218
for _, k := range keys {
180-
r.Insert(k, nil)
219+
r.Insert(k, "")
181220
}
182221
c.Assert(r.Len(), qt.Equals, len(keys))
183222

@@ -230,9 +269,9 @@ func TestWalkPrefix(t *testing.T) {
230269

231270
for _, test := range cases {
232271
out := []string{}
233-
fn := func(s string, v any) bool {
272+
fn := func(s string, v string) (WalkFlag, string) {
234273
out = append(out, s)
235-
return false
274+
return WalkContinue, ""
236275
}
237276
r.WalkPrefix(test.inp, fn)
238277
sort.Strings(out)
@@ -243,7 +282,7 @@ func TestWalkPrefix(t *testing.T) {
243282

244283
func TestWalkPath(t *testing.T) {
245284
c := qt.New(t)
246-
r := New[any]()
285+
r := New[string]()
247286

248287
keys := []string{
249288
"foo",
@@ -254,7 +293,7 @@ func TestWalkPath(t *testing.T) {
254293
"zipzap",
255294
}
256295
for _, k := range keys {
257-
r.Insert(k, nil)
296+
r.Insert(k, "")
258297
}
259298
c.Assert(r.Len(), qt.Equals, len(keys))
260299

@@ -299,9 +338,9 @@ func TestWalkPath(t *testing.T) {
299338

300339
for _, test := range cases {
301340
out := []string{}
302-
fn := func(s string, v any) bool {
341+
fn := func(s string, v string) (WalkFlag, string) {
303342
out = append(out, s)
304-
return false
343+
return WalkContinue, ""
305344
}
306345
r.WalkPath(test.inp, fn)
307346
sort.Strings(out)
@@ -312,20 +351,20 @@ func TestWalkPath(t *testing.T) {
312351

313352
func TestWalkDelete(t *testing.T) {
314353
c := qt.New(t)
315-
r := New[any]()
316-
r.Insert("init0/0", nil)
317-
r.Insert("init0/1", nil)
318-
r.Insert("init0/2", nil)
319-
r.Insert("init0/3", nil)
320-
r.Insert("init1/0", nil)
321-
r.Insert("init1/1", nil)
322-
r.Insert("init1/2", nil)
323-
r.Insert("init1/3", nil)
324-
r.Insert("init2", nil)
325-
326-
deleteFn := func(s string, v any) bool {
354+
r := New[string]()
355+
r.Insert("init0/0", "")
356+
r.Insert("init0/1", "")
357+
r.Insert("init0/2", "")
358+
r.Insert("init0/3", "")
359+
r.Insert("init1/0", "")
360+
r.Insert("init1/1", "")
361+
r.Insert("init1/2", "")
362+
r.Insert("init1/3", "")
363+
r.Insert("init2", "")
364+
365+
deleteFn := func(s string, v string) (WalkFlag, string) {
327366
r.Delete(s)
328-
return false
367+
return WalkContinue, ""
329368
}
330369

331370
r.WalkPrefix("init1", deleteFn)
@@ -387,24 +426,24 @@ func BenchmarkRadix(b *testing.B) {
387426

388427
b.Run("Walk", func(b *testing.B) {
389428
for b.Loop() {
390-
r.Walk(func(s string, v *v) bool {
391-
return false
429+
r.Walk(func(s string, v *v) (WalkFlag, *v) {
430+
return WalkContinue, nil
392431
})
393432
}
394433
})
395434

396435
b.Run("WalkPrefix", func(b *testing.B) {
397436
for b.Loop() {
398-
r.WalkPrefix("init50", func(s string, v *v) bool {
399-
return false
437+
r.WalkPrefix("init50", func(s string, v *v) (WalkFlag, *v) {
438+
return WalkContinue, nil
400439
})
401440
}
402441
})
403442

404443
b.Run("WalkPath", func(b *testing.B) {
405444
for b.Loop() {
406-
r.WalkPath("init50/50", func(s string, v *v) bool {
407-
return false
445+
r.WalkPath("init50/50", func(s string, v *v) (WalkFlag, *v) {
446+
return WalkContinue, nil
408447
})
409448
}
410449
})

0 commit comments

Comments
 (0)