Skip to content

Commit

Permalink
Move key, nonce & input validation into wrapCipher. Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Nov 3, 2024
1 parent c9b06a5 commit 47116a6
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 95 deletions.
19 changes: 4 additions & 15 deletions src/_micro.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ const _1 = BigInt(1);
// Can be speed-up using BigUint64Array, but would be more complicated
export function poly1305(msg: Uint8Array, key: Uint8Array): Uint8Array {
abytes(msg);
abytes(key);
abytes(key, 32);
let acc = _0;
const r = bytesToNumberLE(key.subarray(0, 16)) & CLAMP_R;
const s = bytesToNumberLE(key.subarray(16));
Expand Down Expand Up @@ -255,11 +255,8 @@ function computeTag(
export const xsalsa20poly1305 = /* @__PURE__ */ wrapCipher(
{ blockSize: 64, nonceLength: 24, tagLength: 16 },
function xsalsa20poly1305(key: Uint8Array, nonce: Uint8Array) {
abytes(key);
abytes(nonce);
return {
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
const m = concatBytes(new Uint8Array(32), plaintext);
const c = xsalsa20(key, nonce, m);
const authKey = c.subarray(0, 32);
Expand All @@ -268,12 +265,11 @@ export const xsalsa20poly1305 = /* @__PURE__ */ wrapCipher(
return concatBytes(tag, data);
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
if (ciphertext.length < 16) throw new Error('encrypted data must be at least 16 bytes');
const c = concatBytes(new Uint8Array(16), ciphertext);
const passedTag = c.subarray(16, 32);
const authKey = xsalsa20(key, nonce, new Uint8Array(32));
const tag = poly1305(c.subarray(32), authKey);
if (!equalBytes(c.subarray(16, 32), tag)) throw new Error('invalid poly1305 tag');
if (!equalBytes(tag, passedTag)) throw new Error('invalid poly1305 tag');
return xsalsa20(key, nonce, c).subarray(32);
},
};
Expand All @@ -292,24 +288,17 @@ export const _poly1305_aead =
(fn: XorStream) =>
(key: Uint8Array, nonce: Uint8Array, AAD?: Uint8Array): Cipher => {
const tagLength = 16;
const keyLength = 32;
abytes(key, keyLength);
abytes(nonce);
return {
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
const res = fn(key, nonce, plaintext, undefined, 1);
const tag = computeTag(fn, key, nonce, res, AAD);
return concatBytes(res, tag);
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
if (ciphertext.length < tagLength)
throw new Error(`encrypted data must be at least ${tagLength} bytes`);
const passedTag = ciphertext.subarray(-tagLength);
const data = ciphertext.subarray(0, -tagLength);
const tag = computeTag(fn, key, nonce, data, AAD);
if (!equalBytes(passedTag, tag)) throw new Error('invalid poly1305 tag');
if (!equalBytes(tag, passedTag)) throw new Error('invalid poly1305 tag');
return fn(key, nonce, data, undefined, 1);
},
};
Expand Down
48 changes: 13 additions & 35 deletions src/aes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,15 @@ function decrypt(xk: Uint32Array, s0: number, s1: number, s2: number, s3: number
return { s0: t0, s1: t1, s2: t2, s3: t3 };
}

function getDst(len: number, dst?: Uint8Array) {
if (dst === undefined) return new Uint8Array(len);
abytes(dst);
if (dst.length < len)
throw new Error(`aes: wrong destination length, expected at least ${len}, got: ${dst.length}`);
if (!isAligned32(dst)) throw new Error('unaligned dst');
return dst;
function getDst(len: number, output?: Uint8Array): Uint8Array {
if (output === undefined) return new Uint8Array(len);
abytes(output);
if (output.length < len)
throw new Error(
`aes: wrong destination length, expected at least ${len}, got: ${output.length}`
);
if (!isAligned32(output)) throw new Error('unaligned output');
return output;
}

// TODO: investigate merging with ctr32
Expand Down Expand Up @@ -324,8 +326,6 @@ function ctr32(
export const ctr = wrapCipher(
{ blockSize: 16, nonceLength: 16 },
function ctr(key: Uint8Array, nonce: Uint8Array): CipherWithOutput {
abytes(key);
abytes(nonce, BLOCK_SIZE);
function processCtr(buf: Uint8Array, dst?: Uint8Array) {
abytes(buf);
if (dst !== undefined) {
Expand All @@ -351,7 +351,7 @@ function validateBlockDecrypt(data: Uint8Array) {
abytes(data);
if (data.length % BLOCK_SIZE !== 0) {
throw new Error(
`aes/(cbc-ecb).decrypt ciphertext should consist of blocks with size ${BLOCK_SIZE}`
`aes-(cbc/ecb).decrypt ciphertext should consist of blocks with size ${BLOCK_SIZE}`
);
}
}
Expand Down Expand Up @@ -404,7 +404,6 @@ export type BlockOpts = { disablePadding?: boolean };
export const ecb = wrapCipher(
{ blockSize: 16 },
function ecb(key: Uint8Array, opts: BlockOpts = {}): CipherWithOutput {
abytes(key);
const pcks5 = !opts.disablePadding;
return {
encrypt(plaintext: Uint8Array, dst?: Uint8Array) {
Expand Down Expand Up @@ -449,8 +448,6 @@ export const ecb = wrapCipher(
export const cbc = wrapCipher(
{ blockSize: 16, nonceLength: 16 },
function cbc(key: Uint8Array, iv: Uint8Array, opts: BlockOpts = {}): CipherWithOutput {
abytes(key);
abytes(iv, 16);
const pcks5 = !opts.disablePadding;
return {
encrypt(plaintext: Uint8Array, dst?: Uint8Array) {
Expand Down Expand Up @@ -511,8 +508,6 @@ export const cbc = wrapCipher(
export const cfb = wrapCipher(
{ blockSize: 16, nonceLength: 16 },
function cfb(key: Uint8Array, iv: Uint8Array): CipherWithOutput {
abytes(key);
abytes(iv, 16);
function processCfb(src: Uint8Array, isEncrypt: boolean, dst?: Uint8Array) {
abytes(src);
const srcLen = src.length;
Expand Down Expand Up @@ -584,11 +579,8 @@ function computeTag(
* As for nonce size, prefer 12-byte, instead of 8-byte.
*/
export const gcm = wrapCipher(
{ blockSize: 16, nonceLength: 12, tagLength: 16 },
{ blockSize: 16, nonceLength: 12, tagLength: 16, varSizeNonce: true },
function gcm(key: Uint8Array, nonce: Uint8Array, AAD?: Uint8Array): Cipher {
abytes(key);
abytes(nonce);
if (AAD !== undefined) abytes(AAD);
// NIST 800-38d doesn't enforce minimum nonce length.
// We enforce 8 bytes for compat with openssl.
// 12 bytes are recommended. More than 12 bytes would be converted into 12.
Expand Down Expand Up @@ -621,7 +613,6 @@ export const gcm = wrapCipher(
}
return {
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
const { xk, authKey, counter, tagMask } = deriveKeys();
const out = new Uint8Array(plaintext.length + tagLength);
const toClean: (Uint8Array | Uint32Array)[] = [xk, authKey, counter, tagMask];
Expand All @@ -634,9 +625,6 @@ export const gcm = wrapCipher(
return out;
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
if (ciphertext.length < tagLength)
throw new Error(`aes/gcm: ciphertext less than tagLen (${tagLength})`);
const { xk, authKey, counter, tagMask } = deriveKeys();
const toClean: (Uint8Array | Uint32Array)[] = [xk, authKey, tagMask, counter];
if (!isAligned32(ciphertext)) toClean.push((ciphertext = copyBytes(ciphertext)));
Expand Down Expand Up @@ -665,7 +653,7 @@ const limit = (name: string, min: number, max: number) => (value: number) => {
* RFC 8452, https://datatracker.ietf.org/doc/html/rfc8452
*/
export const siv = wrapCipher(
{ blockSize: 16, nonceLength: 12, tagLength: 16 },
{ blockSize: 16, nonceLength: 12, tagLength: 16, varSizeNonce: true },
function siv(key: Uint8Array, nonce: Uint8Array, AAD?: Uint8Array): Cipher {
const tagLength = 16;
// From RFC 8452: Section 6
Expand All @@ -674,12 +662,8 @@ export const siv = wrapCipher(
const NONCE_LIMIT = limit('nonce', 12, 12);
const CIPHER_LIMIT = limit('ciphertext', 16, 2 ** 36 + 16);
abytes(key, 16, 24, 32);
abytes(nonce);
NONCE_LIMIT(nonce.length);
if (AAD !== undefined) {
abytes(AAD);
AAD_LIMIT(AAD.length);
}
if (AAD !== undefined) AAD_LIMIT(AAD.length);
function deriveKeys() {
const xk = expandKeyLE(key);
const encKey = new Uint8Array(key.length);
Expand Down Expand Up @@ -732,7 +716,6 @@ export const siv = wrapCipher(
}
return {
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
PLAIN_LIMIT(plaintext.length);
const { encKey, authKey } = deriveKeys();
const tag = _computeTag(encKey, authKey, plaintext);
Expand All @@ -746,7 +729,6 @@ export const siv = wrapCipher(
return out;
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
CIPHER_LIMIT(ciphertext.length);
const tag = ciphertext.subarray(-tagLength);
const { encKey, authKey } = deriveKeys();
Expand Down Expand Up @@ -872,7 +854,6 @@ export const aeskw = wrapCipher(
{ blockSize: 8 },
(kek: Uint8Array): Cipher => ({
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
if (!plaintext.length || plaintext.length % 8 !== 0)
throw new Error('invalid plaintext length');
if (plaintext.length === 8)
Expand All @@ -882,7 +863,6 @@ export const aeskw = wrapCipher(
return out;
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
// ciphertext must be at least 24 bytes and a multiple of 8 bytes
// 24 because should have at least two block (1 iv + 2).
// Replace with 16 to enable '8-byte keys'
Expand Down Expand Up @@ -946,7 +926,6 @@ export const aeskwp = wrapCipher(
{ blockSize: 8 },
(kek: Uint8Array): Cipher => ({
encrypt(plaintext: Uint8Array) {
abytes(plaintext);
if (!plaintext.length) throw new Error('invalid plaintext length');
const padded = Math.ceil(plaintext.length / 8) * 8;
const out = new Uint8Array(8 + padded);
Expand All @@ -958,7 +937,6 @@ export const aeskwp = wrapCipher(
return out;
},
decrypt(ciphertext: Uint8Array) {
abytes(ciphertext);
// 16 because should have at least one block
if (ciphertext.length < 16) throw new Error('invalid ciphertext length');
const out = copyBytes(ciphertext);
Expand Down
23 changes: 3 additions & 20 deletions src/chacha.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// prettier-ignore
import { createCipher, rotl } from './_arx.js';
import { bytes as abytes } from './_assert.js';
import { poly1305 } from './_poly1305.js';
import {
CipherWithOutput,
XorStream,
clean,
createView,
equalBytes,
getDst,
setBigUint64,
wrapCipher,
} from './utils.js';
Expand Down Expand Up @@ -236,35 +236,18 @@ export const _poly1305_aead =
(xorStream: XorStream) =>
(key: Uint8Array, nonce: Uint8Array, AAD?: Uint8Array): CipherWithOutput => {
const tagLength = 16;
abytes(key, 32);
abytes(nonce);
return {
encrypt(plaintext: Uint8Array, output?: Uint8Array) {
abytes(plaintext);
const plength = plaintext.length;
const clength = plength + tagLength;
if (output) {
abytes(output, clength);
} else {
output = new Uint8Array(clength);
}
output = getDst(plength + tagLength, output);
xorStream(key, nonce, plaintext, output, 1);
const tag = computeTag(xorStream, key, nonce, output.subarray(0, -tagLength), AAD);
output.set(tag, plength); // append tag
clean(tag);
return output;
},
decrypt(ciphertext: Uint8Array, output?: Uint8Array) {
abytes(ciphertext);
const clength = ciphertext.length;
const plength = clength - tagLength;
if (clength < tagLength)
throw new Error(`encrypted data must be at least ${tagLength} bytes`);
if (output) {
abytes(output, plength);
} else {
output = new Uint8Array(plength);
}
output = getDst(ciphertext.length - tagLength, output);
const data = ciphertext.subarray(0, -tagLength);
const passedTag = ciphertext.subarray(-tagLength);
const tag = computeTag(xorStream, key, nonce, data, AAD);
Expand Down
17 changes: 3 additions & 14 deletions src/salsa.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { createCipher, rotl } from './_arx.js';
import { bytes as abytes } from './_assert.js';
import { poly1305 } from './_poly1305.js';
import { Cipher, clean, equalBytes, wrapCipher } from './utils.js';
import { Cipher, clean, equalBytes, getDst, wrapCipher } from './utils.js';

// Salsa20 stream cipher was released in 2005.
// Salsa's goal was to implement AES replacement that does not rely on S-Boxes,
Expand Down Expand Up @@ -122,11 +122,8 @@ export const xsalsa20poly1305 = /* @__PURE__ */ wrapCipher(
{ blockSize: 64, nonceLength: 24, tagLength: 16 },
(key: Uint8Array, nonce: Uint8Array): Cipher => {
const tagLength = 16;
abytes(key, 32);
abytes(nonce, 24);
return {
encrypt(plaintext: Uint8Array, output?: Uint8Array) {
abytes(plaintext);
// This is small optimization (calculate auth key with same call as encryption itself) makes it hard
// to separate tag calculation and encryption itself, since 32 byte is half-block of salsa (64 byte)
const clength = plaintext.length + 32;
Expand All @@ -145,15 +142,7 @@ export const xsalsa20poly1305 = /* @__PURE__ */ wrapCipher(
return output.subarray(tagLength);
},
decrypt(ciphertext: Uint8Array, output?: Uint8Array) {
abytes(ciphertext);
if (ciphertext.length < tagLength)
throw new Error('encrypted data should be at least 16 bytes');
const clength = ciphertext.length + 32; // 32 is authKey length
if (output) {
abytes(output, clength);
} else {
output = new Uint8Array(clength);
}
output = getDst(ciphertext.length + 32, output); // 32 is authKey length
// Create new ciphertext array:
// tmp part auth tag ciphertext
// [bytes 0..32] [bytes 32..48] [bytes 48..]
Expand All @@ -165,7 +154,7 @@ export const xsalsa20poly1305 = /* @__PURE__ */ wrapCipher(
const authKeyBuf = output.subarray(0, 32);
clean(authKeyBuf);
const authKey = xsalsa20(key, nonce, authKeyBuf, authKeyBuf);
const tag = poly1305(output.subarray(32 + tagLength), authKey); // alloc
const tag = poly1305(output.subarray(48), authKey); // alloc
if (!equalBytes(output.subarray(32, 48), tag)) throw new Error('invalid tag');
// NOTE: first 32 bytes skipped (used for authKey)
xsalsa20(key, nonce, output.subarray(16), output.subarray(16));
Expand Down
Loading

0 comments on commit 47116a6

Please sign in to comment.