Skip to content

Commit

Permalink
enable batched IDF properly
Browse files Browse the repository at this point in the history
  • Loading branch information
jxmorris12 committed Feb 29, 2024
1 parent 7ed6cf3 commit c01c787
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 25 deletions.
26 changes: 3 additions & 23 deletions bm25_pt/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import functools
import math
import scipy
import torch
import transformers
import tqdm
Expand All @@ -18,23 +17,6 @@ def documents_to_bags(docs: torch.Tensor, vocab_size: int) -> torch.sparse.Tenso
return torch.sparse_coo_tensor(idxs, vals, size=(num_docs, vocab_size)).coalesce()


def torch_sparse_to_scipy(t: torch.sparse.Tensor):
indices = t.coalesce().indices()
values = t.coalesce().values()
size = t.size()
coo_matrix = scipy.sparse.coo_matrix((values.numpy(), (indices[0].numpy(), indices[1].numpy())), shape=size)
return coo_matrix.tocsr()


def sparse_divide(A: torch.sparse.Tensor, B: torch.sparse.Tensor) -> torch.sparse.Tensor:
"""Have to do sparse division on CPU in scipy."""
device = A.device
A = torch_sparse_to_scipy(A)
B = torch_sparse_to_scipy(B)
r = (A / B)
return torch.sparse_coo_tensor(r.nonzero(), r.data, r.shape, device=device)


class TokenizedBM25:
k1: float
b: float
Expand Down Expand Up @@ -99,9 +81,7 @@ def _score_pair_slow(self, query: torch.Tensor, document_bag: torch.sparse.Tenso
den = occurrences + (self.k1 *
(1 - self.b + self.b * (this_document_length / self._average_document_length)))
word_score = self.compute_IDF(word) * num / den
print("word:", word, "num:", num, "den:", den, "idf:", self.compute_IDF(word))
score += word_score
print("\t total:", score)
return score

def score_slow(self, query: torch.Tensor) -> torch.Tensor:
Expand All @@ -112,16 +92,16 @@ def score(self, query: torch.Tensor) -> torch.Tensor:
return self.score_batch(query[None]).flatten()

def _score_batch(self, queries: torch.Tensor) -> torch.Tensor:
# TODO: Batch idf computation, this shouldn't be too slow though since it's cached.
# TODO: Change all dense computations to sparse
num_queries, seq_length = queries.shape
queries_bag = self.docs_to_bags(queries)

num = (self._corpus * (self.k1 + 1))
normalized_lengths = (self.k1 * (1 - self.b + self.b * (self._corpus_lengths[:, None] / self._average_document_length)))
den = normalized_lengths.repeat((1, self._corpus.shape[1])) + self._corpus
score = (self._IDF[None, :] * sparse_divide(num, den)).sum()
scores = (self._IDF[None, :] * (num.to_dense() / den))

bm25_scores = queries_bag @ scores.T
bm25_scores = queries_bag.float().to_dense() @ scores.T

return bm25_scores

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="bm25_pt",
version="0.0.4",
version="0.0.5",
description="bm25 search algorithm in pytorch",
author="Jack Morris",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion test/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_scores_equal_batch():
# check term counts
count_of_the = " ".join(corpus + [""]).count("the ")
token_of_the = bm25.tokenizer.encode("the", add_special_tokens=False, return_tensors='pt').item()
assert count_of_the == bm25._corpus.sum(0)[1996]
assert count_of_the == bm25._corpus.sum(0)[token_of_the]


def test_scores_equal_gpu():
Expand Down
33 changes: 33 additions & 0 deletions test/real_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import List

import pickle
import pytest
import os
import torch

from bm25_pt import BM25


current_folder = os.path.dirname(os.path.abspath(__file__))

@pytest.fixture
def scrolls_corpus() -> List[str]:
file_path = os.path.join(current_folder, "scrolls_test_corpus.p")
return pickle.load(open(file_path, "rb"))

@pytest.fixture
def scrolls_queries() -> List[str]:
file_path = os.path.join(current_folder, "scrolls_test_queries.p")
return pickle.load(open(file_path, "rb"))

def test_scrolls(scrolls_corpus, scrolls_queries):
scrolls_corpus = scrolls_corpus[:1]
bm25 = BM25()
print("indexing")
bm25.index(scrolls_corpus)
print("scoring (slow)")
doc_scores_slow = bm25.score_slow(scrolls_queries[0])
print("scoring (fast)")
doc_scores = bm25.score(scrolls_queries[0])

torch.testing.assert_close(doc_scores, doc_scores_slow)
Binary file added test/scrolls_test_corpus.p
Binary file not shown.
Binary file added test/scrolls_test_queries.p
Binary file not shown.

0 comments on commit c01c787

Please sign in to comment.