Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support raw public keys (RFC 7250) #208

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 126 additions & 41 deletions client-state-machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [

state.Params.ServerName = state.Opts.ServerName

// Application Layer Protocol Negotiation
var alpn *ALPNExtension
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
}

// Construct base ClientHello
ch := &ClientHelloBody{
LegacyVersion: wireVersion(state.hsCtx.hIn),
Expand All @@ -119,15 +113,42 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
return nil, nil, AlertInternalError
}
}
// XXX: These optional extensions can't be folded into the above because Go
// interface-typed values are never reported as nil
if alpn != nil {

// Application Layer Protocol Negotiation
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
alpn := &ALPNExtension{Protocols: state.Opts.NextProtos}
err := ch.Extensions.Add(alpn)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}

// Certificate type negotiation (to enable raw public keys)
if state.Config.AllowRawPublicKeys {
cct := &ClientCertTypeExtension{HandshakeType: HandshakeTypeClientHello}
sct := &ServerCertTypeExtension{HandshakeType: HandshakeTypeClientHello}

cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeRawPublicKey)
sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeRawPublicKey)

if !state.Config.ForbidX509 {
cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeX509)
sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeX509)
}

err = ch.Extensions.Add(cct)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ClientCertType extension [%v]", err)
return nil, nil, AlertInternalError
}
err = ch.Extensions.Add(sct)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ServerCertType extension [%v]", err)
return nil, nil, AlertInternalError
}
}

if state.cookie != nil {
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
if err != nil {
Expand Down Expand Up @@ -590,11 +611,15 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,

serverALPN := &ALPNExtension{}
serverEarlyData := &EarlyDataExtension{}
serverClientCertType := &ClientCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions}
serverServerCertType := &ServerCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions}
bifurcation marked this conversation as resolved.
Show resolved Hide resolved

foundExts, err := ee.Extensions.Parse(
[]ExtensionBody{
serverALPN,
serverEarlyData,
serverClientCertType,
serverServerCertType,
})
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err)
Expand All @@ -607,6 +632,26 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
state.Params.NextProto = serverALPN.Protocols[0]
}

if foundExts[ExtensionTypeServerCertType] {
certType := serverServerCertType.CertificateTypes[0]
if !CertificateTypeValid(certType, state.Config.AllowRawPublicKeys, state.Config.ForbidX509) {
logf(logTypeHandshake, "[ClientStateWaitEE] Server sent illegal certificate type: %v", certType)
return nil, nil, AlertHandshakeFailure
}

state.Params.ServerCertType = certType
}

if foundExts[ExtensionTypeClientCertType] {
certType := serverClientCertType.CertificateTypes[0]
if !CertificateTypeValid(certType, state.Config.AllowRawPublicKeys, state.Config.ForbidX509) {
logf(logTypeHandshake, "[ClientStateWaitEE] Client sent illegal certificate type: %v", certType)
return nil, nil, AlertHandshakeFailure
}

state.Params.ClientCertType = certType
}

state.handshakeHash.Write(hm.Marshal())

toSend := []HandshakeAction{}
Expand Down Expand Up @@ -818,44 +863,71 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)

serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
certs := make([][]byte, len(state.serverCertificate.CertificateList))
for i, certEntry := range state.serverCertificate.CertificateList {
certs[i] = certEntry.CertData
}

var err error
var serverPublicKey crypto.PublicKey
var eeCert *x509.Certificate
switch state.Params.ServerCertType {
case CertificateTypeX509:
eeCert, err = x509.ParseCertificate(certs[0])
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse client cert: %v", err)
return nil, nil, AlertDecodeError
}

serverPublicKey = eeCert.PublicKey

case CertificateTypeRawPublicKey:
serverPublicKey, err = unmarshalSigningKey(certs[0])
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse raw public key: %v", err)
return nil, nil, AlertDecodeError
}
}
bifurcation marked this conversation as resolved.
Show resolved Hide resolved

if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
return nil, nil, AlertHandshakeFailure
}

certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList))
rawCerts := make([][]byte, len(state.serverCertificate.CertificateList))
for i, certEntry := range state.serverCertificate.CertificateList {
certs[i] = certEntry.CertData
rawCerts[i] = certEntry.CertData.Raw
}

var verifiedChains [][]*x509.Certificate
if !state.Config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: state.Config.RootCAs,
CurrentTime: state.Config.time(),
DNSName: state.Config.ServerName,
Intermediates: x509.NewCertPool(),
}
if state.Params.ServerCertType == CertificateTypeX509 {
if !state.Config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: state.Config.RootCAs,
CurrentTime: state.Config.time(),
DNSName: state.Config.ServerName,
Intermediates: x509.NewCertPool(),
}

for i, cert := range certs {
if i == 0 {
continue
}

caCert, err := x509.ParseCertificate(cert)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error parsing server chain: %v", err)
return nil, nil, AlertDecodeError
}

for i, cert := range certs {
if i == 0 {
continue
opts.Intermediates.AddCert(caCert)
}
var err error
verifiedChains, err = eeCert.Verify(opts)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err)
return nil, nil, AlertBadCertificate
}
opts.Intermediates.AddCert(cert)
}
var err error
verifiedChains, err = certs[0].Verify(opts)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err)
return nil, nil, AlertBadCertificate
}
}

