Skip to content

Commit

Permalink
Add a refcount to dl load() and close() calls
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Mar 14, 2024
1 parent 20c256c commit c61a8c8
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 14 deletions.
25 changes: 18 additions & 7 deletions gen/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -77,10 +78,14 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {

// load initializes the library and updates the versioned symbols.
// Multiple calls to an already loaded library will return without error.
func (l *library) load() error {
func (l *library) load() (rerr error) {
l.Lock()
defer l.Unlock()
if l.dl != nil {

l.refcount.Inc()
defer l.refcount.DecOnError(rerr)

if l.refcount > 1 {
return nil
}

Expand All @@ -99,10 +104,17 @@ func (l *library) load() error {
// close the underlying library and ensure that the global pointer to the
// library is set to nil to ensure that subsequent calls to open will reinitialize it.
// Multiple calls to an already closed nvml library will return without error.
func (l *library) close() error {
func (l *library) close() (rerr error) {
l.Lock()
defer l.Unlock()

l.refcount.Dec()
defer l.refcount.IncOnError(rerr)

if l.refcount > 0 {
return nil
}

if l.dl == nil {
return nil
}
Expand All @@ -111,7 +123,6 @@ func (l *library) close() error {
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil

return nil
Expand Down
43 changes: 43 additions & 0 deletions gen/nvml/refcount.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import ()

type refcount int

func (r *refcount) Inc() {
(*r)++
}

func (r *refcount) Dec() {
if *r > 0 {
(*r)--
}
}

func (r *refcount) IncOnError(err error) {
if err != nil {
r.Inc()
}
}

func (r *refcount) DecOnError(err error) {
if err != nil {
r.Dec()
}
}
25 changes: 18 additions & 7 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -77,10 +78,14 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {

// load initializes the library and updates the versioned symbols.
// Multiple calls to an already loaded library will return without error.
func (l *library) load() error {
func (l *library) load() (rerr error) {
l.Lock()
defer l.Unlock()
if l.dl != nil {

l.refcount.Inc()
defer l.refcount.DecOnError(rerr)

if l.refcount > 1 {
return nil
}

Expand All @@ -99,10 +104,17 @@ func (l *library) load() error {
// close the underlying library and ensure that the global pointer to the
// library is set to nil to ensure that subsequent calls to open will reinitialize it.
// Multiple calls to an already closed nvml library will return without error.
func (l *library) close() error {
func (l *library) close() (rerr error) {
l.Lock()
defer l.Unlock()

l.refcount.Dec()
defer l.refcount.IncOnError(rerr)

if l.refcount > 0 {
return nil
}

if l.dl == nil {
return nil
}
Expand All @@ -111,7 +123,6 @@ func (l *library) close() error {
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil

return nil
Expand Down
43 changes: 43 additions & 0 deletions pkg/nvml/refcount.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import ()

type refcount int

func (r *refcount) Inc() {
(*r)++
}

func (r *refcount) Dec() {
if *r > 0 {
(*r)--
}
}

func (r *refcount) IncOnError(err error) {
if err != nil {
r.Inc()
}
}

func (r *refcount) DecOnError(err error) {
if err != nil {
r.Dec()
}
}

0 comments on commit c61a8c8

Please sign in to comment.