Skip to content

Commit

Permalink
Merge pull request #119 from elezar/fix-topology-common-ancestor
Browse files Browse the repository at this point in the history
Add internal method to get device handle
  • Loading branch information
elezar authored May 14, 2024
2 parents ca3eca7 + 339dc1a commit a94486b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 6 deletions.
48 changes: 45 additions & 3 deletions pkg/nvml/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,48 @@
package nvml

import (
"fmt"
"reflect"
"unsafe"
)

// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice.
// This is required for functions such as GetTopologyCommonAncestor which
// accept Device arguments that need to be passed to internal nvml* functions
// as nvmlDevice parameters.
func nvmlDeviceHandle(d Device) nvmlDevice {
var helper func(val reflect.Value) nvmlDevice
helper = func(val reflect.Value) nvmlDevice {
if val.Kind() == reflect.Interface {
val = val.Elem()
}

if val.Kind() == reflect.Ptr {
val = val.Elem()
}

if val.Type() == reflect.TypeOf(nvmlDevice{}) {
return val.Interface().(nvmlDevice)
}

if val.Kind() != reflect.Struct {
panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind()))
}

for i := 0; i < val.Type().NumField(); i++ {
if !val.Type().Field(i).Anonymous {
continue
}
if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) {
continue
}
return helper(val.Field(i))
}
panic(fmt.Errorf("unable to convert %T to nvmlDevice", d))
}
return helper(reflect.ValueOf(d))
}

// EccBitType
type EccBitType = MemoryErrorType

Expand Down Expand Up @@ -220,10 +259,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device

func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) {
var pathInfo GpuTopologyLevel
ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo)
ret := nvmlDeviceGetTopologyCommonAncestorStub(device1, nvmlDeviceHandle(device2), &pathInfo)
return pathInfo, ret
}

// nvmlDeviceGetTopologyCommonAncestorStub allows us to override this for testing.
var nvmlDeviceGetTopologyCommonAncestorStub = nvmlDeviceGetTopologyCommonAncestor

// nvml.DeviceGetTopologyNearestGpus()
func (l *library) DeviceGetTopologyNearestGpus(device Device, level GpuTopologyLevel) ([]Device, Return) {
return device.GetTopologyNearestGpus(level)
Expand All @@ -250,7 +292,7 @@ func (l *library) DeviceGetP2PStatus(device1 Device, device2 Device, p2pIndex Gp

func (device1 nvmlDevice) GetP2PStatus(device2 Device, p2pIndex GpuP2PCapsIndex) (GpuP2PStatus, Return) {
var p2pStatus GpuP2PStatus
ret := nvmlDeviceGetP2PStatus(device1, device2.(nvmlDevice), p2pIndex, &p2pStatus)
ret := nvmlDeviceGetP2PStatus(device1, nvmlDeviceHandle(device2), p2pIndex, &p2pStatus)
return p2pStatus, ret
}

Expand Down Expand Up @@ -1182,7 +1224,7 @@ func (l *library) DeviceOnSameBoard(device1 Device, device2 Device) (int, Return

func (device1 nvmlDevice) OnSameBoard(device2 Device) (int, Return) {
var onSameBoard int32
ret := nvmlDeviceOnSameBoard(device1, device2.(nvmlDevice), &onSameBoard)
ret := nvmlDeviceOnSameBoard(device1, nvmlDeviceHandle(device2), &onSameBoard)
return int(onSameBoard), ret
}

Expand Down
99 changes: 99 additions & 0 deletions pkg/nvml/device_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGetTopologyCommonAncestor(t *testing.T) {
type wrappedDevice struct {
Device
}

type wrappedWrappedDevice struct {
wrappedDevice
}

testCases := []struct {
description string
device Device
}{
{
description: "nvmlDevice",
device: nvmlDevice{},
},
{
description: "pointer to nvmlDevice",
device: &nvmlDevice{},
},
{
description: "wrapped device",
device: wrappedDevice{
Device: nvmlDevice{},
},
},
{
description: "pointer to wrapped device",
device: &wrappedDevice{
Device: nvmlDevice{},
},
},
{
description: "nested wrapped device",
device: wrappedWrappedDevice{
wrappedDevice: wrappedDevice{
Device: nvmlDevice{},
},
},
},
{
description: "non-device fields included",
device: struct {
name string
Name string
Device
}{
Device: wrappedDevice{
Device: nvmlDevice{},
},
},
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
defer setNvmlDeviceGetTopologyCommonAncestorStubForTest(SUCCESS)()

_, ret := nvmlDevice{}.GetTopologyCommonAncestor(tc.device)
require.Equal(t, SUCCESS, ret)
})
}
}

func setNvmlDeviceGetTopologyCommonAncestorStubForTest(ret Return) func() {
original := nvmlDeviceGetTopologyCommonAncestorStub

nvmlDeviceGetTopologyCommonAncestorStub = func(Device1, Device2 nvmlDevice, PathInfo *GpuTopologyLevel) Return {
return ret
}
return func() {
nvmlDeviceGetTopologyCommonAncestorStub = original
}
}
4 changes: 2 additions & 2 deletions pkg/nvml/gpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (device nvmlDevice) GpmSampleGet(gpmSample GpmSample) Return {
}

func (gpmSample nvmlGpmSample) Get(device Device) Return {
return nvmlGpmSampleGet(device.(nvmlDevice), gpmSample)
return nvmlGpmSampleGet(nvmlDeviceHandle(device), gpmSample)
}

// nvml.GpmQueryDeviceSupport()
Expand Down Expand Up @@ -137,5 +137,5 @@ func (device nvmlDevice) GpmMigSampleGet(gpuInstanceId int, gpmSample GpmSample)
}

func (gpmSample nvmlGpmSample) MigGet(device Device, gpuInstanceId int) Return {
return nvmlGpmMigSampleGet(device.(nvmlDevice), uint32(gpuInstanceId), gpmSample)
return nvmlGpmMigSampleGet(nvmlDeviceHandle(device), uint32(gpuInstanceId), gpmSample)
}
2 changes: 1 addition & 1 deletion pkg/nvml/vgpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (device nvmlDevice) VgpuTypeGetMaxInstances(vgpuTypeId VgpuTypeId) (int, Re

func (vgpuTypeId nvmlVgpuTypeId) GetMaxInstances(device Device) (int, Return) {
var vgpuInstanceCount uint32
ret := nvmlVgpuTypeGetMaxInstances(device.(nvmlDevice), vgpuTypeId, &vgpuInstanceCount)
ret := nvmlVgpuTypeGetMaxInstances(nvmlDeviceHandle(device), vgpuTypeId, &vgpuInstanceCount)
return int(vgpuInstanceCount), ret
}

Expand Down

0 comments on commit a94486b

Please sign in to comment.