diff --git a/pkg/vf/virtionet.go b/pkg/vf/virtionet.go index acc7bcf..25a077a 100644 --- a/pkg/vf/virtionet.go +++ b/pkg/vf/virtionet.go @@ -19,20 +19,12 @@ type VirtioNet struct { localAddr *net.UnixAddr } -func localUnixSocketPath() (string, error) { - homeDir, err := os.UserHomeDir() +func localUnixSocketPath(dir string) (string, error) { + tmpFile, err := os.CreateTemp(dir, fmt.Sprintf("vfkit-%d-*.sock", os.Getpid())) if err != nil { return "", err } - dir := filepath.Join(homeDir, "Library", "Application Support", "vfkit") - if err := os.MkdirAll(dir, 0755); err != nil { - return "", err - } - tmpFile, err := os.CreateTemp(dir, fmt.Sprintf("net-%d-*.sock", os.Getpid())) - if err != nil { - return "", err - } - // slightly racy, but this is in a directory only user-writable + // slightly racy, but hopefully this is in a directory only user-writable defer tmpFile.Close() defer os.Remove(tmpFile.Name()) @@ -44,7 +36,7 @@ func (dev *VirtioNet) connectUnixPath() error { Name: dev.UnixSocketPath, Net: "unixgram", } - localSocketPath, err := localUnixSocketPath() + localSocketPath, err := localUnixSocketPath(filepath.Dir(dev.UnixSocketPath)) if err != nil { return err } diff --git a/pkg/vf/virtionet_test.go b/pkg/vf/virtionet_test.go index 4ad8a9e..bc390cc 100644 --- a/pkg/vf/virtionet_test.go +++ b/pkg/vf/virtionet_test.go @@ -1,9 +1,6 @@ package vf import ( - "bytes" - "fmt" - "math/rand" "net" "os" "path/filepath" @@ -14,8 +11,26 @@ import ( "github.com/stretchr/testify/require" ) -func testConnectUnixgram(t *testing.T) error { - unixSocketPath := filepath.Join("/tmp", fmt.Sprintf("vnet-test-%x.sock", rand.Int31n(0xffff))) //#nosec G404 -- no need for crypto/rand here +func sourceSocketPath(t *testing.T, sourcePathLen int) (string, func()) { + // the 't.sock' name is chosen to be shorter than what + // localUnixSocketPath will generate so that the source socket path + // will not exceed the 104 byte limit while the destination socket path + // will, and will trigger an error + const sourceSocketName = "t.sock" + tmpDir := "/tmp" + subDirLen := sourcePathLen - len(tmpDir) - 2*len("/") - len(sourceSocketName) - 1 + subDir := filepath.Join(tmpDir, strings.Repeat("a", subDirLen)) + err := os.Mkdir(subDir, 0700) + require.NoError(t, err) + unixSocketPath := filepath.Join(subDir, sourceSocketName) + require.Equal(t, len(unixSocketPath), sourcePathLen-1) + return unixSocketPath, func() { os.RemoveAll(subDir) } + +} +func testConnectUnixgram(t *testing.T, sourcePathLen int) error { + unixSocketPath, closer := sourceSocketPath(t, sourcePathLen) + defer closer() + addr, err := net.ResolveUnixAddr("unixgram", unixSocketPath) require.NoError(t, err) @@ -23,7 +38,6 @@ func testConnectUnixgram(t *testing.T) error { require.NoError(t, err) defer l.Close() - defer os.Remove(unixSocketPath) dev := &VirtioNet{ &config.VirtioNet{ @@ -37,46 +51,31 @@ func testConnectUnixgram(t *testing.T) error { func TestConnectUnixPath(t *testing.T) { t.Run("Successful connection - no error", func(t *testing.T) { - err := testConnectUnixgram(t) + // 50 is an arbitrary number, small enough for the 104 bytes limit not to be exceeded + err := testConnectUnixgram(t, 50) require.NoError(t, err) }) t.Run("Failed connection - End socket longer than 104 bytes", func(t *testing.T) { - // Retrieve HOME env variable (used by the os.UserHomeDir) - origUserHome := os.Getenv("HOME") - defer func() { - os.Setenv("HOME", origUserHome) - }() - - // Create a string of 100 bytes to update the user home to be sure to create a socket path > 104 bytes - b := bytes.Repeat([]byte("a"), 100) - subDir := string(b) - - // Update HOME env so os.UserHomeDir returns the update path with subfolder - updatedUserHome := filepath.Join(origUserHome, subDir) - os.Setenv("HOME", updatedUserHome) - defer os.RemoveAll(updatedUserHome) - - err := testConnectUnixgram(t) + err := testConnectUnixgram(t, 104) // It should return an error require.Error(t, err) - require.ErrorContains(t, err, "invalid argument") + require.ErrorContains(t, err, "is too long") }) } func TestLocalUnixSocketPath(t *testing.T) { t.Run("Success case - Creates temporary socket path", func(t *testing.T) { // Retrieve HOME env variable (used by the os.UserHomeDir) - userHome := os.Getenv("HOME") + socketDir := t.TempDir() - path, err := localUnixSocketPath() + path, err := localUnixSocketPath(socketDir) // Assert successful execution require.NoError(t, err) // Check if path starts with the expected prefix - expectedPrefix := filepath.Join(userHome, "Library", "Application Support", "vfkit") - require.Truef(t, strings.HasPrefix(path, expectedPrefix), "Path doesn't start with expected prefix: %v", path) + require.Truef(t, strings.HasPrefix(path, socketDir), "Path doesn't start with expected prefix: %v", path) // Check if path ends with a socket extension require.Equalf(t, ".sock", filepath.Ext(path), "Path doesn't end with .sock extension: %v, ext is %v", path, filepath.Ext(path))