
Commit

feat:core:#95
louis030195 committed May 15, 2023
1 parent 05408a9 commit 2fd1fb5
Showing 7 changed files with 220 additions and 68 deletions.
73 changes: 73 additions & 0 deletions embedbase/app.py
@@ -12,6 +12,7 @@
SearchRequest,
AddRequest,
UpdateRequest,
MultiDatasetSearchRequest,
)
from embedbase.utils import embedbase_ascii, get_user_id
from embedbase.settings import Settings
@@ -420,6 +421,75 @@ async def semantic_search(
"hash": match.hash,
"embedding": match.embedding,
"metadata": match.metadata,
"dataset_id": match.dataset_id,
}
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"query": query, "similarities": similarities},
)


async def multi_dataset_semantic_search(
self,
request: Request,
request_body: MultiDatasetSearchRequest,
):
"""
Run semantic search over multiple datasets using embeddings.
If you are a large language model, you can use this endpoint to search for similar documents
across multiple datasets, i.e. documents that help you answer users' questions.
"""
query = request_body.query
dataset_ids = request_body.dataset_ids

# if the query is empty and no where filter is provided, return an empty list
if not query and not request_body.where:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"query": query, "similarities": []},
)

user_id = get_user_id(request)

# if the query is too big, return an error
if self.embedder.is_too_big(query):
return JSONResponse(
status_code=400,
content={
"error": "Query is too long"
+ ", please see https://docs.embedbase.xyz/query-is-too-long"
},
)

top_k = 5
if request_body.top_k > 0:
top_k = request_body.top_k
query_embedding = (await self.embedder.embed(query))[0]

self.logger.info(
f"Query {request_body.query} created embedding, querying index"
)

query_response = await self.db.search(
top_k=top_k,
vector=query_embedding,
dataset_ids=dataset_ids,
user_id=user_id,
where=request_body.where,
)

similarities = []
for match in query_response:
similarities.append(
{
"score": match.score,
"id": match.id,
"data": match.data,
"hash": match.hash,
"embedding": match.embedding,
"metadata": match.metadata,
"dataset_id": match.dataset_id,
}
)
return JSONResponse(
@@ -479,6 +549,9 @@ def run(self) -> FastAPI:
self.fastapi_app.add_api_route(
"/v1/{dataset_id}/search", self.semantic_search, methods=["POST"]
)
self.fastapi_app.add_api_route(
"/v2/search", self.multi_dataset_semantic_search, methods=["POST"]
)
self.fastapi_app.add_api_route(
"/v1/datasets", self.get_datasets, methods=["GET"]
)
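For context, here is a minimal sketch of calling the new /v2/search route from a client. The host, dataset names, and query are illustrative placeholders; the payload fields mirror MultiDatasetSearchRequest.

import requests

# Hypothetical client call to the new multi-dataset endpoint; assumes an
# embedbase server listening on localhost:8000 and two existing datasets.
response = requests.post(
    "http://localhost:8000/v2/search",
    json={
        "query": "how do I reset my password?",
        "dataset_ids": ["support-docs", "faq"],  # search both datasets at once
        "top_k": 5,
    },
)
for match in response.json()["similarities"]:
    # each match now carries dataset_id, so callers can tell results apart
    print(match["dataset_id"], match["score"], match["data"])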
1 change: 1 addition & 0 deletions embedbase/database/base.py
@@ -21,6 +21,7 @@ class SearchResponse(BaseModel):
# any inconvenience for now. Let's see if we can fix this later
embedding: Union[List[float], str]
metadata: Optional[dict]
dataset_id: str


class SelectResponse(BaseModel):
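Pieced together, the response model now looks roughly like this; field types other than dataset_id are inferred from usage elsewhere in the diff, so treat them as assumptions.

from typing import List, Optional, Union

from pydantic import BaseModel

class SearchResponse(BaseModel):
    id: str
    score: float                        # inferred from match.score usage above
    data: Optional[str]                 # inferred; may be absent if data is not stored
    hash: str
    embedding: Union[List[float], str]  # str allowed per the serialization comment
    metadata: Optional[dict]
    dataset_id: str                     # new in this commit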
10 changes: 9 additions & 1 deletion embedbase/database/memory_db.py
@@ -98,7 +98,14 @@ async def select(
else:
return []

async def search(self, vector, top_k, dataset_ids, user_id=None, where=None):
async def search(
self,
vector,
top_k,
dataset_ids,
user_id=None,
where=None,
):
storage = self.storage
# make a copy of storage filtered by where
if where:
@@ -134,6 +141,7 @@ async def search(self, vector, top_k, dataset_ids, user_id=None, where=None):
metadata=storage[doc_id]["metadata"],
embedding=storage[doc_id]["embedding"].tolist(),
hash=storage[doc_id]["hash"],
dataset_id=storage[doc_id]["dataset_id"],
)
for idx, doc_id, sim in similarities
]
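A hypothetical direct call against the in-memory backend, showing the reshaped signature; the class name and vector values are assumed for illustration, and the backend instance is taken as given.

from embedbase.database.memory_db import MemoryDatabase  # class name assumed

async def demo(db: MemoryDatabase) -> None:
    results = await db.search(
        vector=[0.1, 0.2, 0.3],          # query embedding, dimensions illustrative
        top_k=3,
        dataset_ids=["docs", "notes"],   # several datasets searched in one pass
        user_id=None,
        where={"source": "wiki"},        # optional metadata filter
    )
    for r in results:
        # each hit now reports which dataset it came from
        print(r.dataset_id, r.score)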
30 changes: 17 additions & 13 deletions embedbase/database/postgres_db.py
@@ -1,7 +1,9 @@
import asyncio
import json
from typing import List, Optional

import asyncio
import itertools
import json

from pandas import DataFrame, Series

from embedbase.database import VectorDatabase
@@ -62,6 +64,7 @@ def __init__(
hash text,
embedding vector({self._dimensions}),
metadata json,
dataset_id text
)
language plpgsql
as $$
@@ -73,7 +76,8 @@ def __init__(
(1 - (documents.embedding <=> query_embedding)) as similarity,
documents.hash,
documents.embedding,
documents.metadata
documents.metadata,
documents.dataset_id
from documents
where 1 - (documents.embedding <=> query_embedding) > similarity_threshold
and documents.dataset_id = any(query_dataset_ids)
@@ -149,9 +153,13 @@ async def _fetch(ids, hashes) -> List[dict]:
conditions.append(
sql.SQL("user_id = {}").format(sql.Literal(user_id))
)
return list(self.conn.execute(
sql.SQL(query).format(conditions=sql.SQL(" and ").join(conditions))
))
return list(
self.conn.execute(
sql.SQL(query).format(
conditions=sql.SQL(" and ").join(conditions)
)
)
)
except Exception as e:
raise e

@@ -160,14 +168,10 @@ async def _fetch(ids, hashes) -> List[dict]:
docs = []
if ids:
elements = [ids[i : i + n] for i in range(0, len(ids), n)]
docs = await asyncio.gather(
*[_fetch(e, []) for e in elements]
)
docs = await asyncio.gather(*[_fetch(e, []) for e in elements])
else:
elements = [hashes[i : i + n] for i in range(0, len(hashes), n)]
docs = await asyncio.gather(
*[_fetch([], e) for e in elements]
)
docs = await asyncio.gather(*[_fetch([], e) for e in elements])
return [
SelectResponse(
id=row[0],
@@ -247,7 +251,6 @@ async def search(
user_id: Optional[str] = None,
where=None,
):

d = {
"query_embedding": str(vector),
"similarity_threshold": 0.0, # TODO: make this configurable
@@ -282,6 +285,7 @@ async def search(
hash=row[3],
embedding=row[4].tolist(),
metadata=row[5],
dataset_id=row[6],
)
)
return data
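For intuition, the ranking expression 1 - (documents.embedding <=> query_embedding) in the SQL function above is cosine similarity, assuming pgvector's <=> cosine-distance operator; the equivalent computation in Python:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # pgvector: a <=> b is cosine distance, so similarity = 1 - distance
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0])))  # ~0.707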
32 changes: 16 additions & 16 deletions embedbase/database/supabase_db.py
@@ -1,11 +1,15 @@
from typing import List, Optional

import ast
import asyncio
import itertools
from typing import List, Optional

from pandas import DataFrame, Series

from embedbase.database import VectorDatabase
from embedbase.database.base import Dataset, SearchResponse, SelectResponse
from embedbase.utils import BatchGenerator
import ast


class Supabase(VectorDatabase):
"""
@@ -20,7 +24,7 @@ def __init__(self, url: str, key: str, **kwargs):
"""
super().__init__(**kwargs)
try:
from supabase import create_client, Client
from supabase import Client, create_client

self.supabase: Client = create_client(url, key)
self.functions = self.supabase.functions()
@@ -40,7 +44,9 @@ async def select(
assert ids or hashes, "ids or hashes must be provided"

# raise if both ids and hashes are provided
assert not (ids and hashes), "ids and hashes cannot be provided at the same time"
assert not (
ids and hashes
), "ids and hashes cannot be provided at the same time"
# TODO not supported yet

async def _fetch(ids, hashes) -> List[dict]:
@@ -72,14 +78,10 @@ async def _fetch(ids, hashes) -> List[dict]:
docs = []
if ids:
elements = [ids[i : i + n] for i in range(0, len(ids), n)]
docs = await asyncio.gather(
*[_fetch(e, []) for e in elements]
)
docs = await asyncio.gather(*[_fetch(e, []) for e in elements])
else:
elements = [hashes[i : i + n] for i in range(0, len(hashes), n)]
docs = await asyncio.gather(
*[_fetch([], e) for e in elements]
)
docs = await asyncio.gather(*[_fetch([], e) for e in elements])
return [
SelectResponse(
id=row["id"],
@@ -144,6 +146,7 @@ async def search(
dataset_ids: List[str],
user_id: Optional[str] = None,
where=None,
distinct: bool = True,
):
d = {
"query_embedding": vector,
@@ -157,7 +160,7 @@ async def _fetch(ids, hashes) -> List[dict]:
"match_documents",
d,
)

if where:
# raise if where is not a dict
if not isinstance(where, dict):
@@ -166,11 +169,7 @@ async def _fetch(ids, hashes) -> List[dict]:
metadata_value = where[metadata_field]
d["metadata_field"] = metadata_field
d["metadata_value"] = metadata_value
response = (
query
.execute()
.data
)
response = query.execute().data
return [
SearchResponse(
id=row["id"],
@@ -179,6 +178,7 @@ async def search(
hash=row["hash"],
metadata=row["metadata"],
score=row["score"],
dataset_id=row["dataset_id"],
)
for row in response
]
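A hypothetical call into the Supabase backend with the updated signature. The URL, key, and dataset names are placeholders, and the new distinct flag's exact semantics are not visible in this hunk, so it is shown only with its default value.

from typing import List

from embedbase.database.supabase_db import Supabase

async def demo(query_embedding: List[float]) -> None:
    db = Supabase(url="https://xyz.supabase.co", key="service-role-key")
    results = await db.search(
        vector=query_embedding,
        top_k=5,
        dataset_ids=["support-docs", "faq"],
        where={"category": "billing"},  # mapped to metadata_field / metadata_value
        distinct=True,                  # new parameter, default shown
    )
    for r in results:
        print(r.dataset_id, r.score)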
9 changes: 9 additions & 0 deletions embedbase/models.py
@@ -1,8 +1,10 @@
from typing import List, Optional, Union

from pydantic import BaseModel

# TODO: response models once stable


class Document(BaseModel):
# data can be
# - a string - for example "This is a document"
@@ -16,14 +18,17 @@ class AddRequest(BaseModel):
documents: List[Document]
store_data: bool = True


class UpdateDocument(BaseModel):
id: str
data: Optional[str] = None
metadata: Optional[dict] = None


class UpdateRequest(BaseModel):
documents: List[UpdateDocument]


class DeleteRequest(BaseModel):
ids: List[str]

@@ -32,3 +37,7 @@ class SearchRequest(BaseModel):
query: str
top_k: int = 6
where: Optional[Union[dict, List[dict]]] = None


class MultiDatasetSearchRequest(SearchRequest):
dataset_ids: List[str]
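Since MultiDatasetSearchRequest only adds dataset_ids on top of SearchRequest, a request body validates like this (values illustrative):

from embedbase.models import MultiDatasetSearchRequest

body = MultiDatasetSearchRequest(
    query="password reset",
    dataset_ids=["support-docs", "faq"],
    top_k=5,                     # inherited; defaults to 6
    where={"lang": "en"},        # inherited optional filter
)
print(body.json())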