diff --git a/src/utils.ts b/src/utils.ts index 7e8d906..88d7f8d 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -132,6 +132,17 @@ export function toBytes(data: Input): Uint8Array { return data; } +/** + * Checks if two U8A use same underlying buffer and overlaps (will corrupt and break if input and output same) + */ +export function overlapBytes(a: Uint8Array, b: Uint8Array): boolean { + return ( + a.buffer === b.buffer && // probably will fail with some obscure proxies, but this is best we can do + a.byteOffset < b.byteOffset + b.byteLength && // a starts before b end + b.byteOffset < a.byteOffset + a.byteLength // b starts before a end + ); +} + /** * Copies several Uint8Arrays into one. */ @@ -237,6 +248,14 @@ export const wrapCipher = , P extends CipherParams>( } const cipher = constructor(key, ...args); + const checkOutput = (fnLength: number, data: Uint8Array, output?: Uint8Array) => { + if (output !== undefined) { + if (fnLength !== 2) throw new Error('cipher output not supported'); + abytes(output); + if (overlapBytes(data, output)) + throw new Error('input and output use same buffer and overlap'); + } + }; // Create wrapped cipher with validation and single-use encryption let called = false; @@ -245,10 +264,12 @@ export const wrapCipher = , P extends CipherParams>( if (called) throw new Error('cannot encrypt() twice with same key + nonce'); called = true; abytes(data); + checkOutput(cipher.encrypt.length, data, output); return (cipher as CipherWithOutput).encrypt(data, output); }, decrypt(data: Uint8Array, output?: Uint8Array) { abytes(data); + checkOutput(cipher.decrypt.length, data, output); if (tagl && data.length < tagl) throw new Error('invalid ciphertext length: smaller than tagLength=' + tagl); return (cipher as CipherWithOutput).decrypt(data, output); @@ -273,11 +294,9 @@ export type XorStream = ( export function getOutput(expectedLength: number, out?: Uint8Array, onlyAligned = true) { if (out === undefined) return new Uint8Array(expectedLength); if (out.length !== expectedLength) - throw new Error( - 'invalid output length, expected at least ' + expectedLength + ', got: ' + out.length - ); + throw new Error('invalid output length, expected ' + expectedLength + ', got: ' + out.length); if (onlyAligned && !isAligned32(out)) throw new Error('invalid output, must be aligned'); - return out.subarray(0, expectedLength).fill(0); + return out.fill(0); } // Polyfill for Safari 14 diff --git a/test/basic.test.js b/test/basic.test.js index 72a73a4..c9368c7 100644 --- a/test/basic.test.js +++ b/test/basic.test.js @@ -133,6 +133,78 @@ describe('Basic', () => { } }); } + should(`${k} (re-use)`, () => { + const { fn, keyLen } = opts; + const msg = new Uint8Array(2 * opts.fn.blockSize).fill(12); + const key = randomBytes(keyLen); + const nonce = randomBytes(fn.nonceLength); + const AAD = randomBytes(64); + let cipher = fn(key, nonce, AAD); + // Not supported! + if (k.startsWith('micro')) return; + if (k.startsWith('gcm')) return; + if (k.startsWith('siv')) return; + if (k.startsWith('aeskw')) return; + // Wrapper changes length :( + if (cipher.encrypt.length === 2) { + // Tmp buffer + let outLen = msg.length; + if (fn.tagLength) outLen += fn.tagLength; + if (k === 'xsalsa20poly1305') outLen += 16; + if (k.includes('cbc') || k.includes('ecb')) outLen += 16; + // Expected result + cipher = fn(key, nonce, AAD); + const exp = cipher.encrypt(msg); + const out = new Uint8Array(outLen); + // First pass + cipher = fn(key, nonce, AAD); + const res = cipher.encrypt(msg, out); + deepStrictEqual(res, exp); + // check if res is output + deepStrictEqual(res, out.subarray(res.byteOffset, res.byteOffset + res.length)); + deepStrictEqual(res.buffer, out.buffer); // make sure that underlying array buffer is same + // Second pass + out.fill(42); + cipher = fn(key, nonce, AAD); + const res2 = cipher.encrypt(msg, out); + deepStrictEqual(res2, exp); + deepStrictEqual(res2, out.subarray(res2.byteOffset, res2.byteOffset + res2.length)); + deepStrictEqual(res2.buffer, out.buffer); // make sure that underlying array buffer is same + // Throws on same buffer: + cipher = fn(key, nonce, AAD); + out.set(msg); + const msg2 = out.subarray(0, msg.length); + throws(() => cipher.encrypt(msg2, out)); + } + if (cipher.decrypt.length === 2) { + // Expected result + cipher = fn(key, nonce, AAD); + const input = cipher.encrypt(msg); + // Tmp buffer + let outLen = msg.length; + if (k.endsWith('xsalsa20poly1305')) outLen += 32 + 16; + if (k.includes('cbc') || k.includes('ecb')) outLen += 16; + const out = new Uint8Array(outLen); + // First pass + const res = cipher.decrypt(input, out); + deepStrictEqual(res, msg); + deepStrictEqual(res, out.subarray(res.byteOffset, res.byteOffset + res.length)); + deepStrictEqual(res.buffer, out.buffer); // make sure that underlying array buffer is same + // Second pass + out.fill(42); + const res2 = cipher.decrypt(input, out); + deepStrictEqual(res2, msg); + deepStrictEqual(res2, out.subarray(res2.byteOffset, res2.byteOffset + res2.length)); + deepStrictEqual(res2.buffer, out.buffer); // make sure that underlying array buffer is same + // Throws on same buffer: + const tmp = new Uint8Array(Math.max(out.length, input.length)); + tmp.set(input); + const out2 = tmp.subarray(0, out.length); + const input2 = tmp.subarray(0, input.length); + throws(() => cipher.decrypt(input2, out2)); + } + }); + // Human tests ^, AI abomination v should('unaligned', () => { if (!['xsalsa20poly1305', 'xchacha20poly1305', 'chacha20poly1305'].includes(k)) return; diff --git a/test/utils.test.js b/test/utils.test.js index afad541..1ce83f1 100644 --- a/test/utils.test.js +++ b/test/utils.test.js @@ -2,7 +2,7 @@ const { deepStrictEqual, throws } = require('assert'); const fc = require('fast-check'); const { describe, should } = require('micro-should'); const { TYPE_TEST, unalign } = require('./utils.js'); -const { bytesToHex, concatBytes, hexToBytes } = require('../utils.js'); +const { bytesToHex, concatBytes, hexToBytes, overlapBytes } = require('../utils.js'); describe('utils', () => { const staticHexVectors = [ @@ -57,6 +57,41 @@ describe('utils', () => { }) ) ); + should('sameBytes', () => { + // Basic + const buffer = new ArrayBuffer(20); + const a = new Uint8Array(buffer, 0, 10); // Bytes 0-9 + const b = new Uint8Array(buffer, 5, 10); // Bytes 5-14 + const c = new Uint8Array(buffer, 10, 10); // Bytes 10-19 + const d = new Uint8Array(new ArrayBuffer(20), 0, 10); // Different buffer + deepStrictEqual(overlapBytes(a, b), true); + deepStrictEqual(overlapBytes(a, c), false); + deepStrictEqual(overlapBytes(b, c), true); + deepStrictEqual(overlapBytes(a, d), false); + // Scan + const res = []; + const main = new Uint8Array(8 + 4); // 2byte + first + 2byte + const first = main.subarray(2).subarray(0, 8); + for (let i = 0; i < main.length; i++) { + const second = main.subarray(i).subarray(0, 1); // one byte window + deepStrictEqual(second, new Uint8Array(1)); + res.push(overlapBytes(first, second)); + } + deepStrictEqual(res, [ + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + false, + false, + ]); + }); }); describe('utils etc', () => {