Skip to content

Commit

Permalink
Add a refcount to dl load() and close() calls
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Mar 14, 2024
1 parent 20c256c commit 928b2bf
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 48 deletions.
48 changes: 24 additions & 24 deletions gen/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -80,20 +81,20 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {
func (l *library) load() error {
l.Lock()
defer l.Unlock()
if l.dl != nil {
return nil
}

dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}
err := l.refcount.FirstIn(func() error {
dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

l.dl = dl
l.updateVersionedSymbols()
l.dl = dl
l.updateVersionedSymbols()
return nil
})

return nil
return err
}

// close the underlying library and ensure that the global pointer to the
Expand All @@ -103,18 +104,17 @@ func (l *library) close() error {
l.Lock()
defer l.Unlock()

if l.dl == nil {
return nil
}

err := l.dl.Close()
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}
err := l.refcount.LastOut(func() error {
err := l.dl.Close()
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil
l.dl = nil
return nil
})

return nil
return err
}

// Default all versioned APIs to v1 (to infer the types)
Expand Down
48 changes: 48 additions & 0 deletions gen/nvml/refcount.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import ()

type refcount int

func (r *refcount) FirstIn(f func() error) error {
(*r)++
if (*r) > 1 {
return nil
}

err := f()
if err != nil {
(*r)--
}

return err
}

func (r *refcount) LastOut(f func() error) error {
if (*r) != 1 {
return nil
}

err := f()
if err == nil {
(*r)--
}

return err
}
48 changes: 24 additions & 24 deletions pkg/nvml/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ var errLibraryAlreadyLoaded = errors.New("library already loaded")
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
path string
flags int
refcount refcount
dl dynamicLibrary
}

// libnvml is a global instance of the nvml library.
Expand Down Expand Up @@ -80,20 +81,20 @@ var newDynamicLibrary = func(path string, flags int) dynamicLibrary {
func (l *library) load() error {
l.Lock()
defer l.Unlock()
if l.dl != nil {
return nil
}

dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}
err := l.refcount.FirstIn(func() error {
dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}

l.dl = dl
l.updateVersionedSymbols()
l.dl = dl
l.updateVersionedSymbols()
return nil
})

return nil
return err
}

// close the underlying library and ensure that the global pointer to the
Expand All @@ -103,18 +104,17 @@ func (l *library) close() error {
l.Lock()
defer l.Unlock()

if l.dl == nil {
return nil
}

err := l.dl.Close()
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}
err := l.refcount.LastOut(func() error {
err := l.dl.Close()
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}

l.dl = nil
l.dl = nil
return nil
})

return nil
return err
}

// Default all versioned APIs to v1 (to infer the types)
Expand Down
48 changes: 48 additions & 0 deletions pkg/nvml/refcount.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nvml

import ()

type refcount int

func (r *refcount) FirstIn(f func() error) error {
(*r)++
if (*r) > 1 {
return nil
}

err := f()
if err != nil {
(*r)--
}

return err
}

func (r *refcount) LastOut(f func() error) error {
if (*r) != 1 {
return nil
}

err := f()
if err == nil {
(*r)--
}

return err
}

0 comments on commit 928b2bf

Please sign in to comment.