diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a8b17d..1d064e7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,14 +7,6 @@ jobs: strategy: matrix: go: - - "1.10" - - "1.11" - - "1.12" - - "1.13" - - "1.14" - - "1.15" - - "1.16" - - "1.17" - "1.18" - "1.19" name: Go ${{ matrix.go }} test diff --git a/README.md b/README.md index 7a2e028..1d1daa4 100644 --- a/README.md +++ b/README.md @@ -16,24 +16,24 @@ element to be selected are not equal, but rather defined by relative "weights" ```go import ( /* ...snip... */ - wr "github.com/mroth/weightedrand" + "github.com/mroth/weightedrand/v2" ) func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! - chooser, _ := wr.NewChooser( - wr.Choice{Item: "🍒", Weight: 0}, - wr.Choice{Item: "🍋", Weight: 1}, - wr.Choice{Item: "🍊", Weight: 1}, - wr.Choice{Item: "🍉", Weight: 3}, - wr.Choice{Item: "🥑", Weight: 5}, + chooser, _ := weightedrand.NewChooser( + weightedrand.NewChoice('🍒', 0), + weightedrand.NewChoice('🍋', 1), + weightedrand.NewChoice('🍊', 1), + weightedrand.NewChoice('🍉', 3), + weightedrand.NewChoice('🥑', 5), ) - /* The following will print 🍋 and 🍊 with 0.1 probability, 🍉 with 0.3 - probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note - the weights don't have to add up to 10, that was just done here to make the - example easier to read.) */ - result := chooser.Pick().(string) + // The following will print 🍋 and 🍊 with 0.1 probability, 🍉 with 0.3 + // probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note + // the weights don't have to add up to 10, that was just done here to make + // the example easier to read.) + result := chooser.Pick() fmt.Println(result) } ``` @@ -73,6 +73,11 @@ right choice! If you are only picking from the same distribution once, `randutil` will be faster. `weightedrand` optimizes for repeated calls at the expense of some initialization time and memory storage. +## Requirements + +weightedrand >= v2 requires go1.18 or greater. For support on earlier versions +of go, use weightedrand [v1](https://github.com/mroth/weightedrand/tree/v1). + ## Credits To better understand the algorithm used in this library (as well as the one used diff --git a/examples/compbench/bench_test.go b/examples/compbench/bench_test.go index 6a6f7e4..c8e417b 100644 --- a/examples/compbench/bench_test.go +++ b/examples/compbench/bench_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/jmcvetta/randutil" - "github.com/mroth/weightedrand" + "github.com/mroth/weightedrand/v2" ) const BMMinChoices = 10 @@ -98,9 +98,9 @@ func BenchmarkSingle(b *testing.B) { }) } -func mockChoices(tb testing.TB, n int) []weightedrand.Choice { +func mockChoices(tb testing.TB, n int) []weightedrand.Choice[rune, uint] { tb.Helper() - choices := make([]weightedrand.Choice, 0, n) + choices := make([]weightedrand.Choice[rune, uint], 0, n) for i := 0; i < n; i++ { s := '🥑' w := rand.Intn(10) @@ -110,7 +110,7 @@ func mockChoices(tb testing.TB, n int) []weightedrand.Choice { return choices } -func convertChoices(tb testing.TB, cs []weightedrand.Choice) []randutil.Choice { +func convertChoices(tb testing.TB, cs []weightedrand.Choice[rune, uint]) []randutil.Choice { tb.Helper() res := make([]randutil.Choice, len(cs)) for i, c := range cs { diff --git a/examples/compbench/go.mod b/examples/compbench/go.mod index 205ae69..bc67fc8 100644 --- a/examples/compbench/go.mod +++ b/examples/compbench/go.mod @@ -1,10 +1,10 @@ module github.com/mroth/weightedrand/examples/compbench -go 1.15 +go 1.18 require ( github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff - github.com/mroth/weightedrand v0.0.0 + github.com/mroth/weightedrand/v2 v2.0.0 ) -replace github.com/mroth/weightedrand => ../.. +replace github.com/mroth/weightedrand/v2 => ../.. diff --git a/examples/frequency/main.go b/examples/frequency/main.go index 9c1d0b3..08db4af 100644 --- a/examples/frequency/main.go +++ b/examples/frequency/main.go @@ -6,18 +6,18 @@ import ( "math/rand" "time" - wr "github.com/mroth/weightedrand" + "github.com/mroth/weightedrand/v2" ) func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! - c, err := wr.NewChooser( - wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0) - wr.Choice{Item: '🍋', Weight: 1}, - wr.Choice{Item: '🍊', Weight: 1}, - wr.Choice{Item: '🍉', Weight: 3}, - wr.Choice{Item: '🥑', Weight: 5}, + c, err := weightedrand.NewChooser( + weightedrand.NewChoice('🍒', 0), + weightedrand.NewChoice('🍋', 1), + weightedrand.NewChoice('🍊', 1), + weightedrand.NewChoice('🍉', 3), + weightedrand.NewChoice('🥑', 5), ) if err != nil { log.Fatal(err) @@ -26,7 +26,7 @@ func main() { /* Let's pick a bunch of fruits so we can see the distribution in action! */ fruits := make([]rune, 40*18) for i := 0; i < len(fruits); i++ { - fruits[i] = c.Pick().(rune) + fruits[i] = c.Pick() } fmt.Println(string(fruits)) diff --git a/go.mod b/go.mod index edcdc7e..1ed6cf1 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/mroth/weightedrand +module github.com/mroth/weightedrand/v2 -go 1.10 +go 1.18 diff --git a/weightedrand.go b/weightedrand.go index 9c6c943..3320121 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -17,26 +17,30 @@ import ( ) // Choice is a generic wrapper that can be used to add weights for any item. -type Choice struct { - Item interface{} - Weight uint +type Choice[T any, W integer] struct { + Item T + Weight W +} + +type integer interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr } // NewChoice creates a new Choice with specified item and weight. -func NewChoice(item interface{}, weight uint) Choice { - return Choice{Item: item, Weight: weight} +func NewChoice[T any, W integer](item T, weight W) Choice[T, W] { + return Choice[T, W]{Item: item, Weight: weight} } // A Chooser caches many possible Choices in a structure designed to improve // performance on repeated calls for weighted random selection. -type Chooser struct { - data []Choice +type Chooser[T any, W integer] struct { + data []Choice[T, W] totals []int max int } // NewChooser initializes a new Chooser for picking from the provided choices. -func NewChooser(choices ...Choice) (*Chooser, error) { +func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], error) { sort.Slice(choices, func(i, j int) bool { return choices[i].Weight < choices[j].Weight }) @@ -45,6 +49,10 @@ func NewChooser(choices ...Choice) (*Chooser, error) { runningTotal := 0 for i, c := range choices { weight := int(c.Weight) + if weight < 0 { + continue // ignore negative weights, can never be picked + } + if (maxInt - runningTotal) <= weight { return nil, errWeightOverflow } @@ -56,7 +64,7 @@ func NewChooser(choices ...Choice) (*Chooser, error) { return nil, errNoValidChoices } - return &Chooser{data: choices, totals: totals, max: runningTotal}, nil + return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil } const ( @@ -80,7 +88,7 @@ var ( // Pick returns a single weighted random Choice.Item from the Chooser. // // Utilizes global rand as the source of randomness. -func (c Chooser) Pick() interface{} { +func (c Chooser[T, W]) Pick() T { r := rand.Intn(c.max) + 1 i := searchInts(c.totals, r) return c.data[i].Item @@ -95,7 +103,7 @@ func (c Chooser) Pick() interface{} { // // It is the responsibility of the caller to ensure the provided rand.Source is // free from thread safety issues. -func (c Chooser) PickSource(rs *rand.Rand) interface{} { +func (c Chooser[T, W]) PickSource(rs *rand.Rand) T { r := rs.Intn(c.max) + 1 i := searchInts(c.totals, r) return c.data[i].Item diff --git a/weightedrand_test.go b/weightedrand_test.go index 29f8ba3..bcc8358 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -24,7 +24,7 @@ func Example() { NewChoice('🍉', 0), NewChoice('🥑', 42), ) - fruit := chooser.Pick().(rune) + fruit := chooser.Pick() fmt.Printf("%c", fruit) //Output: 🥑 } @@ -45,32 +45,37 @@ const ( func TestNewChooser(t *testing.T) { tests := []struct { name string - cs []Choice + cs []Choice[rune, int] wantErr error }{ { name: "zero choices", - cs: []Choice{}, + cs: []Choice[rune, int]{}, wantErr: errNoValidChoices, }, { name: "no choices with positive weight", - cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, + cs: []Choice[rune, int]{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, wantErr: errNoValidChoices, }, { name: "choice with weight equals 1", - cs: []Choice{{Item: 'a', Weight: 1}}, + cs: []Choice[rune, int]{{Item: 'a', Weight: 1}}, wantErr: nil, }, { name: "weight overflow", - cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}}, + cs: []Choice[rune, int]{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}}, wantErr: errWeightOverflow, }, { name: "nominal case", - cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, + cs: []Choice[rune, int]{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, + wantErr: nil, + }, + { + name: "negative weight case", + cs: []Choice[rune, int]{{Item: 'a', Weight: 3}, {Item: 'b', Weight: -2}}, wantErr: nil, }, } @@ -100,7 +105,7 @@ func TestChooser_Pick(t *testing.T) { counts := make(map[int]int) for i := 0; i < testIterations; i++ { c := chooser.Pick() - counts[c.(int)]++ + counts[c]++ } verifyFrequencyCounts(t, counts, choices) @@ -127,7 +132,7 @@ func TestChooser_PickSource(t *testing.T) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) for i := 0; i < testIterations/2; i++ { c := chooser.PickSource(rs) - counts[c.(int)]++ + counts[c]++ } } go checker(counts1) @@ -140,19 +145,19 @@ func TestChooser_PickSource(t *testing.T) { // Similar to what is used in randutil test, but in randomized order to avoid // any issues with algorithms that are accidentally dependant on presorted data. -func mockFrequencyChoices(t *testing.T, n int) []Choice { +func mockFrequencyChoices(t *testing.T, n int) []Choice[int, int] { t.Helper() - choices := make([]Choice, 0, n) + choices := make([]Choice[int, int], 0, n) list := rand.Perm(n) for _, v := range list { - c := NewChoice(v, uint(v)) + c := NewChoice(v, v) choices = append(choices, c) } t.Log("mocked choices of", choices) return choices } -func verifyFrequencyCounts(t *testing.T, counts map[int]int, choices []Choice) { +func verifyFrequencyCounts(t *testing.T, counts map[int]int, choices []Choice[int, int]) { t.Helper() // Ensure weight 0 results in no results @@ -202,7 +207,7 @@ func BenchmarkPick(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = chooser.Pick().(rune) + _ = chooser.Pick() } }) } @@ -220,19 +225,19 @@ func BenchmarkPickParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) for pb.Next() { - _ = chooser.PickSource(rs).(rune) + _ = chooser.PickSource(rs) } }) }) } } -func mockChoices(n int) []Choice { - choices := make([]Choice, 0, n) +func mockChoices(n int) []Choice[rune, int] { + choices := make([]Choice[rune, int], 0, n) for i := 0; i < n; i++ { s := '🥑' w := rand.Intn(10) - c := NewChoice(s, uint(w)) + c := NewChoice(s, w) choices = append(choices, c) } return choices