diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e50beb9..907a6c7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,14 +7,11 @@ jobs: strategy: matrix: go: - - "1.18" - - "1.19" - - "1.20" - - "1.21" + - "1.22" name: Go ${{ matrix.go }} test steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - - run: go test -race + - run: go test -race . diff --git a/go.mod b/go.mod index 1ed6cf1..dc2b8d2 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/mroth/weightedrand/v2 +module github.com/mroth/weightedrand/v3 -go 1.18 +go 1.22 diff --git a/weightedrand.go b/weightedrand.go index ec9a73b..b8fd40d 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -9,9 +9,11 @@ package weightedrand import ( + "cmp" "errors" - "math/rand" - "sort" + "math" + "math/rand/v2" + "slices" ) // Choice is a generic wrapper that can be used to add weights for any item. @@ -33,30 +35,27 @@ func NewChoice[T any, W integer](item T, weight W) Choice[T, W] { // performance on repeated calls for weighted random selection. type Chooser[T any, W integer] struct { data []Choice[T, W] - totals []int - max int + totals []uint64 + max uint64 + + customRand *rand.Rand } // NewChooser initializes a new Chooser for picking from the provided choices. func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], error) { - sort.Slice(choices, func(i, j int) bool { - return choices[i].Weight < choices[j].Weight + slices.SortFunc(choices, func(a, b Choice[T, W]) int { + return cmp.Compare(a.Weight, b.Weight) }) - totals := make([]int, len(choices)) - runningTotal := 0 + totals := make([]uint64, len(choices)) + var runningTotal uint64 for i, c := range choices { if c.Weight < 0 { continue // ignore negative weights, can never be picked } - // case of single ~uint64 or similar value that exceeds maxInt on its own - if uint64(c.Weight) >= maxInt { - return nil, errWeightOverflow - } - - weight := int(c.Weight) // convert weight to int for internal counter usage - if (maxInt - runningTotal) <= weight { + weight := uint64(c.Weight) // convert weight to uint64 for internal counter usage + if (math.MaxUint64 - runningTotal) <= weight { return nil, errWeightOverflow } runningTotal += weight @@ -67,14 +66,14 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro return nil, errNoValidChoices } - return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil + return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal, customRand: nil}, nil } -const ( - intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize - maxInt = 1<<(intSize-1) - 1 - maxUint64 = 1<<64 - 1 -) +// SetRand applies an optional custom randomness source r for the Chooser. If +// set to nil nil, global rand will be used. +func (c *Chooser[T, W]) SetRand(r *rand.Rand) { + c.customRand = r +} // Possible errors returned by NewChooser, preventing the creation of a Chooser // with unsafe runtime states. @@ -91,53 +90,17 @@ var ( // Pick returns a single weighted random Choice.Item from the Chooser. // -// Utilizes global rand as the source of randomness. Safe for concurrent usage. +// Utilizes global rand as the source of randomness by default, which is safe +// for concurrent usage. If a custom rand source was set with SetRand, that +// source will be used instead. func (c Chooser[T, W]) Pick() T { - r := rand.Intn(c.max) + 1 - i := searchInts(c.totals, r) - return c.data[i].Item -} + var r uint64 + if c.customRand == nil { + r = rand.Uint64N(c.max) + 1 + } else { + r = c.customRand.Uint64N(c.max) + 1 + } -// PickSource returns a single weighted random Choice.Item from the Chooser, -// utilizing the provided *rand.Rand source rs for randomness. -// -// The primary use-case for this is avoid lock contention from the global random -// source if utilizing Chooser(s) from multiple goroutines in extremely -// high-throughput situations. -// -// It is the responsibility of the caller to ensure the provided rand.Source is -// free from thread safety issues. -// -// Deprecated: Since go1.21 global rand no longer suffers from lock contention -// when used in multiple high throughput goroutines, as long as you don't -// manually seed it. Use [Chooser.Pick] instead. -func (c Chooser[T, W]) PickSource(rs *rand.Rand) T { - r := rs.Intn(c.max) + 1 - i := searchInts(c.totals, r) + i, _ := slices.BinarySearch(c.totals, r) return c.data[i].Item } - -// The standard library sort.SearchInts() just wraps the generic sort.Search() -// function, which takes a function closure to determine truthfulness. However, -// since this function is utilized within a for loop, it cannot currently be -// properly inlined by the compiler, resulting in non-trivial performance -// overhead. -// -// Thus, this is essentially manually inlined version. In our use case here, it -// results in a significant throughput increase for Pick. -// -// See also github.com/mroth/xsort. -func searchInts(a []int, x int) int { - // Possible further future optimization for searchInts via SIMD if we want - // to write some Go assembly code: http://0x80.pl/articles/simd-search.html - i, j := 0, len(a) - for i < j { - h := int(uint(i+j) >> 1) // avoid overflow when computing h - if a[h] < x { - i = h + 1 - } else { - j = h - } - } - return i -} diff --git a/weightedrand_test.go b/weightedrand_test.go index 6dc642b..cda03ff 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -3,10 +3,8 @@ package weightedrand import ( "fmt" "math" - "math/rand" - "sync" + "math/rand/v2" "testing" - "time" ) /****************************************************************************** @@ -41,37 +39,42 @@ const ( func TestNewChooser(t *testing.T) { tests := []struct { name string - cs []Choice[rune, int] + cs []Choice[rune, int64] wantErr error }{ { name: "zero choices", - cs: []Choice[rune, int]{}, + cs: []Choice[rune, int64]{}, wantErr: errNoValidChoices, }, { name: "no choices with positive weight", - cs: []Choice[rune, int]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, + cs: []Choice[rune, int64]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, wantErr: errNoValidChoices, }, { name: "choice with weight equals 1", - cs: []Choice[rune, int]{{Item: 'a', Weight: 1}}, + cs: []Choice[rune, int64]{{Item: 'a', Weight: 1}}, wantErr: nil, }, { - name: "weight overflow", - cs: []Choice[rune, int]{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}}, + name: "weight overflow", + cs: []Choice[rune, int64]{ + {Item: 'a', Weight: math.MaxInt64/2 + 1}, + {Item: 'b', Weight: math.MaxInt64/2 + 1}, + {Item: 'c', Weight: math.MaxInt64/2 + 1}, + {Item: 'd', Weight: math.MaxInt64/2 + 1}, + }, wantErr: errWeightOverflow, }, { name: "nominal case", - cs: []Choice[rune, int]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, + cs: []Choice[rune, int64]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, wantErr: nil, }, { name: "negative weight case", - cs: []Choice[rune, int]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}}, + cs: []Choice[rune, int64]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}}, wantErr: nil, }, } @@ -96,8 +99,24 @@ func TestNewChooser(t *testing.T) { wantErr error }{ { - name: "weight overflow from single uint64 exceeding system maxInt", - cs: []Choice[rune, uint64]{{Item: 'a', Weight: maxInt + 1}}, + name: "single uint64 equalling MaxUint64", + cs: []Choice[rune, uint64]{{Item: 'a', Weight: math.MaxUint64}}, + wantErr: errWeightOverflow, + }, + { + name: "single uint64 equalling MaxUint64 and a zero weight", + cs: []Choice[rune, uint64]{ + {Item: 'a', Weight: math.MaxUint64}, + {Item: 'b', Weight: 0}, + }, + wantErr: errWeightOverflow, + }, + { + name: "multiple uint64s with sum MaxUint64", + cs: []Choice[rune, uint64]{ + {Item: 'a', Weight: math.MaxUint64/2 + 1}, + {Item: 'b', Weight: math.MaxUint64/2 + 1}, + }, wantErr: errWeightOverflow, }, } @@ -139,38 +158,6 @@ func TestChooser_Pick(t *testing.T) { verifyFrequencyCounts(t, counts, choices) } -// TestChooser_PickSource is the same test methodology as TestChooser_Pick, but -// here we use the PickSource method and access the same chooser concurrently -// from multiple different goroutines, each providing its own source of -// randomness. -func TestChooser_PickSource(t *testing.T) { - choices := mockFrequencyChoices(t, testChoices) - chooser, err := NewChooser(choices...) - if err != nil { - t.Fatal(err) - } - t.Log("totals in chooser", chooser.totals) - - counts1 := make(map[int]int) - counts2 := make(map[int]int) - var wg sync.WaitGroup - wg.Add(2) - checker := func(counts map[int]int) { - defer wg.Done() - rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) - for i := 0; i < testIterations/2; i++ { - c := chooser.PickSource(rs) - counts[c]++ - } - } - go checker(counts1) - go checker(counts2) - wg.Wait() - - verifyFrequencyCounts(t, counts1, choices) - verifyFrequencyCounts(t, counts2, choices) -} - // Similar to what is used in randutil test, but in randomized order to avoid // any issues with algorithms that are accidentally dependant on presorted data. func mockFrequencyChoices(t *testing.T, n int) []Choice[int, int] { @@ -259,30 +246,11 @@ func BenchmarkPickParallel(b *testing.B) { } } -func BenchmarkPickSourceParallel(b *testing.B) { - for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { - b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { - choices := mockChoices(n) - chooser, err := NewChooser(choices...) - if err != nil { - b.Fatal(err) - } - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) - for pb.Next() { - _ = chooser.PickSource(rs) - } - }) - }) - } -} - func mockChoices(n int) []Choice[rune, int] { choices := make([]Choice[rune, int], 0, n) for i := 0; i < n; i++ { s := '🥑' - w := rand.Intn(10) + w := rand.IntN(10) c := NewChoice(s, w) choices = append(choices, c) }