Skip to content

Commit

Permalink
feat(🧮): Add invert on Matrix4 (Shopify#2791)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcandillon authored Dec 4, 2024
1 parent 31cd2f6 commit 42f81f0
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 0 deletions.
93 changes: 93 additions & 0 deletions packages/skia/src/renderer/__tests__/e2e/Matrix4.spec.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import {
toMatrix3,
Matrix4,
mapPoint3d,
invert4,
scale,
rotateZ,
} from "../../../skia/types";

const ckPerspective = (d: number) => [
Expand Down Expand Up @@ -272,4 +275,94 @@ describe("Matrix4", () => {
const result = mapPoint3d(translationMatrix, point);
expect(result).toEqual(expectedResult);
});

const almostEqual = (
a: number[] | Matrix4 | readonly [number, number, number],
b: number[] | Matrix4 | readonly [number, number, number],
epsilon = 1e-10
) => {
expect(a.length).toBe(b.length);
a.forEach((val, idx) => {
expect(Math.abs(val - b[idx])).toBeLessThan(epsilon);
});
};

const matrixEqual = (a: number[] | Matrix4, b: number[] | Matrix4) => {
expect(a.length).toBe(b.length);
a.forEach((val, idx) => {
// Object.is will distinguish -0 from 0, so we use === instead
const equal = val === b[idx] || (val === 0 && b[idx] === 0);
expect(equal).toBe(true);
});
};

it("should return the identity matrix when inverting the identity matrix", () => {
const identityMatrix = Matrix4();
const result = invert4(identityMatrix);
matrixEqual(result, identityMatrix);
});

it("should correctly invert a translation matrix", () => {
const translationMatrix = translate(100, -50, 25);
const inverse = invert4(translationMatrix);
// Inverse of translation(x,y,z) should be translation(-x,-y,-z)
const expectedInverse = translate(-100, 50, -25);
matrixEqual(inverse, expectedInverse);
});

it("should correctly invert a scale matrix", () => {
const scaleMatrix = scale(2, 4, 8);
const inverse = invert4(scaleMatrix);
// Inverse of scale(x,y,z) should be scale(1/x,1/y,1/z)
const expectedInverse = scale(1 / 2, 1 / 4, 1 / 8);
matrixEqual(inverse, expectedInverse);
});

it("should return identity matrix for non-invertible matrix", () => {
// A matrix of all zeros is not invertible
const nonInvertibleMatrix = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
] as Matrix4;
const result = invert4(nonInvertibleMatrix);
matrixEqual(result, Matrix4());
});

it("multiplying a matrix by its inverse should give the identity matrix", () => {
// Create a complex transformation matrix
const complexMatrix = processTransform3d([
{ scale: 2 },
{ rotateZ: Math.PI / 4 },
{ translate: [100, 100, 50] as const },
]);

const inverse = invert4(complexMatrix);
const result = multiply4(complexMatrix, inverse);

// Due to floating point arithmetic, we use almostEqual instead of exact equality
almostEqual(result, Matrix4());
});

it("should correctly transform points when using inverse matrix", () => {
const transformMatrix = translate(100, 100, 100);
const inverse = invert4(transformMatrix);

const point = [200, 0, 300] as const;
const transformedPoint = mapPoint3d(inverse, point);
const expectedPoint = [100, -100, 200] as const;

// Using almostEqual for floating point comparison
almostEqual(transformedPoint, expectedPoint);
});

it("should maintain inverse relationship for rotations", () => {
const rotationMatrix = rotateZ(Math.PI / 3); // 60 degrees rotation
const inverse = invert4(rotationMatrix);
const point = [100, 100, 0] as const;

// Transform point forward then backward should give original point
const transformed = mapPoint3d(rotationMatrix, point);
const backTransformed = mapPoint3d(inverse, transformed);

almostEqual(backTransformed, point);
});
});
101 changes: 101 additions & 0 deletions packages/skia/src/skia/types/Matrix4.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,104 @@ export const convertToAffineMatrix = (m4: Matrix4) => {
// Returning the 6-element affine transformation matrix
return [a, b, c, d, tx, ty];
};

/**
* Calculates the determinant of a 3x3 matrix
* @worklet
*/
const det3x3 = (
a00: number,
a01: number,
a02: number,
a10: number,
a11: number,
a12: number,
a20: number,
a21: number,
a22: number
): number => {
"worklet";
return (
a00 * (a11 * a22 - a12 * a21) +
a01 * (a12 * a20 - a10 * a22) +
a02 * (a10 * a21 - a11 * a20)
);
};

/**
* Inverts a 4x4 matrix
* @worklet
* @returns The inverted matrix, or the identity matrix if the input is not invertible
*/
export const invert4 = (m: Matrix4): Matrix4 => {
"worklet";

const a00 = m[0],
a01 = m[1],
a02 = m[2],
a03 = m[3];
const a10 = m[4],
a11 = m[5],
a12 = m[6],
a13 = m[7];
const a20 = m[8],
a21 = m[9],
a22 = m[10],
a23 = m[11];
const a30 = m[12],
a31 = m[13],
a32 = m[14],
a33 = m[15];

// Calculate cofactors
const b00 = det3x3(a11, a12, a13, a21, a22, a23, a31, a32, a33);
const b01 = -det3x3(a10, a12, a13, a20, a22, a23, a30, a32, a33);
const b02 = det3x3(a10, a11, a13, a20, a21, a23, a30, a31, a33);
const b03 = -det3x3(a10, a11, a12, a20, a21, a22, a30, a31, a32);

const b10 = -det3x3(a01, a02, a03, a21, a22, a23, a31, a32, a33);
const b11 = det3x3(a00, a02, a03, a20, a22, a23, a30, a32, a33);
const b12 = -det3x3(a00, a01, a03, a20, a21, a23, a30, a31, a33);
const b13 = det3x3(a00, a01, a02, a20, a21, a22, a30, a31, a32);

const b20 = det3x3(a01, a02, a03, a11, a12, a13, a31, a32, a33);
const b21 = -det3x3(a00, a02, a03, a10, a12, a13, a30, a32, a33);
const b22 = det3x3(a00, a01, a03, a10, a11, a13, a30, a31, a33);
const b23 = -det3x3(a00, a01, a02, a10, a11, a12, a30, a31, a32);

const b30 = -det3x3(a01, a02, a03, a11, a12, a13, a21, a22, a23);
const b31 = det3x3(a00, a02, a03, a10, a12, a13, a20, a22, a23);
const b32 = -det3x3(a00, a01, a03, a10, a11, a13, a20, a21, a23);
const b33 = det3x3(a00, a01, a02, a10, a11, a12, a20, a21, a22);

// Calculate determinant
const det = a00 * b00 + a01 * b01 + a02 * b02 + a03 * b03;

// Check if matrix is invertible
if (Math.abs(det) < 1e-8) {
// Return identity matrix if not invertible
return Matrix4();
}

const invDet = 1.0 / det;

// Calculate inverse matrix
return [
b00 * invDet,
b10 * invDet,
b20 * invDet,
b30 * invDet,
b01 * invDet,
b11 * invDet,
b21 * invDet,
b31 * invDet,
b02 * invDet,
b12 * invDet,
b22 * invDet,
b32 * invDet,
b03 * invDet,
b13 * invDet,
b23 * invDet,
b33 * invDet,
] as Matrix4;
};

0 comments on commit 42f81f0

Please sign in to comment.