Skip to content

Commit

Permalink
Merge pull request #199 from martinthomson/quic_record_layer
Browse files Browse the repository at this point in the history
QUIC record layer changes
  • Loading branch information
bifurcation authored Jan 29, 2019
2 parents a14404e + e78e097 commit 83ba9bc
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 129 deletions.
2 changes: 1 addition & 1 deletion client-state-machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
var offeredPSK PreSharedKey
var earlyHash crypto.Hash
var earlySecret []byte
var clientEarlyTrafficKeys keySet
var clientEarlyTrafficKeys KeySet
var clientHello *HandshakeMessage
if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key
Expand Down
8 changes: 5 additions & 3 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ func assertNotByteEquals(t *testing.T, a, b []byte) {
func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
t.Helper()
assertEquals(t, a.Suite, b.Suite)
// Can't compare aeadFactory values
// Can't compare AEADFactory values
assertEquals(t, a.Hash, b.Hash)
assertEquals(t, a.KeyLen, b.KeyLen)
assertEquals(t, a.IvLen, b.IvLen)
assertEquals(t, len(a.KeyLengths), len(b.KeyLengths))
for k, v := range a.KeyLengths {
assertEquals(t, v, b.KeyLengths[k])
}
}

func assertDeepEquals(t *testing.T, a, b interface{}) {
Expand Down
27 changes: 17 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ type Config struct {
NonBlocking bool
UseDTLS bool

RecordLayer RecordLayerFactory

// The same config object can be shared among different connections, so it
// needs its own mutex
mutex sync.RWMutex
Expand Down Expand Up @@ -270,28 +272,33 @@ type Conn struct {
handshakeComplete bool

readBuffer []byte
in, out *RecordLayer
in, out RecordLayer
hsCtx *HandshakeContext
}

func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
if !config.UseDTLS {
c.in = NewRecordLayerTLS(c.conn, directionRead)
c.out = NewRecordLayerTLS(c.conn, directionWrite)
if config.RecordLayer == nil {
c.in = NewRecordLayerTLS(c.conn, DirectionRead)
c.out = NewRecordLayerTLS(c.conn, DirectionWrite)
} else {
c.in = config.RecordLayer.NewLayer(c.conn, DirectionRead)
c.out = config.RecordLayer.NewLayer(c.conn, DirectionWrite)
}
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
} else {
c.in = NewRecordLayerDTLS(c.conn, directionRead)
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
c.in = NewRecordLayerDTLS(c.conn, DirectionRead)
c.out = NewRecordLayerDTLS(c.conn, DirectionWrite)
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
c.hsCtx.timeoutMS = initialTimeout
c.hsCtx.timers = newTimerSet()
c.hsCtx.waitingNextFlight = true
}
c.in.label = c.label()
c.out.label = c.label()
c.in.SetLabel(c.label())
c.out.SetLabel(c.label())
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
return c
}
Expand Down Expand Up @@ -598,15 +605,15 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label)
return AlertDecodeError
}
err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
err := c.in.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError
}

case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
err := c.out.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError
Expand Down Expand Up @@ -906,7 +913,7 @@ func (c *Conn) Writable() bool {
}

// If we're a client in 0-RTT, then we're writable.
if c.isClient && c.out.cipher.epoch == EpochEarlyData {
if c.isClient && c.out.Epoch() == EpochEarlyData {
return true
}

Expand Down
9 changes: 6 additions & 3 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,14 @@ func init() {
}
}

func assertKeySetEquals(t *testing.T, k1, k2 keySet) {
func assertKeySetEquals(t *testing.T, k1, k2 KeySet) {
t.Helper()
// Assume cipher is the same
assertByteEquals(t, k1.iv, k2.iv)
assertByteEquals(t, k1.key, k2.key)
assertTrue(t, len(k1.Keys) > 0, "assert that there are some keys")
assertEquals(t, len(k1.Keys), len(k2.Keys))
for k, v := range k1.Keys {
assertByteEquals(t, v, k2.Keys[k])
}
}

func computeExporter(t *testing.T, c *Conn, label string, context []byte, length int) []byte {
Expand Down
46 changes: 21 additions & 25 deletions crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@ import (

var prng = rand.Reader

type aeadFactory func(key []byte) (cipher.AEAD, error)
type AEADFactory func(key []byte) (cipher.AEAD, error)

type CipherSuiteParams struct {
Suite CipherSuite
Cipher aeadFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLen int // Key length in octets
IvLen int // IV length in octets
Suite CipherSuite
Cipher AEADFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLengths map[string]int // This maps keys (the label used for HKDF-Expand-Label) to the length of the key needed.
}

type signatureAlgorithm uint8
Expand Down Expand Up @@ -91,18 +90,16 @@ var (

cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
TLS_AES_128_GCM_SHA256: {
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLen: 16,
IvLen: 12,
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLengths: map[string]int{labelForKey: 16, labelForIV: 12},
},
TLS_AES_256_GCM_SHA384: {
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLen: 32,
IvLen: 12,
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLengths: map[string]int{labelForKey: 32, labelForIV: 12},
},
}

Expand Down Expand Up @@ -604,19 +601,18 @@ func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte)
return mac.Sum(nil)
}

type keySet struct {
cipher aeadFactory
key []byte
iv []byte
type KeySet struct {
Cipher AEADFactory
Keys map[string][]byte
}

func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
func makeTrafficKeys(params CipherSuiteParams, secret []byte) KeySet {
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
return keySet{
cipher: params.Cipher,
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
ks := KeySet{Cipher: params.Cipher, Keys: make(map[string][]byte, len(params.KeyLengths))}
for label, length := range params.KeyLengths {
ks.Keys[label] = HkdfExpandLabel(params.Hash, secret, label, []byte{}, length)
}
return ks
}

func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
Expand Down
51 changes: 35 additions & 16 deletions handshake-layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H
type HandshakeLayer struct {
ctx *HandshakeContext // The handshake we are attached to
nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records
conn RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
datagram bool // Is this DTLS?
msgSeq uint32 // The DTLS message sequence number
Expand Down Expand Up @@ -153,7 +153,7 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return int(val), nil
}

func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
Expand All @@ -163,7 +163,7 @@ func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
return &h
}

func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
Expand All @@ -174,8 +174,15 @@ func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer
}

func (h *HandshakeLayer) readRecord() error {
logf(logTypeVerbose, "Trying to read record")
pt, err := h.conn.readRecordAnyEpoch()
var pt *TLSPlaintext
var err error

if h.datagram {
logf(logTypeVerbose, "Trying to read record")
pt, err = h.conn.(*DefaultRecordLayer).ReadRecordAnyEpoch()
} else {
pt, err = h.conn.ReadRecord()
}
if err != nil {
return err
}
Expand Down Expand Up @@ -204,7 +211,7 @@ func (h *HandshakeLayer) readRecord() error {
}

assert(h.ctx.hIn.conn != nil)
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
if pt.epoch != h.ctx.hIn.conn.Epoch() {
// This is out of order but we're dropping it.
// TODO([email protected]): If server, need to retransmit Finished.
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
Expand Down Expand Up @@ -394,9 +401,13 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
}

func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
hm.cipher = h.conn.cipher
h.queued = append(h.queued, hm)
return nil
if h.datagram {
hm.cipher = h.conn.(*DefaultRecordLayer).cipher
h.queued = append(h.queued, hm)
return nil
}
_, err := h.WriteMessages([]*HandshakeMessage{hm})
return err
}

func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
Expand Down Expand Up @@ -456,22 +467,30 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int
buf = body
}

