diff --git a/gen/nvml/lib.go b/gen/nvml/lib.go index 4d5eb8e..8d2307b 100644 --- a/gen/nvml/lib.go +++ b/gen/nvml/lib.go @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded") // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex - path string - flags int - dl dynamicLibrary + path string + flags int + refcount refcount + dl dynamicLibrary } // libnvml is a global instance of the nvml library. @@ -77,10 +78,14 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary { // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. -func (l *library) load() error { +func (l *library) load() (rerr error) { l.Lock() defer l.Unlock() - if l.dl != nil { + + l.refcount.Inc() + defer l.refcount.DecOnError(rerr) + + if l.refcount > 1 { return nil } @@ -99,10 +104,17 @@ func (l *library) load() error { // close the underlying library and ensure that the global pointer to the // library is set to nil to ensure that subsequent calls to open will reinitialize it. // Multiple calls to an already closed nvml library will return without error. -func (l *library) close() error { +func (l *library) close() (rerr error) { l.Lock() defer l.Unlock() + l.refcount.Dec() + defer l.refcount.IncOnError(rerr) + + if l.refcount > 0 { + return nil + } + if l.dl == nil { return nil } @@ -111,7 +123,6 @@ func (l *library) close() error { if err != nil { return fmt.Errorf("error closing %s: %w", l.path, err) } - l.dl = nil return nil diff --git a/gen/nvml/refcount.go b/gen/nvml/refcount.go new file mode 100644 index 0000000..3e9435c --- /dev/null +++ b/gen/nvml/refcount.go @@ -0,0 +1,43 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +import () + +type refcount int + +func (r *refcount) Inc() { + (*r)++ +} + +func (r *refcount) Dec() { + if *r > 0 { + (*r)-- + } +} + +func (r *refcount) IncOnError(err error) { + if err != nil { + r.Inc() + } +} + +func (r *refcount) DecOnError(err error) { + if err != nil { + r.Dec() + } +} diff --git a/pkg/nvml/lib.go b/pkg/nvml/lib.go index 4d5eb8e..8d2307b 100644 --- a/pkg/nvml/lib.go +++ b/pkg/nvml/lib.go @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded") // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex - path string - flags int - dl dynamicLibrary + path string + flags int + refcount refcount + dl dynamicLibrary } // libnvml is a global instance of the nvml library. @@ -77,10 +78,14 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary { // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. -func (l *library) load() error { +func (l *library) load() (rerr error) { l.Lock() defer l.Unlock() - if l.dl != nil { + + l.refcount.Inc() + defer l.refcount.DecOnError(rerr) + + if l.refcount > 1 { return nil } @@ -99,10 +104,17 @@ func (l *library) load() error { // close the underlying library and ensure that the global pointer to the // library is set to nil to ensure that subsequent calls to open will reinitialize it. // Multiple calls to an already closed nvml library will return without error. -func (l *library) close() error { +func (l *library) close() (rerr error) { l.Lock() defer l.Unlock() + l.refcount.Dec() + defer l.refcount.IncOnError(rerr) + + if l.refcount > 0 { + return nil + } + if l.dl == nil { return nil } @@ -111,7 +123,6 @@ func (l *library) close() error { if err != nil { return fmt.Errorf("error closing %s: %w", l.path, err) } - l.dl = nil return nil diff --git a/pkg/nvml/refcount.go b/pkg/nvml/refcount.go new file mode 100644 index 0000000..3e9435c --- /dev/null +++ b/pkg/nvml/refcount.go @@ -0,0 +1,43 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +import () + +type refcount int + +func (r *refcount) Inc() { + (*r)++ +} + +func (r *refcount) Dec() { + if *r > 0 { + (*r)-- + } +} + +func (r *refcount) IncOnError(err error) { + if err != nil { + r.Inc() + } +} + +func (r *refcount) DecOnError(err error) { + if err != nil { + r.Dec() + } +}