forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Harrison/elastic search (langchain-ai#2419)
- Loading branch information
Showing
3 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
164 changes: 164 additions & 0 deletions
164
docs/modules/indexes/retrievers/examples/elastic_search_bm25.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ab66dd43", | ||
"metadata": {}, | ||
"source": [ | ||
"# ElasticSearch BM25\n", | ||
"\n", | ||
"This notebook goes over how to use a retriever that under the hood uses ElasticSearcha and BM25.\n", | ||
"\n", | ||
"For more information on the details of BM25 see [this blog post](https://www.elastic.co/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "393ac030", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.retrievers import ElasticSearchBM25Retriever" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "aaf80e7f", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create New Retriever" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "bcb3c8c2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"elasticsearch_url=\"http://localhost:9200\"\n", | ||
"retriever = ElasticSearchBM25Retriever.create(elasticsearch_url, \"langchain-index-3\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"id": "b605284d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Alternatively, you can load an existing index\n", | ||
"# import elasticsearch\n", | ||
"# elasticsearch_url=\"http://localhost:9200\"\n", | ||
"# retriever = ElasticSearchBM25Retriever(elasticsearch.Elasticsearch(elasticsearch_url), \"langchain-index\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1c518c42", | ||
"metadata": {}, | ||
"source": [ | ||
"## Add texts (if necessary)\n", | ||
"\n", | ||
"We can optionally add texts to the retriever (if they aren't already in there)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "98b1c017", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"['386c76c9-4355-4c12-aaeb-7b80054caf93',\n", | ||
" 'fffd279c-a0c9-4158-a904-6e242c517c99',\n", | ||
" '7f5528a3-18d0-43b0-894d-f6770a002219',\n", | ||
" 'e2ef5e32-d5bd-44e2-b045-cfc5a8e0a0a1',\n", | ||
" 'cce8ba48-e473-4235-bca2-2c8d65e73ccf']" | ||
] | ||
}, | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "08437fa2", | ||
"metadata": {}, | ||
"source": [ | ||
"## Use Retriever\n", | ||
"\n", | ||
"We can now use the retriever!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "c0455218", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"result = retriever.get_relevant_documents(\"foo\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "7dfa5c29", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[Document(page_content='foo', metadata={}),\n", | ||
" Document(page_content='foo bar', metadata={})]" | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "74bd9256", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Wrapper around Elasticsearch vector database.""" | ||
from __future__ import annotations | ||
|
||
import uuid | ||
from typing import Any, Iterable, List | ||
|
||
from langchain.docstore.document import Document | ||
from langchain.schema import BaseRetriever | ||
|
||
|
||
class ElasticSearchBM25Retriever(BaseRetriever): | ||
"""Wrapper around Elasticsearch using BM25 as a retrieval method. | ||
To connect to an Elasticsearch instance that requires login credentials, | ||
including Elastic Cloud, use the Elasticsearch URL format | ||
https://username:password@es_host:9243. For example, to connect to Elastic | ||
Cloud, create the Elasticsearch URL with the required authentication details and | ||
pass it to the ElasticVectorSearch constructor as the named parameter | ||
elasticsearch_url. | ||
You can obtain your Elastic Cloud URL and login credentials by logging in to the | ||
Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and | ||
navigating to the "Deployments" page. | ||
To obtain your Elastic Cloud password for the default "elastic" user: | ||
1. Log in to the Elastic Cloud console at https://cloud.elastic.co | ||
2. Go to "Security" > "Users" | ||
3. Locate the "elastic" user and click "Edit" | ||
4. Click "Reset password" | ||
5. Follow the prompts to reset the password | ||
The format for Elastic Cloud URLs is | ||
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243. | ||
""" | ||
|
||
def __init__(self, client: Any, index_name: str): | ||
self.client = client | ||
self.index_name = index_name | ||
|
||
@classmethod | ||
def create( | ||
cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75 | ||
) -> ElasticSearchBM25Retriever: | ||
from elasticsearch import Elasticsearch | ||
|
||
# Create an Elasticsearch client instance | ||
es = Elasticsearch(elasticsearch_url) | ||
|
||
# Define the index settings and mappings | ||
index_settings = { | ||
"settings": { | ||
"analysis": {"analyzer": {"default": {"type": "standard"}}}, | ||
"similarity": { | ||
"custom_bm25": { | ||
"type": "BM25", | ||
"k1": k1, | ||
"b": b, | ||
} | ||
}, | ||
}, | ||
"mappings": { | ||
"properties": { | ||
"content": { | ||
"type": "text", | ||
"similarity": "custom_bm25", # Use the custom BM25 similarity | ||
} | ||
} | ||
}, | ||
} | ||
|
||
# Create the index with the specified settings and mappings | ||
es.indices.create(index=index_name, body=index_settings) | ||
return cls(es, index_name) | ||
|
||
def add_texts( | ||
self, | ||
texts: Iterable[str], | ||
refresh_indices: bool = True, | ||
) -> List[str]: | ||
"""Run more texts through the embeddings and add to the retriver. | ||
Args: | ||
texts: Iterable of strings to add to the retriever. | ||
refresh_indices: bool to refresh ElasticSearch indices | ||
Returns: | ||
List of ids from adding the texts into the retriever. | ||
""" | ||
try: | ||
from elasticsearch.helpers import bulk | ||
except ImportError: | ||
raise ValueError( | ||
"Could not import elasticsearch python package. " | ||
"Please install it with `pip install elasticsearch`." | ||
) | ||
requests = [] | ||
ids = [] | ||
for i, text in enumerate(texts): | ||
_id = str(uuid.uuid4()) | ||
request = { | ||
"_op_type": "index", | ||
"_index": self.index_name, | ||
"content": text, | ||
"_id": _id, | ||
} | ||
ids.append(_id) | ||
requests.append(request) | ||
bulk(self.client, requests) | ||
|
||
if refresh_indices: | ||
self.client.indices.refresh(index=self.index_name) | ||
return ids | ||
|
||
def get_relevant_documents(self, query: str) -> List[Document]: | ||
query_dict = {"query": {"match": {"content": query}}} | ||
res = self.client.search(index=self.index_name, body=query_dict) | ||
|
||
docs = [] | ||
for r in res["hits"]["hits"]: | ||
docs.append(Document(page_content=r["_source"]["content"])) | ||
return docs | ||
|
||
async def aget_relevant_documents(self, query: str) -> List[Document]: | ||
raise NotImplementedError |