diff --git a/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca b/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca new file mode 100644 index 0000000..e4a5f9c --- /dev/null +++ b/testdata/fuzz/FuzzNewChooser/a547669aeb7ca0ca @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0000000\x8f00000000") diff --git a/weightedrand.go b/weightedrand.go index 3320121..52fd89b 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -48,11 +48,16 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro totals := make([]int, len(choices)) runningTotal := 0 for i, c := range choices { - weight := int(c.Weight) - if weight < 0 { + if c.Weight < 0 { continue // ignore negative weights, can never be picked } + // case of single ~uint64 or similar value that exceeds maxInt on its own + if uint64(c.Weight) >= maxInt { + return nil, errWeightOverflow + } + + weight := int(c.Weight) // convert weight to int for internal counter usage if (maxInt - runningTotal) <= weight { return nil, errWeightOverflow } @@ -68,8 +73,9 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro } const ( - intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize - maxInt = 1<<(intSize-1) - 1 + intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize + maxInt = 1<<(intSize-1) - 1 + maxUint64 = 1<<64 - 1 ) // Possible errors returned by NewChooser, preventing the creation of a Chooser diff --git a/weightedrand_test.go b/weightedrand_test.go index 824eeb0..0eb4206 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -81,10 +81,42 @@ func TestNewChooser(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewChooser(tt.cs...) + c, err := NewChooser(tt.cs...) if err != tt.wantErr { t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) } + + if err == nil { // run a few Picks to make sure there are no panics + for i := 0; i < 10; i++ { + _ = c.Pick() + } + } + }) + } + + u64tests := []struct { + name string + cs []Choice[rune, uint64] + wantErr error + }{ + { + name: "weight overflow from single uint64 exceeding system maxInt", + cs: []Choice[rune, uint64]{{Item: 'a', Weight: maxInt + 1}}, + wantErr: errWeightOverflow, + }, + } + for _, tt := range u64tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewChooser(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewChooser() error = %v, wantErr %v", err, tt.wantErr) + } + + if err == nil { // run a few Picks to make sure there are no panics + for i := 0; i < 10; i++ { + _ = c.Pick() + } + } }) } }