From bb5e38c62e4c6b1d8c12506b10901350afee0860 Mon Sep 17 00:00:00 2001 From: Kevin Klues Date: Wed, 3 Apr 2024 19:45:51 +0000 Subject: [PATCH] Remove global newDynamicLibrary function and replace with embedded one Signed-off-by: Kevin Klues --- pkg/nvml/lib.go | 22 ++++++------ pkg/nvml/lib_test.go | 79 ++++++++++++++++++-------------------------- 2 files changed, 44 insertions(+), 57 deletions(-) diff --git a/pkg/nvml/lib.go b/pkg/nvml/lib.go index be4b3fb..79bfaaa 100644 --- a/pkg/nvml/lib.go +++ b/pkg/nvml/lib.go @@ -34,20 +34,27 @@ const ( var errLibraryNotLoaded = errors.New("library not loaded") var errLibraryAlreadyLoaded = errors.New("library already loaded") +// newDynamicLibraryFunc defines a function to create a new dynamicLibrary +type newDynamicLibraryFunc func(path string, flags int) dynamicLibrary + // library represents an nvml library. // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex - path string - flags int - refcount refcount - dl dynamicLibrary + path string + flags int + refcount refcount + dl dynamicLibrary + newDynamicLibrary newDynamicLibraryFunc } // libnvml is a global instance of the nvml library. var libnvml = library{ path: defaultNvmlLibraryName, flags: defaultNvmlLibraryLoadFlags, + newDynamicLibrary: func(path string, flags int) dynamicLibrary { + return dl.New(path, flags) + }, } var _ Interface = (*library)(nil) @@ -71,11 +78,6 @@ func (l *library) Lookup(name string) error { return l.dl.Lookup(name) } -// newDynamicLibrary is a function variable that can be overridden for testing. -var newDynamicLibrary = func(path string, flags int) dynamicLibrary { - return dl.New(path, flags) -} - // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. func (l *library) load() (rerr error) { @@ -87,7 +89,7 @@ func (l *library) load() (rerr error) { return nil } - dl := newDynamicLibrary(l.path, l.flags) + dl := l.newDynamicLibrary(l.path, l.flags) if err := dl.Open(); err != nil { return fmt.Errorf("error opening %s: %w", l.path, err) } diff --git a/pkg/nvml/lib_test.go b/pkg/nvml/lib_test.go index cfe8477..817b431 100644 --- a/pkg/nvml/lib_test.go +++ b/pkg/nvml/lib_test.go @@ -24,6 +24,14 @@ import ( "github.com/stretchr/testify/require" ) +func newTestLibrary(dl dynamicLibrary) *library { + return &library{ + newDynamicLibrary: func(string, int) dynamicLibrary { + return dl + }, + } +} + func TestLookupFromDefault(t *testing.T) { errClose := errors.New("close error") errOpen := errors.New("open error") @@ -31,7 +39,7 @@ func TestLookupFromDefault(t *testing.T) { testCases := []struct { description string - library dynamicLibrary + dl dynamicLibrary skipLoadLibrary bool expectedLoadError error expectedLookupErrror error @@ -39,13 +47,13 @@ func TestLookupFromDefault(t *testing.T) { }{ { description: "library not loaded yields error", - library: &dynamicLibraryMock{}, + dl: &dynamicLibraryMock{}, skipLoadLibrary: true, expectedLookupErrror: errLibraryNotLoaded, }, { description: "open error is returned", - library: &dynamicLibraryMock{ + dl: &dynamicLibraryMock{ OpenFunc: func() error { return errOpen }, @@ -56,7 +64,7 @@ func TestLookupFromDefault(t *testing.T) { }, { description: "lookup error is returned", - library: &dynamicLibraryMock{ + dl: &dynamicLibraryMock{ OpenFunc: func() error { return nil }, @@ -72,7 +80,7 @@ func TestLookupFromDefault(t *testing.T) { }, { description: "lookup succeeds", - library: &dynamicLibraryMock{ + dl: &dynamicLibraryMock{ OpenFunc: func() error { return nil }, @@ -86,7 +94,7 @@ func TestLookupFromDefault(t *testing.T) { }, { description: "lookup succeeds", - library: &dynamicLibraryMock{ + dl: &dynamicLibraryMock{ OpenFunc: func() error { return nil }, @@ -100,7 +108,7 @@ func TestLookupFromDefault(t *testing.T) { }, { description: "close error is returned", - library: &dynamicLibraryMock{ + dl: &dynamicLibraryMock{ OpenFunc: func() error { return nil }, @@ -117,18 +125,16 @@ func TestLookupFromDefault(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - defer setNewDynamicLibraryDuringTest(tc.library)() - defer resetLibrary() - l := GetLibrary() + l := newTestLibrary(tc.dl) if !tc.skipLoadLibrary { - require.ErrorIs(t, libnvml.load(), tc.expectedLoadError) + require.ErrorIs(t, l.load(), tc.expectedLoadError) } require.ErrorIs(t, l.Lookup("symbol"), tc.expectedLookupErrror) - require.ErrorIs(t, libnvml.close(), tc.expectedCloseError) + require.ErrorIs(t, l.close(), tc.expectedCloseError) if tc.expectedCloseError == nil { - require.Nil(t, libnvml.dl) + require.Nil(t, l.dl) } else { - require.Equal(t, tc.library, libnvml.dl) + require.Equal(t, tc.dl, l.dl) } }) } @@ -144,30 +150,29 @@ func TestLoadAndCloseNesting(t *testing.T) { }, } - defer setNewDynamicLibraryDuringTest(dl)() - defer resetLibrary() + l := newTestLibrary(dl) // When calling close before opening the library nothing happens. require.Equal(t, 0, len(dl.calls.Close)) - require.Nil(t, libnvml.close()) + require.Nil(t, l.close()) require.Equal(t, 0, len(dl.calls.Close)) // When calling load twice, the library was only opened once require.Equal(t, 0, len(dl.calls.Open)) - require.Nil(t, libnvml.load()) + require.Nil(t, l.load()) require.Equal(t, 1, len(dl.calls.Open)) - require.Nil(t, libnvml.load()) + require.Nil(t, l.load()) require.Equal(t, 1, len(dl.calls.Open)) // Only after calling close twice, was the library closed require.Equal(t, 0, len(dl.calls.Close)) - require.Nil(t, libnvml.close()) + require.Nil(t, l.close()) require.Equal(t, 0, len(dl.calls.Close)) - require.Nil(t, libnvml.close()) + require.Nil(t, l.close()) require.Equal(t, 1, len(dl.calls.Close)) // Calling close again doesn't attempt to close the library again - require.Nil(t, libnvml.close()) + require.Nil(t, l.close()) require.Equal(t, 1, len(dl.calls.Close)) } @@ -234,31 +239,11 @@ func TestLoadAndCloseWithErrors(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - defer setNewDynamicLibraryDuringTest(tc.dl)() - defer resetLibrary() - - _ = libnvml.load() - require.Equal(t, tc.expectedLoadRefcount, libnvml.refcount) - _ = libnvml.close() - require.Equal(t, tc.expectedCloseRefcount, libnvml.refcount) + l := newTestLibrary(tc.dl) + _ = l.load() + require.Equal(t, tc.expectedLoadRefcount, l.refcount) + _ = l.close() + require.Equal(t, tc.expectedCloseRefcount, l.refcount) }) } } - -func setNewDynamicLibraryDuringTest(dl dynamicLibrary) func() { - original := newDynamicLibrary - newDynamicLibrary = func(string, int) dynamicLibrary { - return dl - } - - return func() { - newDynamicLibrary = original - } -} - -func resetLibrary() { - libnvml = library{ - path: defaultNvmlLibraryName, - flags: defaultNvmlLibraryLoadFlags, - } -}