diff --git a/gen/nvml/lib.go b/gen/nvml/lib.go index 4d5eb8e..42d2666 100644 --- a/gen/nvml/lib.go +++ b/gen/nvml/lib.go @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded") // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex - path string - flags int - dl dynamicLibrary + path string + flags int + refcount refcount + dl dynamicLibrary } // libnvml is a global instance of the nvml library. @@ -77,16 +78,17 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary { // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. -func (l *library) load() error { +func (l *library) load() (rerr error) { l.Lock() defer l.Unlock() - if l.dl != nil { + + defer func() { l.refcount.IncOnNoError(rerr) }() + if l.refcount > 0 { return nil } dl := newDynamicLibrary(l.path, l.flags) - err := dl.Open() - if err != nil { + if err := dl.Open(); err != nil { return fmt.Errorf("error opening %s: %w", l.path, err) } @@ -99,16 +101,18 @@ func (l *library) load() error { // close the underlying library and ensure that the global pointer to the // library is set to nil to ensure that subsequent calls to open will reinitialize it. // Multiple calls to an already closed nvml library will return without error. -func (l *library) close() error { +func (l *library) close() (rerr error) { l.Lock() defer l.Unlock() - if l.dl == nil { + fmt.Printf("refcount: %v\n", l.refcount) + + defer func() { l.refcount.DecOnNoError(rerr) }() + if l.refcount != 1 { return nil } - err := l.dl.Close() - if err != nil { + if err := l.dl.Close(); err != nil { return fmt.Errorf("error closing %s: %w", l.path, err) } diff --git a/gen/nvml/lib_test.go b/gen/nvml/lib_test.go index 9059605..a800b7e 100644 --- a/gen/nvml/lib_test.go +++ b/gen/nvml/lib_test.go @@ -134,6 +134,43 @@ func TestLookupFromDefault(t *testing.T) { } } +func TestLoadAndCloseNesting(t *testing.T) { + var numOpens, numCloses int + + dl := &dynamicLibraryMock{ + OpenFunc: func() error { + numOpens++ + return nil + }, + CloseFunc: func() error { + numCloses++ + return nil + }, + } + + defer setNewDynamicLibraryDuringTest(dl)() + defer resetLibrary() + GetLibrary() + + // When calling load twice, the library was only opened once + require.Equal(t, 0, numOpens) + require.Nil(t, libnvml.load()) + require.Equal(t, 1, numOpens) + require.Nil(t, libnvml.load()) + require.Equal(t, 1, numOpens) + + // Only after calling close twice, was the library closed + require.Equal(t, 0, numCloses) + require.Nil(t, libnvml.close()) + require.Equal(t, 0, numCloses) + require.Nil(t, libnvml.close()) + require.Equal(t, 1, numCloses) + + // Calling close again doesn't attempt to close the library again + require.Nil(t, libnvml.close()) + require.Equal(t, 1, numCloses) +} + func setNewDynamicLibraryDuringTest(dl dynamicLibrary) func() { original := newDynamicLibrary newDynamicLibrary = func(string, int) dynamicLibrary { diff --git a/gen/nvml/refcount.go b/gen/nvml/refcount.go new file mode 100644 index 0000000..4d1e212 --- /dev/null +++ b/gen/nvml/refcount.go @@ -0,0 +1,31 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +type refcount int + +func (r *refcount) IncOnNoError(err error) { + if err == nil { + (*r)++ + } +} + +func (r *refcount) DecOnNoError(err error) { + if err == nil && (*r) > 0 { + (*r)-- + } +} diff --git a/pkg/nvml/lib.go b/pkg/nvml/lib.go index 4d5eb8e..42d2666 100644 --- a/pkg/nvml/lib.go +++ b/pkg/nvml/lib.go @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded") // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex - path string - flags int - dl dynamicLibrary + path string + flags int + refcount refcount + dl dynamicLibrary } // libnvml is a global instance of the nvml library. @@ -77,16 +78,17 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary { // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. -func (l *library) load() error { +func (l *library) load() (rerr error) { l.Lock() defer l.Unlock() - if l.dl != nil { + + defer func() { l.refcount.IncOnNoError(rerr) }() + if l.refcount > 0 { return nil } dl := newDynamicLibrary(l.path, l.flags) - err := dl.Open() - if err != nil { + if err := dl.Open(); err != nil { return fmt.Errorf("error opening %s: %w", l.path, err) } @@ -99,16 +101,18 @@ func (l *library) load() error { // close the underlying library and ensure that the global pointer to the // library is set to nil to ensure that subsequent calls to open will reinitialize it. // Multiple calls to an already closed nvml library will return without error. -func (l *library) close() error { +func (l *library) close() (rerr error) { l.Lock() defer l.Unlock() - if l.dl == nil { + fmt.Printf("refcount: %v\n", l.refcount) + + defer func() { l.refcount.DecOnNoError(rerr) }() + if l.refcount != 1 { return nil } - err := l.dl.Close() - if err != nil { + if err := l.dl.Close(); err != nil { return fmt.Errorf("error closing %s: %w", l.path, err) } diff --git a/pkg/nvml/lib_test.go b/pkg/nvml/lib_test.go index 9059605..a800b7e 100644 --- a/pkg/nvml/lib_test.go +++ b/pkg/nvml/lib_test.go @@ -134,6 +134,43 @@ func TestLookupFromDefault(t *testing.T) { } } +func TestLoadAndCloseNesting(t *testing.T) { + var numOpens, numCloses int + + dl := &dynamicLibraryMock{ + OpenFunc: func() error { + numOpens++ + return nil + }, + CloseFunc: func() error { + numCloses++ + return nil + }, + } + + defer setNewDynamicLibraryDuringTest(dl)() + defer resetLibrary() + GetLibrary() + + // When calling load twice, the library was only opened once + require.Equal(t, 0, numOpens) + require.Nil(t, libnvml.load()) + require.Equal(t, 1, numOpens) + require.Nil(t, libnvml.load()) + require.Equal(t, 1, numOpens) + + // Only after calling close twice, was the library closed + require.Equal(t, 0, numCloses) + require.Nil(t, libnvml.close()) + require.Equal(t, 0, numCloses) + require.Nil(t, libnvml.close()) + require.Equal(t, 1, numCloses) + + // Calling close again doesn't attempt to close the library again + require.Nil(t, libnvml.close()) + require.Equal(t, 1, numCloses) +} + func setNewDynamicLibraryDuringTest(dl dynamicLibrary) func() { original := newDynamicLibrary newDynamicLibrary = func(string, int) dynamicLibrary { diff --git a/pkg/nvml/refcount.go b/pkg/nvml/refcount.go new file mode 100644 index 0000000..4d1e212 --- /dev/null +++ b/pkg/nvml/refcount.go @@ -0,0 +1,31 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +type refcount int + +func (r *refcount) IncOnNoError(err error) { + if err == nil { + (*r)++ + } +} + +func (r *refcount) DecOnNoError(err error) { + if err == nil && (*r) > 0 { + (*r)-- + } +} diff --git a/pkg/nvml/refcount_test.go b/pkg/nvml/refcount_test.go new file mode 100644 index 0000000..73bf479 --- /dev/null +++ b/pkg/nvml/refcount_test.go @@ -0,0 +1,139 @@ +/** +# Copyright 2023 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRefcount(t *testing.T) { + testCases := []struct { + description string + workload func(r *refcount) + expectedRefcount refcount + }{ + { + description: "No inc or dec", + workload: func(r *refcount) {}, + expectedRefcount: refcount(0), + }, + { + description: "Single inc, no error", + workload: func(r *refcount) { + r.IncOnNoError(nil) + }, + expectedRefcount: refcount(1), + }, + { + description: "Single inc, with error", + workload: func(r *refcount) { + r.IncOnNoError(errors.ErrUnsupported) + }, + expectedRefcount: refcount(0), + }, + { + description: "Double inc, no error", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(nil) + }, + expectedRefcount: refcount(2), + }, + { + description: "Double inc, one with error", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(errors.ErrUnsupported) + }, + expectedRefcount: refcount(1), + }, + { + description: "Single dec, no error", + workload: func(r *refcount) { + r.DecOnNoError(nil) + }, + expectedRefcount: refcount(0), + }, + { + description: "Single dec, with error", + workload: func(r *refcount) { + r.DecOnNoError(errors.ErrUnsupported) + }, + expectedRefcount: refcount(0), + }, + { + description: "Single inc, single dec, no errors", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.DecOnNoError(nil) + }, + expectedRefcount: refcount(0), + }, + { + description: "Double inc, Double dec, no errors", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(nil) + r.DecOnNoError(nil) + r.DecOnNoError(nil) + }, + expectedRefcount: refcount(0), + }, + { + description: "Double inc, Double dec, one inc error", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(errors.ErrUnsupported) + r.DecOnNoError(nil) + r.DecOnNoError(nil) + }, + expectedRefcount: refcount(0), + }, + { + description: "Double inc, Double dec, one dec error", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(nil) + r.DecOnNoError(nil) + r.DecOnNoError(errors.ErrUnsupported) + }, + expectedRefcount: refcount(1), + }, + { + description: "Double inc, Tripple dec, one dec error early on", + workload: func(r *refcount) { + r.IncOnNoError(nil) + r.IncOnNoError(nil) + r.DecOnNoError(errors.ErrUnsupported) + r.DecOnNoError(nil) + r.DecOnNoError(nil) + }, + expectedRefcount: refcount(0), + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + var r refcount + tc.workload(&r) + require.Equal(t, tc.expectedRefcount, r) + }) + } +}