Skip to content

Commit

Permalink
Prohibit inputs sharing buffers with outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Nov 28, 2024
1 parent b4e769e commit 7e444bf
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 5 deletions.
27 changes: 23 additions & 4 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ export function toBytes(data: Input): Uint8Array {
return data;
}

/**
* Checks if two U8A use same underlying buffer and overlaps (will corrupt and break if input and output same)
*/
export function overlapBytes(a: Uint8Array, b: Uint8Array): boolean {
return (
a.buffer === b.buffer && // probably will fail with some obscure proxies, but this is best we can do
a.byteOffset < b.byteOffset + b.byteLength && // a starts before b end
b.byteOffset < a.byteOffset + a.byteLength // b starts before a end
);
}

/**
* Copies several Uint8Arrays into one.
*/
Expand Down Expand Up @@ -237,6 +248,14 @@ export const wrapCipher = <C extends CipherCons<any>, P extends CipherParams>(
}

const cipher = constructor(key, ...args);
const checkOutput = (fnLength: number, data: Uint8Array, output?: Uint8Array) => {
if (output !== undefined) {
if (fnLength !== 2) throw new Error('cipher output not supported');
abytes(output);
if (overlapBytes(data, output))
throw new Error('input and output use same buffer and overlap');
}
};

// Create wrapped cipher with validation and single-use encryption
let called = false;
Expand All @@ -245,10 +264,12 @@ export const wrapCipher = <C extends CipherCons<any>, P extends CipherParams>(
if (called) throw new Error('cannot encrypt() twice with same key + nonce');
called = true;
abytes(data);
checkOutput(cipher.encrypt.length, data, output);
return (cipher as CipherWithOutput).encrypt(data, output);
},
decrypt(data: Uint8Array, output?: Uint8Array) {
abytes(data);
checkOutput(cipher.decrypt.length, data, output);
if (tagl && data.length < tagl)
throw new Error('invalid ciphertext length: smaller than tagLength=' + tagl);
return (cipher as CipherWithOutput).decrypt(data, output);
Expand All @@ -273,11 +294,9 @@ export type XorStream = (
export function getOutput(expectedLength: number, out?: Uint8Array, onlyAligned = true) {
if (out === undefined) return new Uint8Array(expectedLength);
if (out.length !== expectedLength)
throw new Error(
'invalid output length, expected at least ' + expectedLength + ', got: ' + out.length
);
throw new Error('invalid output length, expected ' + expectedLength + ', got: ' + out.length);
if (onlyAligned && !isAligned32(out)) throw new Error('invalid output, must be aligned');
return out.subarray(0, expectedLength).fill(0);
return out.fill(0);
}

// Polyfill for Safari 14
Expand Down
72 changes: 72 additions & 0 deletions test/basic.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,78 @@ describe('Basic', () => {
}
});
}
should(`${k} (re-use)`, () => {
const { fn, keyLen } = opts;
const msg = new Uint8Array(2 * opts.fn.blockSize).fill(12);
const key = randomBytes(keyLen);
const nonce = randomBytes(fn.nonceLength);
const AAD = randomBytes(64);
let cipher = fn(key, nonce, AAD);
// Not supported!
if (k.startsWith('micro')) return;
if (k.startsWith('gcm')) return;
if (k.startsWith('siv')) return;
if (k.startsWith('aeskw')) return;
// Wrapper changes length :(
if (cipher.encrypt.length === 2) {
// Tmp buffer
let outLen = msg.length;
if (fn.tagLength) outLen += fn.tagLength;
if (k === 'xsalsa20poly1305') outLen += 16;
if (k.includes('cbc') || k.includes('ecb')) outLen += 16;
// Expected result
cipher = fn(key, nonce, AAD);
const exp = cipher.encrypt(msg);
const out = new Uint8Array(outLen);
// First pass
cipher = fn(key, nonce, AAD);
const res = cipher.encrypt(msg, out);
deepStrictEqual(res, exp);
// check if res is output
deepStrictEqual(res, out.subarray(res.byteOffset, res.byteOffset + res.length));
deepStrictEqual(res.buffer, out.buffer); // make sure that underlying array buffer is same
// Second pass
out.fill(42);
cipher = fn(key, nonce, AAD);
const res2 = cipher.encrypt(msg, out);
deepStrictEqual(res2, exp);
deepStrictEqual(res2, out.subarray(res2.byteOffset, res2.byteOffset + res2.length));
deepStrictEqual(res2.buffer, out.buffer); // make sure that underlying array buffer is same
// Throws on same buffer:
cipher = fn(key, nonce, AAD);
out.set(msg);
const msg2 = out.subarray(0, msg.length);
throws(() => cipher.encrypt(msg2, out));
}
if (cipher.decrypt.length === 2) {
// Expected result
cipher = fn(key, nonce, AAD);
const input = cipher.encrypt(msg);
// Tmp buffer
let outLen = msg.length;
if (k.endsWith('xsalsa20poly1305')) outLen += 32 + 16;
if (k.includes('cbc') || k.includes('ecb')) outLen += 16;
const out = new Uint8Array(outLen);
// First pass
const res = cipher.decrypt(input, out);
deepStrictEqual(res, msg);
deepStrictEqual(res, out.subarray(res.byteOffset, res.byteOffset + res.length));
deepStrictEqual(res.buffer, out.buffer); // make sure that underlying array buffer is same
// Second pass
out.fill(42);
const res2 = cipher.decrypt(input, out);
deepStrictEqual(res2, msg);
deepStrictEqual(res2, out.subarray(res2.byteOffset, res2.byteOffset + res2.length));
deepStrictEqual(res2.buffer, out.buffer); // make sure that underlying array buffer is same
// Throws on same buffer:
const tmp = new Uint8Array(Math.max(out.length, input.length));
tmp.set(input);
const out2 = tmp.subarray(0, out.length);
const input2 = tmp.subarray(0, input.length);
throws(() => cipher.decrypt(input2, out2));
}
});
// Human tests ^, AI abomination v

