From bfd11c489ef794f26faa8504360683784c9b5830 Mon Sep 17 00:00:00 2001 From: cacosandon Date: Sat, 6 Apr 2024 16:01:06 -0300 Subject: [PATCH] refactor(yroom_storage): make base class abstract and optional throttling save --- pycrdt_websocket/django_channels/__init__.py | 1 - .../django_channels/yroom_storage.py | 80 +++++++++---------- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/pycrdt_websocket/django_channels/__init__.py b/pycrdt_websocket/django_channels/__init__.py index 04b431a..0f860bd 100644 --- a/pycrdt_websocket/django_channels/__init__.py +++ b/pycrdt_websocket/django_channels/__init__.py @@ -1,3 +1,2 @@ from .yjs_consumer import YjsConsumer as YjsConsumer from .yroom_storage import BaseYRoomStorage as BaseYRoomStorage -from .yroom_storage import RedisYRoomStorage as RedisYRoomStorage diff --git a/pycrdt_websocket/django_channels/yroom_storage.py b/pycrdt_websocket/django_channels/yroom_storage.py index fe05f29..f41f08c 100644 --- a/pycrdt_websocket/django_channels/yroom_storage.py +++ b/pycrdt_websocket/django_channels/yroom_storage.py @@ -1,11 +1,12 @@ import time +from abc import ABC, abstractmethod from typing import Optional import redis.asyncio as redis from pycrdt import Doc -class BaseYRoomStorage: +class BaseYRoomStorage(ABC): """Base class for YRoom storage. This class is responsible for storing, retrieving, updating and persisting the Ypy document. Each Django Channels Consumer should have its own YRoomStorage instance, although all consumers @@ -55,12 +56,12 @@ async def save_snapshot(self): ``` """ - def __init__(self, room_name: str) -> None: + def __init__(self, room_name: str, save_throttle_interval: int | None) -> None: self.room_name = room_name - + self.save_throttle_interval = save_throttle_interval self.last_saved_at = time.time() - self.save_throttle_interval = 5 + @abstractmethod async def get_document(self) -> Doc: """Gets the document from the storage. Ideally it should be retrieved first from temporary storage (e.g. Redis) and then from @@ -68,53 +69,46 @@ async def get_document(self) -> Doc: Returns: The document with the latest changes. """ + ... - raise NotImplementedError - - async def update_document(self, update: bytes): + @abstractmethod + async def update_document(self, update: bytes) -> None: """Updates the document in the storage. Updates could be received by Yjs client (e.g. from a WebSocket) or from the server (e.g. from a Django Celery job). Args: update: The update to apply to the document. """ + ... - raise NotImplementedError - + @abstractmethod async def load_snapshot(self) -> Optional[bytes]: - """Gets the document from the database. Override this method to + """Gets the document encoded as update from the database. Override this method to implement a persistent storage. Defaults to None. Returns: The latest document snapshot. """ - return None + ... + @abstractmethod async def save_snapshot(self) -> None: - """Saves a snapshot of the document to the storage. - If you need to persist the document to a database, you should do it here. - Default implementation does nothing. - """ - - pass + """Saves the document encoded as update to the database.""" + ... async def throttled_save_snapshot(self) -> None: - """Saves a snapshot of the document to the storage, debouncing the calls.""" + """Saves the document encoded as update to the database, throttled.""" - if time.time() - self.last_saved_at <= self.save_throttle_interval: + if ( + not self.save_throttle_interval + or time.time() - self.last_saved_at <= self.save_throttle_interval + ): return await self.save_snapshot() self.last_saved_at = time.time() - async def close(self): - """Closes the storage. - Default implementation does nothing. - """ - - pass - class RedisYRoomStorage(BaseYRoomStorage): """A YRoom storage that uses Redis as main storage, without @@ -123,17 +117,11 @@ class RedisYRoomStorage(BaseYRoomStorage): room_name: The name of the room. """ - def __init__(self, room_name: str) -> None: - super().__init__(room_name) + def __init__(self, room_name: str, save_throttle_interval: int | None = None) -> None: + super().__init__(room_name, save_throttle_interval) self.redis_key = f"document:{self.room_name}" - self.redis = self.make_redis() - - def make_redis(self): - """Makes a Redis client. - Defaults to a local client""" - - return redis.Redis(host="localhost", port=6379, db=0) + self.redis = self._make_redis() async def get_document(self) -> Doc: snapshot = await self.redis.get(self.redis_key) @@ -153,7 +141,7 @@ async def update_document(self, update: bytes): try: current_document = await self.get_document() - updated_snapshot = self._apply_update_to_snapshot(current_document, update) + updated_snapshot = self._apply_update_to_document(current_document, update) async with self.redis.pipeline() as pipe: while True: @@ -165,9 +153,9 @@ async def update_document(self, update: bytes): break except redis.WatchError: - current_snapshot = await self.get_document() - updated_snapshot = self._apply_update_to_snapshot( - current_snapshot, + current_document = await self.get_document() + updated_snapshot = self._apply_update_to_document( + current_document, update, ) @@ -177,11 +165,23 @@ async def update_document(self, update: bytes): await self.throttled_save_snapshot() + async def load_snapshot(self) -> Optional[bytes]: + return None + + async def save_snapshot(self) -> Optional[bytes]: + return None + async def close(self): await self.save_snapshot() await self.redis.close() - def _apply_update_to_snapshot(self, document: Doc, update: bytes) -> bytes: + def _apply_update_to_document(self, document: Doc, update: bytes) -> bytes: document.apply_update(update) return document.get_update() + + def _make_redis(self): + """Makes a Redis client. + Defaults to a local client""" + + return redis.Redis(host="localhost", port=6379, db=0)