Skip to content

Commit

Permalink
Merge pull request #113 from klueska/fix-bug-load-close
Browse files Browse the repository at this point in the history
Fix bug with wrong instance of lib being called for load/close
  • Loading branch information
klueska authored Apr 23, 2024
2 parents a97d07c + 73ee14c commit c7513c3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
6 changes: 3 additions & 3 deletions pkg/nvml/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import "C"

// nvml.Init()
func (l *library) Init() Return {
if err := libnvml.load(); err != nil {
if err := l.load(); err != nil {
return ERROR_LIBRARY_NOT_FOUND
}
return nvmlInit()
}

// nvml.InitWithFlags()
func (l *library) InitWithFlags(flags uint32) Return {
if err := libnvml.load(); err != nil {
if err := l.load(); err != nil {
return ERROR_LIBRARY_NOT_FOUND
}
return nvmlInitWithFlags(flags)
Expand All @@ -39,7 +39,7 @@ func (l *library) Shutdown() Return {
return ret
}

err := libnvml.close()
err := l.close()
if err != nil {
return ERROR_UNKNOWN
}
Expand Down
44 changes: 22 additions & 22 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,93 +198,93 @@ func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
// When new versioned symbols are added, these would have to be initialized above and have
// corresponding checks and subsequent assignments added below.
func (l *library) updateVersionedSymbols() {
err := l.LookupSymbol("nvmlInit_v2")
err := l.dl.Lookup("nvmlInit_v2")
if err == nil {
nvmlInit = nvmlInit_v2
}
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2")
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v2")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
}
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3")
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v3")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
}
err = l.LookupSymbol("nvmlDeviceGetCount_v2")
err = l.dl.Lookup("nvmlDeviceGetCount_v2")
if err == nil {
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
}
err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2")
err = l.dl.Lookup("nvmlDeviceGetHandleByIndex_v2")
if err == nil {
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
}
err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2")
err = l.dl.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
if err == nil {
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
}
err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2")
err = l.dl.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
if err == nil {
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
}
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
// a different set of parameters than the v1 function.
//err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2")
//err = l.dl.Lookup("nvmlDeviceRemoveGpu_v2")
//if err == nil {
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
//}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
}
err = l.LookupSymbol("nvmlEventSetWait_v2")
err = l.dl.Lookup("nvmlEventSetWait_v2")
if err == nil {
nvmlEventSetWait = nvmlEventSetWait_v2
}
err = l.LookupSymbol("nvmlDeviceGetAttributes_v2")
err = l.dl.Lookup("nvmlDeviceGetAttributes_v2")
if err == nil {
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
}
err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2")
err = l.dl.Lookup("nvmlComputeInstanceGetInfo_v2")
if err == nil {
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
}
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
err = l.dl.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
if err == nil {
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
}
err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2")
err = l.dl.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
if err == nil {
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
}
Expand Down

0 comments on commit c7513c3

Please sign in to comment.