Skip to content

Commit

Permalink
feat: add GetWaitingPod support
Browse files Browse the repository at this point in the history
  • Loading branch information
chansuke committed Sep 14, 2024
1 parent 705f97d commit 4e1ee21
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 4 deletions.
13 changes: 13 additions & 0 deletions guest/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,16 @@ type NodeScore interface {
// which is keyed by the node name and valued by the score.
Map() map[string]int
}

// WaitingPod represents a pod currently waiting in the permit phase.
type WaitingPod interface {
GetPod() proto.Pod
// GetPendingPlugins returns a list of pending Permit plugin's name.
GetPendingPlugins() []string
// Allow declares the waiting pod is allowed to be scheduled by the plugin named as "pluginName".
// If this is the last remaining plugin to allow, then a success signal is delivered
// to unblock the pod.
Allow(pluginName string)
// Reject declares the waiting pod unschedulable.
Reject(pluginName, msg string)
}
14 changes: 14 additions & 0 deletions guest/handle/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package handle
import (
"runtime"

"sigs.k8s.io/kube-scheduler-wasm-extension/guest/api"
"sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/mem"
)

Expand All @@ -34,3 +35,16 @@ func RejectWaitingPod(uid string) bool {
runtime.KeepAlive(uid)
return wasmBool == 1
}

func GetWaitingPod(uid string) api.WaitingPod {
ptr, size := mem.StringToPtr(uid)

// Wrap to avoid TinyGo 0.28: cannot use an exported function as value
mem.SendAndGetString(ptr, size, func(input_ptr, input_size, ptr uint32, limit mem.BufLimit) {
getWaitingPod(input_ptr, input_size, ptr, limit)
})
runtime.KeepAlive(uid)

waitingPod := make([]api.WaitingPod, size)
return waitingPod[0]
}
3 changes: 3 additions & 0 deletions guest/handle/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ import "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/mem"

//go:wasmimport k8s.io/scheduler handle.reject_waiting_pod
func rejectWaitingPod(input_ptr, input_size, ptr uint32, limit mem.BufLimit)

//go:wasmimport k8s.io/scheduler handle.get_waiting_pod
func getWaitingPod(input_ptr, input_size, ptr uint32, limit mem.BufLimit)
3 changes: 3 additions & 0 deletions guest/handle/imports_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ import "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/mem"

// rejectWaitingPod is stubbed for compilation outside TinyGo.
func rejectWaitingPod(uint32, uint32, uint32, mem.BufLimit) {}

// getWaitingPod is stubbed for compilation outside TinyGo.
func getWaitingPod(uint32, uint32, uint32, mem.BufLimit) { return }
6 changes: 6 additions & 0 deletions guest/internal/mem/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,9 @@ func SendAndGetUint64(input_ptr uint32, input_size uint32, fn func(input_ptr, in
fn(input_ptr, input_size, uint32(readBufPtr), readBufLimit)
return binary.LittleEndian.Uint64(readBuf)
}

func SendAndGetString(input_ptr uint32, input_size uint32, fn func(input_ptr, input_size, ptr uint32, limit BufLimit)) string {
fn(input_ptr, input_size, uint32(readBufPtr), readBufLimit)
size := binary.LittleEndian.Uint32(readBuf)
return string(readBuf[size : size+binary.LittleEndian.Uint32(readBuf[size:])])
}
16 changes: 16 additions & 0 deletions guest/testdata/handle/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func main() {
switch os.Args[1] {
case "rejectWaitingPod":
plugin = pluginForReject{}
case "getWaitingPod":
plugin = pluginForGet{}
}
}
prefilter.SetPlugin(plugin)
Expand Down Expand Up @@ -82,3 +84,17 @@ func (pluginForReject) Filter(_ api.CycleState, pod proto.Pod, nodeInfo api.Node
Code: api.StatusCodeSuccess,
}
}

// pluginForGet checks the function of GetWaitingPod
type pluginForGet struct{ noopPlugin }

func (pluginForGet) Filter(_ api.CycleState, pod proto.Pod, nodeInfo api.NodeInfo) *api.Status {
// Call GetWaitingPod first
waitingPod := handle.GetWaitingPod(pod.GetUid())

// This is being skipped, note the reason.
return &api.Status{
Code: api.StatusCodeSkip,
Reason: "UID is " + pod.GetUid() + " and waitingPod is " + waitingPod.Name,
}
}
29 changes: 29 additions & 0 deletions scheduler/plugin/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const (
k8sSchedulerResultNormalizedScoreList = "result.normalized_score_list"
k8sSchedulerHandleEventRecorderEventf = "handle.eventrecorder.eventf"
k8sSchedulerHandleRejectWaitingPod = "handle.reject_waiting_pod"
k8sSchedulerHandleGetWaitingPod = "handle.get_waiting_pod"
)

func instantiateHostApi(ctx context.Context, runtime wazero.Runtime) (wazeroapi.Module, error) {
Expand Down Expand Up @@ -122,6 +123,9 @@ func instantiateHostScheduler(ctx context.Context, runtime wazero.Runtime, guest
NewFunctionBuilder().
WithGoModuleFunction(wazeroapi.GoModuleFunc(host.k8sHandleRejectWaitingPodFn), []wazeroapi.ValueType{i32, i32, i32, i32}, []wazeroapi.ValueType{}).
WithParameterNames("buf", "buf_len").Export(k8sSchedulerHandleRejectWaitingPod).
NewFunctionBuilder().
WithGoModuleFunction(wazeroapi.GoModuleFunc(host.k8sHandleGetWaitingPodFn), []wazeroapi.ValueType{i32, i32, i32, i32}, []wazeroapi.ValueType{}).
WithParameterNames("buf", "buf_len").Export(k8sSchedulerHandleGetWaitingPod).
Instantiate(ctx)
}

Expand Down Expand Up @@ -534,3 +538,28 @@ func (h host) k8sHandleRejectWaitingPodFn(ctx context.Context, mod wazeroapi.Mod
}
writeUint64(mod.Memory(), wasmBool, oBuf, oBufLimit)
}

