Skip to content

Commit

Permalink
Call library.dl.Lookup() directly in updateVersionedSymbols
Browse files Browse the repository at this point in the history
The refcount is still 0 within the body of load(), making the check in
LookupSymbol think that the library isn't loaded yet (when in fact it is) while
calling updateVersionedSymbols(). This change replaces calls with
library.LookupSymbol() within updateVersionedSymbols() to calls of
library.dl.Lookup() directly to avoid the check against refcount.

Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Apr 22, 2024
1 parent 8a7ac4d commit 31312e5
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,93 +198,93 @@ func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
// When new versioned symbols are added, these would have to be initialized above and have
// corresponding checks and subsequent assignments added below.
func (l *library) updateVersionedSymbols() {
err := l.LookupSymbol("nvmlInit_v2")
err := l.dl.Lookup("nvmlInit_v2")
if err == nil {
nvmlInit = nvmlInit_v2
}
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2")
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v2")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
}
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3")
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v3")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
}
err = l.LookupSymbol("nvmlDeviceGetCount_v2")
err = l.dl.Lookup("nvmlDeviceGetCount_v2")
if err == nil {
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
}
err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2")
err = l.dl.Lookup("nvmlDeviceGetHandleByIndex_v2")
if err == nil {
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
}
err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2")
err = l.dl.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
if err == nil {
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
}
err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2")
err = l.dl.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
if err == nil {
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
}
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
// a different set of parameters than the v1 function.
//err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2")
//err = l.dl.Lookup("nvmlDeviceRemoveGpu_v2")
//if err == nil {
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
//}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
}
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4")
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
}
err = l.LookupSymbol("nvmlEventSetWait_v2")
err = l.dl.Lookup("nvmlEventSetWait_v2")
if err == nil {
nvmlEventSetWait = nvmlEventSetWait_v2
}
err = l.LookupSymbol("nvmlDeviceGetAttributes_v2")
err = l.dl.Lookup("nvmlDeviceGetAttributes_v2")
if err == nil {
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
}
err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2")
err = l.dl.Lookup("nvmlComputeInstanceGetInfo_v2")
if err == nil {
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
}
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
if err == nil {
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
if err == nil {
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2")
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
}
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3")
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
if err == nil {
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
}
err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
err = l.dl.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
if err == nil {
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
}
err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2")
err = l.dl.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
if err == nil {
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
}
Expand Down

0 comments on commit 31312e5

Please sign in to comment.