Skip to content

Commit 1ba8087

Browse files
authored
tpl/collections: Add collections.D using Vitter's Method D for sequential random sampling
1 parent 84dd495 commit 1ba8087

File tree

8 files changed

+387
-7
lines changed

8 files changed

+387
-7
lines changed

‎common/maps/cache.go‎

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,27 @@ import (
2020
// Cache is a simple thread safe cache backed by a map.
2121
type Cache[K comparable, T any] struct {
2222
m map[K]T
23+
opts CacheOptions
2324
hasBeenInitialized bool
2425
sync.RWMutex
2526
}
2627

27-
// NewCache creates a new Cache.
28+
// CacheOptions are the options for the Cache.
29+
type CacheOptions struct {
30+
// If set, the cache will not grow beyond this size.
31+
Size uint64
32+
}
33+
34+
var defaultCacheOptions = CacheOptions{}
35+
36+
// NewCache creates a new Cache with default options.
2837
func NewCache[K comparable, T any]() *Cache[K, T] {
29-
return &Cache[K, T]{m: make(map[K]T)}
38+
return &Cache[K, T]{m: make(map[K]T), opts: defaultCacheOptions}
39+
}
40+
41+
// NewCacheWithOptions creates a new Cache with the given options.
42+
func NewCacheWithOptions[K comparable, T any](opts CacheOptions) *Cache[K, T] {
43+
return &Cache[K, T]{m: make(map[K]T), opts: opts}
3044
}
3145

3246
// Delete deletes the given key from the cache.
@@ -65,6 +79,7 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) {
6579
if err != nil {
6680
return v, err
6781
}
82+
c.clearIfNeeded()
6883
c.m[key] = v
6984
return v, nil
7085
}
@@ -127,7 +142,15 @@ func (c *Cache[K, T]) SetIfAbsent(key K, value T) {
127142
}
128143
}
129144

145+
func (c *Cache[K, T]) clearIfNeeded() {
146+
if c.opts.Size > 0 && uint64(len(c.m)) >= c.opts.Size {
147+
// clear the map
148+
clear(c.m)
149+
}
150+
}
151+
130152
func (c *Cache[K, T]) set(key K, value T) {
153+
c.clearIfNeeded()
131154
c.m[key] = value
132155
}
133156

‎common/maps/cache_test.go‎

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright 2024 The Hugo Authors. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package maps
15+
16+
import (
17+
"testing"
18+
19+
qt "github.com/frankban/quicktest"
20+
)
21+
22+
func TestCacheSize(t *testing.T) {
23+
c := qt.New(t)
24+
25+
cache := NewCacheWithOptions[string, string](CacheOptions{Size: 10})
26+
27+
for i := 0; i < 30; i++ {
28+
cache.Set(string(rune('a'+i)), "value")
29+
}
30+
31+
c.Assert(len(cache.m), qt.Equals, 10)
32+
33+
for i := 20; i < 50; i++ {
34+
cache.GetOrCreate(string(rune('a'+i)), func() (string, error) {
35+
return "value", nil
36+
})
37+
}
38+
39+
c.Assert(len(cache.m), qt.Equals, 10)
40+
41+
for i := 100; i < 200; i++ {
42+
cache.SetIfAbsent(string(rune('a'+i)), "value")
43+
}
44+
45+
c.Assert(len(cache.m), qt.Equals, 10)
46+
47+
cache.InitAndGet("foo", func(
48+
get func(key string) (string, bool), set func(key string, value string),
49+
) error {
50+
for i := 50; i < 100; i++ {
51+
set(string(rune('a'+i)), "value")
52+
}
53+
return nil
54+
})
55+
56+
c.Assert(len(cache.m), qt.Equals, 10)
57+
}

