Skip to content

Commit

Permalink
Harrison/relevancy score (langchain-ai#3907)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Grippeling <[email protected]>
Co-authored-by: Ryan <[email protected]>
Co-authored-by: Zander Chase <[email protected]>
  • Loading branch information
4 people authored May 2, 2023
1 parent c582f2e commit 13269fb
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vs/
.vscode/
.idea/
# Byte-compiled / optimized / DLL files
Expand Down
7 changes: 4 additions & 3 deletions langchain/retrievers/time_weighted_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
)
results = {}
for fetched_doc, relevance in docs_and_scores:
buffer_idx = fetched_doc.metadata["buffer_idx"]
doc = self.memory_stream[buffer_idx]
results[buffer_idx] = (doc, relevance)
if "buffer_idx" in fetched_doc.metadata:
buffer_idx = fetched_doc.metadata["buffer_idx"]
doc = self.memory_stream[buffer_idx]
results[buffer_idx] = (doc, relevance)
return results

def get_relevant_documents(self, query: str) -> List[Document]:
Expand Down
26 changes: 26 additions & 0 deletions langchain/vectorstores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def _redis_prefix(index_name: str) -> str:
return f"doc:{index_name}"


def _default_relevance_score(val: float) -> float:
    """Default relevance function: return ``1 - val``.

    NOTE(review): presumably ``val`` is a distance in [0, 1] so the result
    is a similarity in [0, 1]; nothing here enforces that range.
    """
    relevance = 1 - val
    return relevance


class Redis(VectorStore):
"""Wrapper around Redis vector database.
Expand Down Expand Up @@ -108,6 +112,9 @@ def __init__(
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_relevance_score,
**kwargs: Any,
):
"""Initialize with necessary components."""
Expand All @@ -133,6 +140,7 @@ def __init__(
self.content_key = content_key
self.metadata_key = metadata_key
self.vector_key = vector_key
self.relevance_score_fn = relevance_score_fn

def _create_index(self, dim: int = 1536) -> None:
try:
Expand Down Expand Up @@ -328,6 +336,24 @@ def similarity_search_with_score(

return docs

def _similarity_search_with_relevance_scores(
    self,
    query: str,
    k: int = 4,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs and relevance scores, normalized on a scale from 0 to 1.

    0 is dissimilar, 1 is most similar. The normalization itself is
    whatever ``self.relevance_score_fn`` does to each raw score.

    Args:
        query: Text to search for.
        k: Number of results requested from
            ``similarity_search_with_score``.
        **kwargs: Accepted for interface compatibility; not forwarded to
            the underlying search.

    Returns:
        ``(document, relevance)`` pairs in the order returned by
        ``similarity_search_with_score``.

    Raises:
        ValueError: If no ``relevance_score_fn`` is configured.
    """
    if self.relevance_score_fn is None:
        # The message previously named the Weaviate constructor (copy-paste
        # slip); this method belongs to the Redis store.
        raise ValueError(
            "relevance_score_fn must be provided to"
            " Redis constructor to normalize scores"
        )
    docs_and_scores = self.similarity_search_with_score(query, k=k)
    return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]

