From 36c1f0ee2c2bb5476dd146e7bbbc5ffb103c7c8e Mon Sep 17 00:00:00 2001 From: davfsa Date: Wed, 22 Jun 2022 01:59:54 +0200 Subject: [PATCH] Keep message reference when updating cache messages (#1192) - Deserialize `referenced_message` as the partial message it is --- changes/1192.bugfix.md | 2 ++ hikari/impl/cache.py | 7 ++++++- hikari/impl/entity_factory.py | 4 ++-- hikari/internal/cache.py | 3 --- hikari/messages.py | 9 ++++++--- tests/hikari/impl/test_entity_factory.py | 2 +- 6 files changed, 17 insertions(+), 10 deletions(-) create mode 100644 changes/1192.bugfix.md diff --git a/changes/1192.bugfix.md b/changes/1192.bugfix.md new file mode 100644 index 0000000000..f84e19a270 --- /dev/null +++ b/changes/1192.bugfix.md @@ -0,0 +1,2 @@ +Properly garbage collect message references in the cache + - Properly deserialize `PartialMessage.referenced_message` as a partial message diff --git a/hikari/impl/cache.py b/hikari/impl/cache.py index 5d071cb4df..6293788645 100644 --- a/hikari/impl/cache.py +++ b/hikari/impl/cache.py @@ -1563,7 +1563,12 @@ def _set_message( referenced_message: typing.Optional[cache_utility.RefCell[cache_utility.MessageData]] = None if message.referenced_message: - referenced_message = self._set_message(message.referenced_message) + reference_id = message.referenced_message.id + referenced_message = self._message_entries.get(reference_id) or self._referenced_messages.get(reference_id) + + if referenced_message: + # Since the message is partial, if we don't have it cached, there is nothing we can do about it + referenced_message.object.update(message.referenced_message) # Only increment ref counts if this wasn't previously cached. if message.id not in self._referenced_messages and message.id not in self._message_entries: diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index eefc296fff..042ca2e861 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -2534,9 +2534,9 @@ def deserialize_message( if "message_reference" in payload: message_reference = self._deserialize_message_reference(payload["message_reference"]) - referenced_message: typing.Optional[message_models.Message] = None + referenced_message: typing.Optional[message_models.PartialMessage] = None if referenced_message_payload := payload.get("referenced_message"): - referenced_message = self.deserialize_message(referenced_message_payload) + referenced_message = self.deserialize_partial_message(referenced_message_payload) application: typing.Optional[message_models.MessageApplication] = None if "application" in payload: diff --git a/hikari/internal/cache.py b/hikari/internal/cache.py index 7f1bae284e..9980034731 100644 --- a/hikari/internal/cache.py +++ b/hikari/internal/cache.py @@ -723,9 +723,6 @@ def build_from_entity( if not member and message.member: member = RefCell(MemberData.build_from_entity(message.member)) - if not referenced_message and message.referenced_message: - referenced_message = RefCell(MessageData.build_from_entity(message.referenced_message)) - interaction = ( MessageInteractionData.build_from_entity(message.interaction, user=interaction_user) if message.interaction diff --git a/hikari/messages.py b/hikari/messages.py index 923283e939..cc647e9865 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -895,7 +895,7 @@ class PartialMessage(snowflakes.Unique): This is a string used for validating a message was sent. """ - referenced_message: undefined.UndefinedNoneOr[Message] = attr.field(hash=False, eq=False, repr=False) + referenced_message: undefined.UndefinedNoneOr[PartialMessage] = attr.field(hash=False, eq=False, repr=False) """The message that was replied to. If `type` is `MessageType.REPLY` and `hikari.undefined.UNDEFINED`, Discord's @@ -1737,8 +1737,11 @@ class Message(PartialMessage): nonce: typing.Optional[str] = attr.field(hash=False, eq=False, repr=False) """The message nonce. This is a string used for validating a message was sent.""" - referenced_message: typing.Optional[Message] = attr.field(hash=False, eq=False, repr=False) - """The message that was replied to.""" + referenced_message: typing.Optional[PartialMessage] = attr.field(hash=False, eq=False, repr=False) + """The message that was replied to. + + If `type` is `MessageType.REPLY` and `builtins.None`, the message was deleted. + """ interaction: typing.Optional[MessageInteraction] = attr.field(hash=False, eq=False, repr=False) """Information about the interaction this message was created by.""" diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index c5fcfe0cd2..151dc77034 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -4854,7 +4854,7 @@ def test_deserialize_message( assert message.message_reference.guild_id == 278325129692446720 assert isinstance(message.message_reference, message_models.MessageReference) - assert message.referenced_message == entity_factory_impl.deserialize_message(referenced_message) + assert message.referenced_message == entity_factory_impl.deserialize_partial_message(referenced_message) assert message.flags == message_models.MessageFlag.IS_CROSSPOST # Sticker