-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7ed6cf3
commit c01c787
Showing
6 changed files
with
38 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
setup( | ||
name="bm25_pt", | ||
version="0.0.4", | ||
version="0.0.5", | ||
description="bm25 search algorithm in pytorch", | ||
author="Jack Morris", | ||
author_email="[email protected]", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import List | ||
|
||
import pickle | ||
import pytest | ||
import os | ||
import torch | ||
|
||
from bm25_pt import BM25 | ||
|
||
|
||
current_folder = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
@pytest.fixture
def scrolls_corpus() -> List[str]:
    """Load the pickled test corpus stored next to this test module.

    Returns:
        The object unpickled from ``scrolls_test_corpus.p`` (annotated as a
        list of strings).
    """
    file_path = os.path.join(current_folder, "scrolls_test_corpus.p")
    # NOTE: unpickling can execute arbitrary code; this is only acceptable
    # because the file is a test asset living beside this module.
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(file_path, "rb") as corpus_file:
        return pickle.load(corpus_file)
|
||
@pytest.fixture
def scrolls_queries() -> List[str]:
    """Load the pickled test queries stored next to this test module.

    Returns:
        The object unpickled from ``scrolls_test_queries.p`` (annotated as a
        list of strings).
    """
    file_path = os.path.join(current_folder, "scrolls_test_queries.p")
    # NOTE: unpickling can execute arbitrary code; this is only acceptable
    # because the file is a test asset living beside this module.
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(file_path, "rb") as queries_file:
        return pickle.load(queries_file)
|
||
def test_scrolls(scrolls_corpus, scrolls_queries):
    """Check that ``BM25.score`` agrees with ``BM25.score_slow``.

    Indexes only the first corpus document, scores the first query with both
    scoring methods, and asserts the two results are close.
    """
    # Only the first document is indexed (presumably to keep the slow
    # scoring path quick -- confirm with the author).
    first_doc_only = scrolls_corpus[:1]

    ranker = BM25()
    print("indexing")
    ranker.index(first_doc_only)

    print("scoring (slow)")
    query = scrolls_queries[0]
    expected_scores = ranker.score_slow(query)

    print("scoring (fast)")
    actual_scores = ranker.score(query)

    torch.testing.assert_close(actual_scores, expected_scores)
Binary file not shown.
Binary file not shown.