diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..326b3d1 --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,100 @@ +//go:build go1.18 +// +build go1.18 + +package weightedrand + +import ( + "encoding/binary" + "errors" + "fmt" + "reflect" + "testing" +) + +// Fuzz testing does not support slices as a corpus type in go1.18, thus we +// write a bunch of boilerplate here to allow us to encode []uint64 as []byte +// for kicks. + +func bEncodeSlice(xs []uint64) []byte { + bs := make([]byte, len(xs)*8) + for i, x := range xs { + n := i * 8 + binary.LittleEndian.PutUint64(bs[n:], x) + } + return bs +} + +func bDecodeSlice(bs []byte) []uint64 { + n := len(bs) / 8 + xs := make([]uint64, 0, n) + for i := 0; i < n; i++ { + x := binary.LittleEndian.Uint64(bs[8*i:]) + xs = append(xs, x) + } + return xs +} + +// test our own encoder to make sure we didn't introduce errors. +func Test_bEncodeSlice(t *testing.T) { + var testcases = [][]uint64{ + {}, + {1}, + {42}, + {912346}, + {1, 2}, + {1, 1, 1}, + {1, 2, 3}, + {1, 1000000}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + for _, tc := range testcases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + before := tc + encoded := bEncodeSlice(before) + if want, got := len(before)*8, len(encoded); want != got { + t.Errorf("encoded length not as expected: want %d got %d", want, got) + } + decoded := bDecodeSlice(encoded) + if !reflect.DeepEqual(before, decoded) { + t.Errorf("want %v got %v", before, decoded) + } + }) + } +} + +func FuzzNewChooser(f *testing.F) { + var fuzzcases = [][]uint64{ + {}, + {0}, + {1}, + {1, 1}, + {1, 2, 3}, + {0, 1, 2}, + } + for _, tc := range fuzzcases { + f.Add(bEncodeSlice(tc)) + } + + f.Fuzz(func(t *testing.T, encodedWeights []byte) { + weights := bDecodeSlice(encodedWeights) + const sentinel = 1 + + cs := make([]Choice[int, uint64], 0, len(weights)) + for _, w := range weights { + cs = append(cs, Choice[int, uint64]{Item: sentinel, Weight: w}) + } + + // fuzz for error or panic on NewChooser + c, err := NewChooser(cs...) + if err != nil && !errors.Is(err, errNoValidChoices) && !errors.Is(err, errWeightOverflow) { + t.Fatal(err) + } + + if err == nil { + result := c.Pick() // fuzz for panic on Panic + if result != sentinel { // fuzz for returned value unexpected (just use same non-zero sentinel value for all choices) + t.Fatalf("expected %v got %v", sentinel, result) + } + } + }) +} diff --git a/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca b/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca new file mode 100644 index 0000000..e4a5f9c --- /dev/null +++ b/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0000000\x8f00000000") diff --git a/weightedrand.go b/weightedrand.go index 3320121..52fd89b 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -48,11 +48,16 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro totals := make([]int, len(choices)) runningTotal := 0 for i, c := range choices { - weight := int(c.Weight) - if weight < 0 { + if c.Weight < 0 { continue // ignore negative weights, can never be picked } + // case of single ~uint64 or similar value that exceeds maxInt on its own + if uint64(c.Weight) >= maxInt { + return nil, errWeightOverflow + } + + weight := int(c.Weight) // convert weight to int for internal counter usage if (maxInt - runningTotal) <= weight { return nil, errWeightOverflow } @@ -68,8 +73,9 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro } const ( - intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize - maxInt = 1<<(intSize-1) - 1 + intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize + maxInt = 1<<(intSize-1) - 1 + maxUint64 = 1<<64 - 1 ) // Possible errors returned by NewChooser, preventing the creation of a Chooser diff --git a/weightedrand_test.go b/weightedrand_test.go index 824eeb0..0eb4206 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -81,10 +81,42 @@ func TestNewChooser(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewChooser(tt.cs...) + c, err := NewChooser(tt.cs...) if err != tt.wantErr { t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) } + + if err == nil { // run a few Picks to make sure there are no panics + for i := 0; i < 10; i++ { + _ = c.Pick() + } + } + }) + } + + u64tests := []struct { + name string + cs []Choice[rune, uint64] + wantErr error + }{ + { + name: "weight overflow from single uint64 exceeding system maxInt", + cs: []Choice[rune, uint64]{{Item: 'a', Weight: maxInt + 1}}, + wantErr: errWeightOverflow, + }, + } + for _, tt := range u64tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewChooser(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) + } + + if err == nil { // run a few Picks to make sure there are no panics + for i := 0; i < 10; i++ { + _ = c.Pick() + } + } }) } }