func (h host) k8sHandleGetWaitingPodFn(ctx context.Context, mod wazeroapi.Module, stack []uint64) {
iBuf := uint32(stack[0])
iBufLen := uint32(stack[1])
oBuf := uint32(stack[2])
oBufLimit := uint32(stack[3])

b, ok := mod.Memory().Read(iBuf, iBufLen)
if !ok {
panic("out of memory reading getWaitingPod")
}
uid := types.UID(b)
waitingPod := h.handle.GetWaitingPod(uid)
println("waitingPod: ", waitingPod)
if waitingPod == nil {
print("waitingPod is nil")
stack[0] = 0 // Return 0 to indicate no pod found or an error
return
}

print("waitingPod!!!!!!: ", waitingPod)
print("waitingPod.GetPod() : ", waitingPod.GetPod())

stack[0] = uint64(marshalIfUnderLimit(mod.Memory(), waitingPod.GetPod(), oBuf, oBufLimit))
}
28 changes: 28 additions & 0 deletions scheduler/plugin/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,31 @@ func Test_k8sHandleRejectWaitingPodFn(t *testing.T) {
t.Fatalf("unexpected uid: %v != %v", want, have)
}
}

func Test_k8sHandleGetWaitingPodFn(t *testing.T) {
recorder := &test.FakeRecorder{EventMsg: ""}
handle := &test.FakeHandle{Recorder: recorder}
h := host{handle: handle}

// Create a fake wasm module, which has data the guest should write.
mem := wazerotest.NewMemory(wazerotest.PageSize)
mod := wazerotest.NewModule(mem)
uid := types.UID("c6feae3a-7082-42a5-a5ec-6ae2e1603727")
copy(mem.Bytes, uid)

// Invoke the host function in the same way the guest would have.
h.k8sHandleGetWaitingPodFn(context.Background(), mod, []uint64{
0,
uint64(len(uid)),
0,
0,
})

// Checking the value stored on handle
have := handle.GetWaitingPodValue
want := uid

if want != have {
t.Fatalf("unexpected uid: %v != %v", want, have)
}
}
53 changes: 53 additions & 0 deletions scheduler/plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,59 @@ func TestRejectWaitingPod(t *testing.T) {
}
}

// This test checks whether framework.handle.GetWaitingPod can be called within wasm file.
func TestGetWaitingPod(t *testing.T) {
tests := []struct {
name string
guestURL string
pod *v1.Pod
args []string
expectedUID types.UID
expectedStatusCode framework.Code
expectedStatusMsg string
}{
{
name: "Pod is not returned",
guestURL: test.URLTestHandle,
pod: test.PodSmall,
args: []string{"test", "getWaitingPod"},
expectedUID: "",
expectedStatusCode: framework.Success,
expectedStatusMsg: "",
},
{
name: "Pod is returned",
guestURL: test.URLTestHandle,
pod: test.PodForHandleTest,
args: []string{"test", "getWaitingPod"},
expectedUID: test.PodForHandleTest.GetUID(),
expectedStatusCode: framework.Success,
expectedStatusMsg: "",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
guestURL := tc.guestURL
recorder := &test.FakeRecorder{EventMsg: ""}
handle := &test.FakeHandle{Recorder: recorder}
p, err := wasm.NewFromConfig(ctx, "wasm", wasm.PluginConfig{GuestURL: guestURL, Args: tc.args}, handle)
if err != nil {
t.Fatal(err)
}
defer p.(io.Closer).Close()

status := p.(framework.FilterPlugin).Filter(ctx, nil, tc.pod, nil)
if want, have := tc.expectedStatusCode, status.Code(); want != have {
t.Fatalf("unexpected status code: want %v, have %v", want, have)
}
if want, have := tc.expectedStatusMsg, status.Message(); want != have {
t.Fatalf("unexpected status message: want %v, have %v", want, have)
}
})
}
}

// Extracts and trims the actual log message from a formatted klog string
// (klog includes timestamp before actual log message)
func extractMessage(log string) string {
Expand Down
10 changes: 6 additions & 4 deletions scheduler/test/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (f *FakeRecorder) Eventf(regarding runtime.Object, related runtime.Object,
type FakeHandle struct {
Recorder events.EventRecorder
RejectWaitingPodValue types.UID
GetWaitingPodValue types.UID
}

func (h *FakeHandle) EventRecorder() events.EventRecorder {
Expand Down Expand Up @@ -67,17 +68,18 @@ func (h *FakeHandle) Parallelizer() (p parallelize.Parallelizer) {
return
}

func (h *FakeHandle) GetWaitingPod(uid types.UID) (w framework.WaitingPod) {
return
}

func (h *FakeHandle) IterateOverWaitingPods(callback func(framework.WaitingPod)) {
}

func (h *FakeHandle) NominatedPodsForNode(nodeName string) (f []*framework.PodInfo) {
return
}

func (h *FakeHandle) GetWaitingPod(uid types.UID) (w framework.WaitingPod) {
h.GetWaitingPodValue = uid
return
}

func (h *FakeHandle) RejectWaitingPod(uid types.UID) (b bool) {
h.RejectWaitingPodValue = uid
return uid == types.UID("handle-test")
Expand Down

0 comments on commit 4e1ee21

Please sign in to comment.