From 67bccdc753dff25816c26f2bfdeef82b9c2266a2 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 9 Mar 2024 22:55:56 -0300 Subject: [PATCH] Refactor cache service and fix async issues (#1512) --- src/backend/langflow/api/utils.py | 4 +- src/backend/langflow/api/v1/chat.py | 19 +- src/backend/langflow/api/v1/login.py | 3 +- src/backend/langflow/base/data/utils.py | 3 +- src/backend/langflow/base/io/chat.py | 1 + .../langflow/components/agents/ReActAgent.py | 70 +++++++ .../langflow/components/data/APIRequest.py | 2 +- .../components/experimental/__init__.py | 2 - .../langflow/components/helpers/__init__.py | 4 + .../langflow/components/models/CohereModel.py | 4 +- src/backend/langflow/graph/edge/base.py | 6 +- src/backend/langflow/graph/graph/base.py | 33 +++- src/backend/langflow/graph/vertex/base.py | 186 ++++++++++-------- .../langflow/interface/initialize/loading.py | 22 ++- src/backend/langflow/schema/schema.py | 2 +- .../langflow/services/cache/__init__.py | 14 +- src/backend/langflow/services/cache/base.py | 80 +++++++- .../langflow/services/cache/factory.py | 13 +- .../langflow/services/cache/service.py | 115 +++++++++-- src/backend/langflow/services/chat/service.py | 20 +- .../langflow/services/settings/base.py | 2 +- .../BotMessageSquare/BotMessageSquare.jsx | 2 +- tests/test_cache.py | 6 +- tests/test_data_components.py | 22 +-- 24 files changed, 465 insertions(+), 170 deletions(-) create mode 100644 src/backend/langflow/components/agents/ReActAgent.py diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index f51e0fd5f0..1a53c74e39 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -197,7 +197,7 @@ def format_elapsed_time(elapsed_time: float) -> str: return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}" -def build_and_cache_graph( +async def build_and_cache_graph( flow_id: str, session: Session, chat_service: "ChatService", @@ -212,7 +212,7 @@ def build_and_cache_graph( graph = other_graph else: graph = graph.update(other_graph) - chat_service.set_cache(flow_id, graph) + await chat_service.set_cache(flow_id, graph) return graph diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 77a99b0355..e19605f280 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -58,9 +58,9 @@ async def get_vertices( try: # First, we need to check if the flow_id is in the cache graph = None - if cache := chat_service.get_cache(flow_id): + if cache := await chat_service.get_cache(flow_id): graph = cache.get("result") - graph = build_and_cache_graph(flow_id, session, chat_service, graph) + graph = await build_and_cache_graph(flow_id, session, chat_service, graph) if stop_component_id or start_component_id: try: vertices = graph.sort_vertices(stop_component_id, start_component_id) @@ -98,11 +98,11 @@ async def build_vertex( next_vertices_ids = [] try: start_time = time.perf_counter() - cache = chat_service.get_cache(flow_id) + cache = await chat_service.get_cache(flow_id) if not cache: # If there's no cache logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}") - graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service) + graph = await build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service) else: graph = cache.get("result") result_data_response = ResultDataResponse(results={}) @@ -121,8 +121,11 @@ async def build_vertex( artifacts = vertex.artifacts else: raise ValueError(f"No result found for vertex {vertex_id}") - next_vertices_ids = vertex.successors_ids - next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)] + async with chat_service._cache_locks[flow_id] as lock: + graph.remove_from_predecessors(vertex_id) + next_vertices_ids = vertex.successors_ids + next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)] + await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock) result_data_response = ResultDataResponse(**result_dict.model_dump()) @@ -134,7 +137,7 @@ async def build_vertex( artifacts = {} # If there's an error building the vertex # we need to clear the cache - chat_service.clear_cache(flow_id) + await chat_service.clear_cache(flow_id) # Log the vertex build if not vertex.will_stream: @@ -157,7 +160,7 @@ async def build_vertex( inactivated_vertices = list(graph.inactivated_vertices) graph.reset_inactivated_vertices() graph.reset_activated_vertices() - chat_service.set_cache(flow_id, graph) + await chat_service.set_cache(flow_id, graph) # graph.stop_vertex tells us if the user asked # to stop the build of the graph at a certain vertex diff --git a/src/backend/langflow/api/v1/login.py b/src/backend/langflow/api/v1/login.py index ecc0fea6a1..4f44e79bcf 100644 --- a/src/backend/langflow/api/v1/login.py +++ b/src/backend/langflow/api/v1/login.py @@ -1,5 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security import OAuth2PasswordRequestForm +from sqlmodel import Session + from langflow.api.v1.schemas import Token from langflow.services.auth.utils import ( authenticate_user, @@ -8,7 +10,6 @@ create_user_tokens, ) from langflow.services.deps import get_session, get_settings_service -from sqlmodel import Session router = APIRouter(tags=["Login"]) diff --git a/src/backend/langflow/base/data/utils.py b/src/backend/langflow/base/data/utils.py index 4210bc38af..450f6e04b3 100644 --- a/src/backend/langflow/base/data/utils.py +++ b/src/backend/langflow/base/data/utils.py @@ -104,7 +104,8 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R elif file_path.endswith(".yaml") or file_path.endswith(".yml"): text = yaml.safe_load(text) elif file_path.endswith(".xml"): - text = ET.fromstring(text) + xml_element = ET.fromstring(text) + text = ET.tostring(xml_element, encoding="unicode") except Exception as e: if not silent_errors: raise ValueError(f"Error loading file {file_path}: {e}") from e diff --git a/src/backend/langflow/base/io/chat.py b/src/backend/langflow/base/io/chat.py index 2dafd7bc22..3411f603ef 100644 --- a/src/backend/langflow/base/io/chat.py +++ b/src/backend/langflow/base/io/chat.py @@ -26,6 +26,7 @@ def build_config(self): "session_id": { "display_name": "Session ID", "info": "If provided, the message will be stored in the memory.", + "advanced": True, }, "return_record": { "display_name": "Return Record", diff --git a/src/backend/langflow/components/agents/ReActAgent.py b/src/backend/langflow/components/agents/ReActAgent.py new file mode 100644 index 0000000000..a9a813ada1 --- /dev/null +++ b/src/backend/langflow/components/agents/ReActAgent.py @@ -0,0 +1,70 @@ +# from typing import Dict, List + +# import dspy + +# from langflow import CustomComponent +# from langflow.field_typing import Text + + +# class ReActAgentComponent(CustomComponent): +# display_name = "ReAct Agent" +# description = "A component to create a ReAct Agent." +# icon = "user-secret" + +# def build_config(self): +# return { +# "input_value": { +# "display_name": "Input", +# "input_types": ["Text"], +# "info": "The input value for the ReAct Agent.", +# }, +# "instructions": { +# "display_name": "Instructions", +# "info": "The Prompt.", +# }, +# "inputs": { +# "display_name": "Inputs", +# "info": "The Name and Description of the Input Fields.", +# }, +# "outputs": { +# "display_name": "Outputs", +# "info": "The Name and Description of the Output Fields.", +# }, +# } + +# def build( +# self, +# input_value: List[dict], +# instructions: Text, +# inputs: List[dict], +# outputs: List[Dict], +# ) -> Text: +# # inputs is a list of dictionaries where the key is the name of the input +# # and the value is the description of the input +# input_fields = ( +# {} +# ) # dict[str, FieldInfo] InputField and OutputField are subclasses of pydantic.Field +# for input_dict in inputs: +# for name, description in input_dict.items(): +# prefix = name if ":" in name else f"{name}:" +# input_fields[name] = dspy.InputField( +# prefix=prefix, description=description +# ) + +# output_fields = {} # dict[str, FieldInfo] +# for output_dict in outputs: +# for name, description in output_dict.items(): +# prefix = name if ":" in name else f"{name}:" +# output_fields[name] = dspy.OutputField( +# prefix=prefix, description=description +# ) + +# signature = dspy.make_signature(inputs, instructions=instructions) +# agent = dspy.ReAct( +# signature=signature, +# ) +# inputs_dict = {} +# for input_dict in input_value: +# inputs_dict.update(input_dict) + +# result = agent(inputs_dict) diff --git a/src/backend/langflow/components/data/APIRequest.py b/src/backend/langflow/components/data/APIRequest.py index 6199d541bd..75dbe69a50 100644 --- a/src/backend/langflow/components/data/APIRequest.py +++ b/src/backend/langflow/components/data/APIRequest.py @@ -106,7 +106,7 @@ async def build( bodies = [body.data] if len(urls) != len(bodies): # add bodies with None - bodies += [None] * (len(urls) - len(bodies)) + bodies += [None] * (len(urls) - len(bodies)) # type: ignore async with httpx.AsyncClient() as client: results = await asyncio.gather( *[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)] diff --git a/src/backend/langflow/components/experimental/__init__.py b/src/backend/langflow/components/experimental/__init__.py index 5771f4353e..412e9075af 100644 --- a/src/backend/langflow/components/experimental/__init__.py +++ b/src/backend/langflow/components/experimental/__init__.py @@ -3,12 +3,10 @@ from .GetNotified import GetNotifiedComponent from .ListFlows import ListFlowsComponent from .MergeRecords import MergeRecordsComponent -from .MessageHistory import MessageHistoryComponent from .Notify import NotifyComponent from .RunFlow import RunFlowComponent from .RunnableExecutor import RunnableExecComponent from .SQLExecutor import SQLExecutorComponent -from .TextToRecord import TextToRecordComponent __all__ = [ "ClearMessageHistoryComponent", diff --git a/src/backend/langflow/components/helpers/__init__.py b/src/backend/langflow/components/helpers/__init__.py index f659a04380..7cb7f4a4fb 100644 --- a/src/backend/langflow/components/helpers/__init__.py +++ b/src/backend/langflow/components/helpers/__init__.py @@ -1,8 +1,10 @@ from .CustomComponent import Component from .DocumentToRecord import DocumentToRecordComponent from .IDGenerator import UUIDGeneratorComponent +from .MessageHistory import MessageHistoryComponent from .PythonFunction import PythonFunctionComponent from .RecordsAsText import RecordsAsTextComponent +from .TextToRecord import TextToRecordComponent __all__ = [ "Component", @@ -10,4 +12,6 @@ "UUIDGeneratorComponent", "PythonFunctionComponent", "RecordsAsTextComponent", + "TextToRecordComponent", + "MessageHistoryComponent", ] diff --git a/src/backend/langflow/components/models/CohereModel.py b/src/backend/langflow/components/models/CohereModel.py index 1cab1bec7a..55d22ca167 100644 --- a/src/backend/langflow/components/models/CohereModel.py +++ b/src/backend/langflow/components/models/CohereModel.py @@ -1,4 +1,5 @@ from langchain_community.chat_models.cohere import ChatCohere +from pydantic.v1 import SecretStr from langflow.components.models.base.model import LCModelComponent from langflow.field_typing import Text @@ -44,8 +45,9 @@ def build( temperature: float = 0.75, stream: bool = False, ) -> Text: + api_key = SecretStr(cohere_api_key) output = ChatCohere( # type: ignore - cohere_api_key=cohere_api_key, + cohere_api_key=api_key, temperature=temperature, ) return self.get_result(output=output, stream=stream, input_value=input_value) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 0b4091ab52..e6583cc078 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -122,7 +122,9 @@ async def honor(self, source: "Vertex", target: "Vertex") -> None: return if not source._built: - await source.build() + # The system should be read-only, so we should not be building vertices + # that are not already built. + raise ValueError(f"Source vertex {source.id} is not built.") if self.matched_type == "Text": self.result = source._built_result @@ -132,7 +134,7 @@ async def honor(self, source: "Vertex", target: "Vertex") -> None: target.params[self.target_param] = self.result self.is_fulfilled = True - async def get_result(self, source: "Vertex", target: "Vertex"): + async def get_result_from_source(self, source: "Vertex", target: "Vertex"): # Fulfill the contract if it has not been fulfilled. if not self.is_fulfilled: await self.honor(source, target) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 98b44fa42c..fa4a0e5ebc 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -240,6 +240,7 @@ def metadata(self): def build_graph_maps(self): self.predecessor_map, self.successor_map = self.build_adjacency_maps() + self.in_degree_map = self.build_in_degree() self.parent_child_map = self.build_parent_child_map() @@ -295,6 +296,15 @@ def build_adjacency_maps(self): successor_map[edge.source_id].append(edge.target_id) return predecessor_map, successor_map + def build_run_map(self): + run_map = defaultdict(list) + # The run map gets the predecessor_map and maps the info like this: + # {vertex_id: every id that contains the vertex_id in the predecessor_map} + for vertex_id, predecessors in self.predecessor_map.items(): + for predecessor in predecessors: + run_map[predecessor].append(vertex_id) + return run_map + @classmethod def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph": """ @@ -939,16 +949,37 @@ def sort_vertices( # save the only the rest self.vertices_layers = vertices_layers[1:] self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)} + self.run_map, self.run_predecessors = ( + self.build_run_map(), + self.predecessor_map.copy(), + ) + # Return just the first layer return first_layer + def vertex_has_no_more_predecessors(self, vertex_id: str) -> bool: + """Returns whether a vertex has no more predecessors.""" + return not self.run_predecessors.get(vertex_id) + def should_run_vertex(self, vertex_id: str) -> bool: """Returns whether a component should be run.""" - should_run = vertex_id in self.vertices_to_run + # the self.run_map is a map of vertex_id to a list of predecessors + # each time a vertex is run, we remove it from the list of predecessors + # if a vertex has no more predecessors, it should be run + should_run = vertex_id in self.vertices_to_run and self.vertex_has_no_more_predecessors(vertex_id) + if should_run: self.vertices_to_run.remove(vertex_id) + # remove the vertex from the run_map + self.remove_from_predecessors(vertex_id) return should_run + def remove_from_predecessors(self, vertex_id: str): + predecessors = self.run_map.get(vertex_id, []) + for predecessor in predecessors: + if vertex_id in self.run_predecessors[predecessor]: + self.run_predecessors[predecessor].remove(vertex_id) + def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: """Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first.""" diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 06d761533a..9738b65f5e 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -1,4 +1,5 @@ import ast +import asyncio import inspect import types from enum import Enum @@ -7,7 +8,6 @@ Any, AsyncIterator, Callable, - Coroutine, Dict, Iterator, List, @@ -56,6 +56,7 @@ def __init__( ) -> None: # is_external means that the Vertex send or receives data from # an external source (e.g the chat) + self._lock = asyncio.Lock() self.will_stream = False self.updated_raw_params = False self.id: str = data["id"] @@ -171,6 +172,7 @@ def __getstate__(self): } def __setstate__(self, state): + self._lock = asyncio.Lock() self._data = state["_data"] self.params = state["params"] self.base_type = state["base_type"] @@ -383,7 +385,7 @@ async def _build(self, user_id=None): Initiate the build process. """ logger.debug(f"Building {self.display_name}") - await self._build_each_node_in_params_dict(user_id) + await self._build_each_vertex_in_params_dict(user_id) await self._get_and_instantiate_class(user_id) self._validate_built_object() @@ -452,105 +454,123 @@ async def _run( result = await generate_result(self._built_object, inputs, self.has_external_output, session_id) self._built_result = result - async def _build_each_node_in_params_dict(self, user_id=None): + async def _build_each_vertex_in_params_dict(self, user_id=None): """ - Iterates over each node in the params dictionary and builds it. + Iterates over each vertex in the params dictionary and builds it. """ for key, value in self._raw_params.items(): - if self._is_node(value): + if self._is_vertex(value): if value == self: del self.params[key] continue - await self._build_node_and_update_params(key, value, user_id) - elif isinstance(value, list) and self._is_list_of_nodes(value): - await self._build_list_of_nodes_and_update_params(key, value, user_id) + await self._build_vertex_and_update_params( + key, + value, + ) + elif isinstance(value, list) and self._is_list_of_vertices(value): + await self._build_list_of_vertices_and_update_params(key, value) elif isinstance(value, dict): - await self._build_dict_and_update_params(key, value, user_id) + await self._build_dict_and_update_params( + key, + value, + ) elif key not in self.params or self.updated_raw_params: self.params[key] = value - async def _build_dict_and_update_params(self, key, nodes_dict: Dict[str, "Vertex"], user_id=None): + async def _build_dict_and_update_params( + self, + key, + vertices_dict: Dict[str, "Vertex"], + ): """ - Iterates over a dictionary of nodes, builds each and updates the params dictionary. + Iterates over a dictionary of vertices, builds each and updates the params dictionary. """ - for sub_key, value in nodes_dict.items(): - if not self._is_node(value): + for sub_key, value in vertices_dict.items(): + if not self._is_vertex(value): self.params[key][sub_key] = value else: - built = await value.get_result(requester=self, user_id=user_id) - self.params[key][sub_key] = built + result = await value.get_result() + self.params[key][sub_key] = result - def _is_node(self, value): + def _is_vertex(self, value): """ Checks if the provided value is an instance of Vertex. """ return isinstance(value, Vertex) - def _is_list_of_nodes(self, value): + def _is_list_of_vertices(self, value): """ Checks if the provided value is a list of Vertex instances. """ - return all(self._is_node(node) for node in value) + return all(self._is_vertex(vertex) for vertex in value) - async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any: - # PLEASE REVIEW THIS IF STATEMENT - # Check if the Vertex was built already - if self._built: - return self._built_object if not self.use_result else self._built_result + async def get_result( + self, + ) -> Any: + """ + Retrieves the result of the vertex. - if self.is_task and self.task_id is not None: - task = self.get_task() + This is a read-only method so it raises an error if the vertex has not been built yet. - result = task.get(timeout=timeout) - if isinstance(result, Coroutine): - result = await result - if result is not None: # If result is ready - self._update_built_object_and_artifacts(result) - return self._built_object - else: - # Handle the case when the result is not ready (retry, throw exception, etc.) - pass + Returns: + The result of the vertex. + """ + async with self._lock: + return await self._get_result() + + async def _get_result(self) -> Any: + """ + Retrieves the result of the built component. + + If the component has not been built yet, a ValueError is raised. - # If there's no task_id, build the vertex locally - await self.build(requester=requester, user_id=user_id) - return self._built_object + Returns: + The built result if use_result is True, else the built object. + """ + if not self._built: + raise ValueError(f"Component {self.display_name} has not been built yet") + return self._built_result if self.use_result else self._built_object - async def _build_node_and_update_params(self, key, node: "Vertex", user_id=None): + async def _build_vertex_and_update_params(self, key, vertex: "Vertex"): """ - Builds a given node and updates the params dictionary accordingly. + Builds a given vertex and updates the params dictionary accordingly. """ - result = await node.get_result(requester=self, user_id=user_id) + result = await vertex.get_result() self._handle_func(key, result) if isinstance(result, list): self._extend_params_list_with_result(key, result) self.params[key] = result - async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None): + async def _build_list_of_vertices_and_update_params( + self, + key, + vertices: List["Vertex"], + ): """ - Iterates over a list of nodes, builds each and updates the params dictionary. + Iterates over a list of vertices, builds each and updates the params dictionary. """ self.params[key] = [] - for node in nodes: - built = await node.get_result(requester=self, user_id=user_id) + for vertex in vertices: + result = await vertex.get_result() # Weird check to see if the params[key] is a list # because sometimes it is a Record and breaks the code if not isinstance(self.params[key], list): self.params[key] = [self.params[key]] - if isinstance(built, list): - self.params[key].extend(built) + if isinstance(result, list): + self.params[key].extend(result) else: try: - if self.params[key] == built: + if self.params[key] == result: continue - self.params[key].append(built) + self.params[key].append(result) except AttributeError as e: logger.exception(e) raise ValueError( - f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}" - f"Error building node {self.display_name}: {str(e)}" + f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}" + f"Error building vertex {self.display_name}: {str(e)}" ) from e def _handle_func(self, key, result): @@ -580,12 +600,9 @@ async def _get_and_instantiate_class(self, user_id=None): Gets the class from a dictionary and instantiates it with the params. """ if self.base_type is None: - raise ValueError(f"Base type for node {self.display_name} not found") + raise ValueError(f"Base type for vertex {self.display_name} not found") try: result = await loading.instantiate_class( - node_type=self.vertex_type, - base_type=self.base_type, - params=self.params, user_id=user_id, vertex=self, ) @@ -593,7 +610,7 @@ async def _get_and_instantiate_class(self, user_id=None): except Exception as exc: logger.exception(exc) - raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc + raise ValueError(f"Error building vertex {self.display_name}: {str(exc)}") from exc def _update_built_object_and_artifacts(self, result): """ @@ -647,33 +664,34 @@ async def build( requester: Optional["Vertex"] = None, **kwargs, ) -> Any: - if self.state == VertexStates.INACTIVE: - # If the vertex is inactive, return None - self.build_inactive() - return - - if self.frozen and self._built: - return self.get_requester_result(requester) - elif self._built and requester is not None: - # This means that the vertex has already been built - # and we are just getting the result for the requester - return await self.get_requester_result(requester) - self._reset() - - if self._is_chat_input() and inputs: - inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")} - self.update_raw_params(inputs, overwrite=True) - - # Run steps - for step in self.steps: - if step not in self.steps_ran: - if inspect.iscoroutinefunction(step): - await step(user_id=user_id, **kwargs) - else: - step(user_id=user_id, **kwargs) - self.steps_ran.append(step) + async with self._lock: + if self.state == VertexStates.INACTIVE: + # If the vertex is inactive, return None + self.build_inactive() + return + + if self.frozen and self._built: + return self.get_requester_result(requester) + elif self._built and requester is not None: + # This means that the vertex has already been built + # and we are just getting the result for the requester + return await self.get_requester_result(requester) + self._reset() + + if self._is_chat_input() and inputs: + inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")} + self.update_raw_params(inputs, overwrite=True) + + # Run steps + for step in self.steps: + if step not in self.steps_ran: + if inspect.iscoroutinefunction(step): + await step(user_id=user_id, **kwargs) + else: + step(user_id=user_id, **kwargs) + self.steps_ran.append(step) - self._finalize_build() + self._finalize_build() return await self.get_requester_result(requester) @@ -686,7 +704,11 @@ async def get_requester_result(self, requester: Optional["Vertex"]): # Get the requester edge requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None) # Return the result of the requester edge - return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester) + return ( + None + if requester_edge is None + else await requester_edge.get_result_from_source(source=self, target=requester) + ) def add_edge(self, edge: "ContractEdge") -> None: if edge not in self.edges: diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 554fd99d7d..581f853d5c 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -1,6 +1,6 @@ import inspect import json -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type import orjson from langchain.agents import agent as agent_module @@ -40,27 +40,29 @@ async def instantiate_class( - node_type: str, - base_type: str, - params: Dict, + vertex: "Vertex", user_id=None, - vertex: Optional["Vertex"] = None, ) -> Any: """Instantiate class from module type and key, and params""" + vertex_type = vertex.vertex_type + base_type = vertex.base_type + params = vertex.params params = convert_params_to_sets(params) params = convert_kwargs(params) - if node_type in CUSTOM_NODES: - if custom_node := CUSTOM_NODES.get(node_type): + if vertex_type in CUSTOM_NODES: + if custom_node := CUSTOM_NODES.get(vertex_type): if hasattr(custom_node, "initialize"): return custom_node.initialize(**params) return custom_node(**params) - logger.debug(f"Instantiating {node_type} of type {base_type}") - class_object = import_by_type(_type=base_type, name=node_type) + logger.debug(f"Instantiating {vertex_type} of type {base_type}") + if not base_type: + raise ValueError("No base type provided for vertex") + class_object = import_by_type(_type=base_type, name=vertex_type) return await instantiate_based_on_type( class_object=class_object, base_type=base_type, - node_type=node_type, + node_type=vertex_type, params=params, user_id=user_id, vertex=vertex, diff --git a/src/backend/langflow/schema/schema.py b/src/backend/langflow/schema/schema.py index 5ee4be748f..4079f90ac1 100644 --- a/src/backend/langflow/schema/schema.py +++ b/src/backend/langflow/schema/schema.py @@ -16,7 +16,7 @@ class Record(BaseModel): _default_value: str = "" @model_validator(mode="before") - def validate_data(values): + def validate_data(cls, values): if not values.get("data"): values["data"] = {} # Any other keyword should be added to the data dictionary diff --git a/src/backend/langflow/services/cache/__init__.py b/src/backend/langflow/services/cache/__init__.py index bf3a7c5eec..fa6c3ae517 100644 --- a/src/backend/langflow/services/cache/__init__.py +++ b/src/backend/langflow/services/cache/__init__.py @@ -1,9 +1,17 @@ -from . import factory, service -from langflow.services.cache.service import InMemoryCache +from langflow.services.cache.service import ( + AsyncInMemoryCache, + BaseCacheService, + RedisCache, + ThreadingInMemoryCache, +) +from . import factory, service __all__ = [ "factory", "service", - "InMemoryCache", + "ThreadingInMemoryCache", + "AsyncInMemoryCache", + "BaseCacheService", + "RedisCache", ] diff --git a/src/backend/langflow/services/cache/base.py b/src/backend/langflow/services/cache/base.py index 3b34e12f6f..e2d36b73ca 100644 --- a/src/backend/langflow/services/cache/base.py +++ b/src/backend/langflow/services/cache/base.py @@ -1,4 +1,7 @@ import abc +import asyncio +import threading +from typing import Optional from langflow.services.base import Service @@ -11,7 +14,7 @@ class BaseCacheService(Service): name = "cache_service" @abc.abstractmethod - def get(self, key): + def get(self, key, lock: Optional[threading.Lock] = None): """ Retrieve an item from the cache. @@ -23,7 +26,7 @@ def get(self, key): """ @abc.abstractmethod - def set(self, key, value): + def set(self, key, value, lock: Optional[threading.Lock] = None): """ Add an item to the cache. @@ -33,7 +36,7 @@ def set(self, key, value): """ @abc.abstractmethod - def upsert(self, key, value): + def upsert(self, key, value, lock: Optional[threading.Lock] = None): """ Add an item to the cache if it doesn't exist, or update it if it does. @@ -43,7 +46,7 @@ def upsert(self, key, value): """ @abc.abstractmethod - def delete(self, key): + def delete(self, key, lock: Optional[threading.Lock] = None): """ Remove an item from the cache. @@ -52,7 +55,7 @@ def delete(self, key): """ @abc.abstractmethod - def clear(self): + def clear(self, lock: Optional[threading.Lock] = None): """ Clear all items from the cache. """ @@ -96,3 +99,70 @@ def __delitem__(self, key): Args: key: The key of the item to remove. """ + + +class AsyncBaseCacheService(Service): + """ + Abstract base class for a async cache. + """ + + name = "cache_service" + + @abc.abstractmethod + async def get(self, key, lock: Optional[asyncio.Lock] = None): + """ + Retrieve an item from the cache. + + Args: + key: The key of the item to retrieve. + + Returns: + The value associated with the key, or None if the key is not found. + """ + + @abc.abstractmethod + async def set(self, key, value, lock: Optional[asyncio.Lock] = None): + """ + Add an item to the cache. + + Args: + key: The key of the item. + value: The value to cache. + """ + + @abc.abstractmethod + async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None): + """ + Add an item to the cache if it doesn't exist, or update it if it does. + + Args: + key: The key of the item. + value: The value to cache. + """ + + @abc.abstractmethod + async def delete(self, key, lock: Optional[asyncio.Lock] = None): + """ + Remove an item from the cache. + + Args: + key: The key of the item to remove. + """ + + @abc.abstractmethod + async def clear(self, lock: Optional[asyncio.Lock] = None): + """ + Clear all items from the cache. + """ + + @abc.abstractmethod + def __contains__(self, key): + """ + Check if the key is in the cache. + + Args: + key: The key of the item to check. + + Returns: + True if the key is in the cache, False otherwise. + """ diff --git a/src/backend/langflow/services/cache/factory.py b/src/backend/langflow/services/cache/factory.py index 145f4e6533..48c518b311 100644 --- a/src/backend/langflow/services/cache/factory.py +++ b/src/backend/langflow/services/cache/factory.py @@ -1,6 +1,11 @@ from typing import TYPE_CHECKING -from langflow.services.cache.service import BaseCacheService, InMemoryCache, RedisCache +from langflow.services.cache.service import ( + AsyncInMemoryCache, + BaseCacheService, + RedisCache, + ThreadingInMemoryCache, +) from langflow.services.factory import ServiceFactory from langflow.utils.logger import logger @@ -29,7 +34,9 @@ def create(self, settings_service: "SettingsService"): logger.debug("Redis cache is connected") return redis_cache logger.warning("Redis cache is not connected, falling back to in-memory cache") - return InMemoryCache() + return ThreadingInMemoryCache() elif settings_service.settings.CACHE_TYPE == "memory": - return InMemoryCache() + return ThreadingInMemoryCache() + elif settings_service.settings.CACHE_TYPE == "async": + return AsyncInMemoryCache() diff --git a/src/backend/langflow/services/cache/service.py b/src/backend/langflow/services/cache/service.py index ced3458514..d86c89336d 100644 --- a/src/backend/langflow/services/cache/service.py +++ b/src/backend/langflow/services/cache/service.py @@ -1,16 +1,17 @@ +import asyncio import pickle import threading import time from collections import OrderedDict +from typing import Optional from loguru import logger from langflow.services.base import Service -from langflow.services.cache.base import BaseCacheService +from langflow.services.cache.base import AsyncBaseCacheService, BaseCacheService -class InMemoryCache(BaseCacheService, Service): - +class ThreadingInMemoryCache(BaseCacheService, Service): """ A simple in-memory cache using an OrderedDict. @@ -49,7 +50,7 @@ def __init__(self, max_size=None, expiration_time=60 * 60): self.max_size = max_size self.expiration_time = expiration_time - def get(self, key): + def get(self, key, lock: Optional[threading.Lock] = None): """ Retrieve an item from the cache. @@ -59,7 +60,7 @@ def get(self, key): Returns: The value associated with the key, or None if the key is not found or the item has expired. """ - with self._lock: + with lock or self._lock: return self._get_without_lock(key) def _get_without_lock(self, key): @@ -80,7 +81,7 @@ def _get_without_lock(self, key): self.delete(key) return None - def set(self, key, value, pickle=False): + def set(self, key, value, lock: Optional[threading.Lock] = None): """ Add an item to the cache. @@ -90,7 +91,7 @@ def set(self, key, value, pickle=False): key: The key of the item. value: The value to cache. """ - with self._lock: + with lock or self._lock: if key in self._cache: # Remove existing key before re-inserting to update order self.delete(key) @@ -98,12 +99,10 @@ def set(self, key, value, pickle=False): # Remove least recently used item self._cache.popitem(last=False) # pickle locally to mimic Redis - if pickle: - value = pickle.dumps(value) self._cache[key] = {"value": value, "time": time.time()} - def upsert(self, key, value): + def upsert(self, key, value, lock: Optional[threading.Lock] = None): """ Inserts or updates a value in the cache. If the existing value and the new value are both dictionaries, they are merged. @@ -112,7 +111,7 @@ def upsert(self, key, value): key: The key of the item. value: The value to insert or update. """ - with self._lock: + with lock or self._lock: existing_value = self._get_without_lock(key) if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict): existing_value.update(value) @@ -120,7 +119,7 @@ def upsert(self, key, value): self.set(key, value) - def get_or_set(self, key, value): + def get_or_set(self, key, value, lock: Optional[threading.Lock] = None): """ Retrieve an item from the cache. If the item does not exist, set it with the provided value. @@ -132,27 +131,27 @@ def get_or_set(self, key, value): Returns: The cached value associated with the key. """ - with self._lock: + with lock or self._lock: if key in self._cache: return self.get(key) self.set(key, value) return value - def delete(self, key): + def delete(self, key, lock: Optional[threading.Lock] = None): """ Remove an item from the cache. Args: key: The key of the item to remove. """ - with self._lock: + with lock or self._lock: self._cache.pop(key, None) - def clear(self): + def clear(self, lock: Optional[threading.Lock] = None): """ Clear all items from the cache. """ - with self._lock: + with lock or self._lock: self._cache.clear() def __contains__(self, key): @@ -323,3 +322,85 @@ def __delitem__(self, key): def __repr__(self): """Return a string representation of the RedisCache instance.""" return f"RedisCache(expiration_time={self.expiration_time})" + + +class AsyncInMemoryCache(AsyncBaseCacheService, Service): + def __init__(self, max_size=None, expiration_time=3600): + self.cache = OrderedDict() + + self.lock = asyncio.Lock() + self.max_size = max_size + self.expiration_time = expiration_time + + async def get(self, key, lock: Optional[asyncio.Lock] = None): + if not lock: + async with self.lock: + return await self._get(key) + else: + return await self._get(key) + + async def _get(self, key): + item = self.cache.get(key, None) + if item and (time.time() - item["time"] < self.expiration_time): + self.cache.move_to_end(key) + return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"] + if item: + await self.delete(key) + return None + + async def set(self, key, value, lock: Optional[asyncio.Lock] = None): + if not lock: + async with self.lock: + await self._set( + key, + value, + ) + else: + await self._set( + key, + value, + ) + + async def _set(self, key, value): + if self.max_size and len(self.cache) >= self.max_size: + self.cache.popitem(last=False) + self.cache[key] = {"value": value, "time": time.time()} + self.cache.move_to_end(key) + + async def delete(self, key, lock: Optional[asyncio.Lock] = None): + if not lock: + async with self.lock: + await self._delete(key) + else: + await self._delete(key) + + async def _delete(self, key): + if key in self.cache: + del self.cache[key] + + async def clear(self, lock: Optional[asyncio.Lock] = None): + if not lock: + async with self.lock: + await self._clear() + else: + await self._clear() + + async def _clear(self): + self.cache.clear() + + async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None): + if not lock: + async with self.lock: + await self._upsert(key, value) + else: + await self._upsert(key, value) + + async def _upsert(self, key, value): + existing_value = await self.get(key) + if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict): + existing_value.update(value) + value = existing_value + await self.set(key, value) + + def __contains__(self, key): + return key in self.cache diff --git a/src/backend/langflow/services/chat/service.py b/src/backend/langflow/services/chat/service.py index 323b3d6f34..4df75d437b 100644 --- a/src/backend/langflow/services/chat/service.py +++ b/src/backend/langflow/services/chat/service.py @@ -1,4 +1,6 @@ -from typing import Any +import asyncio +from collections import defaultdict +from typing import Any, Optional from langflow.services.base import Service from langflow.services.deps import get_cache_service @@ -8,30 +10,30 @@ class ChatService(Service): name = "chat_service" def __init__(self): + self._cache_locks = defaultdict(asyncio.Lock) self.cache_service = get_cache_service() - def set_cache(self, client_id: str, data: Any) -> bool: + async def set_cache(self, flow_id: str, data: Any, lock: Optional[asyncio.Lock] = None) -> bool: """ Set the cache for a client. """ # client_id is the flow id but that already exists in the cache # so we need to change it to something else - result_dict = { "result": data, "type": type(data), } - self.cache_service.upsert(client_id, result_dict) - return client_id in self.cache_service + await self.cache_service.upsert(flow_id, result_dict, lock=lock or self._cache_locks[flow_id]) + return flow_id in self.cache_service - def get_cache(self, client_id: str) -> Any: + async def get_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None) -> Any: """ Get the cache for a client. """ - return self.cache_service.get(client_id) + return await self.cache_service.get(flow_id, lock=lock or self._cache_locks[flow_id]) - def clear_cache(self, client_id: str): + async def clear_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None): """ Clear the cache for a client. """ - self.cache_service.delete(client_id) + self.cache_service.delete(flow_id, lock=lock or self._cache_locks[flow_id]) diff --git a/src/backend/langflow/services/settings/base.py b/src/backend/langflow/services/settings/base.py index f2e78b8cb7..9c7ab21c4d 100644 --- a/src/backend/langflow/services/settings/base.py +++ b/src/backend/langflow/services/settings/base.py @@ -38,7 +38,7 @@ class Settings(BaseSettings): DEV: bool = False DATABASE_URL: Optional[str] = None - CACHE_TYPE: str = "memory" + CACHE_TYPE: str = "async" REMOVE_API_KEYS: bool = False COMPONENTS_PATH: List[str] = [] LANGCHAIN_CACHE: str = "InMemoryCache" diff --git a/src/frontend/src/icons/BotMessageSquare/BotMessageSquare.jsx b/src/frontend/src/icons/BotMessageSquare/BotMessageSquare.jsx index 64c8fbb0fa..ea16e06823 100644 --- a/src/frontend/src/icons/BotMessageSquare/BotMessageSquare.jsx +++ b/src/frontend/src/icons/BotMessageSquare/BotMessageSquare.jsx @@ -9,7 +9,7 @@ const SvgBotMessageSquare = (props) => ( stroke-width="2" stroke-linecap="round" stroke-linejoin="round" - class="lucide lucide-bot-message-square" + className="lucide lucide-bot-message-square" {...props} > diff --git a/tests/test_cache.py b/tests/test_cache.py index 7368846730..6d4911efdf 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,8 +1,9 @@ import json -from langflow.graph import Graph import pytest +from langflow.graph import Graph + def get_graph(_type="basic"): """Get a graph from a json file""" @@ -38,7 +39,8 @@ def langchain_objects_are_equal(obj1, obj2): # Test build_graph -def test_build_graph(client, basic_data_graph): +@pytest.mark.asyncio +async def test_build_graph(client, basic_data_graph): graph = Graph.from_payload(basic_data_graph) assert graph is not None assert len(graph.vertices) == len(basic_data_graph["nodes"]) diff --git a/tests/test_data_components.py b/tests/test_data_components.py index bc46b263ef..ca92ee1904 100644 --- a/tests/test_data_components.py +++ b/tests/test_data_components.py @@ -28,9 +28,7 @@ async def test_successful_get_request(api_request): respx.get(url).mock(return_value=Response(200, json=mock_response)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url) # Assertions assert result.data["status_code"] == 200 @@ -46,9 +44,7 @@ async def test_failed_request(api_request): respx.get(url).mock(return_value=Response(404)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url) # Assertions assert result.data["status_code"] == 404 @@ -60,14 +56,10 @@ async def test_timeout(api_request): # Mocking a timeout url = "https://example.com/api/timeout" method = "GET" - respx.get(url).mock( - side_effect=httpx.TimeoutException(message="Timeout", request=None) - ) + respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None)) # Making the request - result = await api_request.make_request( - client=httpx.AsyncClient(), method=method, url=url, timeout=1 - ) + result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1) # Assertions assert result.data["status_code"] == 408 @@ -106,7 +98,6 @@ def test_directory_component_build_with_multithreading( # Arrange directory_component = data.DirectoryComponent() path = os.path.dirname(os.path.abspath(__file__)) - types = ["py"] depth = 1 max_concurrency = 2 load_hidden = False @@ -123,7 +114,6 @@ def test_directory_component_build_with_multithreading( # Act directory_component.build( path, - types, depth, max_concurrency, load_hidden, @@ -134,9 +124,7 @@ def test_directory_component_build_with_multithreading( # Assert mock_resolve_path.assert_called_once_with(path) - mock_retrieve_file_paths.assert_called_once_with( - path, types, load_hidden, recursive, depth - ) + mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth) mock_parallel_load_records.assert_called_once_with( mock_retrieve_file_paths.return_value, silent_errors, max_concurrency )