Skip to content

Commit

Permalink
format and checks pass
Browse files Browse the repository at this point in the history
  • Loading branch information
husseinmozannar committed Nov 13, 2024
1 parent 77018ee commit 5aa89da
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import aiofiles
import PIL.Image
from PIL import Image
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, MultiModalMessage, TextMessage
Expand All @@ -38,14 +37,17 @@
SystemMessage,
UserMessage,
)

from PIL import Image
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright

from ._events import WebSurferEvent
from ._playwright_controller import PlaywrightController
from ._prompts import WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT
from ._set_of_mark import add_set_of_mark
from ._tool_definitions import (
TOOL_CLICK,
TOOL_HISTORY_BACK,
TOOL_HOVER,
TOOL_PAGE_DOWN,
TOOL_PAGE_UP,
TOOL_READ_PAGE_AND_ANSWER,
Expand All @@ -56,12 +58,9 @@
TOOL_TYPE,
TOOL_VISIT_URL,
TOOL_WEB_SEARCH,
TOOL_HOVER,
)
from ._types import InteractiveRegion, UserContent
from ._utils import message_content_to_str
from ._playwright_controller import PlaywrightController
from ._prompts import WEB_SURFER_TOOL_PROMPT, WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE

# Viewport dimensions
VIEWPORT_HEIGHT = 900
Expand Down Expand Up @@ -625,12 +624,10 @@ async def __generate_reply(self, cancellation_token: CancellationToken) -> Tuple

# Add the multimodal message and make the request
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
print(text_prompt)
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False})
message = response.content
print(response)
self._last_download = None

if isinstance(message, str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import base64
import os
import random
import asyncio
from typing import Any, Dict, Optional, Tuple, Union, cast, Callable
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast

from playwright._impl._errors import Error as PlaywrightError
from playwright._impl._errors import TimeoutError
from playwright.async_api import Download, Page

from ._types import (
InteractiveRegion,
VisualViewport,
Expand All @@ -19,8 +21,8 @@ def __init__(
self,
animate_actions: bool = False,
downloads_folder: Optional[str] = None,
viewport_width: int = None,
viewport_height: int = None,
viewport_width: int = 1440,
viewport_height: int = 900,
_download_handler: Optional[Callable[[Download], None]] = None,
to_resize_viewport: bool = True,
) -> None:
Expand All @@ -40,7 +42,7 @@ def __init__(
self._download_handler = _download_handler
self.to_resize_viewport = to_resize_viewport
self._page_script: str = ""
self.last_cursor_position: Tuple[int, int] = (0, 0)
self.last_cursor_position: Tuple[float, float] = (0.0, 0.0)

# Read page_script
with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt") as fh:
Expand Down Expand Up @@ -97,8 +99,8 @@ async def get_page_metadata(self, page: Page) -> Dict[str, Any]:

async def on_new_page(self, page: Page) -> None:
assert page is not None
page.on("download", self._download_handler)
if self.to_resize_viewport:
page.on("download", self._download_handler) # type: ignore
if self.to_resize_viewport and self.viewport_width and self.viewport_height:
await page.set_viewport_size({"width": self.viewport_width, "height": self.viewport_height})
await self.sleep(page, 0.2)
await page.add_init_script(path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"))
Expand Down Expand Up @@ -148,7 +150,9 @@ async def page_up(self, page: Page) -> None:
assert page is not None
await page.evaluate(f"window.scrollBy(0, -{self.viewport_height-50});")

async def gradual_cursor_animation(self, page: Page, start_x: int, start_y: int, end_x: int, end_y: int) -> None:
async def gradual_cursor_animation(
self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float
) -> None:
# animation helper
steps = 20
for step in range(steps):
Expand Down Expand Up @@ -209,11 +213,11 @@ async def remove_cursor_box(self, page: Page, identifier: str) -> None:
}})();
""")

async def click_id(self, page: Page, identifier: str) -> None:
async def click_id(self, page: Page, identifier: str) -> Page | None:
"""
Returns new page if a new page is opened, otherwise None.
"""
new_page = None
new_page: Page | None = None
assert page is not None
target = page.locator(f"[__elementId='{identifier}']")

Expand Down Expand Up @@ -258,7 +262,7 @@ async def click_id(self, page: Page, identifier: str) -> None:
await self.on_new_page(new_page)
except TimeoutError:
pass
return new_page
return new_page # type: ignore

async def hover_id(self, page: Page, identifier: str) -> None:
"""
Expand Down Expand Up @@ -286,6 +290,7 @@ async def hover_id(self, page: Page, identifier: str) -> None:
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
await asyncio.sleep(0.1)
await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2)

await self.remove_cursor_box(page, identifier)
else:
Expand Down Expand Up @@ -357,13 +362,17 @@ async def get_webpage_text(self, page: Page, n_lines: int = 100) -> str:
return: text in the first n_lines of the page
"""
assert page is not None
text_in_viewport = await page.evaluate("""() => {
return document.body.innerText;
}""")
text_in_viewport = "\n".join(text_in_viewport.split("\n")[:n_lines])
# remove empty lines
text_in_viewport = "\n".join([line for line in text_in_viewport.split("\n") if line.strip()])
return text_in_viewport
try:
text_in_viewport = await page.evaluate("""() => {
return document.body.innerText;
}""")
text_in_viewport = "\n".join(text_in_viewport.split("\n")[:n_lines])
# remove empty lines
text_in_viewport = "\n".join([line for line in text_in_viewport.split("\n") if line.strip()])
assert isinstance(text_in_viewport, str)
return text_in_viewport
except Exception:
return ""

async def get_page_markdown(self, page: Page) -> str:
# TODO: replace with mdconvert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
You are a helpful assistant that can summarize long documents to answer question.
"""


def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str:
base_prompt = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport."
if question is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict

from autogen_core.components.tools._base import ParametersSchema, ToolSchema


Expand All @@ -13,7 +14,10 @@ def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema:
),
)

REASONING_TOOL_PROMPT = "A short description of the action to be performed and reason for doing so, do not mention the user."

REASONING_TOOL_PROMPT = (
"A short description of the action to be performed and reason for doing so, do not mention the user."
)

TOOL_VISIT_URL: ToolSchema = _load_tool(
{
Expand Down

This file was deleted.

0 comments on commit 5aa89da

Please sign in to comment.