diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 7ee5e55..7604d39 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -15,9 +15,48 @@ package nvml import ( + "fmt" + "reflect" "unsafe" ) +// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice. +// This is required for functions such as GetTopologyCommonAncestor which +// accept Device arguments that need to be passed to internal nvml* functions +// as nvmlDevice parameters. +func nvmlDeviceHandle(d Device) nvmlDevice { + var helper func(val reflect.Value) nvmlDevice + helper = func(val reflect.Value) nvmlDevice { + if val.Kind() == reflect.Interface { + val = val.Elem() + } + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Type() == reflect.TypeOf(nvmlDevice{}) { + return val.Interface().(nvmlDevice) + } + + if val.Kind() != reflect.Struct { + panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind())) + } + + for i := 0; i < val.Type().NumField(); i++ { + if !val.Type().Field(i).Anonymous { + continue + } + if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) { + continue + } + return helper(val.Field(i)) + } + panic(fmt.Errorf("unable to convert %T to nvmlDevice", d)) + } + return helper(reflect.ValueOf(d)) +} + // EccBitType type EccBitType = MemoryErrorType @@ -220,10 +259,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) { var pathInfo GpuTopologyLevel - ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo) + ret := nvmlDeviceGetTopologyCommonAncestorStub(device1, nvmlDeviceHandle(device2), &pathInfo) return pathInfo, ret } +// nvmlDeviceGetTopologyCommonAncestorStub allows us to override this for testing. +var nvmlDeviceGetTopologyCommonAncestorStub = nvmlDeviceGetTopologyCommonAncestor + // nvml.DeviceGetTopologyNearestGpus() func (l *library) DeviceGetTopologyNearestGpus(device Device, level GpuTopologyLevel) ([]Device, Return) { return device.GetTopologyNearestGpus(level) @@ -250,7 +292,7 @@ func (l *library) DeviceGetP2PStatus(device1 Device, device2 Device, p2pIndex Gp func (device1 nvmlDevice) GetP2PStatus(device2 Device, p2pIndex GpuP2PCapsIndex) (GpuP2PStatus, Return) { var p2pStatus GpuP2PStatus - ret := nvmlDeviceGetP2PStatus(device1, device2.(nvmlDevice), p2pIndex, &p2pStatus) + ret := nvmlDeviceGetP2PStatus(device1, nvmlDeviceHandle(device2), p2pIndex, &p2pStatus) return p2pStatus, ret } @@ -1182,7 +1224,7 @@ func (l *library) DeviceOnSameBoard(device1 Device, device2 Device) (int, Return func (device1 nvmlDevice) OnSameBoard(device2 Device) (int, Return) { var onSameBoard int32 - ret := nvmlDeviceOnSameBoard(device1, device2.(nvmlDevice), &onSameBoard) + ret := nvmlDeviceOnSameBoard(device1, nvmlDeviceHandle(device2), &onSameBoard) return int(onSameBoard), ret } diff --git a/pkg/nvml/device_test.go b/pkg/nvml/device_test.go new file mode 100644 index 0000000..e92d73c --- /dev/null +++ b/pkg/nvml/device_test.go @@ -0,0 +1,99 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetTopologyCommonAncestor(t *testing.T) { + type wrappedDevice struct { + Device + } + + type wrappedWrappedDevice struct { + wrappedDevice + } + + testCases := []struct { + description string + device Device + }{ + { + description: "nvmlDevice", + device: nvmlDevice{}, + }, + { + description: "pointer to nvmlDevice", + device: &nvmlDevice{}, + }, + { + description: "wrapped device", + device: wrappedDevice{ + Device: nvmlDevice{}, + }, + }, + { + description: "pointer to wrapped device", + device: &wrappedDevice{ + Device: nvmlDevice{}, + }, + }, + { + description: "nested wrapped device", + device: wrappedWrappedDevice{ + wrappedDevice: wrappedDevice{ + Device: nvmlDevice{}, + }, + }, + }, + { + description: "non-device fields included", + device: struct { + name string + Name string + Device + }{ + Device: wrappedDevice{ + Device: nvmlDevice{}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + defer setNvmlDeviceGetTopologyCommonAncestorStubForTest(SUCCESS)() + + _, ret := nvmlDevice{}.GetTopologyCommonAncestor(tc.device) + require.Equal(t, SUCCESS, ret) + }) + } +} + +func setNvmlDeviceGetTopologyCommonAncestorStubForTest(ret Return) func() { + original := nvmlDeviceGetTopologyCommonAncestorStub + + nvmlDeviceGetTopologyCommonAncestorStub = func(Device1, Device2 nvmlDevice, PathInfo *GpuTopologyLevel) Return { + return ret + } + return func() { + nvmlDeviceGetTopologyCommonAncestorStub = original + } +} diff --git a/pkg/nvml/gpm.go b/pkg/nvml/gpm.go index 783514b..acdb2e0 100644 --- a/pkg/nvml/gpm.go +++ b/pkg/nvml/gpm.go @@ -93,7 +93,7 @@ func (device nvmlDevice) GpmSampleGet(gpmSample GpmSample) Return { } func (gpmSample nvmlGpmSample) Get(device Device) Return { - return nvmlGpmSampleGet(device.(nvmlDevice), gpmSample) + return nvmlGpmSampleGet(nvmlDeviceHandle(device), gpmSample) } // nvml.GpmQueryDeviceSupport() @@ -137,5 +137,5 @@ func (device nvmlDevice) GpmMigSampleGet(gpuInstanceId int, gpmSample GpmSample) } func (gpmSample nvmlGpmSample) MigGet(device Device, gpuInstanceId int) Return { - return nvmlGpmMigSampleGet(device.(nvmlDevice), uint32(gpuInstanceId), gpmSample) + return nvmlGpmMigSampleGet(nvmlDeviceHandle(device), uint32(gpuInstanceId), gpmSample) } diff --git a/pkg/nvml/vgpu.go b/pkg/nvml/vgpu.go index bd80077..da49524 100644 --- a/pkg/nvml/vgpu.go +++ b/pkg/nvml/vgpu.go @@ -142,7 +142,7 @@ func (device nvmlDevice) VgpuTypeGetMaxInstances(vgpuTypeId VgpuTypeId) (int, Re func (vgpuTypeId nvmlVgpuTypeId) GetMaxInstances(device Device) (int, Return) { var vgpuInstanceCount uint32 - ret := nvmlVgpuTypeGetMaxInstances(device.(nvmlDevice), vgpuTypeId, &vgpuInstanceCount) + ret := nvmlVgpuTypeGetMaxInstances(nvmlDeviceHandle(device), vgpuTypeId, &vgpuInstanceCount) return int(vgpuInstanceCount), ret }