Skip to content

Commit

Permalink
Add a refcount to dl load() and close() calls
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Mar 14, 2024
1 parent 20c256c commit 9aa6a4f
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 22 deletions.
26 changes: 15 additions & 11 deletions gen/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -77,16 +78,17 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {

// load initializes the library and updates the versioned symbols.
// Multiple calls to an already loaded library will return without error.
func (l *library) load() error {
func (l *library) load() (rerr error) {
l.Lock()
defer l.Unlock()
if l.dl != nil {

defer func() { l.refcount.IncOnNoError(rerr) }()
if l.refcount > 0 {
return nil
}

dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
if err := dl.Open(); err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

Expand All @@ -99,16 +101,18 @@ func (l *library) load() error {
// close the underlying library and ensure that the global pointer to the
// library is set to nil to ensure that subsequent calls to open will reinitialize it.
// Multiple calls to an already closed nvml library will return without error.
func (l *library) close() error {
func (l *library) close() (rerr error) {
l.Lock()
defer l.Unlock()

if l.dl == nil {
fmt.Printf("refcount: %v\n", l.refcount)

defer func() { l.refcount.DecOnNoError(rerr) }()
if l.refcount != 1 {
return nil
}

err := l.dl.Close()
if err != nil {
if err := l.dl.Close(); err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

Expand Down
37 changes: 37 additions & 0 deletions gen/nvml/lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ func TestLookupFromDefault(t *testing.T) {
}
}

func TestLoadAndCloseNesting(t *testing.T) {
var numOpens, numCloses int

dl := &dynamicLibraryMock{
OpenFunc: func() error {
numOpens++
return nil
},
CloseFunc: func() error {
numCloses++
return nil
},
}

defer setNewDynamicLibraryDuringTest(dl)()
defer resetLibrary()
GetLibrary()

// When calling load twice, the library was only opened once
require.Equal(t, 0, numOpens)
require.Nil(t, libnvml.load())
require.Equal(t, 1, numOpens)
require.Nil(t, libnvml.load())
require.Equal(t, 1, numOpens)

// Only after calling close twice, was the library closed
require.Equal(t, 0, numCloses)
require.Nil(t, libnvml.close())
require.Equal(t, 0, numCloses)
require.Nil(t, libnvml.close())
require.Equal(t, 1, numCloses)

// Calling close again doesn't attempt to close the library again
require.Nil(t, libnvml.close())
require.Equal(t, 1, numCloses)
}

func setNewDynamicLibraryDuringTest(dl dynamicLibrary) func() {
original := newDynamicLibrary
newDynamicLibrary = func(string, int) dynamicLibrary {
Expand Down
31 changes: 31 additions & 0 deletions gen/nvml/refcount.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

type refcount int

func (r *refcount) IncOnNoError(err error) {
if err == nil {
(*r)++
}
}

func (r *refcount) DecOnNoError(err error) {
if err == nil && (*r) > 0 {
(*r)--
}
}
139 changes: 139 additions & 0 deletions gen/nvml/refcount_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/**
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import (
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func TestRefcount(t *testing.T) {
testCases := []struct {
description string
workload func(r *refcount)
expectedRefcount refcount
}{
{
description: "No inc or dec",
workload: func(r *refcount) {},
expectedRefcount: refcount(0),
},
{
description: "Single inc, no error",
workload: func(r *refcount) {
r.IncOnNoError(nil)
},
expectedRefcount: refcount(1),
},
{
description: "Single inc, with error",
workload: func(r *refcount) {
r.IncOnNoError(errors.ErrUnsupported)
},
expectedRefcount: refcount(0),
},
{
description: "Double inc, no error",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(nil)
},
expectedRefcount: refcount(2),
},
{
description: "Double inc, one with error",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(errors.ErrUnsupported)
},
expectedRefcount: refcount(1),
},
{
description: "Single dec, no error",
workload: func(r *refcount) {
r.DecOnNoError(nil)
},
expectedRefcount: refcount(0),
},
{
description: "Single dec, with error",
workload: func(r *refcount) {
r.DecOnNoError(errors.ErrUnsupported)
},
expectedRefcount: refcount(0),
},
{
description: "Single inc, single dec, no errors",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.DecOnNoError(nil)
},
expectedRefcount: refcount(0),
},
{
description: "Double inc, Double dec, no errors",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(nil)
r.DecOnNoError(nil)
r.DecOnNoError(nil)
},
expectedRefcount: refcount(0),
},
{
description: "Double inc, Double dec, one inc error",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(errors.ErrUnsupported)
r.DecOnNoError(nil)
r.DecOnNoError(nil)
},
expectedRefcount: refcount(0),
},
{
description: "Double inc, Double dec, one dec error",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(nil)
r.DecOnNoError(nil)
r.DecOnNoError(errors.ErrUnsupported)
},
expectedRefcount: refcount(1),
},
{
description: "Double inc, Tripple dec, one dec error early on",
workload: func(r *refcount) {
r.IncOnNoError(nil)
r.IncOnNoError(nil)
r.DecOnNoError(errors.ErrUnsupported)
r.DecOnNoError(nil)
r.DecOnNoError(nil)
},
expectedRefcount: refcount(0),
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
var r refcount
tc.workload(&r)
require.Equal(t, tc.expectedRefcount, r)
})
}
}
26 changes: 15 additions & 11 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -77,16 +78,17 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {

// load initializes the library and updates the versioned symbols.
// Multiple calls to an already loaded library will return without error.
func (l *library) load() error {
func (l *library) load() (rerr error) {
l.Lock()
defer l.Unlock()
if l.dl != nil {

defer func() { l.refcount.IncOnNoError(rerr) }()
if l.refcount > 0 {
return nil
}

dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
if err := dl.Open(); err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

Expand All @@ -99,16 +101,18 @@ func (l *library) load() error {
// close the underlying library and ensure that the global pointer to the
// library is set to nil to ensure that subsequent calls to open will reinitialize it.
// Multiple calls to an already closed nvml library will return without error.
func (l *library) close() error {
func (l *library) close() (rerr error) {
l.Lock()
defer l.Unlock()

if l.dl == nil {
fmt.Printf("refcount: %v\n", l.refcount)

defer func() { l.refcount.DecOnNoError(rerr) }()
if l.refcount != 1 {
return nil
}

err := l.dl.Close()
if err != nil {
if err := l.dl.Close(); err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

Expand Down
37 changes: 37 additions & 0 deletions pkg/nvml/lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ func TestLookupFromDefault(t *testing.T) {
}
}

func TestLoadAndCloseNesting(t *testing.T) {
var numOpens, numCloses int

dl := &dynamicLibraryMock{
OpenFunc: func() error {
numOpens++
return nil
},
CloseFunc: func() error {
numCloses++
return nil
},
}

defer setNewDynamicLibraryDuringTest(dl)()
defer resetLibrary()
GetLibrary()

// When calling load twice, the library was only opened once
require.Equal(t, 0, numOpens)
require.Nil(t, libnvml.load())
require.Equal(t, 1, numOpens)
require.Nil(t, libnvml.load())
require.Equal(t, 1, numOpens)

// Only after calling close twice, was the library closed
require.Equal(t, 0, numCloses)
require.Nil(t, libnvml.close())
require.Equal(t, 0, numCloses)
require.Nil(t, libnvml.close())
require.Equal(t, 1, numCloses)

// Calling close again doesn't attempt to close the library again
require.Nil(t, libnvml.close())
require.Equal(t, 1, numCloses)
}

func setNewDynamicLibraryDuringTest(dl dynamicLibrary) func() {
original := newDynamicLibrary
newDynamicLibrary = func(string, int) dynamicLibrary {
Expand Down
Loading

0 comments on commit 9aa6a4f

Please sign in to comment.