Skip to content

Commit

Permalink
Merge pull request #4 from treydock/hostkey
Browse files Browse the repository at this point in the history
Add known_hosts and host_key_algorithms options to configuration to support verifying SSH host keys
  • Loading branch information
treydock authored Nov 15, 2020
2 parents 2185f70 + 37228a1 commit 2684a5d
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 40 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## 0.2.0 / TBD

* Update to 1.15 and update Go module dependencies
* Add `known_hosts` configuration option to allow verifying SSH hosts against known hosts
* Add `host_key_algorithms` configuration option to specify host key algorithms to use when verifying SSH hosts

## 0.1.1 / 2020-04-01

Expand Down
28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ modules:
password:
id: prometheus
password: secret
verify:
user: prometheus
private_key: /home/prometheus/.ssh/id_rsa
known_hosts: /etc/ssh/ssh_known_hosts
host_key_algorithms:
- ssh-rsa
command: uptime
command_expect: "load average"
timeout: 5
```
Example with curl would query host1 with the password module and host2 with the default module.
Expand All @@ -47,6 +56,9 @@ Configuration options for each module:
* `user` - The username for the SSH connection
* `password` - The password for the SSH connection, required if `private_key` is not specified
* `private_key` - The SSH private key for the SSH connection, required if `password` is not specified
* `known_hosts` - Optional SSH known hosts file to use to verify hosts
* `host_key_algorithms` - Optional list of SSH host key algorithms to use
* See constants beginning with `KeyAlgo*` in [crypto/ssh](https://godoc.org/golang.org/x/crypto/ssh#pkg-constants)
* `timeout` - Optional timeout of the SSH connection, session and optional command.
* The default comes from the `--collector.ssh.default-timeout` flag.
* `command` - Optional command to run.
Expand Down Expand Up @@ -110,15 +122,27 @@ The following example assumes this exporter is running on the Prometheus server
metrics_path: /ssh
static_configs:
- targets:
- ssh1.example.com
- ssh2.example.com
- host1.example.com:22
- host2.example.com:22
labels:
module: default
- targets:
- host3.example.com:22
- host4.example.com:22
labels:
module: verify
relabel_configs:
- source_labels: [__address__]
target_label: __param_target
- source_labels: [__param_target]
target_label: instance
- target_label: __address__
replacement: 127.0.0.1:9312
- source_labels: [module]
target_label: __param_module
metric_relabel_configs:
- regex: "^(module)$"
action: labeldrop
- job_name: ssh-metrics
metrics_path: /metrics
static_configs:
Expand Down
40 changes: 28 additions & 12 deletions collector/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ package collector

import (
"bytes"
"encoding/base64"
"io/ioutil"
"net"
"regexp"
"strings"
"time"
Expand All @@ -25,14 +27,15 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/treydock/ssh_exporter/config"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

const (
namespace = "ssh"
)

type Metric struct {
Success bool
Success float64
FailureReason string
}

Expand Down Expand Up @@ -70,7 +73,7 @@ func (c *Collector) Collect(ch chan<- prometheus.Metric) {

metric := c.collect()

ch <- prometheus.MustNewConstMetric(c.Success, prometheus.GaugeValue, boolToFloat64(metric.Success))
ch <- prometheus.MustNewConstMetric(c.Success, prometheus.GaugeValue, metric.Success)
for _, reason := range failureReasons {
var value float64
if reason == metric.FailureReason {
Expand Down Expand Up @@ -103,10 +106,11 @@ func (c *Collector) collect() Metric {
}

sshConfig := &ssh.ClientConfig{
User: c.target.User,
Auth: []ssh.AuthMethod{auth},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: time.Duration(c.target.Timeout) * time.Second,
User: c.target.User,
Auth: []ssh.AuthMethod{auth},
HostKeyCallback: hostKeyCallback(&metric, c.target, c.logger),
HostKeyAlgorithms: c.target.HostKeyAlgorithms,
Timeout: time.Duration(c.target.Timeout) * time.Second,
}
connection, err := ssh.Dial("tcp", c.target.Host, sshConfig)
if err != nil {
Expand Down Expand Up @@ -170,7 +174,7 @@ func (c *Collector) collect() Metric {
return metric
}
}
metric.Success = true
metric.Success = 1
return metric
}

Expand All @@ -186,10 +190,22 @@ func getPrivateKeyAuth(privatekey string) (ssh.AuthMethod, error) {
return ssh.PublicKeys(key), nil
}

func boolToFloat64(data bool) float64 {
if data {
return float64(1)
} else {
return float64(0)
func hostKeyCallback(metric *Metric, target *config.Target, logger log.Logger) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
var hostKeyCallback ssh.HostKeyCallback
var err error
if target.KnownHosts != "" {
publicKey := base64.StdEncoding.EncodeToString(key.Marshal())
level.Debug(logger).Log("msg", "Verify SSH known hosts", "hostname", hostname, "remote", remote.String(), "key", publicKey)
hostKeyCallback, err = knownhosts.New(target.KnownHosts)
if err != nil {
metric.FailureReason = "error"
level.Error(logger).Log("msg", "Error creating hostkeycallback function", "err", err)
return err
}
} else {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
}
return hostKeyCallback(hostname, remote, key)
}
}
129 changes: 129 additions & 0 deletions collector/collector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package collector

import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io"
"io/ioutil"
Expand All @@ -26,12 +28,16 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/treydock/ssh_exporter/config"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

const (
listen = 60022
)

var knownHosts *os.File

func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
buffer, err := ioutil.ReadFile("testdata/id_rsa_test1.pub")
if err != nil {
Expand Down Expand Up @@ -69,6 +75,27 @@ func TestMain(m *testing.M) {
PublicKeyHandler: publicKeyHandler,
PasswordHandler: passwordHandler,
}
hostKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
fmt.Printf("ERROR generating RSA host key: %s", err)
os.Exit(1)
}
signer, err := gossh.NewSignerFromKey(hostKey)
if err != nil {
fmt.Printf("ERROR generating host key signer: %s", err)
os.Exit(1)
}
s.AddHostKey(signer)
knownHosts, err = ioutil.TempFile("", "knowm_hosts")
if err != nil {
fmt.Printf("ERROR creating known hosts: %s", err)
os.Exit(1)
}
defer os.Remove(knownHosts.Name())
knownHostsLine := knownhosts.Line([]string{fmt.Sprintf("localhost:%d", listen)}, s.HostSigners[0].PublicKey())
if _, err = knownHosts.Write([]byte(knownHostsLine)); err != nil {
fmt.Printf("ERROR writing known hosts: %s", err)
}
go func() {
if err := s.ListenAndServe(); err != nil {
fmt.Printf("ERROR starting SSH server: %s", err)
Expand Down Expand Up @@ -279,6 +306,108 @@ func TestCollectorPrivateKey(t *testing.T) {
}
}

func TestCollectorKnownHosts(t *testing.T) {
expected := `
# HELP ssh_failure Indicates a failure
# TYPE ssh_failure gauge
ssh_failure{reason="command-error"} 0
ssh_failure{reason="command-output"} 0
ssh_failure{reason="error"} 0
ssh_failure{reason="timeout"} 0
# HELP ssh_success SSH connection was successful
# TYPE ssh_success gauge
ssh_success 1
`
target := &config.Target{
Host: fmt.Sprintf("localhost:%d", listen),
User: "test",
PrivateKey: "testdata/id_rsa_test1",
KnownHosts: knownHosts.Name(),
Timeout: 2,
}
w := log.NewSyncWriter(os.Stderr)
logger := log.NewLogfmtLogger(w)
collector := NewCollector(target, logger)
gatherers := setupGatherer(collector)
if val, err := testutil.GatherAndCount(gatherers); err != nil {
t.Errorf("Unexpected error: %v", err)
} else if val != 6 {
t.Errorf("Unexpected collection count %d, expected 6", val)
}
if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected),
"ssh_success", "ssh_failure"); err != nil {
t.Errorf("unexpected collecting result:\n%s", err)
}
}

func TestCollectorKnownHostsError(t *testing.T) {
expected := `
# HELP ssh_failure Indicates a failure
# TYPE ssh_failure gauge
ssh_failure{reason="command-error"} 0
ssh_failure{reason="command-output"} 0
ssh_failure{reason="error"} 1
ssh_failure{reason="timeout"} 0
# HELP ssh_success SSH connection was successful
# TYPE ssh_success gauge
ssh_success 0
`
target := &config.Target{
Host: fmt.Sprintf("127.0.0.1:%d", listen),
User: "test",
PrivateKey: "testdata/id_rsa_test1",
KnownHosts: knownHosts.Name(),
Timeout: 2,
}
w := log.NewSyncWriter(os.Stderr)
logger := log.NewLogfmtLogger(w)
collector := NewCollector(target, logger)
gatherers := setupGatherer(collector)
if val, err := testutil.GatherAndCount(gatherers); err != nil {
t.Errorf("Unexpected error: %v", err)
} else if val != 6 {
t.Errorf("Unexpected collection count %d, expected 6", val)
}
if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected),
"ssh_success", "ssh_failure"); err != nil {
t.Errorf("unexpected collecting result:\n%s", err)
}
}

func TestCollectorKnownHostsDNE(t *testing.T) {
expected := `
# HELP ssh_failure Indicates a failure
# TYPE ssh_failure gauge
ssh_failure{reason="command-error"} 0
ssh_failure{reason="command-output"} 0
ssh_failure{reason="error"} 1
ssh_failure{reason="timeout"} 0
# HELP ssh_success SSH connection was successful
# TYPE ssh_success gauge
ssh_success 0
`
target := &config.Target{
Host: fmt.Sprintf("localhost:%d", listen),
User: "test",
PrivateKey: "testdata/id_rsa_test1",
KnownHosts: "/dne",
Timeout: 2,
}
w := log.NewSyncWriter(os.Stderr)
logger := log.NewLogfmtLogger(w)
collector := NewCollector(target, logger)
gatherers := setupGatherer(collector)
if val, err := testutil.GatherAndCount(gatherers); err != nil {
t.Errorf("Unexpected error: %v", err)
} else if val != 6 {
t.Errorf("Unexpected collection count %d, expected 6", val)
}
if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected),
"ssh_success", "ssh_failure"); err != nil {
t.Errorf("unexpected collecting result:\n%s", err)
}
}

func TestCollectDNEKey(t *testing.T) {
target := &config.Target{
Host: fmt.Sprintf("localhost:%d", listen),
Expand Down
32 changes: 18 additions & 14 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,27 @@ type SafeConfig struct {
}

type Module struct {
ModuleName string
User string `yaml:"user"`
Password string `yaml:"password"`
PrivateKey string `yaml:"private_key"`
Timeout int `yaml:"timeout"`
Command string `yaml:"command"`
CommandExpect string `yaml:"command_expect"`
ModuleName string
User string `yaml:"user"`
Password string `yaml:"password"`
PrivateKey string `yaml:"private_key"`
KnownHosts string `yaml:"known_hosts"`
HostKeyAlgorithms []string `yaml:"host_key_algorithms"`
Timeout int `yaml:"timeout"`
Command string `yaml:"command"`
CommandExpect string `yaml:"command_expect"`
}

type Target struct {
Host string
User string
Password string
PrivateKey string
Timeout int
Command string
CommandExpect string
Host string
User string
Password string
PrivateKey string
KnownHosts string
HostKeyAlgorithms []string
Timeout int
Command string
CommandExpect string
}

func (sc *SafeConfig) ReloadConfig(configFile string) error {
Expand Down
Loading

0 comments on commit 2684a5d

Please sign in to comment.