should('unaligned', () => {
if (!['xsalsa20poly1305', 'xchacha20poly1305', 'chacha20poly1305'].includes(k)) return;
Expand Down
37 changes: 36 additions & 1 deletion test/utils.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ const { deepStrictEqual, throws } = require('assert');
const fc = require('fast-check');
const { describe, should } = require('micro-should');
const { TYPE_TEST, unalign } = require('./utils.js');
const { bytesToHex, concatBytes, hexToBytes } = require('../utils.js');
const { bytesToHex, concatBytes, hexToBytes, overlapBytes } = require('../utils.js');

describe('utils', () => {
const staticHexVectors = [
Expand Down Expand Up @@ -57,6 +57,41 @@ describe('utils', () => {
})
)
);
should('sameBytes', () => {
// Basic
const buffer = new ArrayBuffer(20);
const a = new Uint8Array(buffer, 0, 10); // Bytes 0-9
const b = new Uint8Array(buffer, 5, 10); // Bytes 5-14
const c = new Uint8Array(buffer, 10, 10); // Bytes 10-19
const d = new Uint8Array(new ArrayBuffer(20), 0, 10); // Different buffer
deepStrictEqual(overlapBytes(a, b), true);
deepStrictEqual(overlapBytes(a, c), false);
deepStrictEqual(overlapBytes(b, c), true);
deepStrictEqual(overlapBytes(a, d), false);
// Scan
const res = [];
const main = new Uint8Array(8 + 4); // 2byte + first + 2byte
const first = main.subarray(2).subarray(0, 8);
for (let i = 0; i < main.length; i++) {
const second = main.subarray(i).subarray(0, 1); // one byte window
deepStrictEqual(second, new Uint8Array(1));
res.push(overlapBytes(first, second));
}
deepStrictEqual(res, [
false,
false,
true,
true,
true,
true,
true,
true,
true,
true,
false,
false,
]);
});
});

describe('utils etc', () => {
Expand Down

0 comments on commit 7e444bf

Please sign in to comment.