Skip to content

Commit

Permalink
Cassandra Vector Store, add metadata filtering + improvements (langch…
Browse files Browse the repository at this point in the history
…ain-ai#9280)

This PR addresses a few minor issues with the Cassandra vector store
implementation and extends the store to support Metadata search.

Thanks to the latest cassIO library (>=0.1.0), metadata filtering is
available in the store.

Further,
- the "relevance" score is prevented from being flipped in the [0,1]
interval, thus ensuring that 1 corresponds to the closest vector (this
is related to how the underlying cassIO class returns the cosine
difference);
- bumped the cassIO package version both in the notebooks and the
pyproject.toml;
- adjusted the textfile location for the vector-store example after the
reshuffling of the Langchain repo dir structure;
- added demonstration of metadata filtering in the Cassandra vector
store notebook;
- better docstring for the Cassandra vector store class;
- fixed test flakiness and removed offending out-of-place escape chars
from a test module docstring;

To my knowledge all relevant tests pass and mypy+black+ruff don't
complain. (mypy gives unrelated errors in other modules, which clearly
don't depend on the content of this PR).

Thank you!
Stefano

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
hemidactylus and baskaryan authored Sep 13, 2023
1 parent 49694f6 commit 415d38a
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.7\""
"!pip install \"cassio>=0.1.0\""
]
},
{
Expand Down Expand Up @@ -155,7 +155,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
55 changes: 51 additions & 4 deletions docs/extras/integrations/vectorstores/cassandra.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.7\""
"!pip install \"cassio>=0.1.0\""
]
},
{
Expand Down Expand Up @@ -152,7 +152,9 @@
"source": [
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n",
"\n",
"loader = TextLoader(SOURCE_FILE_NAME)\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
Expand Down Expand Up @@ -197,7 +199,7 @@
"# table_name=table_name,\n",
"# )\n",
"\n",
"# docsearch_preexisting.similarity_search(query, k=2)"
"# docs = docsearch_preexisting.similarity_search(query, k=2)"
]
},
{
Expand Down Expand Up @@ -253,6 +255,51 @@
"for i, doc in enumerate(found_docs):\n",
" print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "da791c5f",
"metadata": {},
"source": [
"### Metadata filtering\n",
"\n",
"You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n",
"\n",
"Since only one file was inserted, this is just a demonstration of how filters are passed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93f132fa",
"metadata": {},
"outputs": [],
"source": [
"filter = {\"source\": SOURCE_FILE_NAME}\n",
"filtered_docs = docsearch.similarity_search(query, filter=filter, k=5)\n",
"print(f\"{len(filtered_docs)} documents retrieved.\")\n",
"print(f\"{filtered_docs[0].page_content[:64]} ...\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b413ec4",
"metadata": {},
"outputs": [],
"source": [
"filter = {\"source\": \"nonexisting_file.txt\"}\n",
"filtered_docs2 = docsearch.similarity_search(query, filter=filter)\n",
"print(f\"{len(filtered_docs2)} documents retrieved.\")"
]
},
{
"cell_type": "markdown",
"id": "a0fea764",
"metadata": {},
"source": [
"Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with Langchain."
]
}
],
"metadata": {
Expand All @@ -271,7 +318,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
74 changes: 60 additions & 14 deletions libs/langchain/langchain/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

import typing
import uuid
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import numpy as np

Expand All @@ -18,11 +29,12 @@


class Cassandra(VectorStore):
"""`Cassandra` vector store.
"""Wrapper around Apache Cassandra(R) for vector-store workloads.
It is based on the Cassandra vector-store capabilities, provided through cassIO.
There is no notion of a default table name, since each embedding
function implies its own vector dimension, which is part of the schema.
To use it, you need a recent installation of the `cassio` library
and a Cassandra cluster / Astra DB instance supporting vector capabilities.
Visit the cassio.org website for extensive quickstarts and code examples.
Example:
.. code-block:: python
Expand All @@ -31,12 +43,20 @@ class Cassandra(VectorStore):
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
session = ...
keyspace = 'my_keyspace'
vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
session = ... # create your Cassandra session object
keyspace = 'my_keyspace' # the keyspace should exist already
table_name = 'my_vector_store'
vectorstore = Cassandra(embeddings, session, keyspace, table_name)
"""

_embedding_dimension: int | None
_embedding_dimension: Union[int, None]

@staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
    """Turn an optional ``filter`` argument into a metadata dict.

    Args:
        filter_dict: Metadata key/value pairs supplied by the caller,
            or None when no filtering was requested.

    Returns:
        A new empty dict when ``filter_dict`` is None; otherwise the very
        same dict object that was passed in (it is not copied).
    """
    return {} if filter_dict is None else filter_dict

def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
Expand Down Expand Up @@ -81,8 +101,18 @@ def __init__(
def embeddings(self) -> Embeddings:
return self.embedding

@staticmethod
def _dont_flip_the_cos_score(distance: float) -> float:
    """Return ``distance`` unchanged (the identity transformation)."""
    return distance

def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._cosine_relevance_score_fn
"""
The underlying VectorTable already returns a "score proper",
i.e. one in [0, 1] where higher means more *similar*,
so here the final score transformation is not reversing the interval:
"""
return self._dont_flip_the_cos_score

def delete_collection(self) -> None:
"""
Expand Down Expand Up @@ -172,22 +202,24 @@ def similarity_search_with_score_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector.
No support for `filter` query (on metadata) along with vector search.
Args:
embedding (List[float]): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
search_metadata = self._filter_to_metadata(filter)
#
hits = self.table.search(
embedding_vector=embedding,
top_k=k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
)
# We stick to 'cos' distance as it can be normalized on a 0-1 axis
# (1=most relevant), as required by this class' contract.
Expand All @@ -207,23 +239,24 @@ def similarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)

# id-unaware search facilities
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector.
No support for `filter` query (on metadata) along with vector search.
Args:
embedding (List[float]): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
Expand All @@ -235,44 +268,51 @@ def similarity_search_with_score_by_vector(
for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
]

def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents most similar to a text query.

    Args:
        query: Text to search for; it is embedded with
            ``self.embedding.embed_query`` before the lookup.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter, forwarded unchanged to
            ``similarity_search_by_vector``.
        **kwargs: Accepted but not used by this method.

    Returns:
        The result of ``similarity_search_by_vector`` for the embedded query.
    """
    query_vector = self.embedding.embed_query(query)
    return self.similarity_search_by_vector(query_vector, k, filter=filter)

def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents most similar to an embedding vector.

    Args:
        embedding: Vector to look up similar documents for.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter, forwarded unchanged to
            ``similarity_search_with_score_by_vector``.
        **kwargs: Accepted but not used by this method.

    Returns:
        The documents from the scored search, in the same order,
        with the scores dropped.
    """
    scored_hits = self.similarity_search_with_score_by_vector(
        embedding,
        k,
        filter=filter,
    )
    documents: List[Document] = []
    for document, _score in scored_hits:
        documents.append(document)
    return documents

def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float]]:
    """Return (document, score) pairs most similar to a text query.

    Args:
        query: Text to search for; it is embedded with
            ``self.embedding.embed_query`` before the lookup.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter, forwarded unchanged to
            ``similarity_search_with_score_by_vector``.

    Returns:
        The result of ``similarity_search_with_score_by_vector`` for the
        embedded query.
    """
    query_vector = self.embedding.embed_query(query)
    return self.similarity_search_with_score_by_vector(query_vector, k, filter=filter)

