Skip to content

Commit

Permalink
Merge pull request #111 from klueska/simplify-lib
Browse files Browse the repository at this point in the history
Remove need for newDynamicLibraryFunc
  • Loading branch information
klueska authored Apr 8, 2024
2 parents f3a57ee + 4c1d45c commit 65a28b0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 35 deletions.
54 changes: 26 additions & 28 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,19 @@ const (
var errLibraryNotLoaded = errors.New("library not loaded")
var errLibraryAlreadyLoaded = errors.New("library already loaded")

// newDynamicLibraryFunc defines a function to create a new dynamicLibrary
type newDynamicLibraryFunc func(path string, flags int) dynamicLibrary

// library represents an nvml library.
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
refcount refcount
dl dynamicLibrary
newDynamicLibrary newDynamicLibraryFunc
path string
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
var libnvml = library{
path: defaultNvmlLibraryName,
flags: defaultNvmlLibraryLoadFlags,
newDynamicLibrary: func(path string, flags int) dynamicLibrary {
return dl.New(path, flags)
},
path: defaultNvmlLibraryName,
dl: dl.New(defaultNvmlLibraryName, defaultNvmlLibraryLoadFlags),
}

var _ Interface = (*library)(nil)
Expand All @@ -72,7 +64,7 @@ func GetLibrary() Library {
// Lookup checks whether the specified library symbol exists in the library.
// Note that this requires that the library be loaded.
func (l *library) Lookup(name string) error {
if l == nil || l.dl == nil {
if l == nil || l.refcount == 0 {
return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded)
}
return l.dl.Lookup(name)
Expand All @@ -89,12 +81,10 @@ func (l *library) load() (rerr error) {
return nil
}

dl := l.newDynamicLibrary(l.path, l.flags)
if err := dl.Open(); err != nil {
if err := l.dl.Open(); err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

l.dl = dl
l.updateVersionedSymbols()

return nil
Expand All @@ -116,8 +106,6 @@ func (l *library) close() (rerr error) {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil

return nil
}

Expand Down Expand Up @@ -267,13 +255,19 @@ func (l *library) updateVersionedSymbols() {
}
}

// libraryOptions hold the paramaters than can be set by a LibraryOption
type libraryOptions struct {
path string
flags int
}

// LibraryOption represents a functional option to configure the underlying NVML library
type LibraryOption func(*library)
type LibraryOption func(*libraryOptions)

// WithLibraryPath provides an option to set the library name to be used by the NVML library.
func WithLibraryPath(path string) LibraryOption {
return func(l *library) {
l.path = path
return func(o *libraryOptions) {
o.path = path
}
}

Expand All @@ -282,20 +276,24 @@ func WithLibraryPath(path string) LibraryOption {
func SetLibraryOptions(opts ...LibraryOption) error {
libnvml.Lock()
defer libnvml.Unlock()
if libnvml.dl != nil {
if libnvml.refcount != 0 {
return errLibraryAlreadyLoaded
}

o := libraryOptions{}
for _, opt := range opts {
opt(&libnvml)
opt(&o)
}

if libnvml.path == "" {
libnvml.path = defaultNvmlLibraryName
if o.path == "" {
o.path = defaultNvmlLibraryName
}
if libnvml.flags == 0 {
libnvml.flags = defaultNvmlLibraryLoadFlags
if o.flags == 0 {
o.flags = defaultNvmlLibraryLoadFlags
}

libnvml.path = o.path
libnvml.dl = dl.New(o.path, o.flags)

return nil
}
10 changes: 3 additions & 7 deletions pkg/nvml/lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ import (
)

func newTestLibrary(dl dynamicLibrary) *library {
return &library{
newDynamicLibrary: func(string, int) dynamicLibrary {
return dl
},
}
return &library{dl: dl}
}

func TestLookupFromDefault(t *testing.T) {
Expand Down Expand Up @@ -132,9 +128,9 @@ func TestLookupFromDefault(t *testing.T) {
require.ErrorIs(t, l.Lookup("symbol"), tc.expectedLookupErrror)
require.ErrorIs(t, l.close(), tc.expectedCloseError)
if tc.expectedCloseError == nil {
require.Nil(t, l.dl)
require.Equal(t, 0, int(l.refcount))
} else {
require.Equal(t, tc.dl, l.dl)
require.Equal(t, 1, int(l.refcount))
}
})
}
Expand Down

0 comments on commit 65a28b0

Please sign in to comment.