Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into fix-model-produced-invalid-content
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje authored Aug 28, 2024
2 parents f7caf75 + 4f9383a commit 8321c81
Show file tree
Hide file tree
Showing 39 changed files with 2,485 additions and 1,027 deletions.
3 changes: 3 additions & 0 deletions samples/apps/autogen-studio/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ autogenstudio/web/workdir/*
autogenstudio/web/ui/*
autogenstudio/web/skills/user/*
.release.sh
.nightly.sh

notebooks/work_dir/*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
81 changes: 5 additions & 76 deletions samples/apps/autogen-studio/autogenstudio/chatmanager.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import asyncio
import json
import os
import time
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect

from .datamodel import Message, SocketMessage, Workflow
from .utils import (
extract_successful_code_blocks,
get_modified_files,
summarize_chat_history,
)
from .datamodel import Message
from .workflowmanager import WorkflowManager


Expand Down Expand Up @@ -82,76 +75,12 @@ def chat(
connection_id=connection_id,
)

workflow = Workflow.model_validate(workflow)

message_text = message.content.strip()
result_message: Message = workflow_manager.run(message=f"{message_text}", clear_history=False, history=history)

start_time = time.time()
workflow_manager.run(message=f"{message_text}", clear_history=False)
end_time = time.time()

metadata = {
"messages": workflow_manager.agent_history,
"summary_method": workflow.summary_method,
"time": end_time - start_time,
"files": get_modified_files(start_time, end_time, source_dir=work_dir),
}

output = self._generate_output(message_text, workflow_manager, workflow)

output_message = Message(
user_id=message.user_id,
role="assistant",
content=output,
meta=json.dumps(metadata),
session_id=message.session_id,
)

return output_message

def _generate_output(
self,
message_text: str,
workflow_manager: WorkflowManager,
workflow: Workflow,
) -> str:
"""
Generates the output response based on the workflow configuration and agent history.
:param message_text: The text of the incoming message.
:param flow: An instance of `WorkflowManager`.
:param flow_config: An instance of `AgentWorkFlowConfig`.
:return: The output response as a string.
"""

output = ""
if workflow.summary_method == "last":
successful_code_blocks = extract_successful_code_blocks(workflow_manager.agent_history)
last_message = (
workflow_manager.agent_history[-1]["message"]["content"] if workflow_manager.agent_history else ""
)
successful_code_blocks = "\n\n".join(successful_code_blocks)
output = (last_message + "\n" + successful_code_blocks) if successful_code_blocks else last_message
elif workflow.summary_method == "llm":
client = workflow_manager.receiver.client
status_message = SocketMessage(
type="agent_status",
data={
"status": "summarizing",
"message": "Summarizing agent dialogue",
},
connection_id=workflow_manager.connection_id,
)
self.send(status_message.dict())
output = summarize_chat_history(
task=message_text,
messages=workflow_manager.agent_history,
client=client,
)

elif workflow.summary_method == "none":
output = ""
return output
result_message.user_id = message.user_id
result_message.session_id = message.session_id
return result_message


class WebSocketConnectionManager:
Expand Down
35 changes: 34 additions & 1 deletion samples/apps/autogen-studio/autogenstudio/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def ui(
port: int = 8081,
workers: int = 1,
reload: Annotated[bool, typer.Option("--reload")] = False,
docs: bool = False,
docs: bool = True,
appdir: str = None,
database_uri: Optional[str] = None,
):
Expand Down Expand Up @@ -48,6 +48,39 @@ def ui(
)


@app.command()
def serve(
workflow: str = "",
host: str = "127.0.0.1",
port: int = 8084,
workers: int = 1,
docs: bool = False,
):
"""
Serve an API Endpoint based on an AutoGen Studio workflow json file.
Args:
workflow (str): Path to the workflow json file.
host (str, optional): Host to run the UI on. Defaults to 127.0.0.1 (localhost).
port (int, optional): Port to run the UI on. Defaults to 8081.
workers (int, optional): Number of workers to run the UI with. Defaults to 1.
reload (bool, optional): Whether to reload the UI on code changes. Defaults to False.
docs (bool, optional): Whether to generate API docs. Defaults to False.
"""

os.environ["AUTOGENSTUDIO_API_DOCS"] = str(docs)
os.environ["AUTOGENSTUDIO_WORKFLOW_FILE"] = workflow

uvicorn.run(
"autogenstudio.web.serve:app",
host=host,
port=port,
workers=workers,
reload=False,
)


@app.command()
def version():
"""
Expand Down
55 changes: 37 additions & 18 deletions samples/apps/autogen-studio/autogenstudio/database/dbmanager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from datetime import datetime
from typing import Optional

Expand All @@ -15,30 +16,39 @@
Skill,
Workflow,
WorkflowAgentLink,
WorkflowAgentType,
)
from .utils import init_db_samples

valid_link_types = ["agent_model", "agent_skill", "agent_agent", "workflow_agent"]


class WorkflowAgentMap(SQLModel):
agent: Agent
link: WorkflowAgentLink


class DBManager:
"""A class to manage database operations"""

_init_lock = threading.Lock() # Class-level lock

def __init__(self, engine_uri: str):
connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
self.engine = create_engine(engine_uri, connect_args=connection_args)
# run_migration(engine_uri=engine_uri)

def create_db_and_tables(self):
"""Create a new database and tables"""
try:
SQLModel.metadata.create_all(self.engine)
with self._init_lock: # Use the lock
try:
init_db_samples(self)
SQLModel.metadata.create_all(self.engine)
try:
init_db_samples(self)
except Exception as e:
logger.info("Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while creating database tables:" + str(e))
logger.info("Error while creating database tables:" + str(e))

def upsert(self, model: SQLModel):
"""Create a new entity"""
Expand All @@ -62,7 +72,7 @@ def upsert(self, model: SQLModel):
session.refresh(model)
except Exception as e:
session.rollback()
logger.error("Error while upserting %s", e)
logger.error("Error while updating " + str(model_class.__name__) + ": " + str(e))
status = False

response = Response(
Expand Down Expand Up @@ -115,7 +125,7 @@ def get_items(
session.rollback()
status = False
status_message = f"Error while fetching {model_class.__name__}"
logger.error("Error while getting %s: %s", model_class.__name__, e)
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))

response: Response = Response(
message=status_message,
Expand Down Expand Up @@ -157,16 +167,16 @@ def delete(self, model_class: SQLModel, filters: dict = None):
status_message = f"{model_class.__name__} Deleted Successfully"
else:
print(f"Row with filters {filters} not found")
logger.info("Row with filters %s not found", filters)
logger.info("Row with filters + filters + not found")
status_message = "Row not found"
except exc.IntegrityError as e:
session.rollback()
logger.error("Integrity ... Error while deleting: %s", e)
logger.error("Integrity ... Error while deleting: " + str(e))
status_message = f"The {model_class.__name__} is linked to another entity and cannot be deleted."
status = False
except Exception as e:
session.rollback()
logger.error("Error while deleting: %s", e)
logger.error("Error while deleting: " + str(e))
status_message = f"Error while deleting: {e}"
status = False
response = Response(
Expand All @@ -182,6 +192,7 @@ def get_linked_entities(
primary_id: int,
return_json: bool = False,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
):
"""
Get all entities linked to the primary entity.
Expand Down Expand Up @@ -217,19 +228,21 @@ def get_linked_entities(
linked_entities = agent.agents
elif link_type == "workflow_agent":
linked_entities = session.exec(
select(Agent)
.join(WorkflowAgentLink)
select(WorkflowAgentLink, Agent)
.join(Agent, WorkflowAgentLink.agent_id == Agent.id)
.where(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_type == agent_type,
)
).all()

linked_entities = [WorkflowAgentMap(agent=agent, link=link) for link, agent in linked_entities]
linked_entities = sorted(linked_entities, key=lambda x: x.link.sequence_id) # type: ignore
except Exception as e:
logger.error("Error while getting linked entities: %s", e)
logger.error("Error while getting linked entities: " + str(e))
status_message = f"Error while getting linked entities: {e}"
status = False
if return_json:
linked_entities = [self._model_to_dict(row) for row in linked_entities]
linked_entities = [row.model_dump() for row in linked_entities]

response = Response(
message=status_message,
Expand All @@ -245,6 +258,7 @@ def link(
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
) -> Response:
"""
Link two entities together.
Expand Down Expand Up @@ -357,6 +371,7 @@ def link(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()
if existing_link:
Expand All @@ -373,6 +388,7 @@ def link(
workflow_id=primary_id,
agent_id=secondary_id,
agent_type=agent_type,
sequence_id=sequence_id,
)
session.add(workflow_agent_link)
# add and commit the link
Expand All @@ -385,7 +401,7 @@ def link(

except Exception as e:
session.rollback()
logger.error("Error while linking: %s", e)
logger.error("Error while linking: " + str(e))
status = False
status_message = f"Error while linking due to an exception: {e}"

Expand All @@ -402,6 +418,7 @@ def unlink(
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = 0,
) -> Response:
"""
Unlink two entities.
Expand All @@ -417,6 +434,7 @@ def unlink(
"""
status = True
status_message = ""
print("primary", primary_id, "secondary", secondary_id, "sequence", sequence_id, "agent_type", agent_type)

if link_type not in valid_link_types:
status = False
Expand Down Expand Up @@ -452,6 +470,7 @@ def unlink(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()

Expand All @@ -465,7 +484,7 @@ def unlink(

except Exception as e:
session.rollback()
logger.error("Error while unlinking: %s", e)
logger.error("Error while unlinking: " + str(e))
status = False
status_message = f"Error while unlinking due to an exception: {e}"

Expand Down
Loading

0 comments on commit 8321c81

Please sign in to comment.