From 950f3017e054e0be90d7b34058d2d7f33fb83ddf Mon Sep 17 00:00:00 2001 From: Matthew Rothenberg Date: Tue, 11 Dec 2018 16:28:00 -0500 Subject: [PATCH] prevent negative weights in Choice These should break things (by design), so let's just guard against it at the API interface via type system. Things are kept internally as ints because most golang stdlib functions expect that, so we can avoid casting everywhere. --- weightedrand.go | 4 ++-- weightedrand_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/weightedrand.go b/weightedrand.go index efff032..f113f3d 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -18,7 +18,7 @@ import ( // Choice is a generic wrapper that can be used to add weights for any object type Choice struct { Item interface{} - Weight int + Weight uint } // A Chooser caches many possible Choices in a structure designed to improve @@ -38,7 +38,7 @@ func NewChooser(cs ...Choice) Chooser { totals := make([]int, n, n) runningTotal := 0 for i, c := range cs { - runningTotal += c.Weight + runningTotal += int(c.Weight) totals[i] = runningTotal } return Chooser{data: cs, totals: totals, max: runningTotal} diff --git a/weightedrand_test.go b/weightedrand_test.go index 9d98543..1e8038e 100644 --- a/weightedrand_test.go +++ b/weightedrand_test.go @@ -16,7 +16,7 @@ func mockChoices(n int) []Choice { for i := 0; i < n; i++ { s := "⚽️" w := rand.Intn(10) - c := Choice{Item: s, Weight: w} + c := Choice{Item: s, Weight: uint(w)} choices = append(choices, c) } return choices @@ -35,7 +35,7 @@ func TestWeightedChoice(t *testing.T) { presorted data. */ list := rand.Perm(10) for _, v := range list { - c := Choice{Weight: v, Item: v} + c := Choice{Weight: uint(v), Item: v} choices = append(choices, c) } t.Log("FYI mocked choices of", choices) @@ -59,8 +59,8 @@ func TestWeightedChoice(t *testing.T) { for i, c := range choices[0 : len(choices)-1] { next := choices[i+1] cw, nw := c.Weight, next.Weight - if !(chosenCount[cw] < chosenCount[nw]) { - t.Error("Value not lesser", cw, nw, chosenCount[cw], chosenCount[nw]) + if !(chosenCount[int(cw)] < chosenCount[int(nw)]) { + t.Error("Value not lesser", cw, nw, chosenCount[int(cw)], chosenCount[int(nw)]) } }