Skip to content

Commit

Permalink
refactor(pkg,api,agent): use param as reverse connection path
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Aug 1, 2024
1 parent 9c21a02 commit 73a83f6
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func (a *Agent) Listen(ctx context.Context) error {
"{sshEndpoint}", strings.Split(sshEndpoint, ":")[0],
).Replace("{namespace}.{tenantName}@{sshEndpoint}")

listener, err := a.cli.NewReverseListener(ctx, a.authData.Token)
listener, err := a.cli.NewReverseListener(ctx, a.authData.Token, "/ssh/connection")
if err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type publicAPI interface {
Endpoints() (*models.Endpoints, error)
AuthDevice(req *models.DeviceAuthRequest) (*models.DeviceAuthResponse, error)
AuthPublicKey(req *models.PublicKeyAuthRequest, token string) (*models.PublicKeyAuthResponse, error)
NewReverseListener(ctx context.Context, token string) (*revdial.Listener, error)
NewReverseListener(ctx context.Context, token string, connPath string) (*revdial.Listener, error)
}

//go:generate mockery --name=Client --filename=client.go
Expand Down
4 changes: 2 additions & 2 deletions pkg/api/client/client_public.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ func (c *client) AuthPublicKey(req *models.PublicKeyAuthRequest, token string) (

// NewReverseListener creates a new reverse listener connection to ShellHub's server. This listener receives the SSH
// requests coming from the ShellHub server. Only authenticated devices can obtain a listener connection.
func (c *client) NewReverseListener(ctx context.Context, token string) (*revdial.Listener, error) {
func (c *client) NewReverseListener(ctx context.Context, token string, connPath string) (*revdial.Listener, error) {
if token == "" {
return nil, errors.New("token is empty")
}

if err := c.reverser.Auth(ctx, token); err != nil {
if err := c.reverser.Auth(ctx, token, connPath); err != nil {
return nil, err
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/api/client/client_public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,15 +433,15 @@ func TestReverseListener(t *testing.T) {
description: "fail when connot auth the agent on the SSH server",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
requiredMocks: func() {
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c").Return(errors.New("")).Once()
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", "").Return(errors.New("")).Once()
},
expected: errors.New(""),
},
{
description: "fail when connot create a new reverse listener",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
requiredMocks: func() {
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c").Return(nil).Once()
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", "").Return(nil).Once()

mock.On("NewListener").Return(nil, errors.New("")).Once()
},
Expand All @@ -451,7 +451,7 @@ func TestReverseListener(t *testing.T) {
description: "success to create a new reverse listener",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
requiredMocks: func() {
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c").Return(nil).Once()
mock.On("Auth", context.Background(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", "").Return(nil).Once()

mock.On("NewListener").Return(new(revdial.Listener), nil).Once()
},
Expand All @@ -468,7 +468,7 @@ func TestReverseListener(t *testing.T) {

test.requiredMocks()

_, err = cli.NewReverseListener(ctx, test.token)
_, err = cli.NewReverseListener(ctx, test.token, "")
assert.Equal(t, err, test.expected)
})
}
Expand Down
18 changes: 9 additions & 9 deletions pkg/api/client/mocks/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions pkg/api/client/mocks/reverser.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pkg/api/client/reverser.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

//go:generate mockery --name=IReverser --filename=reverser.go
type IReverser interface {
Auth(ctx context.Context, token string) error
Auth(ctx context.Context, token string, connPath string) error
NewListener() (*revdial.Listener, error)
}

Expand All @@ -35,8 +35,8 @@ func NewReverser(host string) *Reverser {
}

// Auth creates a initial connection to the ShellHub SSH's server and authenticate it with the token received.
func (r *Reverser) Auth(ctx context.Context, token string) error {
uri, err := url.JoinPath(r.host, "/ssh/connection")
func (r *Reverser) Auth(ctx context.Context, token string, connPath string) error {
uri, err := url.JoinPath(r.host, connPath)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/connman/connman.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func New() *ConnectionManager {
}
}

func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter) {
dialer := revdial.NewDialer(conn, "/ssh/revdial")
func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) {
dialer := revdial.NewDialer(conn, connPath)

m.dialers.Store(key, dialer)

Expand Down
2 changes: 1 addition & 1 deletion pkg/httptunnel/httptunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (t *Tunnel) Router() http.Handler {
return c.String(http.StatusBadRequest, err.Error())
}

t.connman.Set(id, wsconnadapter.New(conn))
t.connman.Set(id, wsconnadapter.New(conn), t.DialerPath)

return nil
})
Expand Down

0 comments on commit 73a83f6

Please sign in to comment.