diff --git a/.vscode/launch.json b/.vscode/launch.json index 2c241927..5dfb5b91 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,8 @@ "8000", "--reload" ], - "jinja": true + "jinja": true, + "envFile": "${workspaceFolder}/.env" } ] } diff --git a/config.py b/config.py index 7b94a976..00fb6abb 100644 --- a/config.py +++ b/config.py @@ -76,7 +76,13 @@ def get_env_variable( logger = logging.getLogger() -debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true" +debug_mode = os.getenv("DEBUG_RAG_API", "False").lower() in ( + "true", + "1", + "yes", + "y", + "t", +) console_json = get_env_variable("CONSOLE_JSON", "False").lower() == "true" if debug_mode: diff --git a/main.py b/main.py index b2cbfa7f..747a3b10 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,9 @@ from langchain.schema import Document from contextlib import asynccontextmanager from dotenv import find_dotenv, load_dotenv +from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from fastapi.exceptions import RequestValidationError from langchain_core.runnables.config import run_in_executor from langchain.text_splitter import RecursiveCharacterTextSplitter from fastapi import ( @@ -84,7 +86,7 @@ async def lifespan(app: FastAPI): yield -app = FastAPI(lifespan=lifespan) +app = FastAPI(lifespan=lifespan, debug=debug_mode) app.add_middleware( CORSMiddleware, @@ -213,10 +215,17 @@ async def delete_documents(document_ids: List[str] = Body(...)): @app.post("/query") -async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): - user_authorized = ( - "public" if not hasattr(request.state, "user") else request.state.user.get("id") - ) +async def query_embeddings_by_file_id( + body: QueryRequestBody, + request: Request, +): + if not hasattr(request.state, "user"): + user_authorized = body.entity_id if body.entity_id else "public" + else: + user_authorized = ( + body.entity_id if body.entity_id else request.state.user.get("id") + ) + authorized_documents = [] try: @@ -245,9 +254,24 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): if doc_user_id is None or doc_user_id == user_authorized: authorized_documents = documents else: - logger.warn( - f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}" - ) + # If using entity_id and access denied, try again with user's actual ID + if body.entity_id and hasattr(request.state, "user"): + user_authorized = request.state.user.get("id") + if doc_user_id == user_authorized: + authorized_documents = documents + else: + if body.entity_id == doc_user_id: + logger.warning( + f"Entity ID {body.entity_id} matches document user_id but user {user_authorized} is not authorized" + ) + else: + logger.warning( + f"Access denied for both entity ID {body.entity_id} and user {user_authorized} to document with user_id {doc_user_id}" + ) + else: + logger.warning( + f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}" + ) return authorized_documents @@ -361,8 +385,9 @@ def get_loader(filename: str, file_content_type: str, filepath: str): @app.post("/local/embed") -async def embed_local_file(document: StoreDocument, request: Request): - +async def embed_local_file( + document: StoreDocument, request: Request, entity_id: str = None +): # Check if the file exists if not os.path.exists(document.filepath): raise HTTPException( @@ -371,9 +396,9 @@ async def embed_local_file(document: StoreDocument, request: Request): ) if not hasattr(request.state, "user"): - user_id = "public" + user_id = entity_id if entity_id else "public" else: - user_id = request.state.user.get("id") + user_id = entity_id if entity_id else request.state.user.get("id") try: loader, known_type = get_loader( @@ -410,15 +435,18 @@ async def embed_local_file(document: StoreDocument, request: Request): @app.post("/embed") async def embed_file( - request: Request, file_id: str = Form(...), file: UploadFile = File(...) + request: Request, + file_id: str = Form(...), + file: UploadFile = File(...), + entity_id: str = Form(None), ): response_status = True response_message = "File processed successfully." known_type = None if not hasattr(request.state, "user"): - user_id = "public" + user_id = entity_id if entity_id else "public" else: - user_id = request.state.user.get("id") + user_id = entity_id if entity_id else request.state.user.get("id") temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) os.makedirs(temp_base_path, exist_ok=True) @@ -538,14 +566,17 @@ async def load_document_context(id: str): @app.post("/embed-upload") async def embed_file_upload( - request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...) + request: Request, + file_id: str = Form(...), + uploaded_file: UploadFile = File(...), + entity_id: str = Form(None), ): temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename) if not hasattr(request.state, "user"): - user_id = "public" + user_id = entity_id if entity_id else "public" else: - user_id = request.state.user.get("id") + user_id = entity_id if entity_id else request.state.user.get("id") try: with open(temp_file_path, "wb") as temp_file: @@ -624,6 +655,22 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody): raise HTTPException(status_code=500, detail=str(e)) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + body = await request.body() + logger.debug(f"Validation error occurred") + logger.debug(f"Raw request body: {body.decode()}") + logger.debug(f"Validation errors: {exc.errors()}") + return JSONResponse( + status_code=422, + content={ + "detail": exc.errors(), + "body": body.decode(), + "message": "Request validation failed", + }, + ) + + if debug_mode: app.include_router(router=pgvector_router) diff --git a/models.py b/models.py index b584c992..11b4bde9 100644 --- a/models.py +++ b/models.py @@ -26,9 +26,10 @@ class StoreDocument(BaseModel): class QueryRequestBody(BaseModel): - file_id: str query: str + file_id: str k: int = 4 + entity_id: Optional[str] = None class CleanupMethod(str, Enum):