diff --git a/.all-contributorsrc b/.all-contributorsrc
index c54b1e29e5..2ad013d704 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -1904,6 +1904,16 @@
"contributions": [
"doc"
]
+ },
+ {
+ "login": "olzhasar",
+ "name": "Olzhas Arystanov",
+ "avatar_url": "https://avatars.githubusercontent.com/u/12471703?v=4",
+ "profile": "https://olzhasar.com",
+ "contributions": [
+ "bug",
+ "doc"
+ ]
}
],
"contributorsPerLine": 7,
diff --git a/README.md b/README.md
index 9ef75fc762..257a6ccf67 100644
--- a/README.md
+++ b/README.md
@@ -582,6 +582,7 @@ see [the contribution guide](CONTRIBUTING.rst).
Mohammed Babelly 💻 |
Charles Duffy 💻 |
Evgeny Demchenko 📖 |
+ Olzhas Arystanov 🐛 📖 |
diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py
index 8e702ea1aa..e4a2df7825 100644
--- a/litestar/handlers/websocket_handlers/listener.py
+++ b/litestar/handlers/websocket_handlers/listener.py
@@ -335,10 +335,6 @@ class WebsocketListener(ABC):
"""A sequence of :class:`Guard <.types.Guard>` callables."""
middleware: list[Middleware] | None = None
"""A sequence of :class:`Middleware <.types.Middleware>`."""
- on_accept: AnyCallable | None = None
- """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been accepted. Can receive any dependencies"""
- on_disconnect: AnyCallable | None = None
- """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been disconnected. Can receive any dependencies"""
receive_mode: WebSocketMode = "text"
""":class:`WebSocket <.connection.WebSocket>` mode to receive data in, either ``text`` or ``binary``."""
send_mode: WebSocketMode = "text"
@@ -380,6 +376,9 @@ def __init__(self, owner: Router) -> None:
self._owner = owner
def to_handler(self) -> WebsocketListenerRouteHandler:
+ on_accept = self.on_accept if self.on_accept != WebsocketListener.on_accept else None
+ on_disconnect = self.on_disconnect if self.on_disconnect != WebsocketListener.on_disconnect else None
+
handler = WebsocketListenerRouteHandler(
dependencies=self.dependencies,
dto=self.dto,
@@ -389,8 +388,8 @@ def to_handler(self) -> WebsocketListenerRouteHandler:
send_mode=self.send_mode,
receive_mode=self.receive_mode,
name=self.name,
- on_accept=self.on_accept,
- on_disconnect=self.on_disconnect,
+ on_accept=on_accept,
+ on_disconnect=on_disconnect,
opt=self.opt,
path=self.path,
return_dto=self.return_dto,
@@ -402,6 +401,16 @@ def to_handler(self) -> WebsocketListenerRouteHandler:
handler.owner = self._owner
return handler
+ def on_accept(self, *args: Any, **kwargs: Any) -> Any:
+ """Called after a :class:`WebSocket <.connection.WebSocket>` connection
+ has been accepted. Can receive any dependencies
+ """
+
+ def on_disconnect(self, *args: Any, **kwargs: Any) -> Any:
+ """Called after a :class:`WebSocket <.connection.WebSocket>` connection
+ has been disconnected. Can receive any dependencies
+ """
+
@abstractmethod
def on_receive(self, *args: Any, **kwargs: Any) -> Any:
"""Called after data has been received from the WebSocket.
diff --git a/pyproject.toml b/pyproject.toml
index 06b34c0998..b145ccee9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -464,6 +464,7 @@ known-first-party = ["litestar", "tests", "examples"]
"litestar/_openapi/schema_generation/schema.py" = ["C901"]
"litestar/exceptions/*.*" = ["N818"]
"litestar/handlers/**/*.*" = ["N801"]
+"litestar/handlers/websocket_handlers/listener.py" = ["B027"]
"litestar/params.py" = ["N802"]
"test_apps/**/*.*" = ["D", "TRY", "EM", "S", "PTH"]
"tests/**/*.*" = [
diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py
index 08c74690d4..f6afec0a2f 100644
--- a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py
+++ b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py
@@ -394,10 +394,10 @@ def some_dependency() -> str:
class Listener(WebsocketListener):
path = "/{name: str}"
- def on_accept(self, name: str, state: State, query: dict, some: str) -> None: # type: ignore[override]
+ def on_accept(self, name: str, state: State, query: dict, some: str) -> None: # pyright: ignore
on_accept_mock(name=name, state=state, query=query, some=some)
- def on_disconnect(self, name: str, state: State, query: dict, some: str) -> None: # type: ignore[override]
+ def on_disconnect(self, name: str, state: State, query: dict, some: str) -> None: # pyright: ignore
on_disconnect_mock(name=name, state=state, query=query, some=some)
def on_receive(self, data: bytes) -> None: # pyright: ignore