From 7f3f6097e74d4eddd3dde387a2ef46af5769b5d1 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 14 Sep 2023 08:43:50 -0700 Subject: [PATCH] Add mmr support to redis retriever (#10556) --- .../integrations/vectorstores/redis.ipynb | 46 ++++++++++++++++--- .../langchain/vectorstores/redis/base.py | 6 ++- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/docs/extras/integrations/vectorstores/redis.ipynb b/docs/extras/integrations/vectorstores/redis.ipynb index 4b13672fc5297..bb67201501210 100644 --- a/docs/extras/integrations/vectorstores/redis.ipynb +++ b/docs/extras/integrations/vectorstores/redis.ipynb @@ -158,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -242,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": { "tags": [] }, @@ -253,7 +253,7 @@ "rds = Redis.from_texts(\n", " texts,\n", " embeddings,\n", - " metadatas=metadats,\n", + " metadatas=metadata,\n", " redis_url=\"redis://localhost:6379\",\n", " index_name=\"users\"\n", ")" @@ -597,7 +597,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -607,7 +607,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -1110,6 +1110,38 @@ "retriever.get_relevant_documents(\"foo\")" ] }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = rds.as_retriever(search_type=\"mmr\", search_kwargs={\"fetch_k\": 20, \"k\": 4, \"lambda_mult\": 0.1})" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={'id': 'doc:users:8f6b673b390647809d510112cde01a27', 'user': 'john', 'job': 'engineer', 'credit_score': 'high', 'age': '18'}),\n", + " Document(page_content='bar', metadata={'id': 'doc:users:93521560735d42328b48c9c6f6418d6a', 'user': 'tyler', 'job': 'engineer', 'credit_score': 'high', 'age': '100'}),\n", + " Document(page_content='foo', metadata={'id': 'doc:users:125ecd39d07845eabf1a699d44134a5b', 'user': 'nancy', 'job': 'doctor', 'credit_score': 'high', 'age': '94'}),\n", + " Document(page_content='foo', metadata={'id': 'doc:users:d6200ab3764c466082fde3eaab972a2a', 'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '45'})]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever.get_relevant_documents(\"foo\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1227,7 +1259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index 830a97d1a5ce0..dbca5c36b7e00 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -1425,6 +1425,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever): "similarity", "similarity_distance_threshold", "similarity_score_threshold", + "mmr", ] """Allowed search types.""" @@ -1438,7 +1439,6 @@ def _get_relevant_documents( ) -> List[Document]: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, **self.search_kwargs) - elif self.search_type == "similarity_distance_threshold": if self.search_kwargs["distance_threshold"] is None: raise ValueError( @@ -1454,6 +1454,10 @@ def _get_relevant_documents( ) ) docs = [doc for doc, _ in docs_and_similarities] + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + query, **self.search_kwargs + ) else: raise ValueError(f"search_type of {self.search_type} not allowed.") return docs