Skip to content

Commit

Permalink
refactor heapsort to support generic (#553)
Browse files Browse the repository at this point in the history
* refactor: generic heapsort

* revert: remove generic test

* revert: max heap

* refactor: make max heap more generic

* fix: zero index

* revert: generic MaxHeap

* revert: heapifyDown
  • Loading branch information
phantomnat authored Oct 15, 2022
1 parent 03c8ce8 commit 6e6d4d7
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 38 deletions.
75 changes: 39 additions & 36 deletions sort/heapsort.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
package sort

import "github.com/TheAlgorithms/Go/constraints"

type MaxHeap struct {
slice []Comparable
heapSize int
indices map[int]int
}

func buildMaxHeap(slice0 []int) MaxHeap {
var slice []Comparable
for _, i := range slice0 {
slice = append(slice, Int(i))
}
h := MaxHeap{}
h.Init(slice)
return h
}

func (h *MaxHeap) Init(slice []Comparable) {
if slice == nil {
slice = make([]Comparable, 0)
Expand Down Expand Up @@ -73,62 +65,73 @@ func (h MaxHeap) updateidx(i int) {
h.indices[h.slice[i].Idx()] = i
}

func (h *MaxHeap) swap(i, j int) {
h.slice[i], h.slice[j] = h.slice[j], h.slice[i]
h.updateidx(i)
h.updateidx(j)
}

func (h MaxHeap) more(i, j int) bool {
return h.slice[i].More(h.slice[j])
}

func (h MaxHeap) heapifyUp(i int) {
if i == 0 {
return
}
p := i / 2

if h.slice[i].More(h.slice[p]) {
h.slice[i], h.slice[p] = h.slice[p], h.slice[i]
h.updateidx(i)
h.updateidx(p)
h.swap(i, p)
h.heapifyUp(p)
}
}

func (h MaxHeap) heapifyDown(i int) {
heapifyDown(h.slice, h.heapSize, i, h.more, h.swap)
}

func heapifyDown[T any](slice []T, N, i int, moreFunc func(i, j int) bool, swapFunc func(i, j int)) {
l, r := 2*i+1, 2*i+2
max := i

if l < h.heapSize && h.slice[l].More(h.slice[max]) {
if l < N && moreFunc(l, max) {
max = l
}
if r < h.heapSize && h.slice[r].More(h.slice[max]) {
if r < N && moreFunc(r, max) {
max = r
}
if max != i {
h.slice[i], h.slice[max] = h.slice[max], h.slice[i]
h.updateidx(i)
h.updateidx(max)
h.heapifyDown(max)
swapFunc(i, max)

heapifyDown(slice, N, max, moreFunc, swapFunc)
}
}

type Comparable interface {
Idx() int
More(any) bool
}
type Int int

func (a Int) More(b any) bool {
return a > b.(Int)
}
func (a Int) Idx() int {
return int(a)
}
func HeapSort[T constraints.Ordered](slice []T) []T {
N := len(slice)

func HeapSort(slice []int) []int {
h := buildMaxHeap(slice)
for i := len(h.slice) - 1; i >= 1; i-- {
h.slice[0], h.slice[i] = h.slice[i], h.slice[0]
h.heapSize--
h.heapifyDown(0)
moreFunc := func(i, j int) bool {
return slice[i] > slice[j]
}
swapFunc := func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
}

// build a maxheap
for i := N/2 - 1; i >= 0; i-- {
heapifyDown(slice, N, i, moreFunc, swapFunc)
}

res := []int{}
for _, i := range h.slice {
res = append(res, int(i.(Int)))
for i := N - 1; i > 0; i-- {
slice[i], slice[0] = slice[0], slice[i]
heapifyDown(slice, i, 0, moreFunc, swapFunc)
}
return res

return slice
}
4 changes: 2 additions & 2 deletions sort/sorts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestMergeParallel(t *testing.T) {
}

func TestHeap(t *testing.T) {
testFramework(t, sort.HeapSort)
testFramework(t, sort.HeapSort[int])
}

func TestCount(t *testing.T) {
Expand Down Expand Up @@ -227,7 +227,7 @@ func BenchmarkMergeParallel(b *testing.B) {
}

func BenchmarkHeap(b *testing.B) {
benchmarkFramework(b, sort.HeapSort)
benchmarkFramework(b, sort.HeapSort[int])
}

func BenchmarkCount(b *testing.B) {
Expand Down

0 comments on commit 6e6d4d7

Please sign in to comment.