diff --git a/sort/heapsort.go b/sort/heapsort.go index 07a5d9547..741b4e4e8 100644 --- a/sort/heapsort.go +++ b/sort/heapsort.go @@ -1,21 +1,13 @@ package sort +import "github.com/TheAlgorithms/Go/constraints" + type MaxHeap struct { slice []Comparable heapSize int indices map[int]int } -func buildMaxHeap(slice0 []int) MaxHeap { - var slice []Comparable - for _, i := range slice0 { - slice = append(slice, Int(i)) - } - h := MaxHeap{} - h.Init(slice) - return h -} - func (h *MaxHeap) Init(slice []Comparable) { if slice == nil { slice = make([]Comparable, 0) @@ -73,6 +65,16 @@ func (h MaxHeap) updateidx(i int) { h.indices[h.slice[i].Idx()] = i } +func (h *MaxHeap) swap(i, j int) { + h.slice[i], h.slice[j] = h.slice[j], h.slice[i] + h.updateidx(i) + h.updateidx(j) +} + +func (h MaxHeap) more(i, j int) bool { + return h.slice[i].More(h.slice[j]) +} + func (h MaxHeap) heapifyUp(i int) { if i == 0 { return @@ -80,28 +82,29 @@ func (h MaxHeap) heapifyUp(i int) { p := i / 2 if h.slice[i].More(h.slice[p]) { - h.slice[i], h.slice[p] = h.slice[p], h.slice[i] - h.updateidx(i) - h.updateidx(p) + h.swap(i, p) h.heapifyUp(p) } } func (h MaxHeap) heapifyDown(i int) { + heapifyDown(h.slice, h.heapSize, i, h.more, h.swap) +} + +func heapifyDown[T any](slice []T, N, i int, moreFunc func(i, j int) bool, swapFunc func(i, j int)) { l, r := 2*i+1, 2*i+2 max := i - if l < h.heapSize && h.slice[l].More(h.slice[max]) { + if l < N && moreFunc(l, max) { max = l } - if r < h.heapSize && h.slice[r].More(h.slice[max]) { + if r < N && moreFunc(r, max) { max = r } if max != i { - h.slice[i], h.slice[max] = h.slice[max], h.slice[i] - h.updateidx(i) - h.updateidx(max) - h.heapifyDown(max) + swapFunc(i, max) + + heapifyDown(slice, N, max, moreFunc, swapFunc) } } @@ -109,26 +112,26 @@ type Comparable interface { Idx() int More(any) bool } -type Int int -func (a Int) More(b any) bool { - return a > b.(Int) -} -func (a Int) Idx() int { - return int(a) -} +func HeapSort[T constraints.Ordered](slice []T) []T { + N := len(slice) -func HeapSort(slice []int) []int { - h := buildMaxHeap(slice) - for i := len(h.slice) - 1; i >= 1; i-- { - h.slice[0], h.slice[i] = h.slice[i], h.slice[0] - h.heapSize-- - h.heapifyDown(0) + moreFunc := func(i, j int) bool { + return slice[i] > slice[j] + } + swapFunc := func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + } + + // build a maxheap + for i := N/2 - 1; i >= 0; i-- { + heapifyDown(slice, N, i, moreFunc, swapFunc) } - res := []int{} - for _, i := range h.slice { - res = append(res, int(i.(Int))) + for i := N - 1; i > 0; i-- { + slice[i], slice[0] = slice[0], slice[i] + heapifyDown(slice, i, 0, moreFunc, swapFunc) } - return res + + return slice } diff --git a/sort/sorts_test.go b/sort/sorts_test.go index 3c7e89caa..ecee3d33c 100644 --- a/sort/sorts_test.go +++ b/sort/sorts_test.go @@ -118,7 +118,7 @@ func TestMergeParallel(t *testing.T) { } func TestHeap(t *testing.T) { - testFramework(t, sort.HeapSort) + testFramework(t, sort.HeapSort[int]) } func TestCount(t *testing.T) { @@ -227,7 +227,7 @@ func BenchmarkMergeParallel(b *testing.B) { } func BenchmarkHeap(b *testing.B) { - benchmarkFramework(b, sort.HeapSort) + benchmarkFramework(b, sort.HeapSort[int]) } func BenchmarkCount(b *testing.B) {