From f7c6f268a40a203daa44af981ff4f896e5dab702 Mon Sep 17 00:00:00 2001 From: cacosandon Date: Sat, 6 Apr 2024 16:27:50 -0300 Subject: [PATCH] refactor(django-channels@storage): move throttling methods to redis storage and close method to base style: fix linter wrong return type of save_snapshot --- .../storage/base_yroom_storage.py | 22 ++++++------------- .../storage/redis_yroom_storage.py | 21 ++++++++++++++++-- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/pycrdt_websocket/django_channels/storage/base_yroom_storage.py b/pycrdt_websocket/django_channels/storage/base_yroom_storage.py index 427b5af..7a4c272 100644 --- a/pycrdt_websocket/django_channels/storage/base_yroom_storage.py +++ b/pycrdt_websocket/django_channels/storage/base_yroom_storage.py @@ -1,4 +1,3 @@ -import time from abc import ABC, abstractmethod from typing import Optional @@ -55,10 +54,8 @@ async def save_snapshot(self): ``` """ - def __init__(self, room_name: str, save_throttle_interval: int | None) -> None: + def __init__(self, room_name: str) -> None: self.room_name = room_name - self.save_throttle_interval = save_throttle_interval - self.last_saved_at = time.time() @abstractmethod async def get_document(self) -> Doc: @@ -95,15 +92,10 @@ async def save_snapshot(self) -> None: """Saves the document encoded as update to the database.""" ... - async def throttled_save_snapshot(self) -> None: - """Saves the document encoded as update to the database, throttled.""" + async def close(self) -> None: + """Closes the storage connection. - if ( - not self.save_throttle_interval - or time.time() - self.last_saved_at <= self.save_throttle_interval - ): - return - - await self.save_snapshot() - - self.last_saved_at = time.time() + Useful for cleaning up resources like closing a database + connection or saving the document before exiting. + """ + pass diff --git a/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py index 51d1fb3..c0bd024 100644 --- a/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py +++ b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py @@ -1,3 +1,4 @@ +import time from typing import Optional import redis.asyncio as redis @@ -14,7 +15,10 @@ class RedisYRoomStorage(BaseYRoomStorage): """ def __init__(self, room_name: str, save_throttle_interval: int | None = None) -> None: - super().__init__(room_name, save_throttle_interval) + super().__init__(room_name) + + self.save_throttle_interval = save_throttle_interval + self.last_saved_at = time.time() self.redis_key = f"document:{self.room_name}" self.redis = self._make_redis() @@ -64,9 +68,22 @@ async def update_document(self, update: bytes): async def load_snapshot(self) -> Optional[bytes]: return None - async def save_snapshot(self) -> Optional[bytes]: + async def save_snapshot(self) -> None: return None + async def throttled_save_snapshot(self) -> None: + """Saves the document encoded as update to the database, throttled.""" + + if ( + not self.save_throttle_interval + or time.time() - self.last_saved_at <= self.save_throttle_interval + ): + return + + await self.save_snapshot() + + self.last_saved_at = time.time() + async def close(self): await self.save_snapshot() await self.redis.close()