Skip to content

Commit

Permalink
Merge pull request #12 from Visecy/dev-data
Browse files Browse the repository at this point in the history
Update chatbot api
Fixed an issue where restarting would not occur in case of network failure
  • Loading branch information
Ovizro authored Mar 6, 2024
2 parents acbe3ef + a4d0903 commit 46425fa
Show file tree
Hide file tree
Showing 30 changed files with 1,076 additions and 380 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,6 @@ disabled.*
# Karuha data file
.bot/
config.json

# tinode
.tn-cli-cookie
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def hi(session: MessageSession, text: str) -> None:
total = text.split(' ', 1)
if len(total) == 1:
await session.send("Hello!")
return
name = total[1]
await session.send(f"Hello {name}!")
```
Expand Down
1 change: 1 addition & 0 deletions README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def hi(session: MessageSession, text: str) -> None:
total = text.split(' ', 1)
if len(total) == 1:
await session.send("Hello!")
return
name = total[1]
await session.send(f"Hello {name}!")
```
Expand Down
8 changes: 8 additions & 0 deletions examples/echo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from karuha import on_command, MessageSession


@on_command
async def echo(session: MessageSession, text: str):
argv = text.split(None, 1)
if len(argv) >= 1:
await session.send(argv[1])
1 change: 1 addition & 0 deletions examples/hi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ async def hi(session: MessageSession, text: str) -> None:
total = text.split(' ', 1)
if len(total) == 1:
await session.send("Hello!")
return
name = total[1]
await session.send(f"Hello {name}!")
34 changes: 25 additions & 9 deletions karuha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
from .config import get_config, load_config, init_config, save_config, Config
from .config import Server as ServerConfig, Bot as BotConfig
from .bot import Bot
from .event import on, on_event, Event
from .exception import KaruhaException
from .command import CommandCollection, AbstractCommand, AbstractCommandNameParser, BaseSession, MessageSession, get_collection, on_command
from .event import on, Event
from .command import CommandCollection, AbstractCommand, AbstractCommandParser, BaseSession, MessageSession, get_collection, on_command
from .event.message import reset_message_lock
from .text import Drafty, BaseText, PlainText, Message, TextChain
from .plugin_server import init_server
from .logger import logger


_bot_cache: Dict[str, Bot] = {}
_loop = None


def get_bot(name: str = "chatbot") -> Bot:
Expand All @@ -48,10 +48,20 @@ def add_bot(bot: Bot) -> None:
raise ValueError(f"bot {bot.name} has existed")


def get_all_bots() -> List[Bot]:
return list(_bot_cache.values())


def _get_running_loop() -> asyncio.AbstractEventLoop:
if _loop is None:
raise RuntimeError("no running loop")
return _loop


async def async_run() -> None:
global _loop
config = get_config()
loop = asyncio.get_running_loop()
reset_message_lock()
_loop = asyncio.get_running_loop()

for i in config.bots:
if i.name in _bot_cache:
Expand All @@ -61,16 +71,17 @@ async def async_run() -> None:

tasks: List[asyncio.Task] = []
for bot in _bot_cache.values():
tasks.append(loop.create_task(bot.async_run(config.server)))
logger.debug(f"run bot {bot.config}")
tasks.append(_loop.create_task(bot.async_run(config.server)))

if config.server.enable_plugin: # pragma: no cover
server = init_server(config.server.listen)
loop.call_soon(server.start)
_loop.call_soon(server.start)
else:
server = None

if config.log_level == "DEBUG":
loop.set_debug(True)
_loop.set_debug(True)

if not tasks: # pragma: no cover
logger.warning("no bot found")
Expand All @@ -83,6 +94,7 @@ async def async_run() -> None:
if server is not None: # pragma: no cover
logger.info("stop plugin server")
server.stop(None)
_loop = None


def run() -> None:
Expand All @@ -92,6 +104,9 @@ def run() -> None:
pass


from .plugin_server import init_server


__all__ = [
# bot
"add_bot",
Expand Down Expand Up @@ -119,12 +134,13 @@ def run() -> None:
# command
"CommandCollection",
"AbstractCommand",
"AbstractCommandNameParser",
"AbstractCommandParser",
"get_collection",
"BaseSession",
"MessageSession",
# decorator
"on",
"on_event",
"on_command",
# exception
"KaruhaException"
Expand Down
Loading

0 comments on commit 46425fa

Please sign in to comment.