‎go.mod‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ require (
7474
github.com/yuin/goldmark-emoji v1.0.6
7575
go.uber.org/automaxprocs v1.5.3
7676
gocloud.dev v0.43.0
77-
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326
77+
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
7878
golang.org/x/image v0.30.0
7979
golang.org/x/mod v0.27.0
8080
golang.org/x/net v0.43.0
@@ -190,4 +190,4 @@ require (
190190
software.sslmate.com/src/go-pkcs12 v0.2.0 // indirect
191191
)
192192

193-
go 1.24
193+
go 1.24.0

‎go.sum‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
580580
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
581581
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
582582
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
583-
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE=
584-
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
583+
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0=
584+
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
585585
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
586586
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
587587
golang.org/x/image v0.0.0-20210220032944-ac19c3e999fb/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=

‎tpl/collections/collections.go‎

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
"context"
2020
"errors"
2121
"fmt"
22-
"math/rand"
22+
"math/rand/v2"
2323
"reflect"
2424
"strings"
2525
"time"
@@ -41,9 +41,12 @@ func New(deps *deps.Deps) *Namespace {
4141
}
4242
loc := langs.GetLocation(language)
4343

44+
dCache := maps.NewCacheWithOptions[dKey, []int](maps.CacheOptions{Size: 100})
45+
4446
return &Namespace{
4547
loc: loc,
4648
sortComp: compare.New(loc, true),
49+
dCache: dCache,
4750
deps: deps,
4851
}
4952
}
@@ -52,6 +55,7 @@ func New(deps *deps.Deps) *Namespace {
5255
type Namespace struct {
5356
loc *time.Location
5457
sortComp *compare.Namespace
58+
dCache *maps.Cache[dKey, []int]
5559
deps *deps.Deps
5660
}
5761

@@ -520,6 +524,29 @@ func (ns *Namespace) Slice(args ...any) any {
520524
return collections.Slice(args...)
521525
}
522526

527+
type dKey struct {
528+
seed uint64
529+
n int
530+
hi int
531+
}
532+
533+
// D returns a slice of n unique random numbers in the range [0, hi) using the provided seed,
534+
// using J. S. Vitter's Method D for sequential random sampling, from Vitter, J.S.
535+
// - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
536+
// See https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
537+
func (ns *Namespace) D(seed, n, hi any) []int {
538+
key := dKey{seed: cast.ToUint64(seed), n: cast.ToInt(n), hi: cast.ToInt(hi)}
539+
v, _ := ns.dCache.GetOrCreate(key, func() ([]int, error) {
540+
prng := rand.New(rand.NewPCG(key.seed, 0))
541+
result := make([]int, 0, key.n)
542+
_d(prng, key.n, key.hi, func(i int) {
543+
result = append(result, i)
544+
})
545+
return result, nil
546+
})
547+
return v
548+
}
549+
523550
type intersector struct {
524551
r reflect.Value
525552
seen map[any]bool

‎tpl/collections/collections_test.go‎

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,35 @@ func TestUniq(t *testing.T) {
788788
}
789789
}
790790

791+
func TestD(t *testing.T) {
792+
t.Parallel()
793+
c := qt.New(t)
794+
ns := newNs()
795+
796+
c.Assert(ns.D(42, 5, 100), qt.DeepEquals, []int{24, 34, 66, 82, 96})
797+
c.Assert(ns.D(31, 5, 100), qt.DeepEquals, []int{12, 37, 38, 69, 98})
798+
}
799+
800+
func BenchmarkD2(b *testing.B) {
801+
ns := newNs()
802+
803+
runBenchmark := func(seed, n, max int) {
804+
name := fmt.Sprintf("n=%d,max=%d", n, max)
805+
b.Run(name, func(b *testing.B) {
806+
for i := 0; i < b.N; i++ {
807+
ns.D(seed, n, max)
808+
}
809+
})
810+
}
811+
812+
runBenchmark(32, 5, 100)
813+
runBenchmark(32, 50, 1000)
814+
runBenchmark(32, 10, 10000)
815+
runBenchmark(32, 500, 10000)
816+
runBenchmark(32, 10, 500000)
817+
runBenchmark(32, 5000, 500000)
818+
}
819+
791820
func (x *TstX) TstRp() string {
792821
return "r" + x.A
793822
}

‎tpl/collections/vitter.go‎

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE)
2+
//
3+
// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved.
4+
5+
package collections
6+
7+
import (
8+
"math"
9+
"math/rand/v2"
10+
)
11+
12+
// https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
13+
14+
// Copyright Kevin Lawler, released under ISC License
15+
16+
// _d generates an in-order uniform random sample of size 'want' from the range [0, max) using the provided PRNG.
17+
//
18+
// Parameters:
19+
// - prng: random number generator
20+
// - want: number of samples to select
21+
// - max: upper bound of the range [0, max) from which to sample
22+
// - choose: callback function invoked with each selected index in ascending order
23+
//
24+
// If the parameters are invalid (want <= 0 or want > max), no samples are selected.
25+
//
26+
// Vitter, J.S. - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
27+
func _d(prng *rand.Rand, want, max int, choose func(n int)) {
28+
if want <= 0 || want > max {
29+
return
30+
}
31+
// POTENTIAL_OPTIMIZATION_POINT: Christian Neukirchen points out we can replace exp(log(x)*y) by pow(x,y)
32+
// POTENTIAL_OPTIMIZATION_POINT: Vitter paper points out an exponentially distributed random var can provide speed ups
33+
// 'a' is space allocated for the hand
34+
// 'n' (the 'want' parameter here) is the size of the hand
35+
// 'N' (the 'max' parameter here) is the upper bound on the random card values
36+
j := -1
37+
qu1 := -want + 1 + max
38+
const negalphainv = -13 // threshold parameter from Vitter's paper for algorithm selection
39+
threshold := -negalphainv * want
40+
41+
wantf := float64(want)
42+
maxf := float64(max)
43+
ninv := 1.0 / wantf
44+
var nmin1inv float64
45+
Vprime := math.Exp(math.Log(prng.Float64()) * ninv)
46+
47+
qu1real := -wantf + 1.0 + maxf
48+
var U, X, y1, y2, top, bottom, negSreal float64
49+
50+
for want > 1 && threshold < max {
51+
var S int
52+
53+
nmin1inv = 1.0 / (-1.0 + wantf)
54+
55+
for {
56+
for {
57+
X = maxf * (-Vprime + 1.0)
58+
S = int(math.Floor(X))
59+
60+
if S < qu1 {
61+
break
62+
}
63+
64+
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
65+
}
66+
67+
U = prng.Float64()
68+
negSreal = float64(-S)
69+
y1 = math.Exp(math.Log(U*maxf/qu1real) * nmin1inv)
70+
Vprime = y1 * (-X/maxf + 1.0) * (qu1real / (negSreal + qu1real))
71+
72+
if Vprime <= 1.0 {
73+
break
74+
}
75+
76+
y2 = 1.0
77+
top = -1.0 + maxf
78+
var limit int
79+
80+
if -1+want > S {
81+
bottom = -wantf + maxf
82+
limit = -S + max
83+
} else {
84+
bottom = -1.0 + negSreal + maxf
85+
limit = qu1
86+
}
87+
88+
for t := max - 1; t >= limit; t-- {
89+
y2 = (y2 * top) / bottom
90+
top--
91+
bottom--
92+
}
93+
94+
if maxf/(-X+maxf) >= y1*math.Exp(math.Log(y2)*nmin1inv) {
95+
Vprime = math.Exp(math.Log(prng.Float64()) * nmin1inv)
96+
break
97+
}
98+
99+
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
100+
}
101+
102+
j += S + 1
103+
104+
choose(j)
105+
106+
max = -S + (-1 + max)
107+
maxf = negSreal + (-1.0 + maxf)
108+
want--
109+
wantf--
110+
ninv = nmin1inv
111+
112+
qu1 = -S + qu1
113+
qu1real = negSreal + qu1real
114+
115+
threshold += negalphainv
116+
}
117+
118+
if want > 1 {
119+
methodA(prng, want, max, j, choose) // if i>0 then n has been decremented
120+
} else {
121+
S := int(math.Floor(float64(max) * Vprime))
122+
123+
j += S + 1
124+
125+
choose(j)
126+
}
127+
}
128+
129+
// methodA is the simpler fallback algorithm used when Algorithm D's optimizations are not beneficial.
130+
func methodA(prng *rand.Rand, want, max int, j int, choose func(n int)) {
131+
for want >= 2 {
132+
j++
133+
V := prng.Float64()
134+
quot := float64(max-want) / float64(max)
135+
for quot > V {
136+
j++
137+
max--
138+
quot *= float64(max - want)
139+
quot /= float64(max)
140+
}
141+
choose(j)
142+
max--
143+
want--
144+
}
145+
146+
S := int(math.Floor(float64(max) * prng.Float64()))
147+
j += S + 1
148+
choose(j)
149+
}

0 commit comments

Comments
 (0)