@classmethod
def from_texts(
cls: Type[Redis],
Expand Down
68 changes: 65 additions & 3 deletions langchain/vectorstores/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Wrapper around weaviate vector database."""
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional, Type
import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
from uuid import uuid4

import numpy as np
Expand Down Expand Up @@ -58,6 +59,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
return client


def _default_score_normalizer(val: float) -> float:
    """Default score normalizer: return ``1 - 1 / (1 + exp(val))``.

    This is algebraically the logistic sigmoid of ``val``: 0 maps to 0.5
    and larger inputs map closer to 1.
    """
    exp_val = np.exp(val)
    return 1 - 1 / (1 + exp_val)


class Weaviate(VectorStore):
"""Wrapper around Weaviate vector database.
Expand All @@ -80,6 +85,9 @@ def __init__(
text_key: str,
embedding: Optional[Embeddings] = None,
attributes: Optional[List[str]] = None,
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_score_normalizer,
):
"""Initialize with Weaviate client."""
try:
Expand All @@ -98,6 +106,7 @@ def __init__(
self._embedding = embedding
self._text_key = text_key
self._query_attrs = [self._text_key]
self._relevance_score_fn = relevance_score_fn
if attributes is not None:
self._query_attrs.extend(attributes)

Expand All @@ -110,6 +119,11 @@ def add_texts(
"""Upload texts with metadata (properties) to Weaviate."""
from weaviate.util import get_valid_uuid

def json_serializable(value: Any) -> Any:
    """Return ``value`` as-is, except ``datetime.datetime`` -> ISO string.

    Only ``datetime.datetime`` instances are converted (via
    ``isoformat()``); every other value is handed back unchanged.
    """
    return value.isoformat() if isinstance(value, datetime.datetime) else value

with self._client.batch as batch:
ids = []
for i, doc in enumerate(texts):
Expand All @@ -118,7 +132,7 @@ def add_texts(
}
if metadatas is not None:
for key in metadatas[i].keys():
data_properties[key] = metadatas[i][key]
data_properties[key] = json_serializable(metadatas[i][key])

_id = get_valid_uuid(uuid4())

Expand Down Expand Up @@ -267,9 +281,57 @@ def max_marginal_relevance_search_by_vector(
payload[idx].pop("_additional")
meta = payload[idx]
docs.append(Document(page_content=text, metadata=meta))

return docs

def similarity_search_with_score(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
    """Run a near-text query and score each hit against the query embedding.

    Each score is the dot product of the hit's stored vector and the
    embedding of ``query``.

    Args:
        query: Text to search for; sent as the single near-text concept.
        k: Maximum number of results requested.
        **kwargs: If ``search_distance`` is present and truthy it is sent
            as the query's ``certainty``.

    Returns:
        ``(document, score)`` pairs in the order the query returned them.
        Each document's metadata is the remaining result payload, which
        still includes the ``_additional`` entry.

    Raises:
        ValueError: If no embedding model is configured, or if the query
            response contains ``errors``.
    """
    # Fail fast: scoring needs an embedding model, so do not spend a
    # round-trip on the query if one was never configured.
    if self._embedding is None:
        raise ValueError(
            "_embedding cannot be None for similarity_search_with_score"
        )

    content: Dict[str, Any] = {"concepts": [query]}
    if kwargs.get("search_distance"):
        content["certainty"] = kwargs.get("search_distance")
    query_obj = self._client.query.get(self._index_name, self._query_attrs)
    result = (
        query_obj.with_near_text(content)
        .with_limit(k)
        .with_additional("vector")
        .do()
    )
    if "errors" in result:
        raise ValueError(f"Error during query: {result['errors']}")

    hits = result["data"]["Get"][self._index_name]
    if not hits:
        return []

    # The query embedding does not depend on the hit, so compute it once
    # instead of once per result.
    query_vector = self._embedding.embed_query(query)

    docs_and_scores = []
    for res in hits:
        text = res.pop(self._text_key)
        score = np.dot(res["_additional"]["vector"], query_vector)
        # NOTE(review): "_additional" (including the raw vector) is left in
        # the metadata, as before — confirm whether callers rely on it.
        docs_and_scores.append((Document(page_content=text, metadata=res), score))
    return docs_and_scores

def _similarity_search_with_relevance_scores(
    self,
    query: str,
    k: int = 4,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs and relevance scores, normalized on a scale from 0 to 1.

    0 is dissimilar, 1 is most similar. The normalization itself is
    whatever ``self._relevance_score_fn`` does to each raw score.

    Args:
        query: Text to search for.
        k: Number of results requested from
            ``similarity_search_with_score``.
        **kwargs: Forwarded to ``similarity_search_with_score`` so options
            such as ``search_distance`` are honoured rather than dropped.

    Returns:
        ``(document, relevance)`` pairs in the order returned by
        ``similarity_search_with_score``.

    Raises:
        ValueError: If no ``relevance_score_fn`` is configured.
    """
    score_fn = self._relevance_score_fn
    if score_fn is None:
        raise ValueError(
            "relevance_score_fn must be provided to"
            " Weaviate constructor to normalize scores"
        )
    docs_and_scores = self.similarity_search_with_score(query, k=k, **kwargs)
    return [(doc, score_fn(score)) for doc, score in docs_and_scores]

@classmethod
def from_texts(
cls: Type[Weaviate],
Expand Down

0 comments on commit 13269fb

Please sign in to comment.