Skip to content

Commit

Permalink
Remove need for newDynamicLibraryFunc
Browse files Browse the repository at this point in the history
With the introduction of the refcount we no longer need to set library.dl to
nil to track when it is loaded or unloaded. Instead we can use the refcount
directly.

Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Apr 8, 2024
1 parent 5362631 commit 4c1d45c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 35 deletions.
54 changes: 26 additions & 28 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,19 @@ const (
var errLibraryNotLoaded = errors.New("library not loaded")
var errLibraryAlreadyLoaded = errors.New("library already loaded")

// newDynamicLibraryFunc defines a function to create a new dynamicLibrary
type newDynamicLibraryFunc func(path string, flags int) dynamicLibrary

// library represents an nvml library.
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
refcount refcount
dl dynamicLibrary
newDynamicLibrary newDynamicLibraryFunc
path string
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
var libnvml = library{
path: defaultNvmlLibraryName,
flags: defaultNvmlLibraryLoadFlags,
newDynamicLibrary: func(path string, flags int) dynamicLibrary {
return dl.New(path, flags)
},
path: defaultNvmlLibraryName,
dl: dl.New(defaultNvmlLibraryName, defaultNvmlLibraryLoadFlags),
}

var _ Interface = (*library)(nil)
Expand All @@ -72,7 +64,7 @@ func GetLibrary() Library {
// Lookup checks whether the specified library symbol exists in the library.
// Note that this requires that the library be loaded.
func (l *library) Lookup(name string) error {
if l == nil || l.dl == nil {
if l == nil || l.refcount == 0 {
return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded)
}
return l.dl.Lookup(name)
Expand All @@ -89,12 +81,10 @@ func (l *library) load() (rerr error) {
return nil
}

dl := l.newDynamicLibrary(l.path, l.flags)
if err := dl.Open(); err != nil {
if err := l.dl.Open(); err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

l.dl = dl
l.updateVersionedSymbols()

return nil
Expand All @@ -116,8 +106,6 @@ func (l *library) close() (rerr error) {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil

return nil
}

Expand Down Expand Up @@ -267,13 +255,19 @@ func (l *library) updateVersionedSymbols() {
}
}

// libraryOptions hold the paramaters than can be set by a LibraryOption
type libraryOptions struct {
path string
flags int
}

// LibraryOption represents a functional option to configure the underlying NVML library
type LibraryOption func(*library)
type LibraryOption func(*libraryOptions)

// WithLibraryPath provides an option to set the library name to be used by the NVML library.
func WithLibraryPath(path string) LibraryOption {
return func(l *library) {
l.path = path
return func(o *libraryOptions) {
o.path = path
}
}

Expand All @@ -282,20 +276,24 @@ func WithLibraryPath(path string) LibraryOption {
func SetLibraryOptions(opts ...LibraryOption) error {
libnvml.Lock()
defer libnvml.Unlock()
if libnvml.dl != nil {
if libnvml.refcount != 0 {
return errLibraryAlreadyLoaded
}

o := libraryOptions{}
for _, opt := range opts {
opt(&libnvml)
opt(&o)
}

if libnvml.path == "" {
libnvml.path = defaultNvmlLibraryName
if o.path == "" {
o.path = defaultNvmlLibraryName
}
if libnvml.flags == 0 {
libnvml.flags = defaultNvmlLibraryLoadFlags
if o.flags == 0 {
o.flags = defaultNvmlLibraryLoadFlags
}

libnvml.path = o.path
libnvml.dl = dl.New(o.path, o.flags)

return nil
}
10 changes: 3 additions & 7 deletions pkg/nvml/lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ import (
)

func newTestLibrary(dl dynamicLibrary) *library {
return &library{
newDynamicLibrary: func(string, int) dynamicLibrary {
return dl
},
}
return &library{dl: dl}
}

func TestLookupFromDefault(t *testing.T) {
Expand Down Expand Up @@ -132,9 +128,9 @@ func TestLookupFromDefault(t *testing.T) {
require.ErrorIs(t, l.Lookup("symbol"), tc.expectedLookupErrror)
require.ErrorIs(t, l.close(), tc.expectedCloseError)
if tc.expectedCloseError == nil {
require.Nil(t, l.dl)
require.Equal(t, 0, int(l.refcount))
} else {
require.Equal(t, tc.dl, l.dl)
require.Equal(t, 1, int(l.refcount))
}
})
}
Expand Down

0 comments on commit 4c1d45c

Please sign in to comment.