Skip to content

Commit c6b6910

Browse files
jmooringbep
authored and committed
tpl/collections: Improve collections.D
Changes: (1) if n > hi, return the full, sorted range [0, hi) of size hi; (2) throw errors when seed, n, or hi are < 0; (3) improve error messages. Closes #14143.
1 parent ca40254 commit c6b6910

File tree

2 files changed

+107
-29
lines changed

2 files changed

+107
-29
lines changed

‎tpl/collections/collections.go‎

Lines changed: 49 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -537,27 +537,62 @@ type dKey struct {
537537
hi int
538538
}
539539

540-
// D returns a slice of n unique random numbers in the range [0, hi) using the provded seed,
541-
// using J. S. Vitter's Method D for sequential random sampling, from Vitter, J.S.
542-
// - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
543-
// See https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
544-
func (ns *Namespace) D(seed, n, hi int) []int {
545-
key := dKey{seed: cast.ToUint64(seed), n: n, hi: hi}
546-
if key.n <= 0 || key.hi <= 0 || key.n > key.hi {
547-
return nil
548-
}
549-
if key.n > maxSeqSize {
550-
panic(errSeqSizeExceedsLimit)
551-
}
552-
v, _ := ns.dCache.GetOrCreate(key, func() ([]int, error) {
540+
// D returns a sorted slice of unique random integers in the half-open interval
541+
// [0, hi) using the provided seed value. The number of elements in the
542+
// resulting slice is n or hi, whichever is less.
543+
//
544+
// If n <= hi, it returns a sorted random sample of size n using J. S. Vitter’s
545+
// Method D for sequential random sampling.
546+
//
547+
// If n > hi, it returns the full, sorted range [0, hi) of size hi.
548+
//
549+
// If n == 0 or hi == 0, it returns an empty slice.
550+
//
551+
// Reference:
552+
//
553+
// J. S. Vitter, "An efficient algorithm for sequential random sampling," ACM Trans. Math. Softw., vol. 11, no. 1, pp. 37–57, 1985.
554+
// See also: https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
555+
func (ns *Namespace) D(seed, n, hi any) ([]int, error) {
556+
seedInt, err := cast.ToInt64E(seed)
557+
if err != nil || seedInt < 0 {
558+
return nil, fmt.Errorf("the seed value (%v) must be a non-negative integer", seed)
559+
}
560+
561+
nInt, err := cast.ToIntE(n)
562+
if err != nil || nInt < 0 || nInt > maxSeqSize {
563+
return nil, fmt.Errorf("the number of requested values (%v) must be a non-negative integer <= %d", n, maxSeqSize)
564+
}
565+
566+
hiInt, err := cast.ToIntE(hi)
567+
if err != nil || hiInt < 0 || hiInt > maxSeqSize {
568+
return nil, fmt.Errorf("the maximum requested value (%v) must be a non-negative integer <= %d", hi, maxSeqSize)
569+
}
570+
571+
if nInt == 0 || hiInt == 0 {
572+
return []int{}, nil
573+
}
574+
575+
key := dKey{seed: uint64(seedInt), n: nInt, hi: hiInt}
576+
577+
v, err := ns.dCache.GetOrCreate(key, func() ([]int, error) {
578+
if key.n > key.hi {
579+
result := make([]int, key.hi)
580+
for i := 0; i < key.hi; i++ {
581+
result[i] = i
582+
}
583+
return result, nil
584+
}
585+
553586
prng := rand.New(rand.NewPCG(key.seed, 0))
554587
result := make([]int, 0, key.n)
555588
_d(prng, key.n, key.hi, func(i int) {
556589
result = append(result, i)
557590
})
591+
558592
return result, nil
559593
})
560-
return v
594+
595+
return v, err
561596
}
562597

563598
type intersector struct {

‎tpl/collections/collections_test.go‎

Lines changed: 58 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -788,30 +788,73 @@ func TestUniq(t *testing.T) {
788788

789789
func TestD(t *testing.T) {
790790
t.Parallel()
791-
c := qt.New(t)
792791
ns := newNs()
793792

794-
c.Assert(ns.D(42, 5, 100), qt.DeepEquals, []int{24, 34, 66, 82, 96})
795-
c.Assert(ns.D(31, 5, 100), qt.DeepEquals, []int{12, 37, 38, 69, 98})
796-
c.Assert(ns.D(42, 9, 10), qt.DeepEquals, []int{0, 1, 2, 3, 4, 6, 7, 8, 9})
797-
c.Assert(ns.D(42, 10, 10), qt.DeepEquals, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
798-
c.Assert(ns.D(42, 11, 10), qt.IsNil) // n > hi
799-
c.Assert(ns.D(42, -5, 100), qt.IsNil)
800-
c.Assert(ns.D(42, 0, 100), qt.IsNil)
801-
c.Assert(ns.D(42, 5, 0), qt.IsNil)
802-
c.Assert(ns.D(42, 5, -10), qt.IsNil)
803-
c.Assert(ns.D(42, 5, 3000000), qt.DeepEquals, []int{720363, 1041693, 2009179, 2489106, 2873969})
804-
c.Assert(func() { ns.D(31, 2000000, 3000000) }, qt.PanicMatches, "size of result exceeds limit")
793+
const (
794+
errNumberOfRequestedValuesRegexPattern = `.*the number of requested values.*`
795+
errMaximumRequestedValueRegexPattern = `.*the maximum requested value.*`
796+
errSeedValueRegexPattern = `.*the seed value.*`
797+
)
798+
799+
tests := []struct {
800+
name string
801+
seed any
802+
n any
803+
hi any
804+
wantResult []int
805+
wantErrText string
806+
}{
807+
// n <= hi
808+
{"seed_eq_42", 42, 5, 100, []int{24, 34, 66, 82, 96}, ""},
809+
{"seed_eq_31", 31, 5, 100, []int{12, 37, 38, 69, 98}, ""},
810+
{"n_lt_hi", 42, 9, 10, []int{0, 1, 2, 3, 4, 6, 7, 8, 9}, ""},
811+
{"n_eq_hi", 42, 10, 10, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, ""},
812+
{"hi_eq_max_size", 42, 5, maxSeqSize, []int{240121, 347230, 669726, 829701, 957989}, ""},
813+
// n > hi
814+
{"n_gt_hi", 42, 11, 10, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, ""},
815+
// zero values
816+
{"seed_eq_0", 0, 5, 100, []int{0, 2, 29, 50, 72}, ""},
817+
{"n_eq_0", 42, 0, 100, []int{}, ""},
818+
{"hi_eq_0", 42, 5, 0, []int{}, ""},
819+
// errors: values < 0
820+
{"seed_lt_0", -42, 5, 100, nil, errSeedValueRegexPattern},
821+
{"n_lt_0", 42, -1, 100, nil, errNumberOfRequestedValuesRegexPattern},
822+
{"hi_lt_0", 42, 5, -100, nil, errMaximumRequestedValueRegexPattern},
823+
// errors: values that can't be cast to int
824+
{"seed_invalid_type", "foo", 5, 100, nil, errSeedValueRegexPattern},
825+
{"n_invalid_type", 42, "bar", 100, nil, errNumberOfRequestedValuesRegexPattern},
826+
{"hi_invalid_type", 42, 5, "baz", nil, errMaximumRequestedValueRegexPattern},
827+
// errors: values that exceed maxSeqSize
828+
{"n_gt_max_size", 42, maxSeqSize + 1, 10, nil, errNumberOfRequestedValuesRegexPattern},
829+
{"hi_gt_max_size", 42, 5, maxSeqSize + 1, nil, errMaximumRequestedValueRegexPattern},
830+
}
831+
832+
for _, tt := range tests {
833+
t.Run(tt.name, func(t *testing.T) {
834+
t.Parallel()
835+
c := qt.New(t)
836+
837+
got, err := ns.D(tt.seed, tt.n, tt.hi)
838+
839+
if tt.wantErrText != "" {
840+
c.Assert(err, qt.ErrorMatches, tt.wantErrText, qt.Commentf("n=%d, hi=%d", tt.n, tt.hi))
841+
c.Assert(got, qt.IsNil, qt.Commentf("Expected nil result on error"))
842+
return
843+
}
844+
845+
c.Assert(err, qt.IsNil, qt.Commentf("Did not expect an error, but got: %v", err))
846+
c.Assert(got, qt.DeepEquals, tt.wantResult)
847+
})
848+
}
805849
}
806850

807851
func BenchmarkD2(b *testing.B) {
808852
ns := newNs()
809-
810-
runBenchmark := func(seed, n, max int) {
853+
runBenchmark := func(seed, n, max any) {
811854
name := fmt.Sprintf("n=%d,max=%d", n, max)
812855
b.Run(name, func(b *testing.B) {
813856
for b.Loop() {
814-
ns.D(seed, n, max)
857+
_, _ = ns.D(seed, n, max)
815858
}
816859
})
817860
}

0 commit comments

Comments
 (0)