diff --git a/README.md b/README.md index d91d693..189eb7a 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ import ( func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! - c := wr.NewChooser( + chooser, _ := wr.NewChooser( wr.Choice{Item: "🍒", Weight: 0}, wr.Choice{Item: "🍋", Weight: 1}, wr.Choice{Item: "🍊", Weight: 1}, @@ -33,7 +33,7 @@ func main() { probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note the weights don't have to add up to 10, that was just done here to make the example easier to read.) */ - result := c.Pick().(string) + result := chooser.Pick().(string) fmt.Println(result) } ``` diff --git a/examples/compbench/bench_test.go b/examples/compbench/bench_test.go index d756e46..6a6f7e4 100644 --- a/examples/compbench/bench_test.go +++ b/examples/compbench/bench_test.go @@ -32,7 +32,10 @@ func BenchmarkMultiple(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(b, n) - chs := weightedrand.NewChooser(choices...) + chs, err := weightedrand.NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() for i := 0; i < b.N; i++ { chs.Pick() @@ -45,7 +48,10 @@ func BenchmarkMultiple(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(b, n) - chs := weightedrand.NewChooser(choices...) + chs, err := weightedrand.NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) @@ -84,7 +90,7 @@ func BenchmarkSingle(b *testing.B) { choices := mockChoices(b, n) b.ResetTimer() for i := 0; i < b.N; i++ { - chs := weightedrand.NewChooser(choices...) + chs, _ := weightedrand.NewChooser(choices...) 
chs.Pick() } }) diff --git a/examples/frequency/main.go b/examples/frequency/main.go index fdc0be0..9c1d0b3 100644 --- a/examples/frequency/main.go +++ b/examples/frequency/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" "math/rand" "time" @@ -11,13 +12,16 @@ import ( func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! - c := wr.NewChooser( + c, err := wr.NewChooser( wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0) wr.Choice{Item: '🍋', Weight: 1}, wr.Choice{Item: '🍊', Weight: 1}, wr.Choice{Item: '🍉', Weight: 3}, wr.Choice{Item: '🥑', Weight: 5}, ) + if err != nil { + log.Fatal(err) + } /* Let's pick a bunch of fruits so we can see the distribution in action! */ fruits := make([]rune, 40*18) diff --git a/weightedrand.go b/weightedrand.go index d3adce3..90ea2aa 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -11,6 +11,7 @@ package weightedrand import ( + "errors" "math/rand" "sort" ) @@ -34,28 +35,55 @@ type Chooser struct { max int } -// NewChooser initializes a new Chooser for picking from the provided Choices. -func NewChooser(cs ...Choice) Chooser { - sort.Slice(cs, func(i, j int) bool { - return cs[i].Weight < cs[j].Weight +// NewChooser initializes a new Chooser for picking from the provided choices. 
+func NewChooser(choices ...Choice) (*Chooser, error) {
+	sort.Slice(choices, func(i, j int) bool {
+		return choices[i].Weight < choices[j].Weight
 	})
-	totals := make([]int, len(cs))
+
+	totals := make([]int, len(choices))
 	runningTotal := 0
-	for i, c := range cs {
-		runningTotal += int(c.Weight)
+	for i, c := range choices {
+		weight := int(c.Weight)
+		if (maxInt - runningTotal) <= weight {
+			return nil, errWeightOverflow
+		}
+		runningTotal += weight
 		totals[i] = runningTotal
 	}
-	return Chooser{data: cs, totals: totals, max: runningTotal}
+
+	if runningTotal < 1 { // a total of exactly 1 is valid; only 0 leaves nothing to pick
+		return nil, errNoValidChoices
+	}
+
+	return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
 }
 
+const (
+	intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
+	maxInt  = 1<<(intSize-1) - 1
+)
+
+// Possible errors returned by NewChooser, preventing the creation of a Chooser
+// with unsafe runtime states.
+var (
+	// If the sum of provided Choice weights exceed the maximum integer value
+	// for the current platform (e.g. math.MaxInt32 or math.MaxInt64), then
+	// the internal running total will overflow, resulting in an imbalanced
+	// distribution generating improper results.
+	errWeightOverflow = errors.New("sum of Choice Weights exceeds max int")
+	// If there are no Choices available to the Chooser with a weight >= 1,
+	// there are no valid choices and Pick would produce a runtime panic.
+	errNoValidChoices = errors.New("zero Choices with Weight >= 1")
+)
+
 // Pick returns a single weighted random Choice.Item from the Chooser.
 //
-// Utilizes global rand as the source of randomness -- you will likely want to
-// seed it.
-func (chs Chooser) Pick() interface{} {
-	r := rand.Intn(chs.max) + 1
-	i := searchInts(chs.totals, r)
-	return chs.data[i].Item
+// Utilizes global rand as the source of randomness.
+func (c Chooser) Pick() interface{} { + r := rand.Intn(c.max) + 1 + i := searchInts(c.totals, r) + return c.data[i].Item } // PickSource returns a single weighted random Choice.Item from the Chooser, @@ -67,10 +95,10 @@ func (chs Chooser) Pick() interface{} { // // It is the responsibility of the caller to ensure the provided rand.Source is // free from thread safety issues. -func (chs Chooser) PickSource(rs *rand.Rand) interface{} { - r := rs.Intn(chs.max) + 1 - i := searchInts(chs.totals, r) - return chs.data[i].Item +func (c Chooser) PickSource(rs *rand.Rand) interface{} { + r := rs.Intn(c.max) + 1 + i := searchInts(c.totals, r) + return c.data[i].Item } // The standard library sort.SearchInts() just wraps the generic sort.Search() diff --git a/weightedrand_test.go b/weightedrand_test.go index 890cf19..7c3656c 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -18,7 +18,7 @@ import ( // not on any absolute scoring system. In this trivial case, we will assign a // weight of 0 to all but one fruit, so that the output will be predictable. func Example() { - chooser := NewChooser( + chooser, _ := NewChooser( NewChoice('🍋', 0), NewChoice('🍊', 0), NewChoice('🍉', 0), @@ -42,12 +42,52 @@ const ( testIterations = 1000000 ) +func TestNewChooser(t *testing.T) { + tests := []struct { + name string + cs []Choice + wantErr error + }{ + { + name: "zero choices", + cs: []Choice{}, + wantErr: errNoValidChoices, + }, + { + name: "no choices with positive weight", + cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, + wantErr: errNoValidChoices, + }, + { + name: "weight overflow", + cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}}, + wantErr: errWeightOverflow, + }, + { + name: "nominal case", + cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewChooser(tt.cs...) 
+ if err != tt.wantErr { + t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + // TestChooser_Pick assembles a list of Choices, weighted 0-9, and tests that // over the course of 1,000,000 calls to Pick() each choice is returned more // often than choices with a lower weight. func TestChooser_Pick(t *testing.T) { choices := mockFrequencyChoices(t, testChoices) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + t.Fatal(err) + } t.Log("totals in chooser", chooser.totals) // run Pick() a million times, and record how often it returns each of the @@ -67,7 +107,10 @@ func TestChooser_Pick(t *testing.T) { // randomness. func TestChooser_PickSource(t *testing.T) { choices := mockFrequencyChoices(t, testChoices) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + t.Fatal(err) + } t.Log("totals in chooser", chooser.totals) counts1 := make(map[int]int) @@ -137,7 +180,7 @@ func BenchmarkNewChooser(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = NewChooser(choices...) + _, _ = NewChooser(choices...) } }) } @@ -147,7 +190,10 @@ func BenchmarkPick(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(n) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() for i := 0; i < b.N; i++ { @@ -161,7 +207,10 @@ func BenchmarkPickParallel(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(n) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))