import time
import uuid
from typing import TYPE_CHECKING, Annotated, Optional

from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from langflow.api.utils import (
    build_and_cache_graph,
    format_elapsed_time,
    format_exception_message,
    get_next_runnable_vertices,
    get_top_level_vertices,
)
from langflow.api.v1.schemas import (
    InputValueRequest,
    ResultDataResponse,
    StreamData,
    VertexBuildResponse,
    VerticesOrderResponse,
)
from langflow.services.auth.utils import get_current_active_user
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_session_service
from langflow.services.monitor.utils import log_vertex_build

if TYPE_CHECKING:
    from langflow.graph.vertex.types import ChatVertex
    from langflow.services.session.service import SessionService

router = APIRouter(tags=["Chat"])


async def try_running_celery_task(vertex, user_id):
    # Try running the task in celery and set the task_id on the local vertex.
    # If that fails, run the task locally instead.
    try:
        from langflow.worker import build_vertex

        task = build_vertex.delay(vertex)
        vertex.task_id = task.id
    except Exception as exc:
        logger.debug(f"Error running task in celery: {exc}")
        vertex.task_id = None
        await vertex.build(user_id=user_id)
    return vertex
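
# Illustrative sketch, not part of the original module: how a caller might route
# vertices through the celery-or-local fallback above. `graph` and `current_user`
# are hypothetical placeholders for objects available at the call site.
#
#     for vertex in graph.vertices:
#         vertex = await try_running_celery_task(vertex, user_id=current_user.id)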


@router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse)
async def get_vertices(
    flow_id: str,
    stop_component_id: Optional[str] = None,
    start_component_id: Optional[str] = None,
    chat_service: "ChatService" = Depends(get_chat_service),
    session=Depends(get_session),
):
    """Return the sorted build order of vertices for a flow, building and caching the graph if needed."""
    try:
        # First, we need to check if the flow_id is in the cache
        graph = None
        if cache := await chat_service.get_cache(flow_id):
            graph = cache.get("result")
        graph = await build_and_cache_graph(flow_id, session, chat_service, graph)
        if stop_component_id or start_component_id:
            try:
                vertices = graph.sort_vertices(stop_component_id, start_component_id)
            except Exception as exc:
                logger.error(exc)
                vertices = graph.sort_vertices()
        else:
            vertices = graph.sort_vertices()
        # Now vertices is a list of lists of vertex ids,
        # returned in the order they should be built.
        run_id = uuid.uuid4()
        graph.set_run_id(run_id)
        return VerticesOrderResponse(ids=vertices, run_id=run_id)
    except Exception as exc:
        logger.error(f"Error checking build status: {exc}")
        logger.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
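
# Client-side sketch (not part of the original module), assuming a local Langflow
# server on port 7860 with this router mounted under /api/v1; the flow id is a
# placeholder. It shows how the vertex order above could be fetched before
# building vertices one by one:
#
#     import httpx
#
#     resp = httpx.get("http://localhost:7860/api/v1/build/<flow-id>/vertices")
#     order = resp.json()  # e.g. {"ids": [[...], [...]], "run_id": "..."}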


@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
    flow_id: str,
    vertex_id: str,
    background_tasks: BackgroundTasks,
    inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
    chat_service: "ChatService" = Depends(get_chat_service),
    current_user=Depends(get_current_active_user),
):
    """Build a vertex instead of the entire graph."""
    start_time = time.perf_counter()
    next_runnable_vertices = []
    top_level_vertices = []
    try:
        start_time = time.perf_counter()
        cache = await chat_service.get_cache(flow_id)
        if not cache:
            # If there's no cache
            logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
            graph = await build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
        else:
            graph = cache.get("result")
        result_data_response = ResultDataResponse(results={})
        duration = ""
        vertex = graph.get_vertex(vertex_id)
        try:
            if not vertex.frozen or not vertex._built:
                inputs_dict = inputs.model_dump() if inputs else {}
                await vertex.build(user_id=current_user.id, inputs=inputs_dict)

            if vertex.result is not None:
                params = vertex._built_object_repr()
                valid = True
                result_dict = vertex.result
                artifacts = vertex.artifacts
            else:
                raise ValueError(f"No result found for vertex {vertex_id}")
            next_runnable_vertices = await get_next_runnable_vertices(graph, vertex, vertex_id, chat_service, flow_id)
            top_level_vertices = get_top_level_vertices(graph, next_runnable_vertices)
            result_data_response = ResultDataResponse(**result_dict.model_dump())
        except Exception as exc:
            logger.exception(f"Error building vertex: {exc}")
            params = format_exception_message(exc)
            valid = False
            result_data_response = ResultDataResponse(results={})
            artifacts = {}
            # If there's an error building the vertex
            # we need to clear the cache
            await chat_service.clear_cache(flow_id)

        # Log the vertex build
        if not vertex.will_stream:
            background_tasks.add_task(
                log_vertex_build,
                flow_id=flow_id,
                vertex_id=vertex_id,
                valid=valid,
                params=params,
                data=result_data_response,
                artifacts=artifacts,
            )

        timedelta = time.perf_counter() - start_time
        duration = format_elapsed_time(timedelta)
        result_data_response.duration = duration
        result_data_response.timedelta = timedelta
        vertex.add_build_time(timedelta)
        inactivated_vertices = None
        inactivated_vertices = list(graph.inactivated_vertices)
        graph.reset_inactivated_vertices()
        graph.reset_activated_vertices()
        await chat_service.set_cache(flow_id, graph)

        # graph.stop_vertex tells us if the user asked
        # to stop the build of the graph at a certain vertex
        # if it is in next_vertices_ids, we need to remove other
        # vertices from next_vertices_ids
        if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
            next_runnable_vertices = [graph.stop_vertex]

        build_response = VertexBuildResponse(
            inactivated_vertices=inactivated_vertices,
            next_vertices_ids=next_runnable_vertices,
            top_level_vertices=top_level_vertices,
            valid=valid,
            params=params,
            id=vertex.id,
            data=result_data_response,
        )
        return build_response
    except Exception as exc:
        logger.error(f"Error building vertex: {exc}")
        logger.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
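
# Client-side sketch of building a single vertex via the endpoint above. The server
# URL, flow id, and vertex id are placeholders; authentication is omitted and the
# exact InputValueRequest fields may differ from this assumption:
#
#     import httpx
#
#     resp = httpx.post(
#         "http://localhost:7860/api/v1/build/<flow-id>/vertices/<vertex-id>",
#         json={"inputs": {"input_value": "Hello"}},
#     )
#     data = resp.json()  # VertexBuildResponse: valid, params, next_vertices_ids, ...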


# The endpoint below is an SSE (Server-Sent Events) endpoint.
# It receives a flow_id and a vertex_id and streams that vertex's output.
@router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse)
async def build_vertex_stream(
    flow_id: str,
    vertex_id: str,
    session_id: Optional[str] = None,
    chat_service: "ChatService" = Depends(get_chat_service),
    session_service: "SessionService" = Depends(get_session_service),
):
    """Build a single vertex and stream its output instead of building the entire graph."""
    try:

        async def stream_vertex():
            try:
                if not session_id:
                    cache = await chat_service.get_cache(flow_id)
                    if not cache:
                        # If there's no cache
                        raise ValueError(f"No cache found for {flow_id}.")
                    else:
                        graph = cache.get("result")
                else:
                    session_data = await session_service.load_session(session_id, flow_id=flow_id)
                    graph, artifacts = session_data if session_data else (None, None)
                if not graph:
                    raise ValueError(f"No graph found for {flow_id}.")

                vertex: "ChatVertex" = graph.get_vertex(vertex_id)
                if not hasattr(vertex, "stream"):
                    raise ValueError(f"Vertex {vertex_id} does not support streaming")

                if isinstance(vertex._built_result, str) and vertex._built_result:
                    stream_data = StreamData(
                        event="message",
                        data={"message": f"Streaming vertex {vertex_id}"},
                    )
                    yield str(stream_data)
                    stream_data = StreamData(
                        event="message",
                        data={"chunk": vertex._built_result},
                    )
                    yield str(stream_data)
                elif not vertex.frozen or not vertex._built:
                    logger.debug(f"Streaming vertex {vertex_id}")
                    stream_data = StreamData(
                        event="message",
                        data={"message": f"Streaming vertex {vertex_id}"},
                    )
                    yield str(stream_data)
                    async for chunk in vertex.stream():
                        stream_data = StreamData(
                            event="message",
                            data={"chunk": chunk},
                        )
                        yield str(stream_data)
                elif vertex.result is not None:
                    stream_data = StreamData(
                        event="message",
                        data={"chunk": vertex._built_result},
                    )
                    yield str(stream_data)
                else:
                    raise ValueError(f"No result found for vertex {vertex_id}")
            except Exception as exc:
                logger.exception(f"Error building vertex: {exc}")
                yield str(StreamData(event="error", data={"error": str(exc)}))
            finally:
                logger.debug("Closing stream")
                yield str(StreamData(event="close", data={"message": "Stream closed"}))

        return StreamingResponse(stream_vertex(), media_type="text/event-stream")
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Error building vertex") from exc
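
# Sketch of consuming the SSE stream above with httpx (ids and server URL are
# placeholders; each non-empty line is a serialized StreamData event such as
# "message", "error", or "close"):
#
#     import httpx
#
#     with httpx.stream(
#         "GET", "http://localhost:7860/api/v1/build/<flow-id>/<vertex-id>/stream"
#     ) as resp:
#         for line in resp.iter_lines():
#             if line:
#                 print(line)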