if state.Config.VerifyPeerCertificate != nil {
if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil {
if err := state.Config.VerifyPeerCertificate(certs, verifiedChains); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err)
return nil, nil, AlertBadCertificate
}
Expand Down Expand Up @@ -888,7 +960,7 @@ type clientStateWaitFinished struct {

certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
peerCertificates []*x509.Certificate
peerCertificates [][]byte
verifiedChains [][]*x509.Certificate

masterSecret []byte
Expand Down Expand Up @@ -1000,12 +1072,24 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
state.handshakeHash.Write(certm.Marshal())
} else {
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(cert.Chain)),
}
for i, entry := range cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
var certList []CertificateEntry
switch state.Params.ClientCertType {
case CertificateTypeX509:
certList = make([]CertificateEntry, len(cert.Chain))
for i, entry := range cert.Chain {
certList[i] = CertificateEntry{CertData: entry.Raw}
}
case CertificateTypeRawPublicKey:
certData, err := marshalSigningKey(cert.PrivateKey.Public())
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Unable to marshal raw public key [%v]", err)
return nil, nil, AlertInternalError
}

certList = []CertificateEntry{{CertData: certData}}
}

certificate := &CertificateBody{CertificateList: certList}
certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
Expand All @@ -1015,6 +1099,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
toSend = append(toSend, QueueHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())

// Create and send CertificateVerify
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)

Expand Down
15 changes: 15 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ const (
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
ExtensionTypeClientCertType ExtensionType = 19
ExtensionTypeServerCertType ExtensionType = 20
)

// enum {...} NamedGroup
Expand Down Expand Up @@ -164,6 +166,19 @@ const (
KeyUpdateRequested KeyUpdateRequest = 1
)

/*
0 X.509 Y [RFC6091]
1 OpenPGP_RESERVED N [RFC6091][RFC8446] Used in TLS versions prior to 1.3.
2 Raw Public Key Y [RFC7250]
3 1609Dot2 N
*/
type CertificateType uint8

const (
CertificateTypeX509 CertificateType = 0
CertificateTypeRawPublicKey CertificateType = 2
)

type State uint8

const (
Expand Down
29 changes: 23 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
PublicKey crypto.PublicKey
}

type PreSharedKey struct {
Expand Down Expand Up @@ -129,6 +130,12 @@ type Config struct {
NonBlocking bool
UseDTLS bool

// These bools are arranged in opposite directions so that their default
// values reflect the correct default semantics (certs yes, raw keys no).
// ForbidX509 is only meaningful if AllowRawPublicKeys is true.
AllowRawPublicKeys bool
ForbidX509 bool

RecordLayer RecordLayerFactory

// The same config object can be shared among different connections, so it
Expand Down Expand Up @@ -198,15 +205,25 @@ func (c *Config) Init(isClient bool) error {
return nil
}

func (c *Config) certTypeValid() bool {
// ForbidX509 can only be set when AllowRawPublicKeys is also set
// Note that this is equivalent to:
// ForbidX509 => AllowRawPublicKeys
return !c.ForbidX509 || c.AllowRawPublicKeys
}

func (c *Config) ValidForServer() bool {
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
(len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil)
// The server must have either PSKs or certificates
havePSK := reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0
haveCert := len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil

return (havePSK || haveCert) && c.certTypeValid()
}

func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
return len(c.ServerName) > 0 && c.certTypeValid()
}

func (c *Config) time() time.Time {
Expand Down Expand Up @@ -250,7 +267,7 @@ var (
type ConnectionState struct {
HandshakeState State
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
PeerCertificates [][]byte // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
NextProto string // Selected ALPN proto
UsingPSK bool // Are we using PSK.
Expand Down
26 changes: 18 additions & 8 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ var (
psk PreSharedKey
psks *PSKMapCache

basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config
basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, rawConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config
)

func init() {
Expand Down Expand Up @@ -262,6 +262,13 @@ func init() {
InsecureSkipVerify: true,
}

rawConfig = &Config{
ServerName: serverName,
Certificates: certificates,
AllowRawPublicKeys: true,
InsecureSkipVerify: true,
}

pskConfig = &Config{
ServerName: serverName,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
Expand Down Expand Up @@ -357,11 +364,13 @@ func checkConsistency(t *testing.T, client *Conn, server *Conn) {

func testConnInner(t *testing.T, name string, p testInstanceState) {
// Configs array:
configs := map[string]*Config{"basic config": basicConfig,
"HRR": hrrConfig,
"ALPN": alpnConfig,
"FFDH": ffdhConfig,
"x25519": x25519Config,
configs := map[string]*Config{
"basic config": basicConfig,
"HRR": hrrConfig,
"ALPN": alpnConfig,
"RawPK": rawConfig,
"FFDH": ffdhConfig,
"x25519": x25519Config,
}

c := configs[p["config"]]
Expand Down Expand Up @@ -400,6 +409,7 @@ func TestBasicFlows(t *testing.T) {
"basic config",
"HRR",
"ALPN",
"RawPK",
"FFDH",
"x25519",
},
Expand Down Expand Up @@ -1232,9 +1242,9 @@ func TestConnectionState(t *testing.T) {
serverCS := server.ConnectionState()
assertEquals(t, clientCS.CipherSuite.Suite, configClient.CipherSuites[0])
assertDeepEquals(t, clientCS.VerifiedChains, [][]*x509.Certificate{{serverCert}})
assertDeepEquals(t, clientCS.PeerCertificates, []*x509.Certificate{serverCert})
assertDeepEquals(t, clientCS.PeerCertificates, [][]byte{serverCert.Raw})
assertEquals(t, serverCS.CipherSuite.Suite, serverConfig.CipherSuites[0])
assertDeepEquals(t, serverCS.PeerCertificates, []*x509.Certificate{clientCert})
assertDeepEquals(t, serverCS.PeerCertificates, [][]byte{clientCert.Raw})
}

func TestDTLS(t *testing.T) {
Expand Down
Loading