def max_marginal_relevance_search_by_vector(
Expand All @@ -281,6 +321,7 @@ def max_marginal_relevance_search_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Expand All @@ -296,11 +337,14 @@ def max_marginal_relevance_search_by_vector(
Returns:
List of Documents selected by maximal marginal relevance.
"""
search_metadata = self._filter_to_metadata(filter)

prefetchHits = self.table.search(
embedding_vector=embedding,
top_k=fetch_k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
)
# let the mmr utility pick the *indices* in the above array
mmrChosenIndices = maximal_marginal_relevance(
Expand Down Expand Up @@ -328,6 +372,7 @@ def max_marginal_relevance_search(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Expand All @@ -350,6 +395,7 @@ def max_marginal_relevance_search(
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test Cassandra functionality."""
import time
from typing import List, Optional, Type

from cassandra.cluster import Cluster
Expand Down Expand Up @@ -61,9 +62,9 @@ def test_cassandra_with_score() -> None:
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
Document(page_content="foo", metadata={"page": "0.0"}),
Document(page_content="bar", metadata={"page": "1.0"}),
Document(page_content="baz", metadata={"page": "2.0"}),
]
assert scores[0] > scores[1] > scores[2]

Expand All @@ -76,10 +77,10 @@ def test_cassandra_max_marginal_relevance_search() -> None:
______ v2
/ \
/ \ v1
/ | v1
v3 | . | query
\ / v0
\______/ (N.B. very crude drawing)
| / v0
|______/ (N.B. very crude drawing)
With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order).
Expand All @@ -94,8 +95,8 @@ def test_cassandra_max_marginal_relevance_search() -> None:
(mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
}
assert output_set == {
("+0.25", 2),
("-0.124", 0),
("+0.25", "2.0"),
("-0.124", "0.0"),
}


Expand Down Expand Up @@ -150,6 +151,7 @@ def test_cassandra_delete() -> None:
assert len(output) == 1

docsearch.clear()
time.sleep(0.3)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0

Expand Down

0 comments on commit 415d38a

Please sign in to comment.