Skip to content

Commit

Permalink
Enforces passing slices of the exact size when unmarshaling KEM keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Apr 26, 2024
1 parent 12ba47d commit df5ea67
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
19 changes: 11 additions & 8 deletions hpke/shortkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
bitmask = 0x01
}

Nsk := s.PrivateKeySize()
dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
var bytes []byte
ctr := 0
Expand All @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
dkpPrk,
[]byte("candidate"),
[]byte{byte(ctr)},
uint16(s.byteSize()),
uint16(Nsk),
)
bytes[0] &= bitmask
skBig.SetBytes(bytes)
}
l := s.PrivateKeySize()
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(bytes):], bytes)
sk := &shortKEMPrivKey{s, bytes, nil}
return sk.Public(), sk
}

Expand All @@ -83,11 +82,11 @@ func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := s.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(data):l], data[:l])
copy(sk.priv, data[:l])
if !sk.validate() {
return nil, ErrInvalidKEMPrivateKey
}
Expand All @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
}

func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
x, y := elliptic.Unmarshal(s, data)
l := s.PublicKeySize()
if len(data) != l {
return nil, kem.ErrPubKeySize
}
x, y := elliptic.Unmarshal(s, data[:l])
if x == nil {
return nil, ErrInvalidKEMPublicKey
}
Expand Down
13 changes: 7 additions & 6 deletions hpke/xkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
if len(seed) != x.SeedSize() {
panic(kem.ErrSeedSize)
}
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)}
Nsk := x.PrivateKeySize()
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, Nsk)}
dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
bytes := x.labeledExpand(
dkpPrk,
[]byte("sk"),
nil,
uint16(x.PrivateKeySize()),
uint16(Nsk),
)
copy(sk.priv, bytes)
return sk.Public(), sk
Expand All @@ -81,8 +82,8 @@ func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := x.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &xKEMPrivKey{x, make([]byte, l), nil}
copy(sk.priv, data[:l])
Expand All @@ -94,8 +95,8 @@ func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
l := x.PublicKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPublicKey
if len(data) != l {
return nil, kem.ErrPubKeySize
}
pk := &xKEMPubKey{x, make([]byte, l)}
copy(pk.pub, data[:l])
Expand Down

0 comments on commit df5ea67

Please sign in to comment.