var err error
if h.datagram {
// Remember that we sent this.
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
hm.seq,
start,
len(body),
h.conn.cipher.combineSeq(true),
h.conn.(*DefaultRecordLayer).cipher.combineSeq(true),
false,
})
err = h.conn.(*DefaultRecordLayer).writeRecordWithPadding(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
},
hm.cipher, 0)
} else {
err = h.conn.WriteRecord(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
})
}
return true, start + bodylen, h.conn.writeRecordWithPadding(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
},
hm.cipher, 0)
return true, start + bodylen, err
}

func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
Expand Down
18 changes: 9 additions & 9 deletions handshake-layer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestMessageFromBody(t *testing.T) {
chValid := unhex(chValidHex)

b := bytes.NewBuffer(nil)
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionRead))
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionRead))

// Test successful conversion
hm, err := h.HandshakeMessageFromBody(&chValidIn)
Expand All @@ -172,7 +172,7 @@ func TestMessageFromBody(t *testing.T) {
func newHandshakeLayerFromBytes(d []byte) *HandshakeLayer {
hc := &HandshakeContext{}
b := bytes.NewBuffer(d)
hc.hIn = NewHandshakeLayerTLS(hc, NewRecordLayerTLS(b, directionRead))
hc.hIn = NewHandshakeLayerTLS(hc, NewRecordLayerTLS(b, DirectionRead))
return hc.hIn
}

Expand Down Expand Up @@ -224,7 +224,7 @@ func TestReadHandshakeMessage(t *testing.T) {
}

func testWriteHandshakeMessage(h *HandshakeLayer, hm *HandshakeMessage) error {
hm.cipher = h.conn.cipher
hm.cipher = h.conn.(*DefaultRecordLayer).cipher
_, err := h.WriteMessage(hm)
return err
}
Expand All @@ -235,26 +235,26 @@ func TestWriteHandshakeMessage(t *testing.T) {

// Test successful write of single message
b := bytes.NewBuffer(nil)
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
err := testWriteHandshakeMessage(h, shortMessageIn)
assertNotError(t, err, "Failed to write valid short message")
assertByteEquals(t, b.Bytes(), short)

// Test successful write of single long message
b = bytes.NewBuffer(nil)
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
err = testWriteHandshakeMessage(h, longMessageIn)
assertNotError(t, err, "Failed to write valid long message")
assertByteEquals(t, b.Bytes(), long)

// Test write failure on message too large
b = bytes.NewBuffer(nil)
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
err = testWriteHandshakeMessage(h, tooLongMessageIn)
assertError(t, err, "Wrote a message exceeding the length bound")

// Test write failure on underlying write failure
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(ErrorReadWriter{}, directionWrite))
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(ErrorReadWriter{}, DirectionWrite))
err = testWriteHandshakeMessage(h, longMessageIn)
assertError(t, err, "Write succeeded despite error in full fragment send")
err = testWriteHandshakeMessage(h, shortMessageIn)
Expand All @@ -265,7 +265,7 @@ type testReassembleFixture struct {
t *testing.T
c HandshakeContext
h *HandshakeLayer
r *RecordLayer
r *DefaultRecordLayer
rd *pipeConn
wr *pipeConn
m0 *HandshakeMessage
Expand Down Expand Up @@ -298,7 +298,7 @@ func newTestReassembleFixture(t *testing.T) *testReassembleFixture {
f.m1 = newHsFragment(m1, 1, 0, 2048)
f.rd, f.wr = pipe()

f.r = NewRecordLayerDTLS(f.rd, directionRead)
f.r = NewRecordLayerDTLS(f.rd, DirectionRead)
f.h = NewHandshakeLayerDTLS(&f.c, f.r)
f.c.hIn = f.h
f.c.timers = newTimerSet()
Expand Down
Loading

0 comments on commit 83ba9bc

Please sign in to comment.