diff --git a/cmd/envd-ssh/main.go b/cmd/envd-ssh/main.go index c65297784..b7b85520a 100644 --- a/cmd/envd-ssh/main.go +++ b/cmd/envd-ssh/main.go @@ -35,6 +35,7 @@ const ( flagAuthKey = "authorized-keys" flagNoAuth = "no-auth" flagPort = "port" + flagShell = "shell" ) func main() { @@ -66,6 +67,11 @@ func main() { Name: flagPort, Usage: "port to listen on", }, + &cli.StringFlag{ + Name: flagShell, + Usage: "shell to use", + Value: "bash", + }, } // Deal with debug flag. @@ -86,10 +92,11 @@ func main() { } func sshServer(c *cli.Context) error { - shell, err := sshd.GetShell() + err := sshd.GetShell(c.String(flagShell)) if err != nil { logrus.Fatal(err.Error()) } + shell := c.String(flagShell) port := c.Int(flagPort) if port == 0 { diff --git a/pkg/docker/entrypoint.go b/pkg/docker/entrypoint.go index 15b9e8f5d..16c2668dc 100644 --- a/pkg/docker/entrypoint.go +++ b/pkg/docker/entrypoint.go @@ -25,7 +25,7 @@ import ( const ( template = `set -e -/var/envd/bin/envd-ssh --authorized-keys %s --port %d & +/var/envd/bin/envd-ssh --authorized-keys %s --port %d --shell %s & %s wait -n` ) @@ -34,8 +34,9 @@ func entrypointSH(g ir.Graph, workingDir string, sshPort int) string { if g.JupyterConfig != nil { cmds := jupyter.GenerateCommand(g, workingDir) return fmt.Sprintf(template, - config.ContainerauthorizedKeysPath, sshPort, strings.Join(cmds, " ")) + config.ContainerauthorizedKeysPath, sshPort, g.Shell, + strings.Join(cmds, " ")) } return fmt.Sprintf(template, - config.ContainerauthorizedKeysPath, sshPort, "") + config.ContainerauthorizedKeysPath, sshPort, g.Shell, "") } diff --git a/pkg/remote/sshd/os.go b/pkg/remote/sshd/os.go index 003733c02..32d56e3bd 100644 --- a/pkg/remote/sshd/os.go +++ b/pkg/remote/sshd/os.go @@ -24,23 +24,15 @@ import ( var ( errNoShell = fmt.Errorf("failed to find any shell in the PATH") - - shells = []string{ - "zsh", - "bash", - "sh", - } ) // GetShell returns the shell in $PATH. -func GetShell() (string, error) { - for _, shell := range shells { - if path, err := exec.LookPath(shell); err == nil { - logrus.Infof("%s exists at %s", shell, path) - return shell, nil - } - logrus.Debugf("%s does not exist", shell) +func GetShell(shell string) error { + if path, err := exec.LookPath(shell); err == nil { + logrus.Infof("%s exists at %s", shell, path) + return nil } + logrus.Debugf("%s does not exist", shell) - return "", errNoShell + return errNoShell }