Skip to content

Commit

Permalink
refactor(yroom_storage): make base class abstract and optional thrott…
Browse files Browse the repository at this point in the history
…ling save
  • Loading branch information
cacosandon committed Apr 6, 2024
1 parent 889db49 commit bfd11c4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 41 deletions.
1 change: 0 additions & 1 deletion pycrdt_websocket/django_channels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .yjs_consumer import YjsConsumer as YjsConsumer
from .yroom_storage import BaseYRoomStorage as BaseYRoomStorage
from .yroom_storage import RedisYRoomStorage as RedisYRoomStorage
80 changes: 40 additions & 40 deletions pycrdt_websocket/django_channels/yroom_storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import time
from abc import ABC, abstractmethod
from typing import Optional

import redis.asyncio as redis
from pycrdt import Doc


class BaseYRoomStorage:
class BaseYRoomStorage(ABC):
"""Base class for YRoom storage.
This class is responsible for storing, retrieving, updating and persisting the Ypy document.
Each Django Channels Consumer should have its own YRoomStorage instance, although all consumers
Expand Down Expand Up @@ -55,66 +56,59 @@ async def save_snapshot(self):
```
"""

def __init__(self, room_name: str) -> None:
def __init__(self, room_name: str, save_throttle_interval: int | None) -> None:
self.room_name = room_name

self.save_throttle_interval = save_throttle_interval
self.last_saved_at = time.time()
self.save_throttle_interval = 5

@abstractmethod
async def get_document(self) -> Doc:
"""Gets the document from the storage.
Ideally it should be retrieved first from temporary storage (e.g. Redis) and then from
persistent storage (e.g. a database).
Returns:
The document with the latest changes.
"""
...

raise NotImplementedError

async def update_document(self, update: bytes):
@abstractmethod
async def update_document(self, update: bytes) -> None:
"""Updates the document in the storage.
Updates could be received by Yjs client (e.g. from a WebSocket) or from the server
(e.g. from a Django Celery job).
Args:
update: The update to apply to the document.
"""
...

raise NotImplementedError

@abstractmethod
async def load_snapshot(self) -> Optional[bytes]:
"""Gets the document from the database. Override this method to
"""Gets the document encoded as update from the database. Override this method to
implement a persistent storage.
Defaults to None.
Returns:
The latest document snapshot.
"""
return None
...

@abstractmethod
async def save_snapshot(self) -> None:
"""Saves a snapshot of the document to the storage.
If you need to persist the document to a database, you should do it here.
Default implementation does nothing.
"""

pass
"""Saves the document encoded as update to the database."""
...

async def throttled_save_snapshot(self) -> None:
"""Saves a snapshot of the document to the storage, debouncing the calls."""
"""Saves the document encoded as update to the database, throttled."""

if time.time() - self.last_saved_at <= self.save_throttle_interval:
if (
not self.save_throttle_interval
or time.time() - self.last_saved_at <= self.save_throttle_interval
):
return

await self.save_snapshot()

self.last_saved_at = time.time()

async def close(self):
"""Closes the storage.
Default implementation does nothing.
"""

pass


class RedisYRoomStorage(BaseYRoomStorage):
"""A YRoom storage that uses Redis as main storage, without
Expand All @@ -123,17 +117,11 @@ class RedisYRoomStorage(BaseYRoomStorage):
room_name: The name of the room.
"""

def __init__(self, room_name: str) -> None:
super().__init__(room_name)
def __init__(self, room_name: str, save_throttle_interval: int | None = None) -> None:
super().__init__(room_name, save_throttle_interval)

self.redis_key = f"document:{self.room_name}"
self.redis = self.make_redis()

def make_redis(self):
"""Makes a Redis client.
Defaults to a local client"""

return redis.Redis(host="localhost", port=6379, db=0)
self.redis = self._make_redis()

async def get_document(self) -> Doc:
snapshot = await self.redis.get(self.redis_key)
Expand All @@ -153,7 +141,7 @@ async def update_document(self, update: bytes):

try:
current_document = await self.get_document()
updated_snapshot = self._apply_update_to_snapshot(current_document, update)
updated_snapshot = self._apply_update_to_document(current_document, update)

async with self.redis.pipeline() as pipe:
while True:
Expand All @@ -165,9 +153,9 @@ async def update_document(self, update: bytes):

break
except redis.WatchError:
current_snapshot = await self.get_document()
updated_snapshot = self._apply_update_to_snapshot(
current_snapshot,
current_document = await self.get_document()
updated_snapshot = self._apply_update_to_document(
current_document,
update,
)

Expand All @@ -177,11 +165,23 @@ async def update_document(self, update: bytes):

await self.throttled_save_snapshot()

async def load_snapshot(self) -> Optional[bytes]:
return None

async def save_snapshot(self) -> Optional[bytes]:
return None

async def close(self):
await self.save_snapshot()
await self.redis.close()

def _apply_update_to_snapshot(self, document: Doc, update: bytes) -> bytes:
def _apply_update_to_document(self, document: Doc, update: bytes) -> bytes:
document.apply_update(update)

return document.get_update()

def _make_redis(self):
"""Makes a Redis client.
Defaults to a local client"""

return redis.Redis(host="localhost", port=6379, db=0)

0 comments on commit bfd11c4

Please sign in to comment.