Skip to content

Commit

Permalink
Allow missing and null pointer fields to be nil.
Browse files Browse the repository at this point in the history
This commit ensures missing fields for pointer field types remain nil:

  * This allows one to distinguish the field wasn't present from the
    field being NULL when the PtrNil option is false.
  * This also makes it more convenient to reuse structs with pointer
    field types with various encoders like the standard library JSON
    marshaler, where you may want to omit a field if it didn't exist in
    the queried results.

The expectations around this change are captured in the
TestMissingFields test.

In addition, it also fixes an edge case where even with PtrNil set to
true, pointer field types would always return with a pointer to a zero
value instead of being nil even when the query property's value is
VT_NULL.

This fix is checked by the TestNullPointerField test.
  • Loading branch information
jodoherty committed Feb 19, 2023
1 parent b0230a1 commit 3183000
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 9 deletions.
1 change: 1 addition & 0 deletions swbemservices.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package wmi
Expand Down
7 changes: 4 additions & 3 deletions swbemservices_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package wmi
Expand Down Expand Up @@ -103,8 +104,8 @@ func WbemGetMemoryUsageMB(s *SWbemServices) (float64, float64, float64) {
//Run all benchmarks (should run for at least 60s to get a stable number):
//go test -run=NONE -bench=Version -benchtime=120s

//Individual benchmarks:
//go test -run=NONE -bench=NewVersion -benchtime=120s
// Individual benchmarks:
// go test -run=NONE -bench=NewVersion -benchtime=120s
func BenchmarkNewVersion(b *testing.B) {
s, err := InitializeSWbemServices(DefaultClient)
if err != nil {
Expand All @@ -128,7 +129,7 @@ func BenchmarkNewVersion(b *testing.B) {
}
}

//go test -run=NONE -bench=OldVersion -benchtime=120s
// go test -run=NONE -bench=OldVersion -benchtime=120s
func BenchmarkOldVersion(b *testing.B) {
var dst []Win32_OperatingSystem
q := CreateQuery(&dst, "")
Expand Down
13 changes: 7 additions & 6 deletions wmi.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

/*
Expand All @@ -20,7 +21,6 @@ Example code to print names of running processes:
println(i, v.Name)
}
}
*/
package wmi

Expand Down Expand Up @@ -338,11 +338,6 @@ func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismat
f := v.Field(i)
of := f
isPtr := f.Kind() == reflect.Ptr
if isPtr {
ptr := reflect.New(f.Type().Elem())
f.Set(ptr)
f = f.Elem()
}
n := v.Type().Field(i).Name
if n[0] < 'A' || n[0] > 'Z' {
continue
Expand All @@ -367,6 +362,12 @@ func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismat
}
defer prop.Clear()

if isPtr && !(c.PtrNil && prop.VT == 0x1) {
ptr := reflect.New(f.Type().Elem())
f.Set(ptr)
f = f.Elem()
}

if prop.VT == 0x1 { //VT_NULL
continue
}
Expand Down
82 changes: 82 additions & 0 deletions wmi_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package wmi
Expand Down Expand Up @@ -39,6 +40,87 @@ func TestFieldMismatch(t *testing.T) {
}
}

func TestMissingFields(t *testing.T) {
type s struct {
Name string
Missing uint32
MissingPointer *uint32
}

var dst []s

client := &Client{
AllowMissingFields: true,
}
err := client.Query("SELECT Name FROM Win32_Process", &dst)
if err != nil {
t.Fatal(err)
}
for i := range dst {
if dst[i].Missing != 0 {
t.Fatal("Expected Missing field to be 0")
}
if dst[i].MissingPointer != nil {
t.Fatal("Expected MissingPointer field to be nil")
}
}

// NonePtrZero and PtrNil should only affect the behavior of fields that
// exist as result properties, not missing fields.
client = &Client{
NonePtrZero: true,
PtrNil: true,
AllowMissingFields: true,
}
dst = []s{}
err = client.Query("SELECT Name FROM Win32_Process", &dst)
if err != nil {
t.Fatal(err)
}
for i := range dst {
if dst[i].Missing != 0 {
t.Fatal("Expected Missing field to be 0")
}
if dst[i].MissingPointer != nil {
t.Fatal("Expected MissingPointer field to be nil")
}
}
}

func TestNullPointerField(t *testing.T) {
type s struct {
Name string
Status *string
}

var dst []s

client := &Client{}
err := client.Query("SELECT Name, Status FROM Win32_Process WHERE Status IS NULL", &dst)
if err != nil {
t.Fatal(err)
}
for i := range dst {
if dst[i].Status == nil {
t.Fatal("Expected Status field to not be nil")
}
}

client = &Client{
PtrNil: true,
}
dst = []s{}
err = client.Query("SELECT Name, Status FROM Win32_Process WHERE Status IS NULL", &dst)
if err != nil {
t.Fatal(err)
}
for i := range dst {
if dst[i].Status != nil {
t.Fatal("Expected Status field to be nil")
}
}
}

func TestStrings(t *testing.T) {
printed := false
f := func() {
Expand Down

0 comments on commit 3183000

Please sign in to comment.