Skip to content

Commit

Permalink
Embedding Projector:knn for non-normalized vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
alicialics committed Apr 8, 2023
1 parent e5d1771 commit 3693f34
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 46 deletions.
4 changes: 2 additions & 2 deletions tensorboard/plugins/projector/vz_projector/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,12 @@ export class DataSet {
} else {
const knnGpuEnabled = (await util.hasWebGLSupport()) && !IS_FIREFOX;
const result = await (knnGpuEnabled
? knn.findKNNGPUCosDistNorm(data, nNeighbors, (d) => d.vector)
? knn.findKNNGPUCosDist(data, nNeighbors, (d) => d.vector)
: knn.findKNN(
data,
nNeighbors,
(d) => d.vector,
(a, b) => vector.cosDistNorm(a, b)
(a, b) => vector.cosDist(a, b)
));
this.nearest = result;
return Promise.resolve(result);
Expand Down
15 changes: 9 additions & 6 deletions tensorboard/plugins/projector/vz_projector/knn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const KNN_GPU_MSG_ID = 'knn-gpu';
* @param k Number of nearest neighbors to find.
* @param accessor A method that returns the vector, given the data point.
*/
export function findKNNGPUCosDistNorm<T>(
export function findKNNGPUCosDist<T>(
dataPoints: T[],
k: number,
accessor: (dataPoint: T) => Float32Array
Expand All @@ -61,6 +61,10 @@ export function findKNNGPUCosDistNorm<T>(
// pair of points, which we sort using KMin data structure to obtain the
// K nearest neighbors for each point.
const nearest: NearestEntry[][] = new Array(N);
const dpNorm: number[] = new Array(N);
for (let i = 0; i < N; i++) {
dpNorm[i] = Math.sqrt(vector.norm2(accessor(dataPoints[i])));
}
let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
const actualPieceSize = Math.floor(N / numPieces);
const modulo = N % actualPieceSize;
Expand Down Expand Up @@ -120,10 +124,9 @@ export function findKNNGPUCosDistNorm<T>(
// Access i * N's row at `j` column.
// Reach row has N entries and j-th index has cosine distance
// between iReal vs. j-th vectors.
const cosDist = partial[i * N + j];
if (cosDist >= 0) {
kMin.add(cosDist, {index: j, dist: cosDist});
}
const cosDist =
1 - (1 - partial[i * N + j]) / (dpNorm[i] * dpNorm[j]);
kMin.add(cosDist, {index: j, dist: cosDist});
}
nearest[iReal] = kMin.getMinKItems();
}
Expand Down Expand Up @@ -151,7 +154,7 @@ export function findKNNGPUCosDistNorm<T>(
(error) => {
// GPU failed. Reverting back to CPU.
logging.setModalMessage(null!, KNN_GPU_MSG_ID);
let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
let distFunc = (a, b, limit) => vector.cosDist(a, b);
findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
resolve(nearest);
});
Expand Down
65 changes: 27 additions & 38 deletions tensorboard/plugins/projector/vz_projector/knn_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import {findKNN, findKNNGPUCosDistNorm, NearestEntry, TEST_ONLY} from './knn';
import {cosDistNorm, unit} from './vector';
import {findKNN, findKNNGPUCosDist, NearestEntry, TEST_ONLY} from './knn';
import {cosDist} from './vector';

describe('projector knn test', () => {
function getIndices(nearest: NearestEntry[][]): number[][] {
Expand All @@ -22,22 +22,16 @@ describe('projector knn test', () => {
});
}

function unitVector(vector: Float32Array): Float32Array {
// `unit` method replaces the vector in-place.
unit(vector);
return vector;
}

describe('#findKNNGPUCosDistNorm', () => {
describe('#findKNNGPUCosDist', () => {
it('finds n-nearest neighbor for each item', async () => {
const values = await findKNNGPUCosDistNorm(
const values = await findKNNGPUCosDist(
[
{a: unitVector(new Float32Array([1, 2, 0]))},
{a: unitVector(new Float32Array([1, 1, 3]))},
{a: unitVector(new Float32Array([100, 30, 0]))},
{a: unitVector(new Float32Array([95, 23, 3]))},
{a: unitVector(new Float32Array([100, 10, 0]))},
{a: unitVector(new Float32Array([95, 23, 100]))},
{a: new Float32Array([1, 2, 0])},
{a: new Float32Array([1, 1, 3])},
{a: new Float32Array([100, 30, 0])},
{a: new Float32Array([95, 23, 3])},
{a: new Float32Array([100, 10, 0])},
{a: new Float32Array([95, 23, 100])},
],
4,
(data) => data.a
Expand All @@ -54,11 +48,8 @@ describe('projector knn test', () => {
});

it('returns less than N when number of item is lower', async () => {
const values = await findKNNGPUCosDistNorm(
[
unitVector(new Float32Array([1, 2, 0])),
unitVector(new Float32Array([1, 1, 3])),
],
const values = await findKNNGPUCosDist(
[new Float32Array([1, 2, 0]), new Float32Array([1, 1, 3])],
4,
(a) => a
);
Expand All @@ -68,10 +59,8 @@ describe('projector knn test', () => {

it('splits a large data into one that would fit into GPU memory', async () => {
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
const data = new Array(size).fill(
unitVector(new Float32Array([1, 1, 1]))
);
const values = await findKNNGPUCosDistNorm(data, 1, (a) => a);
const data = new Array(size).fill(new Float32Array([1, 1, 1]));
const values = await findKNNGPUCosDist(data, 1, (a) => a);

expect(getIndices(values)).toEqual([
// Since distance to the diagonal entries (distance to self is 0) is
Expand All @@ -84,25 +73,25 @@ describe('projector knn test', () => {
});

describe('#findKNN', () => {
// Covered by equality tests below (#findKNNGPUCosDistNorm == #findKNN).
// Covered by equality tests below (#findKNNGPUCosDist == #findKNN).
});

describe('#findKNNGPUCosDistNorm and #findKNN', () => {
describe('#findKNNGPUCosDist and #findKNN', () => {
it('returns same value when dist metrics are cosine', async () => {
const data = [
unitVector(new Float32Array([1, 2, 0])),
unitVector(new Float32Array([1, 1, 3])),
unitVector(new Float32Array([100, 30, 0])),
unitVector(new Float32Array([95, 23, 3])),
unitVector(new Float32Array([100, 10, 0])),
unitVector(new Float32Array([95, 23, 100])),
new Float32Array([1, 2, 0]),
new Float32Array([1, 1, 3]),
new Float32Array([100, 30, 0]),
new Float32Array([95, 23, 3]),
new Float32Array([100, 10, 0]),
new Float32Array([95, 23, 100]),
];
const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
const findKnnGpuCosVal = await findKNNGPUCosDist(data, 2, (a) => a);
const findKnnVal = await findKNN(
data,
2,
(a) => a,
(a, b, limit) => cosDistNorm(a, b)
(a, b, limit) => cosDist(a, b)
);

// Floating point precision makes it hard to test. Just assert indices.
Expand All @@ -112,15 +101,15 @@ describe('projector knn test', () => {
it('splits a large data without the result being wrong', async () => {
const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
const data = Array.from(new Array(size)).map((_, index) => {
return unitVector(new Float32Array([index + 1, index + 1]));
return new Float32Array([index + 1, index + 1]);
});

const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
const findKnnGpuCosVal = await findKNNGPUCosDist(data, 2, (a) => a);
const findKnnVal = await findKNN(
data,
2,
(a) => a,
(a, b, limit) => cosDistNorm(a, b)
(a, b, limit) => cosDist(a, b)
);

expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
Expand Down

0 comments on commit 3693f34

Please sign in to comment.