From 953da99f62a7995b5b73a21940be7af2a0be59cc Mon Sep 17 00:00:00 2001 From: Matthew Rothenberg Date: Thu, 29 Oct 2020 19:44:42 -0400 Subject: [PATCH] checked constructor for potential error conditions (#4) * checked constructor for potential error conditions This commit introduces a variant of the NewChooser constructor that will error on conditions that could later cause a runtime issue during Pick(). The conditions handled are a lack of valid choices and a potential integer overflow in the running total. This is a proof of concept, but the final API may likely be different to avoid introducing extra complexity into the library. This commit merely serves as the initial code to aid a discussion in the PR. * checked constructor as new default * privatize sentinel errors for NewChooser I don't see a current use case scenario where being able to act upon these as sentinel errors would be significant, so better to avoid the API complexity and keep them private for now. Always easier to export them later if needed than taking them away once out in the wild. * docs: clean up variable names --- README.md | 4 +- examples/compbench/bench_test.go | 12 ++++-- examples/frequency/main.go | 6 ++- weightedrand.go | 64 +++++++++++++++++++++++--------- weightedrand_test.go | 61 +++++++++++++++++++++++++++--- 5 files changed, 117 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index d91d693..189eb7a 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ import ( func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! - c := wr.NewChooser( + chooser, _ := wr.NewChooser( wr.Choice{Item: "🍒", Weight: 0}, wr.Choice{Item: "🍋", Weight: 1}, wr.Choice{Item: "🍊", Weight: 1}, @@ -33,7 +33,7 @@ func main() { probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note the weights don't have to add up to 10, that was just done here to make the example easier to read.) 
*/ - result := c.Pick().(string) + result := chooser.Pick().(string) fmt.Println(result) } ``` diff --git a/examples/compbench/bench_test.go b/examples/compbench/bench_test.go index d756e46..6a6f7e4 100644 --- a/examples/compbench/bench_test.go +++ b/examples/compbench/bench_test.go @@ -32,7 +32,10 @@ func BenchmarkMultiple(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(b, n) - chs := weightedrand.NewChooser(choices...) + chs, err := weightedrand.NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() for i := 0; i < b.N; i++ { chs.Pick() @@ -45,7 +48,10 @@ func BenchmarkMultiple(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(b, n) - chs := weightedrand.NewChooser(choices...) + chs, err := weightedrand.NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) @@ -84,7 +90,7 @@ func BenchmarkSingle(b *testing.B) { choices := mockChoices(b, n) b.ResetTimer() for i := 0; i < b.N; i++ { - chs := weightedrand.NewChooser(choices...) + chs, _ := weightedrand.NewChooser(choices...) chs.Pick() } }) diff --git a/examples/frequency/main.go b/examples/frequency/main.go index fdc0be0..9c1d0b3 100644 --- a/examples/frequency/main.go +++ b/examples/frequency/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" "math/rand" "time" @@ -11,13 +12,16 @@ import ( func main() { rand.Seed(time.Now().UTC().UnixNano()) // always seed random! 
- c := wr.NewChooser( + c, err := wr.NewChooser( wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0) wr.Choice{Item: '🍋', Weight: 1}, wr.Choice{Item: '🍊', Weight: 1}, wr.Choice{Item: '🍉', Weight: 3}, wr.Choice{Item: '🥑', Weight: 5}, ) + if err != nil { + log.Fatal(err) + } /* Let's pick a bunch of fruits so we can see the distribution in action! */ fruits := make([]rune, 40*18) diff --git a/weightedrand.go b/weightedrand.go index d3adce3..90ea2aa 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -11,6 +11,7 @@ package weightedrand import ( + "errors" "math/rand" "sort" ) @@ -34,28 +35,55 @@ type Chooser struct { max int } -// NewChooser initializes a new Chooser for picking from the provided Choices. -func NewChooser(cs ...Choice) Chooser { - sort.Slice(cs, func(i, j int) bool { - return cs[i].Weight < cs[j].Weight +// NewChooser initializes a new Chooser for picking from the provided choices. It returns an error if no Choice has a Weight >= 1 or if the summed Weights exceed max int. +func NewChooser(choices ...Choice) (*Chooser, error) { + sort.Slice(choices, func(i, j int) bool { + return choices[i].Weight < choices[j].Weight }) - totals := make([]int, len(cs)) + + totals := make([]int, len(choices)) runningTotal := 0 - for i, c := range cs { - runningTotal += int(c.Weight) + for i, c := range choices { + weight := int(c.Weight) + if (maxInt - runningTotal) < weight { + return nil, errWeightOverflow + } + runningTotal += weight totals[i] = runningTotal } - return Chooser{data: cs, totals: totals, max: runningTotal} + + if runningTotal < 1 { + return nil, errNoValidChoices + } + + return &Chooser{data: choices, totals: totals, max: runningTotal}, nil } +const ( + intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize + maxInt = 1<<(intSize-1) - 1 +) + +// Possible errors returned by NewChooser, preventing the creation of a Chooser +// with unsafe runtime states. +var ( + // If the sum of provided Choice weights exceeds the maximum integer value + // for the current platform (e.g. 
math.MaxInt32 or math.MaxInt64), then + // the internal running total will overflow, resulting in an imbalanced + // distribution generating improper results. + errWeightOverflow = errors.New("sum of Choice Weights exceeds max int") + // If there are no Choices available to the Chooser with a weight >= 1, + // there are no valid choices and Pick would produce a runtime panic. + errNoValidChoices = errors.New("zero Choices with Weight >= 1") +) + // Pick returns a single weighted random Choice.Item from the Chooser. // -// Utilizes global rand as the source of randomness -- you will likely want to -// seed it. -func (chs Chooser) Pick() interface{} { - r := rand.Intn(chs.max) + 1 - i := searchInts(chs.totals, r) - return chs.data[i].Item +// Utilizes global rand as the source of randomness. +func (c Chooser) Pick() interface{} { + r := rand.Intn(c.max) + 1 + i := searchInts(c.totals, r) + return c.data[i].Item } // PickSource returns a single weighted random Choice.Item from the Chooser, @@ -67,10 +95,10 @@ func (chs Chooser) Pick() interface{} { // // It is the responsibility of the caller to ensure the provided rand.Source is // free from thread safety issues. -func (chs Chooser) PickSource(rs *rand.Rand) interface{} { - r := rs.Intn(chs.max) + 1 - i := searchInts(chs.totals, r) - return chs.data[i].Item +func (c Chooser) PickSource(rs *rand.Rand) interface{} { + r := rs.Intn(c.max) + 1 + i := searchInts(c.totals, r) + return c.data[i].Item } // The standard library sort.SearchInts() just wraps the generic sort.Search() diff --git a/weightedrand_test.go b/weightedrand_test.go index 890cf19..7c3656c 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -18,7 +18,7 @@ import ( // not on any absolute scoring system. In this trivial case, we will assign a // weight of 0 to all but one fruit, so that the output will be predictable. 
func Example() { - chooser := NewChooser( + chooser, _ := NewChooser( NewChoice('🍋', 0), NewChoice('🍊', 0), NewChoice('🍉', 0), @@ -42,12 +42,52 @@ const ( testIterations = 1000000 ) +func TestNewChooser(t *testing.T) { + tests := []struct { + name string + cs []Choice + wantErr error + }{ + { + name: "zero choices", + cs: []Choice{}, + wantErr: errNoValidChoices, + }, + { + name: "no choices with positive weight", + cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}}, + wantErr: errNoValidChoices, + }, + { + name: "weight overflow", + cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}}, + wantErr: errWeightOverflow, + }, + { + name: "nominal case", + cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewChooser(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + // TestChooser_Pick assembles a list of Choices, weighted 0-9, and tests that // over the course of 1,000,000 calls to Pick() each choice is returned more // often than choices with a lower weight. func TestChooser_Pick(t *testing.T) { choices := mockFrequencyChoices(t, testChoices) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + t.Fatal(err) + } t.Log("totals in chooser", chooser.totals) // run Pick() a million times, and record how often it returns each of the @@ -67,7 +107,10 @@ func TestChooser_Pick(t *testing.T) { // randomness. func TestChooser_PickSource(t *testing.T) { choices := mockFrequencyChoices(t, testChoices) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + t.Fatal(err) + } t.Log("totals in chooser", chooser.totals) counts1 := make(map[int]int) @@ -137,7 +180,7 @@ func BenchmarkNewChooser(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = NewChooser(choices...) 
+ _, _ = NewChooser(choices...) } }) } @@ -147,7 +190,10 @@ func BenchmarkPick(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(n) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() for i := 0; i < b.N; i++ { @@ -161,7 +207,10 @@ func BenchmarkPickParallel(b *testing.B) { for n := BMMinChoices; n <= BMMaxChoices; n *= 10 { b.Run(strconv.Itoa(n), func(b *testing.B) { choices := mockChoices(n) - chooser := NewChooser(choices...) + chooser, err := NewChooser(choices...) + if err != nil { + b.Fatal(err) + } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))