Skip to content

Commit

Permalink
Refactor cache service and fix async issues (#1512)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogabrielluiz committed Mar 10, 2024
1 parent 23fe373 commit 67bccdc
Show file tree
Hide file tree
Showing 24 changed files with 465 additions and 170 deletions.
4 changes: 2 additions & 2 deletions src/backend/langflow/api/utils.py
Expand Up @@ -197,7 +197,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


def build_and_cache_graph(
async def build_and_cache_graph(
flow_id: str,
session: Session,
chat_service: "ChatService",
Expand All @@ -212,7 +212,7 @@ def build_and_cache_graph(
graph = other_graph
else:
graph = graph.update(other_graph)
chat_service.set_cache(flow_id, graph)
await chat_service.set_cache(flow_id, graph)
return graph


Expand Down
19 changes: 11 additions & 8 deletions src/backend/langflow/api/v1/chat.py
Expand Up @@ -58,9 +58,9 @@ async def get_vertices(
try:
# First, we need to check if the flow_id is in the cache
graph = None
if cache := chat_service.get_cache(flow_id):
if cache := await chat_service.get_cache(flow_id):
graph = cache.get("result")
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
graph = await build_and_cache_graph(flow_id, session, chat_service, graph)
if stop_component_id or start_component_id:
try:
vertices = graph.sort_vertices(stop_component_id, start_component_id)
Expand Down Expand Up @@ -98,11 +98,11 @@ async def build_vertex(
next_vertices_ids = []
try:
start_time = time.perf_counter()
cache = chat_service.get_cache(flow_id)
cache = await chat_service.get_cache(flow_id)
if not cache:
# If there's no cache
logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
graph = await build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
else:
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
Expand All @@ -121,8 +121,11 @@ async def build_vertex(
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)

result_data_response = ResultDataResponse(**result_dict.model_dump())

Expand All @@ -134,7 +137,7 @@ async def build_vertex(
artifacts = {}
# If there's an error building the vertex
# we need to clear the cache
chat_service.clear_cache(flow_id)
await chat_service.clear_cache(flow_id)

# Log the vertex build
if not vertex.will_stream:
Expand All @@ -157,7 +160,7 @@ async def build_vertex(
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
chat_service.set_cache(flow_id, graph)
await chat_service.set_cache(flow_id, graph)

# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex
Expand Down
3 changes: 2 additions & 1 deletion src/backend/langflow/api/v1/login.py
@@ -1,5 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session

from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
Expand All @@ -8,7 +10,6 @@
create_user_tokens,
)
from langflow.services.deps import get_session, get_settings_service
from sqlmodel import Session

router = APIRouter(tags=["Login"])

Expand Down
3 changes: 2 additions & 1 deletion src/backend/langflow/base/data/utils.py
Expand Up @@ -104,7 +104,8 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R
elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
text = yaml.safe_load(text)
elif file_path.endswith(".xml"):
text = ET.fromstring(text)
xml_element = ET.fromstring(text)
text = ET.tostring(xml_element, encoding="unicode")
except Exception as e:
if not silent_errors:
raise ValueError(f"Error loading file {file_path}: {e}") from e
Expand Down
1 change: 1 addition & 0 deletions src/backend/langflow/base/io/chat.py
Expand Up @@ -26,6 +26,7 @@ def build_config(self):
"session_id": {
"display_name": "Session ID",
"info": "If provided, the message will be stored in the memory.",
"advanced": True,
},
"return_record": {
"display_name": "Return Record",
Expand Down
70 changes: 70 additions & 0 deletions src/backend/langflow/components/agents/ReActAgent.py
@@ -0,0 +1,70 @@
# from typing import Dict, List

# import dspy

# from langflow import CustomComponent
# from langflow.field_typing import Text


# class ReActAgentComponent(CustomComponent):
# display_name = "ReAct Agent"
# description = "A component to create a ReAct Agent."
# icon = "user-secret"

# def build_config(self):
# return {
# "input_value": {
# "display_name": "Input",
# "input_types": ["Text"],
# "info": "The input value for the ReAct Agent.",
# },
# "instructions": {
# "display_name": "Instructions",
# "info": "The Prompt.",
# },
# "inputs": {
# "display_name": "Inputs",
# "info": "The Name and Description of the Input Fields.",
# },
# "outputs": {
# "display_name": "Outputs",
# "info": "The Name and Description of the Output Fields.",
# },
# }

# def build(
# self,
# input_value: List[dict],
# instructions: Text,
# inputs: List[dict],
# outputs: List[Dict],
# ) -> Text:
# # inputs is a list of dictionaries where the key is the name of the input
# # and the value is the description of the input
# input_fields = (
# {}
# ) # dict[str, FieldInfo] InputField and OutputField are subclasses of pydantic.Field
# for input_dict in inputs:
# for name, description in input_dict.items():
# prefix = name if ":" in name else f"{name}:"
# input_fields[name] = dspy.InputField(
# prefix=prefix, description=description
# )

# output_fields = {} # dict[str, FieldInfo]
# for output_dict in outputs:
# for name, description in output_dict.items():
# prefix = name if ":" in name else f"{name}:"
# output_fields[name] = dspy.OutputField(
# prefix=prefix, description=description
# )

# signature = dspy.make_signature(inputs, instructions=instructions)
# agent = dspy.ReAct(
# signature=signature,
# )
# inputs_dict = {}
# for input_dict in input_value:
# inputs_dict.update(input_dict)

# result = agent(inputs_dict)
2 changes: 1 addition & 1 deletion src/backend/langflow/components/data/APIRequest.py
Expand Up @@ -106,7 +106,7 @@ async def build(
bodies = [body.data]
if len(urls) != len(bodies):
# add bodies with None
bodies += [None] * (len(urls) - len(bodies))
bodies += [None] * (len(urls) - len(bodies)) # type: ignore
async with httpx.AsyncClient() as client:
results = await asyncio.gather(
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
Expand Down
2 changes: 0 additions & 2 deletions src/backend/langflow/components/experimental/__init__.py
Expand Up @@ -3,12 +3,10 @@
from .GetNotified import GetNotifiedComponent
from .ListFlows import ListFlowsComponent
from .MergeRecords import MergeRecordsComponent
from .MessageHistory import MessageHistoryComponent
from .Notify import NotifyComponent
from .RunFlow import RunFlowComponent
from .RunnableExecutor import RunnableExecComponent
from .SQLExecutor import SQLExecutorComponent
from .TextToRecord import TextToRecordComponent

__all__ = [
"ClearMessageHistoryComponent",
Expand Down
4 changes: 4 additions & 0 deletions src/backend/langflow/components/helpers/__init__.py
@@ -1,13 +1,17 @@
from .CustomComponent import Component
from .DocumentToRecord import DocumentToRecordComponent
from .IDGenerator import UUIDGeneratorComponent
from .MessageHistory import MessageHistoryComponent
from .PythonFunction import PythonFunctionComponent
from .RecordsAsText import RecordsAsTextComponent
from .TextToRecord import TextToRecordComponent

__all__ = [
"Component",
"DocumentToRecordComponent",
"UUIDGeneratorComponent",
"PythonFunctionComponent",
"RecordsAsTextComponent",
"TextToRecordComponent",
"MessageHistoryComponent",
]
4 changes: 3 additions & 1 deletion src/backend/langflow/components/models/CohereModel.py
@@ -1,4 +1,5 @@
from langchain_community.chat_models.cohere import ChatCohere
from pydantic.v1 import SecretStr

from langflow.components.models.base.model import LCModelComponent
from langflow.field_typing import Text
Expand Down Expand Up @@ -44,8 +45,9 @@ def build(
temperature: float = 0.75,
stream: bool = False,
) -> Text:
api_key = SecretStr(cohere_api_key)
output = ChatCohere( # type: ignore
cohere_api_key=cohere_api_key,
cohere_api_key=api_key,
temperature=temperature,
)
return self.get_result(output=output, stream=stream, input_value=input_value)
6 changes: 4 additions & 2 deletions src/backend/langflow/graph/edge/base.py
Expand Up @@ -122,7 +122,9 @@ async def honor(self, source: "Vertex", target: "Vertex") -> None:
return

if not source._built:
await source.build()
# The system should be read-only, so we should not be building vertices
# that are not already built.
raise ValueError(f"Source vertex {source.id} is not built.")

if self.matched_type == "Text":
self.result = source._built_result
Expand All @@ -132,7 +134,7 @@ async def honor(self, source: "Vertex", target: "Vertex") -> None:
target.params[self.target_param] = self.result
self.is_fulfilled = True

async def get_result(self, source: "Vertex", target: "Vertex"):
async def get_result_from_source(self, source: "Vertex", target: "Vertex"):
# Fulfill the contract if it has not been fulfilled.
if not self.is_fulfilled:
await self.honor(source, target)
Expand Down
33 changes: 32 additions & 1 deletion src/backend/langflow/graph/graph/base.py
Expand Up @@ -240,6 +240,7 @@ def metadata(self):

def build_graph_maps(self):
self.predecessor_map, self.successor_map = self.build_adjacency_maps()

self.in_degree_map = self.build_in_degree()
self.parent_child_map = self.build_parent_child_map()

Expand Down Expand Up @@ -295,6 +296,15 @@ def build_adjacency_maps(self):
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map

def build_run_map(self):
    """Invert ``self.predecessor_map`` into a successor lookup.

    Returns a ``defaultdict(list)`` mapping each predecessor vertex id to
    the list of vertex ids that depend on it (i.e. every vertex whose
    predecessor list contains that id).
    """
    inverted = defaultdict(list)
    for dependent_id, predecessor_ids in self.predecessor_map.items():
        for predecessor_id in predecessor_ids:
            inverted[predecessor_id].append(dependent_id)
    return inverted

@classmethod
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
"""
Expand Down Expand Up @@ -939,16 +949,37 @@ def sort_vertices(
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.run_map, self.run_predecessors = (
self.build_run_map(),
self.predecessor_map.copy(),
)

# Return just the first layer
return first_layer

def vertex_has_no_more_predecessors(self, vertex_id: str) -> bool:
    """Return True when *vertex_id* has no remaining unrun predecessors.

    A vertex missing from ``run_predecessors`` entirely, or mapped to an
    empty collection, counts as having no predecessors left.
    """
    remaining = self.run_predecessors.get(vertex_id)
    return not remaining

def should_run_vertex(self, vertex_id: str) -> bool:
"""Returns whether a component should be run."""
should_run = vertex_id in self.vertices_to_run
# the self.run_map is a map of vertex_id to a list of predecessors
# each time a vertex is run, we remove it from the list of predecessors
# if a vertex has no more predecessors, it should be run
should_run = vertex_id in self.vertices_to_run and self.vertex_has_no_more_predecessors(vertex_id)

if should_run:
self.vertices_to_run.remove(vertex_id)
# remove the vertex from the run_map
self.remove_from_predecessors(vertex_id)
return should_run

def remove_from_predecessors(self, vertex_id: str):
    """Mark *vertex_id* as run by deleting it from its dependents' pending lists.

    ``run_map`` maps a vertex id to the vertices that depend on it; for each
    such dependent, drop *vertex_id* from ``run_predecessors`` if present.
    Unknown vertex ids are a no-op.
    """
    for dependent_id in self.run_map.get(vertex_id, []):
        pending = self.run_predecessors[dependent_id]
        if vertex_id in pending:
            pending.remove(vertex_id)

def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""

Expand Down

0 comments on commit 67bccdc

Please sign in to comment.