Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exploratory work towards a potential v3 release #36

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ jobs:
strategy:
matrix:
go:
- "1.18"
- "1.19"
- "1.20"
- "1.21"
- "1.22"
name: Go ${{ matrix.go }} test
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- run: go test -race
- run: go test -race .
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/mroth/weightedrand/v2
module github.com/mroth/weightedrand/v3

go 1.18
go 1.22
97 changes: 30 additions & 67 deletions weightedrand.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
package weightedrand

import (
"cmp"
"errors"
"math/rand"
"sort"
"math"
"math/rand/v2"
"slices"
)

// Choice is a generic wrapper that can be used to add weights for any item.
Expand All @@ -33,30 +35,27 @@ func NewChoice[T any, W integer](item T, weight W) Choice[T, W] {
// performance on repeated calls for weighted random selection.
type Chooser[T any, W integer] struct {
data []Choice[T, W]
totals []int
max int
totals []uint64
max uint64

customRand *rand.Rand
}

// NewChooser initializes a new Chooser for picking from the provided choices.
func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], error) {
sort.Slice(choices, func(i, j int) bool {
return choices[i].Weight < choices[j].Weight
slices.SortFunc(choices, func(a, b Choice[T, W]) int {
return cmp.Compare(a.Weight, b.Weight)
})

totals := make([]int, len(choices))
runningTotal := 0
totals := make([]uint64, len(choices))
var runningTotal uint64
for i, c := range choices {
if c.Weight < 0 {
continue // ignore negative weights, can never be picked
}

// case of single ~uint64 or similar value that exceeds maxInt on its own
if uint64(c.Weight) >= maxInt {
return nil, errWeightOverflow
}

weight := int(c.Weight) // convert weight to int for internal counter usage
if (maxInt - runningTotal) <= weight {
weight := uint64(c.Weight) // convert weight to uint64 for internal counter usage
if (math.MaxUint64 - runningTotal) <= weight {
return nil, errWeightOverflow
}
runningTotal += weight
Expand All @@ -67,14 +66,14 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro
return nil, errNoValidChoices
}

return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil
return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal, customRand: nil}, nil
}

const (
intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
maxInt = 1<<(intSize-1) - 1
maxUint64 = 1<<64 - 1
)
// SetRand applies an optional custom randomness source r for the Chooser. If
// set to nil nil, global rand will be used.
func (c *Chooser[T, W]) SetRand(r *rand.Rand) {
c.customRand = r
}

// Possible errors returned by NewChooser, preventing the creation of a Chooser
// with unsafe runtime states.
Expand All @@ -91,53 +90,17 @@ var (

// Pick returns a single weighted random Choice.Item from the Chooser.
//
// Utilizes global rand as the source of randomness. Safe for concurrent usage.
// Utilizes global rand as the source of randomness by default, which is safe
// for concurrent usage. If a custom rand source was set with SetRand, that
// source will be used instead.
func (c Chooser[T, W]) Pick() T {
r := rand.Intn(c.max) + 1
i := searchInts(c.totals, r)
return c.data[i].Item
}
var r uint64
if c.customRand == nil {
r = rand.Uint64N(c.max) + 1
} else {
r = c.customRand.Uint64N(c.max) + 1
}

// PickSource returns a single weighted random Choice.Item from the Chooser,
// utilizing the provided *rand.Rand source rs for randomness.
//
// The primary use-case for this is avoid lock contention from the global random
// source if utilizing Chooser(s) from multiple goroutines in extremely
// high-throughput situations.
//
// It is the responsibility of the caller to ensure the provided rand.Source is
// free from thread safety issues.
//
// Deprecated: Since go1.21 global rand no longer suffers from lock contention
// when used in multiple high throughput goroutines, as long as you don't
// manually seed it. Use [Chooser.Pick] instead.
func (c Chooser[T, W]) PickSource(rs *rand.Rand) T {
r := rs.Intn(c.max) + 1
i := searchInts(c.totals, r)
i, _ := slices.BinarySearch(c.totals, r)
return c.data[i].Item
}

// The standard library sort.SearchInts() just wraps the generic sort.Search()
// function, which takes a function closure to determine truthfulness. However,
// since this function is utilized within a for loop, it cannot currently be
// properly inlined by the compiler, resulting in non-trivial performance
// overhead.
//
// Thus, this is essentially manually inlined version. In our use case here, it
// results in a significant throughput increase for Pick.
//
// See also github.com/mroth/xsort.
func searchInts(a []int, x int) int {
// Possible further future optimization for searchInts via SIMD if we want
// to write some Go assembly code: http://0x80.pl/articles/simd-search.html
i, j := 0, len(a)
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
if a[h] < x {
i = h + 1
} else {
j = h
}
}
return i
}
98 changes: 33 additions & 65 deletions weightedrand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package weightedrand
import (
"fmt"
"math"
"math/rand"
"sync"
"math/rand/v2"
"testing"
"time"
)

/******************************************************************************
Expand Down Expand Up @@ -41,37 +39,42 @@ const (
func TestNewChooser(t *testing.T) {
tests := []struct {
name string
cs []Choice[rune, int]
cs []Choice[rune, int64]
wantErr error
}{
{
name: "zero choices",
cs: []Choice[rune, int]{},
cs: []Choice[rune, int64]{},
wantErr: errNoValidChoices,
},
{
name: "no choices with positive weight",
cs: []Choice[rune, int]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
cs: []Choice[rune, int64]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
wantErr: errNoValidChoices,
},
{
name: "choice with weight equals 1",
cs: []Choice[rune, int]{{Item: 'a', Weight: 1}},
cs: []Choice[rune, int64]{{Item: 'a', Weight: 1}},
wantErr: nil,
},
{
name: "weight overflow",
cs: []Choice[rune, int]{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}},
name: "weight overflow",
cs: []Choice[rune, int64]{
{Item: 'a', Weight: math.MaxInt64/2 + 1},
{Item: 'b', Weight: math.MaxInt64/2 + 1},
{Item: 'c', Weight: math.MaxInt64/2 + 1},
{Item: 'd', Weight: math.MaxInt64/2 + 1},
},
wantErr: errWeightOverflow,
},
{
name: "nominal case",
cs: []Choice[rune, int]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
cs: []Choice[rune, int64]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
wantErr: nil,
},
{
name: "negative weight case",
cs: []Choice[rune, int]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}},
cs: []Choice[rune, int64]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}},
wantErr: nil,
},
}
Expand All @@ -96,8 +99,24 @@ func TestNewChooser(t *testing.T) {
wantErr error
}{
{
name: "weight overflow from single uint64 exceeding system maxInt",
cs: []Choice[rune, uint64]{{Item: 'a', Weight: maxInt + 1}},
name: "single uint64 equalling MaxUint64",
cs: []Choice[rune, uint64]{{Item: 'a', Weight: math.MaxUint64}},
wantErr: errWeightOverflow,
},
{
name: "single uint64 equalling MaxUint64 and a zero weight",
cs: []Choice[rune, uint64]{
{Item: 'a', Weight: math.MaxUint64},
{Item: 'b', Weight: 0},
},
wantErr: errWeightOverflow,
},
{
name: "multiple uint64s with sum MaxUint64",
cs: []Choice[rune, uint64]{
{Item: 'a', Weight: math.MaxUint64/2 + 1},
{Item: 'b', Weight: math.MaxUint64/2 + 1},
},
wantErr: errWeightOverflow,
},
}
Expand Down Expand Up @@ -139,38 +158,6 @@ func TestChooser_Pick(t *testing.T) {
verifyFrequencyCounts(t, counts, choices)
}

// TestChooser_PickSource is the same test methodology as TestChooser_Pick, but
// here we use the PickSource method and access the same chooser concurrently
// from multiple different goroutines, each providing its own source of
// randomness.
func TestChooser_PickSource(t *testing.T) {
choices := mockFrequencyChoices(t, testChoices)
chooser, err := NewChooser(choices...)
if err != nil {
t.Fatal(err)
}
t.Log("totals in chooser", chooser.totals)

counts1 := make(map[int]int)
counts2 := make(map[int]int)
var wg sync.WaitGroup
wg.Add(2)
checker := func(counts map[int]int) {
defer wg.Done()
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
for i := 0; i < testIterations/2; i++ {
c := chooser.PickSource(rs)
counts[c]++
}
}
go checker(counts1)
go checker(counts2)
wg.Wait()

verifyFrequencyCounts(t, counts1, choices)
verifyFrequencyCounts(t, counts2, choices)
}

// Similar to what is used in randutil test, but in randomized order to avoid
// any issues with algorithms that are accidentally dependant on presorted data.
func mockFrequencyChoices(t *testing.T, n int) []Choice[int, int] {
Expand Down Expand Up @@ -259,30 +246,11 @@ func BenchmarkPickParallel(b *testing.B) {
}
}

func BenchmarkPickSourceParallel(b *testing.B) {
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) {
choices := mockChoices(n)
chooser, err := NewChooser(choices...)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
for pb.Next() {
_ = chooser.PickSource(rs)
}
})
})
}
}

func mockChoices(n int) []Choice[rune, int] {
choices := make([]Choice[rune, int], 0, n)
for i := 0; i < n; i++ {
s := '🥑'
w := rand.Intn(10)
w := rand.IntN(10)
c := NewChoice(s, w)
choices = append(choices, c)
}
Expand Down
Loading