-
Notifications
You must be signed in to change notification settings - Fork 0
/
transformer_api.py
117 lines (83 loc) · 2.99 KB
/
transformer_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import asyncio
import logging
import os
import uuid
from typing import List

from fastapi import FastAPI, UploadFile
from fastapi import Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

import transformer
import llama_based_retrieval
from repository import db_connector
# FastAPI application that the route decorators below register on.
app = FastAPI()

# Built once at import time and shared across requests; /upload_pdf uses it to
# parse uploads. NOTE(review): any model loading done by Transformer() therefore
# happens at server start-up — confirm in transformer.py.
transformer_model = transformer.Transformer()

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,http://localhost:3000").
# The default "*" keeps the original allow-any-origin behaviour.
# NOTE(review): with allow_credentials=True, Starlette's CORSMiddleware reflects
# the caller's Origin instead of sending "*", so a wildcard here lets any site
# make credentialed requests — configure explicit origins in production.
origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PDFResponse(BaseModel):
    """Response model with a single ``texts`` field: a list of lists of strings.

    Currently not referenced by any of the route handlers in this module.
    """

    texts: list[list[str]]
class ErrorResponse(BaseModel):
    """Response model carrying a single ``error`` message string.

    The handlers in this module build their ``{"error": ...}`` bodies with
    ``JSONResponse`` directly rather than through this model.
    """

    error: str
class Question(BaseModel):
    """Request body for ``POST /question``."""

    # Free-text question forwarded to llama_based_retrieval.ask_question.
    question: str
    # Identifier of the index to query — presumably the ID generated by
    # /upload_pdf for an earlier upload; confirm against the client.
    uniqueId: str
class Organization(BaseModel):
    """Request body for ``POST /organization``."""

    # Name passed to db_connector.addOrganization.
    organizationName: str
class Country(BaseModel):
    """Request model with a single ``countryName`` string field.

    Not yet accepted by the ``/country`` handler, which currently takes no body.
    """

    countryName: str
# Prompt asking the model for a 10-bullet summary of an uploaded file.
# NOTE(review): not referenced by the handlers in this module — presumably
# consumed elsewhere or kept for reference; confirm before changing its text.
CUSTOM_QUERY = (
    "Send me a summary of the file. "
    "In your summary, make sure not to mention the file location nor the data name, "
    "also to have 10 bullet points. "
    "Each bullet point should be on a new row. "
    "Try to incorporate few key points from all the text. "
    "Do it step by step:"
)
@app.post("/upload_pdf")
def upload_pdf(
    files: list[UploadFile] = [], file_extension_list: List[str] = Form(...)
):
    """Parse uploaded files, build a retrieval index for them, and return a summary.

    Args:
        files: Uploaded files. The handler never mutates this list, so the
            shared ``[]`` default is harmless here.
        file_extension_list: Form field forwarded unchanged to
            ``transformer_model.parse_files`` (presumably one extension per
            uploaded file — confirm in transformer.py).

    Returns:
        Element 0 of the ``parse_files`` result with an added ``"autoSummary"``
        key; a 400 JSON ``{"error": ...}`` when no files were sent; or a 500
        JSON ``{"error": <exception message>}`` if any step fails.
    """
    if not files:
        return JSONResponse(content={"error": "No files provided"}, status_code=400)

    # Only mint an ID once we know there is something to index under it.
    unique_id = str(uuid.uuid4())
    try:
        parsed_pdf_list = transformer_model.parse_files(
            files, unique_id, file_extension_list
        )
        # Element 2 of the parse result is what gets indexed; element 0 is the
        # payload returned to the client (structure defined in transformer.py).
        llama_based_retrieval.create_index(parsed_pdf_list[2], unique_id)
        auto_summarization = llama_based_retrieval.auto_summarization(unique_id)
        result = parsed_pdf_list[0]
        result["autoSummary"] = str(auto_summarization)
        return result
    except Exception as e:
        # Keep the client-facing error contract, but record the traceback
        # server-side — previously failures left no trace in the logs.
        logging.getLogger(__name__).exception(
            "upload_pdf failed for upload %s", unique_id
        )
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/question")
def ask_question_to_ebuddy(question: Question):
    """Answer ``question.question`` against the index ``question.uniqueId``, streaming the reply.

    Returns:
        A ``StreamingResponse`` (media type ``text/event-stream``) that emits
        each chunk of ``ask_question(...).response_gen`` unchanged, or a 500
        JSON ``{"error": ...}`` if starting the query raises. Errors raised
        while streaming are not converted: the generator only runs after this
        function has already returned the response.
    """

    async def astreamer(generator):
        # ``response_gen`` is a plain synchronous iterator. Iterating it
        # directly inside this coroutine would block the event loop while
        # waiting for each chunk, stalling every other request. Pull each item
        # in a worker thread instead. The sentinel keeps StopIteration from
        # being raised into a Future, which asyncio rejects.
        iterator = iter(generator)
        done = object()
        try:
            while True:
                chunk = await asyncio.to_thread(next, iterator, done)
                if chunk is done:
                    break
                yield chunk
                # Fixed 100 ms pause after every chunk, kept from the original
                # implementation (presumably deliberate pacing — confirm).
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            # Client went away. Propagate the cancellation rather than
            # swallowing it, so the server can tear the task down cleanly.
            logging.getLogger(__name__).info("question stream cancelled")
            raise

    try:
        answer = llama_based_retrieval.ask_question(
            question.question, question.uniqueId
        )
        return StreamingResponse(
            astreamer(answer.response_gen),
            media_type="text/event-stream",
        )
    except Exception as e:
        logging.getLogger(__name__).exception("ask_question failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/organization")
def get_organization_details(organizationName: Organization):
    """Register an organization and return what the repository layer returns.

    Despite its name, ``organizationName`` is the whole ``Organization``
    request body; the actual name is its ``organizationName`` field.

    Returns:
        The value of ``db_connector.addOrganization(name)`` — presumably the
        new organization's ID; confirm in repository.db_connector.
    """
    name = organizationName.organizationName
    # Replaces a leftover debug print(): logging can be filtered or silenced
    # by the server's log configuration instead of always hitting stdout.
    logging.getLogger(__name__).debug("Adding organization %r", name)
    organizationID = db_connector.addOrganization(name)
    return organizationID
@app.post("/country")
def send_user_data():
    """Placeholder handler for ``POST /country``.

    Accepts no request body and always responds with an empty string.
    """
    placeholder = ""
    return placeholder