Skip to content

Commit

Permalink
checked constructor for potential error conditions (#4)
Browse files Browse the repository at this point in the history
* checked constructor for potential error conditions

This commit introduces a variant of the NewChooser constructor that will error
on conditions that could later cause a runtime issue during Pick().

The conditions handled are a lack of valid choices and a potential integer
overflow in the running total.

This is a proof of concept, but the final API may likely be different to avoid
introducing extra complexity into the library. This commit merely serves as the
intial code to aide a discussion in the PR.

* checked constructor as new default

* privatize sentinel errors for NewChooser

I don't see a current use case scenario where being able to act upon
these as sentinel errors would be significant, so better to avoid the
API complexity and keep them private for now. Always easier to export
them later if needed than taking them away once out in the wild.

* docs: clean up variable names
  • Loading branch information
mroth authored Oct 29, 2020
1 parent 3b00289 commit 953da99
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
func main() {
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!

c := wr.NewChooser(
chooser, _ := wr.NewChooser(
wr.Choice{Item: "🍒", Weight: 0},
wr.Choice{Item: "🍋", Weight: 1},
wr.Choice{Item: "🍊", Weight: 1},
Expand All @@ -33,7 +33,7 @@ func main() {
probability, and 🥑 with 0.5 probability. 🍒 will never be printed. (Note
the weights don't have to add up to 10, that was just done here to make the
example easier to read.) */
result := c.Pick().(string)
result := chooser.Pick().(string)
fmt.Println(result)
}
```
Expand Down
12 changes: 9 additions & 3 deletions examples/compbench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ func BenchmarkMultiple(b *testing.B) {
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
b.Run(strconv.Itoa(n), func(b *testing.B) {
choices := mockChoices(b, n)
chs := weightedrand.NewChooser(choices...)
chs, err := weightedrand.NewChooser(choices...)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
chs.Pick()
Expand All @@ -45,7 +48,10 @@ func BenchmarkMultiple(b *testing.B) {
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
b.Run(strconv.Itoa(n), func(b *testing.B) {
choices := mockChoices(b, n)
chs := weightedrand.NewChooser(choices...)
chs, err := weightedrand.NewChooser(choices...)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
Expand Down Expand Up @@ -84,7 +90,7 @@ func BenchmarkSingle(b *testing.B) {
choices := mockChoices(b, n)
b.ResetTimer()
for i := 0; i < b.N; i++ {
chs := weightedrand.NewChooser(choices...)
chs, _ := weightedrand.NewChooser(choices...)
chs.Pick()
}
})
Expand Down
6 changes: 5 additions & 1 deletion examples/frequency/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"log"
"math/rand"
"time"

Expand All @@ -11,13 +12,16 @@ import (
func main() {
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!

c := wr.NewChooser(
c, err := wr.NewChooser(
wr.Choice{Item: '🍒', Weight: 0}, // alternatively: wr.NewChoice('🍒', 0)
wr.Choice{Item: '🍋', Weight: 1},
wr.Choice{Item: '🍊', Weight: 1},
wr.Choice{Item: '🍉', Weight: 3},
wr.Choice{Item: '🥑', Weight: 5},
)
if err != nil {
log.Fatal(err)
}

/* Let's pick a bunch of fruits so we can see the distribution in action! */
fruits := make([]rune, 40*18)
Expand Down
64 changes: 46 additions & 18 deletions weightedrand.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package weightedrand

import (
"errors"
"math/rand"
"sort"
)
Expand All @@ -34,28 +35,55 @@ type Chooser struct {
max int
}

// NewChooser initializes a new Chooser for picking from the provided Choices.
func NewChooser(cs ...Choice) Chooser {
sort.Slice(cs, func(i, j int) bool {
return cs[i].Weight < cs[j].Weight
// NewChooser initializes a new Chooser for picking from the provided choices.
func NewChooser(choices ...Choice) (*Chooser, error) {
sort.Slice(choices, func(i, j int) bool {
return choices[i].Weight < choices[j].Weight
})
totals := make([]int, len(cs))

totals := make([]int, len(choices))
runningTotal := 0
for i, c := range cs {
runningTotal += int(c.Weight)
for i, c := range choices {
weight := int(c.Weight)
if (maxInt - runningTotal) <= weight {
return nil, errWeightOverflow
}
runningTotal += weight
totals[i] = runningTotal
}
return Chooser{data: cs, totals: totals, max: runningTotal}

if runningTotal <= 1 {
return nil, errNoValidChoices
}

return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
}

const (
intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
maxInt = 1<<(intSize-1) - 1
)

// Possible errors returned by NewChooser, preventing the creation of a Chooser
// with unsafe runtime states.
var (
// If the sum of provided Choice weights exceed the maximum integer value
// for the current platform (e.g. math.MaxInt32 or math.MaxInt64), then
// the internal running total will overflow, resulting in an imbalanced
// distribution generating improper results.
errWeightOverflow = errors.New("sum of Choice Weights exceeds max int")
// If there are no Choices available to the Chooser with a weight >= 1,
// there are no valid choices and Pick would produce a runtime panic.
errNoValidChoices = errors.New("zero Choices with Weight >= 1")
)

// Pick returns a single weighted random Choice.Item from the Chooser.
//
// Utilizes global rand as the source of randomness -- you will likely want to
// seed it.
func (chs Chooser) Pick() interface{} {
r := rand.Intn(chs.max) + 1
i := searchInts(chs.totals, r)
return chs.data[i].Item
// Utilizes global rand as the source of randomness.
func (c Chooser) Pick() interface{} {
r := rand.Intn(c.max) + 1
i := searchInts(c.totals, r)
return c.data[i].Item
}

// PickSource returns a single weighted random Choice.Item from the Chooser,
Expand All @@ -67,10 +95,10 @@ func (chs Chooser) Pick() interface{} {
//
// It is the responsibility of the caller to ensure the provided rand.Source is
// free from thread safety issues.
func (chs Chooser) PickSource(rs *rand.Rand) interface{} {
r := rs.Intn(chs.max) + 1
i := searchInts(chs.totals, r)
return chs.data[i].Item
func (c Chooser) PickSource(rs *rand.Rand) interface{} {
r := rs.Intn(c.max) + 1
i := searchInts(c.totals, r)
return c.data[i].Item
}

// The standard library sort.SearchInts() just wraps the generic sort.Search()
Expand Down
61 changes: 55 additions & 6 deletions weightedrand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
// not on any absolute scoring system. In this trivial case, we will assign a
// weight of 0 to all but one fruit, so that the output will be predictable.
func Example() {
chooser := NewChooser(
chooser, _ := NewChooser(
NewChoice('🍋', 0),
NewChoice('🍊', 0),
NewChoice('🍉', 0),
Expand All @@ -42,12 +42,52 @@ const (
testIterations = 1000000
)

func TestNewChooser(t *testing.T) {
tests := []struct {
name string
cs []Choice
wantErr error
}{
{
name: "zero choices",
cs: []Choice{},
wantErr: errNoValidChoices,
},
{
name: "no choices with positive weight",
cs: []Choice{{Item: 'a', Weight: 0}, {Item: 'b', Weight: 0}},
wantErr: errNoValidChoices,
},
{
name: "weight overflow",
cs: []Choice{{Item: 'a', Weight: maxInt/2 + 1}, {Item: 'b', Weight: maxInt/2 + 1}},
wantErr: errWeightOverflow,
},
{
name: "nominal case",
cs: []Choice{{Item: 'a', Weight: 1}, {Item: 'b', Weight: 2}},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewChooser(tt.cs...)
if err != tt.wantErr {
t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

// TestChooser_Pick assembles a list of Choices, weighted 0-9, and tests that
// over the course of 1,000,000 calls to Pick() each choice is returned more
// often than choices with a lower weight.
func TestChooser_Pick(t *testing.T) {
choices := mockFrequencyChoices(t, testChoices)
chooser := NewChooser(choices...)
chooser, err := NewChooser(choices...)
if err != nil {
t.Fatal(err)
}
t.Log("totals in chooser", chooser.totals)

// run Pick() a million times, and record how often it returns each of the
Expand All @@ -67,7 +107,10 @@ func TestChooser_Pick(t *testing.T) {
// randomness.
func TestChooser_PickSource(t *testing.T) {
choices := mockFrequencyChoices(t, testChoices)
chooser := NewChooser(choices...)
chooser, err := NewChooser(choices...)
if err != nil {
t.Fatal(err)
}
t.Log("totals in chooser", chooser.totals)

counts1 := make(map[int]int)
Expand Down Expand Up @@ -137,7 +180,7 @@ func BenchmarkNewChooser(b *testing.B) {
b.ResetTimer()

for i := 0; i < b.N; i++ {
_ = NewChooser(choices...)
_, _ = NewChooser(choices...)
}
})
}
Expand All @@ -147,7 +190,10 @@ func BenchmarkPick(b *testing.B) {
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
b.Run(strconv.Itoa(n), func(b *testing.B) {
choices := mockChoices(n)
chooser := NewChooser(choices...)
chooser, err := NewChooser(choices...)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()

for i := 0; i < b.N; i++ {
Expand All @@ -161,7 +207,10 @@ func BenchmarkPickParallel(b *testing.B) {
for n := BMMinChoices; n <= BMMaxChoices; n *= 10 {
b.Run(strconv.Itoa(n), func(b *testing.B) {
choices := mockChoices(n)
chooser := NewChooser(choices...)
chooser, err := NewChooser(choices...)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
rs := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
Expand Down

0 comments on commit 953da99

Please sign in to comment.