From e0226cc51699cadc46d9bed1d3f62e628fb2a74c Mon Sep 17 00:00:00 2001 From: Ovizro Date: Thu, 14 Nov 2024 20:38:58 +0800 Subject: [PATCH] More low level API support (#21) - Add support for sending and receiving attachments in messages - Implement file upload and download functionality - Add AgentBot class for running bots on behalf of a specific user - Fix the issue where Bot.send_message did not handle extra parameters correctly. - Implement new bot methods for account management, message deletion, and topic querying - Update __init__.py to include new remove_bot function - Bump version to 0.2.2 --- examples/exec.py | 48 --- examples/jobs.py | 158 +++++++++ karuha/__init__.py | 15 +- karuha/bot.py | 666 ++++++++++++++++++++++++++++++++++++-- karuha/config.py | 28 +- karuha/data/sub.py | 2 +- karuha/event/bot.py | 20 +- karuha/runner.py | 134 +------- karuha/session.py | 270 +++++++++++++--- karuha/text/textchain.py | 20 +- karuha/utils/gathering.py | 115 +++++++ karuha/version.py | 2 +- setup.py | 10 +- tests/client/__init__.py | 0 tests/client/test_bot.py | 69 ++++ tests/utils.py | 68 +++- 16 files changed, 1356 insertions(+), 269 deletions(-) create mode 100644 examples/jobs.py create mode 100644 karuha/utils/gathering.py create mode 100644 tests/client/__init__.py create mode 100644 tests/client/test_bot.py diff --git a/examples/exec.py b/examples/exec.py index 9f71f4e..d6ed292 100644 --- a/examples/exec.py +++ b/examples/exec.py @@ -88,51 +88,3 @@ async def wait(self) -> None: def check_for_exit(self) -> None: if self.pipe_closed and self.exited and self.exit_future: self.exit_future.set_result(True) - - -@on_command -async def run(session: MessageSession, name: str, user_id: str, argv: List[str]) -> None: - user = await session.get_user(user_id) - if not user.staff: - await session.finish("Permission denied") - parser = ArgumentParser(session, name) - parser.add_argument("-c", "--cwd", help="working directory") - parser.add_argument("-e", "--env", action="append", help="environment variable") - parser.add_argument("command", nargs="*", help="command to run") - ns = parser.parse_args(argv) - if not ns.command: - await session.finish("No command specified") - - session.bot.logger.info(f"run: {ns.command}") - loop = asyncio.get_running_loop() - transport, protocol = await loop.subprocess_exec( - DateProtocol, - *ns.command, - cwd=ns.cwd, - env=dict(os.environ, **dict((e.split("=", 1) for e in ns.env or ()))), - stdin=None, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE - ) - - wait_task = asyncio.create_task(protocol.wait()) - while not wait_task.done(): - done, _ = await asyncio.wait( - (wait_task, protocol.output.get()), - return_when=asyncio.FIRST_COMPLETED - ) - if wait_task in done: - done.remove(wait_task) - if not done: - break - data: bytes = done.pop().result() # type: ignore - await session.send(data.decode()) - - while not protocol.output.empty(): - data = protocol.output.get_nowait() - await session.send(data.decode()) - - code = transport.get_returncode() - transport.close() - if code is not None: - await session.send(f"Process exited with code {code}") diff --git a/examples/jobs.py b/examples/jobs.py new file mode 100644 index 0000000..0595cad --- /dev/null +++ b/examples/jobs.py @@ -0,0 +1,158 @@ +import asyncio +import os +from io import StringIO +from traceback import format_exc +from typing import List, Optional + +from psutil import Process + +from karuha import MessageSession, PlainText, on_command +from karuha.text import Italic +from karuha.utils.argparse import ArgumentParser + + +class DateProtocol(asyncio.SubprocessProtocol): + def __init__(self, exit_future: Optional[asyncio.Future] = None) -> None: + self.exit_future = exit_future + self.output = asyncio.Queue() + self.pipe_closed = False + self.exited = False + + def pipe_connection_lost(self, fd: int, exc: Optional[Exception]) -> None: + self.pipe_closed = True + self.check_for_exit() + + def pipe_data_received(self, fd: int, data: bytes) -> None: + self.output.put_nowait(data) + + def process_exited(self) -> None: + self.exited = True + # process_exited() method can be called before + # pipe_connection_lost() method: wait until both methods are + # called. + self.check_for_exit() + + async def wait(self) -> None: + if self.pipe_closed and self.exited: + return + if self.exit_future is None: + self.exit_future = asyncio.Future() + await self.exit_future + + def check_for_exit(self) -> None: + if self.pipe_closed and self.exited and self.exit_future: + self.exit_future.set_result(True) + + +_jobs: List[asyncio.SubprocessTransport] = [] + + +@on_command +async def run(session: MessageSession, name: str, user_id: str, argv: List[str]) -> None: + user = await session.get_user(user_id) + if not user.staff: + await session.finish("Permission denied") + parser = ArgumentParser(session, name) + parser.add_argument("-c", "--cwd", help="working directory") + parser.add_argument("-e", "--env", action="append", help="environment variable") + # parser.add_argument("-s", "--shell", action="store_true", help="shell mode") + parser.add_argument("command", nargs="+", help="command to run") + ns = parser.parse_args(argv) + if not ns.command: + await session.finish("No command specified") + + session.bot.logger.info(f"run: {ns.command}") + loop = asyncio.get_running_loop() + try: + transport, protocol = await loop.subprocess_exec( + DateProtocol, + *ns.command, + cwd=ns.cwd, + env=dict(os.environ, **dict((e.split("=", 1) for e in ns.env or ()))), + # shell=ns.shell, + stdin=None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + except OSError: + await session.finish(format_exc()) + + _jobs.append(transport) + wait_task = asyncio.create_task(protocol.wait()) + while not wait_task.done(): + done, _ = await asyncio.wait( + (wait_task, protocol.output.get()), + return_when=asyncio.FIRST_COMPLETED + ) + if wait_task in done: + done.remove(wait_task) + if not done: + break + data: bytes = done.pop().result() # type: ignore + if text := data.decode().rstrip(): + await session.send(text) + + while not protocol.output.empty(): + data = protocol.output.get_nowait() + if text := data.decode().rstrip(): + await session.send(text) + + _jobs.remove(transport) + code = transport.get_returncode() + transport.close() + if code: + await session.send( + Italic( + content=PlainText(f"Process exited with code {code}") + ) + ) + + +@on_command +async def kill(session: MessageSession, name: str, user_id: str, argv: List[str]) -> None: + user = await session.get_user(user_id) + if not user.staff: + await session.finish("Permission denied") + parser = ArgumentParser(session, name) + parser.add_argument("tid", type=int, help="process id", nargs="?") + parser.add_argument("-s", "--signal", type=int, help="signal to send", default=15) + ns = parser.parse_args(argv) + if ns.tid is None: + # kill all subprocesses + for transport in _jobs: + transport.send_signal(ns.signal) + await session.send("All subprocesses killed") + else: + try: + transport = _jobs[ns.tid] + except IndexError: + await session.send("Invalid process id") + else: + transport.send_signal(ns.signal) + await session.send(f"Killed process {ns.tid}") + + +@on_command +async def jobs(session: MessageSession, name: str, user_id: str, argv: List[str]) -> None: + user = await session.get_user(user_id) + if not user.staff: + await session.finish("Permission denied") + parser = ArgumentParser(session, name) + parser.add_argument("-t", "--tid", action="store_true", help="list tid only") + parser.add_argument("-r", action="store_true", help="restrict output to running jobs") + parser.add_argument("-s", action="store_true", help="restrict output to stopped jobs") + ns = parser.parse_args(argv) + if ns.tid: + await session.finish('\n'.join(str(i) for i in range(len(_jobs)))) + ss = StringIO() + for i, transport in enumerate(_jobs): + pid = transport.get_pid() + process = Process(pid) + status = process.status() + if ns.r and status != "running": + continue + if ns.s and status == "running": + continue + ss.write(f"[{i}] {status} {' '.join(process.cmdline())}\n") + if text := ss.getvalue(): + await session.send(text) diff --git a/karuha/__init__.py b/karuha/__init__.py index 3e08091..5f386fc 100644 --- a/karuha/__init__.py +++ b/karuha/__init__.py @@ -9,31 +9,38 @@ from .version import __version__ -from .config import get_config, load_config, init_config, save_config, Config +from .config import get_config, load_config, init_config, save_config, reset_config, Config from .config import Server as ServerConfig, Bot as BotConfig -from .bot import Bot +from .bot import Bot, BotState from .exception import KaruhaException from .event import on, on_event, Event from .text import Drafty, BaseText, PlainText, Message, TextChain from .command import CommandCollection, AbstractCommand, AbstractCommandParser, BaseSession, MessageSession, CommandSession, get_collection, on_command, rule, on_rule from .data import get_user, get_topic, try_get_user, try_get_topic -from .runner import get_bot, add_bot, try_add_bot, get_all_bots, async_run, run +from .runner import try_get_bot, get_bot, add_bot, try_add_bot, get_all_bots, remove_bot, cancel_all_bots, async_run, run, reset __all__ = [ - # bot + # runner "add_bot", "try_add_bot", "get_bot", + "try_get_bot", "get_all_bots", + "remove_bot", + "cancel_all_bots", "async_run", "run", + "reset", + # bot "Bot", + "BotState", # config "get_config", "init_config", "load_config", "save_config", + "reset_config", "Config", "BotConfig", "ServerConfig", diff --git a/karuha/bot.py b/karuha/bot.py index e29c92c..db66027 100644 --- a/karuha/bot.py +++ b/karuha/bot.py @@ -1,9 +1,10 @@ import asyncio import base64 +from io import IOBase, TextIOBase +import os import platform import sys from asyncio.queues import Queue -from base64 import b64decode from collections import defaultdict from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timezone @@ -14,6 +15,7 @@ import grpc from aiofiles import open as aio_open +from aiohttp import ClientError, ClientSession, FormData from google.protobuf.message import Message from grpc import aio as grpc_aio from pydantic import GetCoreSchemaHandler, TypeAdapter @@ -51,9 +53,12 @@ class Bot(object): initialize_event_callback: Callable[[Self], Any] finalize_event_callback: Callable[[Self], Coroutine] - server_event_callbacks: Dict[str, List[Callable[[Self, Message], Any]]] = ( - defaultdict(list) - ) + server_event_callbacks: Dict[ + str, + List[ + Callable[[Self, Message], Any] + ] + ] = defaultdict(list) client_event_callbacks: Dict[ str, List[ @@ -131,7 +136,10 @@ def __init__( async def hello(self, /, lang: str = "EN") -> Tuple[str, Dict[str, Any]]: """ - send a hello message to the server and get the server id + Handshake message client uses to inform the server of its version and user agent. + This message must be the first that the client sends to the server. + Server responds with a {ctrl} which contains server build build, wire protocol version ver, + session ID sid in case of long polling, as well as server constraints, all in ctrl.params. :param lang: the language of the chatbot :type lang: str @@ -164,9 +172,75 @@ async def hello(self, /, lang: str = "EN") -> Tuple[str, Dict[str, Any]]: self.logger.info(f"server: {build} {ver}") return tid, decode_mapping(ctrl.params) + async def account( + self, + user_id: str, + scheme: Optional[str] = None, + secret: Optional[bytes] = None, + *, + state: str = "ok", + do_login: bool = True, + desc: Optional[pb.SetDesc] = None, + tags: Iterable[str] = (), + cred: Iterable[pb.ClientCred] = (), + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Message {acc} creates users or updates tags or authentication credentials scheme and secret of exiting users. + To create a new user set user to the string new optionally followed by any character sequence, e.g. newr15gsr. + Either authenticated or anonymous session can send an {acc} message to create a new user. + To update authentication data or validate a credential of the current user leave user unset. + + The {acc} message cannot be used to modify desc or cred of an existing user. + Update user's me topic instead. + + :param user_id: the user id + :type user_id: str + :param scheme: the authentication scheme + :type scheme: Optional[str] + :param secret: the authentication secret + :type secret: Optional[bytes] + :param state: the account state + :type state: str + :param do_login: whether to login after updating + :type do_login:bool + :param desc: the account description + :type desc: Optional[pb.SetDesc] + :param tags: the account tags + :type tags: Iterable[str] + :param cred: the account credentials + :type cred: Iterable[pb.ClientCred] + :param extra: the extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + tid = self._get_tid() + ctrl = await self.send_message( + tid, + acc=pb.ClientAcc( + id=tid, + user_id=user_id, + scheme=scheme, + secret=secret, + state=state, + login=do_login, + desc=desc, + tags=tags, + cred=cred, + ), + extra=extra + ) + assert isinstance(ctrl, pb.ServerCtrl) + if ctrl.code < 200 or ctrl.code >= 400: # pragma: no cover + err_text = f"fail to update account: {ctrl.text}" + self.logger.error(err_text) + raise KaruhaBotError(err_text, bot=self, code=ctrl.code) + return tid, decode_mapping(ctrl.params) + async def login(self) -> Tuple[str, Dict[str, Any]]: """ - login to the server and get the user id + Login is used to authenticate the current session. :return: tid and params :rtype: Tuple[str, Dict[str, Any]] @@ -225,7 +299,12 @@ async def subscribe( extra: Optional[pb.ClientExtra] = None ) -> Tuple[str, Dict[str, Any]]: """ - subscribe to a topic + The {sub} packet serves the following functions: + + - creating a new topic + - subscribing user to an existing topic + - attaching session to a previously subscribed topic + - fetching topic data :param topic: topic to subscribe :type topic: str @@ -285,7 +364,13 @@ async def subscribe( async def leave(self, /, topic: str, *, extra: Optional[pb.ClientExtra] = None) -> Tuple[str, Dict[str, Any]]: """ - leave a topic + This is a counterpart to {sub} message. It also serves two functions: + + - leaving the topic without unsubscribing (unsub=false) + - unsubscribing (unsub=true) + + Server responds to {leave} with a {ctrl} packet. Leaving without unsubscribing affects just the current session. + Leaving with unsubscribing will affect all user's sessions. :param topic: topic to leave :type topic: str @@ -327,7 +412,7 @@ async def publish( extra: Optional[pb.ClientExtra] = None ) -> Tuple[str, Dict[str, Any]]: """ - publish message to a topic + The message is used to distribute content to topic subscribers. :param topic: topic to publish :type topic: str @@ -340,10 +425,7 @@ async def publish( :return: tid and params :rtype: Tuple[str, Dict[str, Any]] """ - if head is None: - head = {} - else: - head = encode_mapping(head) + head = {} if head is None else encode_mapping(head) if "auto" not in head: head["auto"] = b"true" tid = self._get_tid() @@ -374,7 +456,23 @@ async def get( *, desc: Optional[pb.GetOpts] = None, extra: Optional[pb.ClientExtra] = None, - ) -> Tuple[str, Optional[pb.ServerMeta]]: ... + ) -> Tuple[str, Optional[pb.ServerMeta]]: + """ + Query topic for description. + + NOTE: only one of `what` can be specified at a time + + :param topic: topic to get + :type topic: str + :param what: fields to get + :type what: Literal["desc", "sub", "data", "tags"], optional + :param desc: description query options + :type desc: Optional[pb.GetOpts] + :param extra: extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and meta + :rtype: Tuple[str, Optional[pb.ServerMeta]] + """ @overload async def get( @@ -385,7 +483,23 @@ async def get( *, sub: Optional[pb.GetOpts] = None, extra: Optional[pb.ClientExtra] = None, - ) -> Tuple[str, Optional[pb.ServerMeta]]: ... + ) -> Tuple[str, Optional[pb.ServerMeta]]: + """ + Query topic for subscriptions. + + NOTE: only one of `what` can be specified at a time + + :param topic: topic to get + :type topic: str + :param what: fields to get + :type what: Literal["desc", "sub", "data", "tags"], optional + :param sub: subscriptions query options + :type sub: Optional[pb.GetOpts] + :param extra: extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and meta + :rtype: Tuple[str, Optional[pb.ServerMeta]] + """ @overload async def get( @@ -396,28 +510,73 @@ async def get( *, data: Optional[pb.GetOpts] = None, extra: Optional[pb.ClientExtra] = None, - ) -> Tuple[str, Optional[pb.ServerMeta]]: ... + ) -> Tuple[str, Optional[pb.ServerMeta]]: + """ + Query topic for data. + + NOTE: only one of `what` can be specified at a time + + :param topic: topic to get + :type topic: str + :param what: fields to get + :type what: Literal["desc", "sub", "data", "tags"], optional + :param data: data query options + :type data: Optional[pb.GetOpts] + :param extra: extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and meta + :rtype: Tuple[str, Optional[pb.ServerMeta]] + """ @overload async def get( self, /, topic: str, - what: Literal["tags"], + what: Optional[Literal["tags"]] = None, *, extra: Optional[pb.ClientExtra] = None, - ) -> Tuple[str, Optional[pb.ServerMeta]]: ... + ) -> Tuple[str, Optional[pb.ServerMeta]]: + """ + Query topic for tags. + + NOTE: only one of `what` can be specified at a time + + :param topic: topic to get + :type topic: str + :param what: fields to get + :type what: Literal["desc", "sub", "data", "tags"], optional + :param extra: extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and meta + :rtype: Tuple[str, Optional[pb.ServerMeta]] + """ @overload async def get( self, /, topic: str, - what: Literal["cred"], + what: Optional[Literal["cred"]] = None, *, extra: Optional[pb.ClientExtra] = None, - ) -> Tuple[str, Optional[pb.ServerMeta]]: ... + ) -> Tuple[str, Optional[pb.ServerMeta]]: + """ + Query topic for credentials. + + NOTE: only one of `what` can be specified at a time + :param topic: topic to get + :type topic: str + :param what: fields to get + :type what: Literal["cred"], optional + :param extra: extra data + :type extra: Optional[pb.ClientExtra] + :return: tid and meta + :rtype: Tuple[str, Optional[pb.ServerMeta]] + """ + + @overload async def get( self, /, @@ -430,7 +589,9 @@ async def get( extra: Optional[pb.ClientExtra] = None ) -> Tuple[str, Optional[pb.ServerMeta]]: """ - get data from a topic + Query topic for metadata, such as description or a list of subscribers, or query message history. + The requester must be subscribed and attached to the topic to receive the full response. + Some limited desc and sub information is available without being attached. NOTE: only one of `what` can be specified at a time @@ -449,6 +610,18 @@ async def get( :return: tid and meta :rtype: Tuple[str, Optional[pb.ServerMeta]] """ + + async def get( + self, + /, + topic: str, + what: Optional[Literal["desc", "sub", "data", "tags", "cred"]] = None, + *, + desc: Optional[pb.GetOpts] = None, + sub: Optional[pb.GetOpts] = None, + data: Optional[pb.GetOpts] = None, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Optional[pb.ServerMeta]]: tid = self._get_tid() if what is None: if desc is not None: @@ -542,7 +715,9 @@ async def set( extra: Optional[pb.ClientExtra] = None ) -> Tuple[str, Dict[str, Any]]: """ - set data to a topic + Update topic metadata, delete messages or topic. + The requester is generally expected to be subscribed and attached to the topic. + Only desc.private and requester's sub.mode can be updated without attaching first. :param topic: topic to set :type topic: str @@ -582,9 +757,343 @@ async def set( raise KaruhaBotError(err_text, bot=self, code=ctrl.code) return tid, decode_mapping(ctrl.params) + @overload + async def delete( + self, + what: Literal["msg"], + *, + topic: str, + del_seq: Iterable[pb.SeqRange] = (), + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + User can soft-delete hard=false (default) or hard-delete hard=true messages. + Soft-deleting messages hides them from the requesting user but does not delete them from storage. + An R permission is required to soft-delete messages. + Hard-deleting messages deletes message content from storage (head, content) leaving a message stub. + It affects all users. A D permission is needed to hard-delete messages. + Messages can be deleted in bulk by specifying one or more message ID ranges in delseq parameter. + Each delete operation is assigned a unique delete ID. + The greatest delete ID is reported back in the clear of the {meta} message. + + :param what: delete type, defaults to "msg" + :type what: Literal["msg"] + :param topic: topic to delete + :type topic: str + :param del_seq: message ID ranges to delete, defaults to () + :type del_seq: Iterable[pb.SeqRange], optional + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: fail to delete messages + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + @overload + async def delete( + self, + what: Literal["topic"], + *, + topic: str, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Deleting a topic deletes the topic including all subscriptions, and all messages. + Only the owner can delete a topic. + + :param what: delete type, defaults to "topic" + :type what: Literal["topic"] + :param topic: topic to delete + :type topic: str + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: fail to delete topic + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + @overload + async def delete( + self, + what: Literal["sub"], + *, + topic: str, + user_id: str, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Deleting a subscription removes specified user from topic subscribers. + It requires an A permission. A user cannot delete own subscription. + A {leave} should be used instead. If the subscription is soft-deleted (default), + it's marked as deleted without actually deleting a record from storage. + + :param what: delete type, defaults to "sub" + :type what: Literal["sub"] + :param topic: topic to delete + :type topic: str + :param user_id: user ID to delete + :type user_id: str + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: fail to delete topic + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + @overload + async def delete( + self, + what: Literal["user"], + *, + user_id: str, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Deleting a user is a very heavy operation. Use caution. + + :param what: delete type, defaults to "user" + :type what: Literal["user"] + :param user_id: user ID to delete + :type user_id: str + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: fail to delete user + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + @overload + async def delete( + self, + what: Literal["cred"], + *, + cred: pb.ClientCred, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Delete credential. + Validated credentials and those with no attempts at validation are hard-deleted. + Credentials with failed attempts at validation are soft-deleted which prevents their reuse by the same user. + + :param what: delete type, defaults to "cred" + :type what: Literal["cred"] + :param cred: credential to delete + :type cred: pb.ClientCred + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: failto delete credential + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + @overload + async def delete( + self, + what: Literal["msg", "topic", "sub", "user", "cred"], + *, + topic: Optional[str] = None, + del_seq: Iterable[pb.SeqRange] = (), + user_id: Optional[str] = None, + cred: Optional[pb.ClientCred] = None, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + """ + Delete messages, subscriptions, topics, users. + + :param what: delete type + :type what: Literal["msg", "topic", "sub", "user", "cred"] + :param topic: topic to delete, defaults to None + :type topic: Optional[str], optional + :param del_seq: message ID ranges to delete, defaults to () + :type del_seq: Iterable[pb.SeqRange], optional + :param user_id: user ID to delete, defaults to None + :type user_id: Optional[str], optional + :param cred: credential to delete, defaults to None + :type cred: Optional[pb.ClientCred], optional + :param hard: hard delete, defaults to False + :type hard: bool, optional + :param extra: extra data, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :raises KaruhaBotError: fail to delete messages + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + + async def delete( + self, + what: Literal["msg", "topic", "sub", "user", "cred"], + *, + topic: Optional[str] = None, + del_seq: Iterable[pb.SeqRange] = (), + user_id: Optional[str] = None, + cred: Optional[pb.ClientCred] = None, + hard: bool = False, + extra: Optional[pb.ClientExtra] = None + ) -> Tuple[str, Dict[str, Any]]: + tid = self._get_tid() + ctrl = await self.send_message( + tid, extra=extra, **{"del": pb.ClientDel( + id=tid, what=getattr(pb.ClientDel.What, what.upper()), + topic=topic, + del_seq=del_seq, + user_id=user_id, + cred=cred, + hard=hard + )} + ) + assert isinstance(ctrl, pb.ServerCtrl) + if ctrl.code < 200 or ctrl.code >= 400: # pragma: no cover + err_text = f"fail to delete: {ctrl.text}" + self.logger.error(err_text) + raise KaruhaBotError(err_text, bot=self, code=ctrl.code) + return tid, decode_mapping(ctrl.params) + + async def note_kp(self, /, topic: str) -> None: + """key press, i.e. a typing notification. + The client should use it to indicate that the user is composing a new message. + + :param topic: topic to note + :type topic: str + """ + await self.send_message(note=pb.ClientNote(topic=topic, what=pb.KP)) + + async def note_recv(self, /, topic: str, seq: int) -> None: + """mark a text as received + a {data} message is received by the client software but may not yet seen by user. + + :param topic: topic to mark + :type topic: str + :param seq: sequence id + :type seq: int + """ + await self.send_message(note=pb.ClientNote(topic=topic, what=pb.RECV, seq_id=seq)) + async def note_read(self, /, topic: str, seq: int) -> None: + """mark a text as read + a {data} message is seen (read) by the user. It implies recv as well. + + :param topic: topic to mark + :type topic: str + :param seq: sequence id + :type seq: int + """ await self.send_message(note=pb.ClientNote(topic=topic, what=pb.READ, seq_id=seq)) + async def upload( + self, + path: Union[str, os.PathLike, IOBase] + ) -> Tuple[str, Dict[str, Any]]: + """ + upload a file + + :param path: file path + :type path: Union[str, os.PathLike] + :raises KaruhaBotError: fail to upload file + :return: tid and params + :rtype: Tuple[str, Dict[str, Any]] + """ + tid = self._get_tid() + + try: + async with await self._get_http_session() as session: + self.logger.debug(f"upload request: {tid=} {path=} {session.headers=}") + while True: + url = "/v0/file/u/" + data = FormData() + data.add_field("id", tid) + if isinstance(path, IOBase): + f = path + _path = None + else: + f = open(path, "rb") + _path = os.path.basename(path) + with f: + data.add_field("file", f, filename=_path) + async with session.post(url, data=data) as resp: + ret = await resp.text() + self.logger.debug(f"upload response: {ret}") + ctrl = from_json(ret)["ctrl"] + params = ctrl["params"] + code = ctrl["code"] + if code != 307: + break + url = params["url"] + # If 307 Temporary Redirect is returned, the client must retry the upload at the provided URL. + self.logger.info(f"upload redirected to {url}") + except OSError as e: + raise KaruhaBotError(f"fail to read file {path}", bot=self) from e + except ClientError as e: + err_text = f"fail to upload file {path}: {e}" + self.logger.error(err_text, exc_info=True) + raise KaruhaBotError(err_text, bot=self) from e + assert ctrl["id"] == tid, "tid mismatch" + if code < 200 or code >= 400: # pragma: no cover + err_text = f"fail to upload file {path}: {ctrl['text']}" + self.logger.error(err_text) + raise KaruhaBotError(err_text, bot=self, code=code) + self.logger.info(f"upload file {path}") + return tid, params + + async def download( + self, + url: str, + path: Union[str, os.PathLike, IOBase] + ) -> None: + """ + download a file + + :param url: file url + :type url: str + :param path: file path to save + :type path: Union[str, os.PathLike] + :raises KaruhaBotError: fail to download file + """ + tid = self._get_tid() + try: + async with await self._get_http_session() as session: + self.logger.debug(f"download request: {tid=} {path=} {session.headers=}") + size = 0 + if isinstance(path, IOBase): + async with session.get(url, params={"id": tid}) as resp: + resp.raise_for_status() + async for chunk in resp.content.iter_any(): + if isinstance(path, TextIOBase): + path.write(chunk.decode()) + else: + path.write(chunk) + size += len(chunk) + else: + async with aio_open(path, "wb") as f: + async with session.get(url, params={"id": tid}) as resp: + resp.raise_for_status() + async for chunk in resp.content.iter_any(): + await f.write(chunk) + size = await f.tell() + self.logger.debug(f"download length: {size}") + except OSError as e: + raise KaruhaBotError(f"fail to write file {path}", bot=self) from e + except ClientError as e: + err_text = f"fail to download file {path}: {e}" + self.logger.error(err_text, exc_info=True) + raise KaruhaBotError(err_text, bot=self) from e + self.logger.info(f"download file {path}") + @overload async def send_message( self, @@ -614,7 +1123,7 @@ async def send_message( ) -> Optional[Message]: """set messages to Tinode server - :param wait_tid: if set, it willl wait until a response message with the same tid is received, defaults to None + :param wait_tid: if set, it will wait until a response message with the same tid is received, defaults to None :type wait_tid: Optional[str], optional :return: message which has the same tid :rtype: Optional[Message] @@ -622,7 +1131,7 @@ async def send_message( if self.state != BotState.running: raise KaruhaBotError("bot is not running", bot=self) - client_msg = pb.ClientMsg(**kwds) # type: ignore + client_msg = pb.ClientMsg(**kwds, extra=extra) # type: ignore ret = None if wait_tid is None: await self.queue.put(client_msg) @@ -735,6 +1244,32 @@ def _get_tid(self) -> str: self._tid_counter += 1 return tid + async def _get_http_session(self) -> ClientSession: + if self.server is None: + raise ValueError("server not specified") + web_host = self.server.web_host + + try: + schema, secret = self.config.schema_, self.config.secret + if self.token is not None and self.token_expires > datetime.now(timezone.utc): + schema, secret = "token", self.token + elif schema == "cookie": + schema, secret_bytes = await read_auth_cookie(secret) + secret = base64.b64encode(secret_bytes).decode() + else: + secret = base64.b64encode(secret.encode()).decode() + except Exception as e: # pragma: no cover + err_text = f"fail to read auth secret: {e}" + self.logger.error(err_text) + raise KaruhaBotError(err_text, bot=self) from e + + headers = { + 'X-Tinode-APIKey': self.server.api_key, + "X-Tinode-Auth": f"{schema.title()} {secret}", + "User-Agent": f"KaruhaBot {APP_VERSION}/{LIB_VERSION}", + } + return ClientSession(web_host, headers=headers) + @contextmanager def _wait_reply(self, tid: Optional[str] = None) -> Generator[asyncio.Future, None, None]: tid = tid or self._get_tid() @@ -848,6 +1383,83 @@ def __repr__(self) -> str: return f"" +class AgentBot(Bot): + """ + the bot that runs on the `extra.on_behalf_of` agent + """ + __slots__ = ["on_behalf_of", "login_user_id"] + + def __init__(self, *args: Any, on_behalf_of: str, **kwds: Any) -> None: + super().__init__(*args, **kwds) + self.on_behalf_of = on_behalf_of + + @classmethod + def from_bot(cls, bot: Bot, /, on_behalf_of: str, name: Optional[str] = None) -> Self: + config = bot.config.model_copy() + if name is None: + config.name = f"{config.name}_agent_{on_behalf_of}" + else: + config.name = name + return cls(config, bot.server, bot.logger.level, on_behalf_of=on_behalf_of) + + @overload + async def send_message( + self, + wait_tid: str, + /, + *, + extra: Optional[pb.ClientExtra] = None, + **kwds: Optional[Message], + ) -> Message: ... + + @overload + async def send_message( + self, + wait_tid: None = None, + /, + *, + extra: Optional[pb.ClientExtra] = None, + **kwds: Optional[Message], + ) -> None: ... + + async def send_message( + self, + wait_tid: Optional[str] = None, + /, *, + extra: Optional[pb.ClientExtra] = None, + **kwds: Optional[Message] + ) -> Optional[Message]: + """set messages to Tinode server + + :param wait_tid: if set, it willl wait until a response message with the same tid is received, defaults to None + :type wait_tid: Optional[str], optional + :param extra: extra fields, defaults to None + :type extra: Optional[pb.ClientExtra], optional + :return: message which has the same tid + :rtype: Optional[Message] + """ + exclude_msg = {"hi", "login"} + keys = set(kwds) + if exclude_keys := keys & exclude_msg: + if len(exclude_keys) != len(keys): + raise KaruhaBotError("cannot mix message types", bot=self) + elif extra is None: + extra = pb.ClientExtra(on_behalf_of=self.on_behalf_of) + elif not extra.on_behalf_of: + extra.on_behalf_of = self.on_behalf_of + elif extra.on_behalf_of != self.on_behalf_of: + raise KaruhaBotError(f"on_behalf_of mismatch: {extra.on_behalf_of} != {self.on_behalf_of}", bot=self) + return await super().send_message(wait_tid, extra=extra, **kwds) + + @property + def user_id(self) -> str: + return self.on_behalf_of + + @user_id.setter + def user_id(self, val: str) -> None: + self.login_user_id = val + + def get_stream(channel: grpc_aio.Channel, /) -> grpc_aio.StreamStreamMultiCallable: # pragma: no cover return channel.stream_stream( '/pbx.Node/MessageLoop', @@ -856,16 +1468,16 @@ def get_stream(channel: grpc_aio.Channel, /) -> grpc_aio.StreamStreamMultiCallab ) -async def read_auth_cookie(cookie_file_name) -> Union[Tuple[str, bytes], Tuple[None, None]]: +async def read_auth_cookie(cookie_file_name: Union[str, os.PathLike]) -> Tuple[str, bytes]: """Read authentication token from a file""" async with aio_open(cookie_file_name, 'r') as cookie: params = from_json(await cookie.read()) schema = params.get("schema") secret = params.get('secret') if schema is None or secret is None: - return None, None + raise ValueError("invalid cookie file") if schema == 'token': - secret = b64decode(secret) + secret = base64.b64decode(secret) else: secret = secret.encode('utf-8') return schema, secret diff --git a/karuha/config.py b/karuha/config.py index 365ff87..affb206 100644 --- a/karuha/config.py +++ b/karuha/config.py @@ -1,6 +1,6 @@ from pathlib import Path from typing import Iterable, Literal, Optional, Tuple, Union -from pydantic import AnyUrl, BaseModel, Field, PrivateAttr, ValidationError, field_validator +from pydantic import AnyUrl, BaseModel, Field, HttpUrl, PrivateAttr, ValidationError, field_validator from typing_extensions import Annotated from .logger import logger, Level @@ -8,7 +8,7 @@ class Server(BaseModel): host: Annotated[str, AnyUrl] = "localhost:16060" - web_host: Annotated[str, AnyUrl] = "localhost:6060" + web_host: Annotated[str, HttpUrl] = "http://localhost:6060" api_key: str = "AQEAAAABAAD_rAp4DJh05a1HAwFT3A6K" ssl: bool = False ssl_host: Optional[str] = None @@ -22,7 +22,9 @@ class Bot(BaseModel): name: str = "chatbot" schema_: Literal["basic", "token", "cookie"] = Field(alias="schema") secret: str + auto_login: bool = True auto_subscribe_new_user: bool = False + file_size_threshold: int = 1024 * 1024 class Config(BaseModel): @@ -74,7 +76,7 @@ def load_config( with open(path, "r", encoding=encoding) as f: config = Config.model_validate_json(f.read()) except OSError: - logger.warn(f"failed to load file '{path}'", exc_info=True) + logger.warning(f"failed to load file '{path}'", exc_info=True) config = Config(_path=path) # type: ignore if auto_create: config.save(path, encoding=encoding, ignore_error=True) @@ -90,16 +92,19 @@ def load_config( def init_config( - server: Union[dict, Server] = Server(), + server: Union[dict, Server, Config] = Server(), bots: Optional[Iterable[Union[dict, Bot]]] = None, log_level: Level = "INFO" ) -> Config: global _config - _config = Config( - server=server, # type: ignore - bots=bots or (), # type: ignore - log_level=log_level - ) + if isinstance(server, Config): + _config = server + else: + _config = Config( + server=server, # type: ignore + bots=bots or (), # type: ignore + log_level=log_level + ) return _config @@ -107,3 +112,8 @@ def save_config() -> Path: config = get_config() config.save() return config._path + + +def reset_config() -> None: + global _config + _config = None diff --git a/karuha/data/sub.py b/karuha/data/sub.py index 29c591b..e9a4f1e 100644 --- a/karuha/data/sub.py +++ b/karuha/data/sub.py @@ -30,7 +30,7 @@ async def ensure_sub(bot: Bot, topic: str) -> bool: def reset_sub(bot: Bot) -> None: - del _subscriptions[bot.uid] + _subscriptions[bot.uid].clear() @on(SubscribeEvent) diff --git a/karuha/event/bot.py b/karuha/event/bot.py index 7e59132..259e0b0 100644 --- a/karuha/event/bot.py +++ b/karuha/event/bot.py @@ -39,26 +39,30 @@ class BotInitEvent(BotEvent): async def __default_handler__(self) -> None: bot = self.bot try: - await self.bot.hello() + await bot.hello() except Exception: - self.bot.logger.error("failed to connect to server, restarting") - self.bot.restart() + bot.logger.error("failed to connect to server, restarting") + bot.restart() + return + + if not bot.config.auto_login: + bot.logger.info("auto login is disabled, skipping") return retry = bot.server.retry if bot.server is not None else 0 for i in range(retry): try: - await self.bot.login() + await bot.login() except (asyncio.TimeoutError, KaruhaBotError): - self.bot.logger.warning(f"login failed, retrying {i+1} times") + bot.logger.warning(f"login failed, retrying {i+1} times") else: break else: try: - await self.bot.login() + await bot.login() except (asyncio.TimeoutError, KaruhaBotError): - self.bot.logger.error("login failed, cancel the bot") - self.bot.cancel() + bot.logger.error("login failed, cancel the bot") + bot.cancel() class BotReadyEvent(BotEvent): diff --git a/karuha/runner.py b/karuha/runner.py index 8f68a4b..264d56b 100644 --- a/karuha/runner.py +++ b/karuha/runner.py @@ -1,129 +1,17 @@ import asyncio import contextlib import signal -from typing import Dict, List, MutableSequence, Optional, Awaitable +from typing import Dict, List, Optional -from .config import get_config +from .config import get_config, reset_config from .bot import Bot, BotState from .event.sys import SystemStartEvent, SystemStopEvent from .logger import logger +from .utils.gathering import DynamicGatheringFuture _bot_cache: Dict[str, Bot] = {} -_gathering_future: Optional["DynamicGatheringFuture"] = None - - -class DynamicGatheringFuture(asyncio.Future): - """ - A dynamic version of `asyncio.tasks._GatheringFuture`. - - It allows to add new tasks to the gathering future. - """ - - __slots__ = ["children", "nfinished", "_cancel_requested"] - - def __init__(self, children: MutableSequence[asyncio.Future], *, loop=None): - super().__init__(loop=loop) - self.children = children - self.nfinished = 0 - self._cancel_requested = False - done_futs = [] - - for child in children: - if child.done(): - done_futs.append(child) - else: - child.add_done_callback(self._done_callback) - - for child in done_futs: - self._done_callback(child) - - def add_task(self, fut: asyncio.Future) -> None: - if self.done(): # pragma: no cover - raise RuntimeError("cannot add child to cancelled parent") - fut.add_done_callback(self._done_callback) - self.children.append(fut) - - def add_coroutine(self, coro: Awaitable) -> None: - fut = asyncio.ensure_future(coro) - if fut is not coro: - # 'coro' was not a Future, therefore, 'fut' is a new - # Future created specifically for 'coro'. Since the caller - # can't control it, disable the "destroy pending task" - # warning. - fut._log_destroy_pending = False # type: ignore[attr-defined] - self.add_task(fut) - - def cancel(self, msg=None) -> bool: # pragma: no cover - if self.done(): - return False - ret = False - for child in self.children: - cancelled = child.cancel(msg=msg) if msg is not None else child.cancel() # type: ignore - if cancelled: - ret = True - if ret: - # If any child tasks were actually cancelled, we should - # propagate the cancellation request regardless of - # *return_exceptions* argument. See issue 32684. - self._cancel_requested = True - return ret - - def _done_callback(self, fut: asyncio.Future) -> None: - self.nfinished += 1 - - if self.done(): # pragma: no cover - if not fut.cancelled(): - # Mark exception retrieved. - fut.exception() - return - - if fut.cancelled(): - # Check if 'fut' is cancelled first, as - # 'fut.exception()' will *raise* a CancelledError - # instead of returning it. - try: - exc = fut._make_cancelled_error() # type: ignore - except AttributeError: - exc = asyncio.CancelledError() - self.set_exception(exc) - return - else: - exc = fut.exception() - if exc is not None: # pragma: no cover - self.set_exception(exc) - return - - if self.nfinished == len(self.children): - # All futures are done; create a list of results - # and set it to the 'outer' future. - results = [] - - for fut in self.children: - if fut.cancelled(): # pragma: no cover - # Check if 'fut' is cancelled first, as - # 'fut.exception()' will *raise* a CancelledError - # instead of returning it. - res = asyncio.CancelledError( - getattr(fut, "_cancel_message", '') or '' - ) - else: - res = fut.exception() - if res is None: - res = fut.result() - results.append(res) - - if self._cancel_requested: # pragma: no cover - # If gather is being cancelled we must propagate the - # cancellation regardless of *return_exceptions* argument. - # See issue 32684. - try: - exc = fut._make_cancelled_error() # type: ignore - except AttributeError: - exc = asyncio.CancelledError() - self.set_exception(exc) - else: - self.set_result(results) +_gathering_future: Optional[DynamicGatheringFuture] = None def try_get_bot(name: str = "chatbot") -> Optional[Bot]: @@ -178,6 +66,10 @@ def get_all_bots() -> List[Bot]: return list(_bot_cache.values()) +def cancel_all_bots() -> bool: + return False if _gathering_future is None else _gathering_future.cancel() + + def _get_running_loop() -> asyncio.AbstractEventLoop: if _gathering_future is None: # pragma: no cover raise RuntimeError("no running loop") @@ -185,8 +77,8 @@ def _get_running_loop() -> asyncio.AbstractEventLoop: def _handle_sigterm() -> None: # pragma: no cover - for bot in _bot_cache.values(): - bot.cancel() + if _gathering_future is not None: + _gathering_future.cancel() async def async_run() -> None: @@ -237,4 +129,10 @@ def run() -> None: # pragma: no cover asyncio.run(async_run()) +def reset() -> None: + global _bot_cache + _bot_cache = {} + reset_config() + + from .plugin_server import init_server diff --git a/karuha/session.py b/karuha/session.py index e516e1f..bfc6793 100644 --- a/karuha/session.py +++ b/karuha/session.py @@ -1,26 +1,55 @@ import asyncio +import mimetypes import os import re -from typing import Any, Dict, List, NoReturn, Optional, Union, overload +import weakref +from functools import partialmethod +from typing import (Any, Dict, Iterable, List, NoReturn, Optional, Union, + overload) + +from aiofiles.ospath import getsize from tinode_grpc import pb from typing_extensions import Self import karuha + from .bot import Bot from .exception import KaruhaRuntimeError class BaseSession(object): + """Represents a session for interacting with a bot on a specific topic. + + This class manages the communication between the bot and the user, + allowing for sending messages, handling attachments, + and managing subscriptions to topics. + """ __slots__ = ["bot", "topic", "_task", "_closed"] def __init__(self, /, bot: Bot, topic: str) -> None: + """Initializes the BaseSession with a bot and a topic. + + :param bot: the bot instance to interact with + :type bot: Bot + :param topic: the topic to send messages to + :type topic: str + """ self.bot = bot self.topic = topic self._closed = False self._task = None - + def bind_task(self, task: Optional[asyncio.Task] = None) -> Self: + """Binds an asyncio task to the session. + + :param task: the task to bind; If None, the current task is used, defaults to None + :type task: Optional[asyncio.Task], optional + :return: the current session instance + :rtype: Self + """ self._task = task or asyncio.current_task() + if self._task is not None: + self._task.add_done_callback(lambda _: self.close()) return self async def send( @@ -30,8 +59,25 @@ async def send( head: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, topic: Optional[str] = None, - replace: Optional[int] = None + replace: Optional[int] = None, + attachments: Optional[Iterable[str]] = None ) -> Optional[int]: + """Send a message to the specified topic. + + :param text: the text or Drafty model to send + :type text: Union[str, dict, "Drafty", "BaseText"] + :param head: additional metadata to include in the message, defaults to None + :type head: Optional[Dict[str, Any]], optional + :param timeout: the timeout in seconds for sending the message, defaults to None + :type timeout: Optional[float], optional + :param topic: the topic to send the message to, defaults to None + :type topic: Optional[str], optional + :param replace: the message ID to replace, defaults to None + :type replace: Optional[int], optional + :param attachments: the list of attachment URLs to include in the message, defaults to None + :type attachments: Optional[Iterable[str]], optional + :return: the message ID if successful, None otherwise + :rtype: Optional[int]""" topic = topic or self.topic await self.subscribe(topic) if replace is not None: @@ -46,42 +92,66 @@ async def send( head = head or {} head["mime"] = "text/x-drafty" _, params = await asyncio.wait_for( - self.bot.publish(topic, text, head=head), - timeout + self.bot.publish( + topic, text, head=head, + extra=pb.ClientExtra(attachments=attachments) if attachments else None + ), + timeout, ) return params.get("seq") send_text = send - async def send_file( + async def send_attachment( self, path: Union[str, os.PathLike], /, *, name: Optional[str] = None, mime: Optional[str] = None, + attachment_cls_name: str = "File", + force_upload: bool = False, **kwds: Any ) -> Optional[int]: + """Send an attachment to the specified topic. + + :param path: the path to the file to send + :type path: Union[str, os.PathLike] + :param name: the name of the file, defaults to None + :type name: Optional[str], optional + :param mime: the MIME type of the file, defaults to None + :type mime: Optional[str], optional + :param attachment_cls_name: the name of the attachment class to use, defaults to "File" + :type attachment_cls_name: str, optional + :param force_upload: force upload even if the file size is below the threshold, defaults to False + :type force_upload: bool, optional + :return: the message ID if successful, None otherwise + :rtype: Optional[int] + """ + self._ensure_status() + attachment_cls = getattr(textchain, attachment_cls_name, None) + if attachment_cls is None: + raise KaruhaRuntimeError("unknown attachment type name") + elif not issubclass(attachment_cls, _Attachment): + raise KaruhaRuntimeError("unknown attachment type") + size = await getsize(path) + if force_upload or size < self.bot.config.file_size_threshold: + return await self.send( + await attachment_cls.from_file( + path, name=name, mime=mime + ), + **kwds + ) + _, upload_params = await self.bot.upload(path) + url = upload_params["url"] + mime = mime or mimetypes.guess_type(path)[0] return await self.send( - await File.from_file( - path, name=name, mime=mime - ), + attachment_cls.from_url(url, name=name, mime=mime), + attachments=[url], **kwds ) - async def send_image( - self, - path: Union[str, os.PathLike], - /, *, - name: Optional[str] = None, - mime: Optional[str] = None, - **kwds: Any - ) -> Optional[int]: - return await self.send( - await Image.from_file( - path, name=name, mime=mime - ), - **kwds - ) + send_file = partialmethod(send_attachment, attachment_cls_name="File") + send_image = partialmethod(send_attachment, attachment_cls_name="Image") async def wait_reply( self, @@ -90,6 +160,19 @@ async def wait_reply( pattern: Optional[re.Pattern] = None, priority: float = 1.2 ) -> "Message": + """Wait for a reply from the specified topic. + + :param topic: the topic to wait for, defaults to None + :type topic: Optional[str], optional + :param user_id: the user ID to wait for, defaults to None + :type user_id: Optional[str], optional + :param pattern: the pattern to match, defaults to None + :type pattern: Optional[re.Pattern], optional + :param priority: the priorityof the message, defaults to 1.2 + :type priority: float, optional + :return: the message + :rtype: "Message" + """ self._ensure_status() loop = asyncio.get_running_loop() dispatcher = SessionDispatcher( @@ -109,6 +192,18 @@ async def send_form( topic: Optional[str] = None, **kwds: Any, ) -> int: + """Send a form to the specified topic. + + :param title: the title of the form + :type title: Union[str, "BaseText"] + :param button: the buttons to include in the form + :type button: Union[str, "Button"], optional + :param topic: the topic to send the form to, defaults to None + :type topic: Optional[str], optional + :return: the message ID if successful, None otherwise + :rtype: int + """ + self._ensure_status() chain = TextChain( Bold(content=PlainText(title)) if isinstance(title, str) else title ) @@ -133,7 +228,7 @@ async def send_form( raise KaruhaRuntimeError("failed to fetch message id") loop = asyncio.get_running_loop() - dispatcher = ButtonReplyDispatcher( + dispatcher = _ButtonReplyDispatcher( self, loop.create_future(), seq_id=seq_id, @@ -144,15 +239,28 @@ async def send_form( resp = await dispatcher.wait() except: # noqa: E722 # pragma: no cover # The dispatcher will automatically deactivate after receiving a message, - # so you only need to actively deactivate it when an exception occurs + # so we only need to actively deactivate it when an exception occurs dispatcher.deactivate() raise return pred_resp.index(resp) async def confirm(self, title: Union[str, "BaseText"], **kwds: Any) -> bool: + """A convenience method to send a form and wait for a reply. + + :param title: the title of the form + :type title: Union[str, "BaseText"] + :return: True if the user selects "Yes", False otherwise + :rtype: bool + """ return not await self.send_form(title, "Yes", "No", **kwds) async def finish(self, text: Union[str, dict, "Drafty", "BaseText"], /, **kwds: Any) -> NoReturn: + """Finish the session and send a message to the user. + + :param text: the message to send + :type text: Union[str, dict, "Drafty", "BaseText"] + :raises KaruhaRuntimeError: if the session is not active + """ await self.send(text, **kwds) self.cancel() @@ -164,6 +272,15 @@ async def subscribe( get: Union[pb.GetQuery, str, None] = "desc sub", **kwds: Any ) -> None: + """Subscribe to the specified topic. + + :param topic: the topic to subscribe to, defaults to None + :type topic: Optional[str], optional + :param force: whether to force the subscription, defaults to False + :type force: bool, optional + :param get: the query to use, defaults to "desc sub + :type get: Union[pb.GetQuery, str, None], optional + """ self._ensure_status() topic = topic or self.topic if karuha.data.has_sub(self.bot, topic) and not force: @@ -171,15 +288,42 @@ async def subscribe( await self.bot.subscribe(topic, get=get, **kwds) async def leave(self, topic: Optional[str] = None, *, force: bool = False, **kwds: Any) -> None: + """Leave the specified topic. + + :param topic: the topic to leave, defaults to None + :type topic: Optional[str], optional + :param force: whether to force the leave, defaults to False + :type force: bool, optional + """ self._ensure_status() topic = topic or self.topic if karuha.data.has_sub(self.bot, topic) or force: await self.bot.leave(topic, **kwds) - + async def get_user(self, user_id: str, *, ensure_user: bool = False) -> "karuha.data.BaseUser": + """Get the user data from the specified user ID. + + :param user_id: the user ID to get the data from + :type user_id: str + :param ensure_user: whether to ensure that the user exists, defaults to False + :type ensure_user: bool, optional + :return: the user data + :rtype: "karuha.data.BaseUser" + """ + self._ensure_status() return await karuha.data.get_user(self.bot, user_id, ensure_user=ensure_user) async def get_topic(self, topic: Optional[str] = None, *, ensure_topic: bool = False) -> "karuha.data.BaseTopic": + """Get the topic data from the specified topic ID. + + :param topic: the topic ID to get the data from, defaults to None + :type topic: Optional[str], optional + :param ensure_topic: whether to ensure that the topic exists, defaults to False + :type ensure_topic: bool, optional + :return: the topic data + :rtype: "karuha.data.BaseTopic" + """ + self._ensure_status() return await karuha.data.get_topic(self.bot, topic or self.topic, ensure_topic=ensure_topic) @overload @@ -188,7 +332,16 @@ async def get_data( topic: Optional[str] = None, *, seq_id: int, - ) -> "Message": ... + ) -> "Message": + """Get the message data from the specified message ID. + + :param topic: the topic ID to get the data from, defaults to None + :type topic: Optional[str], optional + :param seq_id: the message ID to get the data from + :type seq_id: int + :return: the message data + :rtype: "Message" + """ @overload async def get_data( @@ -197,7 +350,18 @@ async def get_data( *, low: Optional[int] = None, hi: Optional[int] = None, - ) -> List["Message"]: ... + ) -> List["Message"]: + """Get the message data from the specified range. + + :param topic: the topic ID to get the data from, defaults to None + :type topic: Optional[str], optional + :param low: the lower bound of the range, defaults to None + :type low: Optional[int], optional + :param hi: the upper bound of the range, defaults to None + :type hi: Optional[int], optional + :return: the message data + :rtype: List["Message"] + """ @overload async def get_data( @@ -207,7 +371,20 @@ async def get_data( seq_id: Optional[int] = None, low: Optional[int] = None, hi: Optional[int] = None, - ) -> Union["Message", List["Message"]]: ... + ) -> Union["Message", List["Message"]]: + """Get the message data from the specified range. + + :param topic: the topic ID to get the data from, defaults to None + :type topic: Optional[str], optional + :param seq_id: the message ID to get the data from, defaults to None + :type seq_id: Optional[int], optional + :param low: the lower bound of the range, defaults to None + :type low: Optional[int], optional + :param hi: the upper bound of the range, defaults to None + :type hi: Optional[int], optional + :return: the message data + :rtype: Union["Message", List["Message"]] + """ async def get_data( self, @@ -224,34 +401,50 @@ async def get_data( ) def close(self) -> None: + """Close the session.""" self._closed = True def cancel(self) -> NoReturn: + """Cancel the session. + + :raises asyncio.CancelledError: cancels the session + """ self.close() + if ( + self._task is not None + and self._task is not asyncio.current_task() + and not self._task.done() + ): + self._task.cancel() raise asyncio.CancelledError - def _ensure_status(self) -> None: - if self._closed: - raise KaruhaRuntimeError("session is closed") - @property def closed(self) -> bool: + """Whether the session is closed.""" return self._closed async def __aenter__(self) -> Self: + """Enter the session.""" + self.bind_task() await self.subscribe() return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit the session.""" self.close() + def _ensure_status(self) -> None: + if self._closed: + raise KaruhaRuntimeError("session is closed") + +from .event.message import MessageDispatcher, get_message_lock +from .text import textchain from .text.drafty import Drafty -from .text.textchain import (BaseText, Bold, Button, File, Form, Image, - NewLine, PlainText, TextChain) from .text.message import Message +from .text.textchain import (BaseText, Bold, Button, Form, NewLine, PlainText, + TextChain, _Attachment) from .utils.dispatcher import FutureDispatcher -from .event.message import MessageDispatcher, get_message_lock class SessionDispatcher(MessageDispatcher, FutureDispatcher[Message]): @@ -286,7 +479,7 @@ def match(self, message: Message, /) -> float: return self.priority -class ButtonReplyDispatcher(SessionDispatcher): +class _ButtonReplyDispatcher(SessionDispatcher): __slots__ = ["seq_id", "_cache"] def __init__( @@ -319,9 +512,10 @@ def match(self, message: Message) -> float: resp = value.get("resp") if value.get("seq") == self.seq_id: self._cache[id(message)] = resp + weakref.finalize(message, self._cache.pop, id(message), None) return self.priority return 0 def run(self, message: Message) -> None: - resp = self._cache.get((id(message))) + resp = self._cache[id(message)] self.future.set_result(resp) diff --git a/karuha/text/textchain.py b/karuha/text/textchain.py index 3cfed30..24e7ce1 100644 --- a/karuha/text/textchain.py +++ b/karuha/text/textchain.py @@ -466,7 +466,7 @@ def from_bytes( ref: Optional[str] = None, **kwds: Any ) -> Self: - return cls( # type: ignore + return cls( mime=mime or "text/plain", name=name, ref=ref, @@ -475,6 +475,22 @@ def from_bytes( **kwds ) + @classmethod + def from_url( + cls, + url: str, + *, + mime: Optional[str] = None, + name: Optional[str] = None, + **kwds: Any + ) -> Self: + return cls( + mime=mime or "text/plain", + name=name, + ref=url, + **kwds + ) + @classmethod async def from_file( cls, @@ -493,7 +509,7 @@ async def from_file( ref=ref, **kwds ) - + async def save(self, path: Union[str, os.PathLike, None] = None) -> None: path = path or self.name if path is None: diff --git a/karuha/utils/gathering.py b/karuha/utils/gathering.py new file mode 100644 index 0000000..a5ad01b --- /dev/null +++ b/karuha/utils/gathering.py @@ -0,0 +1,115 @@ +import asyncio +from typing import Awaitable, MutableSequence + + +class DynamicGatheringFuture(asyncio.Future): + """ + A dynamic version of `asyncio.tasks._GatheringFuture`. + + It allows to add new tasks to the gathering future. + """ + + __slots__ = ["children", "nfinished", "_cancel_requested"] + + def __init__(self, children: MutableSequence[asyncio.Future], *, loop=None): + super().__init__(loop=loop) + self.children = children + self.nfinished = 0 + self._cancel_requested = False + done_futs = [] + + for child in children: + if child.done(): + done_futs.append(child) + else: + child.add_done_callback(self._done_callback) + + for child in done_futs: + self._done_callback(child) + + def add_task(self, fut: asyncio.Future) -> None: + if self.done(): # pragma: no cover + raise RuntimeError("cannot add child to cancelled parent") + fut.add_done_callback(self._done_callback) + self.children.append(fut) + + def add_coroutine(self, coro: Awaitable) -> None: + fut = asyncio.ensure_future(coro) + if fut is not coro: + # 'coro' was not a Future, therefore, 'fut' is a new + # Future created specifically for 'coro'. Since the caller + # can't control it, disable the "destroy pending task" + # warning. + fut._log_destroy_pending = False # type: ignore[attr-defined] + self.add_task(fut) + + def cancel(self, msg=None) -> bool: # pragma: no cover + if self.done(): + return False + ret = False + for child in self.children: + cancelled = child.cancel(msg=msg) if msg is not None else child.cancel() # type: ignore + if cancelled: + ret = True + if ret: + # If any child tasks were actually cancelled, we should + # propagate the cancellation request regardless of + # *return_exceptions* argument. See issue 32684. + self._cancel_requested = True + return ret + + def _done_callback(self, fut: asyncio.Future) -> None: + self.nfinished += 1 + + if self.done(): # pragma: no cover + if not fut.cancelled(): + # Mark exception retrieved. + fut.exception() + return + + if fut.cancelled(): + # Check if 'fut' is cancelled first, as + # 'fut.exception()' will *raise* a CancelledError + # instead of returning it. + try: + exc = fut._make_cancelled_error() # type: ignore + except AttributeError: + exc = asyncio.CancelledError() + self.set_exception(exc) + return + else: + exc = fut.exception() + if exc is not None: # pragma: no cover + self.set_exception(exc) + return + + if self.nfinished == len(self.children): + # All futures are done; create a list of results + # and set it to the 'outer' future. + results = [] + + for fut in self.children: + if fut.cancelled(): # pragma: no cover + # Check if 'fut' is cancelled first, as + # 'fut.exception()' will *raise* a CancelledError + # instead of returning it. + res = asyncio.CancelledError( + getattr(fut, "_cancel_message", '') or '' + ) + else: + res = fut.exception() + if res is None: + res = fut.result() + results.append(res) + + if self._cancel_requested: # pragma: no cover + # If gather is being cancelled we must propagate the + # cancellation regardless of *return_exceptions* argument. + # See issue 32684. + try: + exc = fut._make_cancelled_error() # type: ignore + except AttributeError: + exc = asyncio.CancelledError() + self.set_exception(exc) + else: + self.set_result(results) diff --git a/karuha/version.py b/karuha/version.py index 0d0bca5..c656e58 100644 --- a/karuha/version.py +++ b/karuha/version.py @@ -1,4 +1,4 @@ from importlib.metadata import distribution -APP_VERSION = __version__ = "0.2.1" +APP_VERSION = __version__ = "0.2.2" LIB_VERSION = distribution("tinode_grpc").version diff --git a/setup.py b/setup.py index 32f572d..6b5f0ba 100644 --- a/setup.py +++ b/setup.py @@ -51,16 +51,17 @@ packages=find_packages(), python_requires=">=3.8", install_requires=[ - "typing_extensions>=4.0", + "typing_extensions>=4.9", "grpcio>=1.40.0", "tinode-grpc>=0.20.0b3", "pydantic>2.0", - "aiofiles" + "aiohttp>=3.7", + "aiofiles>=23.1" ], extras_require={ - "image": ["pillow"], + "image": ["pillow>=10.0"], "data": ["greenback"], - "all": ["pillow", "greenback"], + "all": ["pillow>=10.0", "greenback"], }, classifiers=[ @@ -72,6 +73,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Communications :: Chat", ] diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/client/test_bot.py b/tests/client/test_bot.py new file mode 100644 index 0000000..20eb516 --- /dev/null +++ b/tests/client/test_bot.py @@ -0,0 +1,69 @@ +from functools import partial +from io import StringIO + +from tinode_grpc import pb + +from karuha import BaseSession, MessageSession, add_bot, remove_bot, on_rule +from karuha.bot import AgentBot +from karuha.event.bot import BotReadyEvent +from ..utils import AsyncBotClientTestCase + + +class TestBotClient(AsyncBotClientTestCase): + async def test_hi(self) -> None: + await self.bot.hello() + await self.bot.hello(lang="CN") + + async def test_upload(self) -> None: + f = StringIO("Hello world!") + _, params = await self.bot.upload(f) + self.assertTrue("url" in params) + + f = StringIO() + await self.bot.download(params["url"], f) + self.assertEqual(f.getvalue(), "Hello world!") + + async def test_account(self) -> None: + _, params = await self.bot.account( + "newkr_test", + "basic", + b"test:test123456", + cred=[pb.ClientCred(method="email", value="test@example.com")], + do_login=False, + ) + uid = params["user"] + await self.bot.delete("user", user_id=uid, hard=True) + + async def test_on_behalf_of(self) -> None: + _, params = await self.bot.account( + "newkr_test", + "basic", + b"test:test123456", + cred=[pb.ClientCred(method="email", value="test@example.com")], + do_login=False, + ) + uid = params["user"] + try: + agent_bot = AgentBot.from_bot(self.bot, on_behalf_of=uid) + add_bot(agent_bot) + self.addCleanup(partial(remove_bot, agent_bot)) + + with self.catchEvent(BotReadyEvent) as catcher: + ev = await catcher.catch_event() + self.assertIs(ev.bot, agent_bot) + self.assertEqual(agent_bot.user_id, uid) + self.assertEqual(agent_bot.login_user_id, self.bot.user_id) + + @on_rule(user_id=uid, bot=self.bot) + async def _handler(session: MessageSession, text: str) -> None: + self.assertEqual(text, "test") + await session.send("test_reply") + + self.addCleanup(_handler.deactivate) + + async with BaseSession(agent_bot, self.bot.user_id) as session: + await session.send("test") + reply = await session.wait_reply() + self.assertEqual(reply.plain_text, "test_reply") + finally: + await self.bot.delete("user", user_id=uid, hard=True) diff --git a/tests/utils.py b/tests/utils.py index 8e22d18..d67dbb2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,21 +2,21 @@ import json from time import time from types import coroutine -from typing import Any, Awaitable, Dict, Generator, Optional -from unittest import IsolatedAsyncioTestCase +from typing import Any, Awaitable, ClassVar, Dict, Generator, Optional, TypeVar +from unittest import IsolatedAsyncioTestCase, SkipTest from grpc import ChannelConnectivity from grpc import aio as grpc_aio from tinode_grpc import pb from typing_extensions import Self -from karuha import async_run, try_add_bot +from karuha import Config, async_run, get_bot, try_add_bot, cancel_all_bots, reset from karuha.bot import Bot, BotState from karuha.command.collection import new_collection from karuha.command.command import CommandMessage, FunctionCommand from karuha.config import Server as ServerConfig -from karuha.config import init_config -from karuha.store import T +from karuha.config import init_config, load_config +from karuha.event.bot import BotReadyEvent from karuha.text.message import Message from karuha.utils.event_catcher import T_Event from karuha.utils.event_catcher import EventCatcher as _EventCatcher @@ -24,6 +24,8 @@ TEST_TIMEOUT = 3 +T = TypeVar("T") + @coroutine def run_forever() -> Generator[None, None, None]: @@ -196,7 +198,12 @@ async def catch_event(self, timeout: float = TEST_TIMEOUT) -> T_Event: class AsyncBotTestCase(IsolatedAsyncioTestCase): bot = bot_mock - config = init_config(log_level="DEBUG") + config: ClassVar[Config] + + @classmethod + def setUpClass(cls) -> None: + reset() + cls.config = init_config(log_level="DEBUG") async def asyncSetUp(self) -> None: self.assertEqual(self.bot.state, BotState.stopped) @@ -205,10 +212,12 @@ async def asyncSetUp(self) -> None: await self.bot.wait_init() async def asyncTearDown(self) -> None: - loop = asyncio.get_running_loop() self.assertEqual(self.bot.state, BotState.running) - self.bot.cancel() - await loop.shutdown_asyncgens() + cancel_all_bots() + try: + await self.wait_for(self._main_task) + except asyncio.CancelledError: + pass catchEvent = EventCatcher @@ -240,6 +249,47 @@ async def wait_for(self, future: Awaitable[T], /, timeout: Optional[float] = TES return await asyncio.wait_for(future, timeout) +class AsyncBotClientTestCase(IsolatedAsyncioTestCase): + config_path = "config.json" + bot_name = "chatbot" + auto_login: ClassVar[bool] = True + bot: ClassVar[Bot] + + __unittest_skip__ = False + __unittest_skip_why__ = None + + @classmethod + def setUpClass(cls) -> None: + try: + cls.config = load_config(cls.config_path, auto_create=False) + cls.bot = get_bot(cls.bot_name) + except Exception: + cls.__unittest_skip__ = True + cls.__unittest_skip_why__ = "not bot config found" + raise SkipTest(cls.__unittest_skip_why__) from None + cls.bot.config.auto_login = cls.auto_login + + async def asyncSetUp(self) -> None: + self.assertEqual(self.bot.state, BotState.stopped) + try_add_bot(self.bot) + self._main_task = asyncio.create_task(async_run()) + with EventCatcher(BotReadyEvent) as catcher: + await catcher.catch_event() + + async def asyncTearDown(self) -> None: + self.assertEqual(self.bot.state, BotState.running) + cancel_all_bots() + try: + await self.wait_for(self._main_task) + except asyncio.CancelledError: + pass + + catchEvent = EventCatcher + + async def wait_for(self, future: Awaitable[T], /, timeout: Optional[float] = TEST_TIMEOUT) -> T: + return await asyncio.wait_for(future, timeout) + + def new_test_message(content: bytes = b"\"test\"") -> Message: return Message.new( bot_mock, "test", "user", 1, {}, content