Skip to content

Commit

Permalink
Fix some cache bugs and ensure member is always returned on delete calls
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Nov 7, 2021
1 parent 463a1c4 commit 1b4d0f5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
48 changes: 22 additions & 26 deletions hikari/impl/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def _garbage_collect_member(
*,
decrement: typing.Optional[int] = None,
deleting: bool = False,
) -> typing.Optional[cache_utility.RefCell[cache_utility.MemberData]]:
) -> None:
if deleting:
member.object.has_been_deleted = True

Expand All @@ -808,10 +808,10 @@ def _garbage_collect_member(

user_id = member.object.user.object.id
if not guild_record.members or user_id not in guild_record.members:
return None
return

if not self._can_remove_member(member):
return None
return

del guild_record.members[user_id]
self._garbage_collect_user(member.object.user, decrement=1)
Expand All @@ -820,8 +820,6 @@ def _garbage_collect_member(
guild_record.members = None
self._remove_guild_record_if_empty(member.object.guild_id, guild_record)

return member

def clear_members(
self,
) -> cache.CacheView[snowflakes.Snowflake, cache.CacheView[snowflakes.Snowflake, guilds.Member]]:
Expand All @@ -843,9 +841,10 @@ def clear_members_for_guild(
return cache_utility.EmptyCacheView()

cached_members = guild_record.members.freeze()
members_gen = (self._garbage_collect_member(guild_record, m, deleting=True) for m in cached_members.values())
# _garbage_collect_member will only return the member data object if they could be removed, else None.
cached_members = {member.object.user.object.id: member for member in members_gen if member}

for m in cached_members.values():
self._garbage_collect_member(guild_record, m, deleting=True)

self._remove_guild_record_if_empty(guild_id, guild_record)
return cache_utility.CacheMappingView(cached_members, builder=self._build_member) # type: ignore[type-var]

Expand All @@ -868,13 +867,8 @@ def delete_member(
if not member_data:
return None

if not guild_record.members:
guild_record.members = None
self._remove_guild_record_if_empty(guild_id, guild_record)

# _garbage_collect_member will only return the member data object if they could be removed, else None.
garbage_collected = self._garbage_collect_member(guild_record, member_data, deleting=True)
return self._build_member(member_data) if garbage_collected else None
self._garbage_collect_member(guild_record, member_data, deleting=True)
return self._build_member(member_data)

def get_member(
self,
Expand Down Expand Up @@ -1295,12 +1289,13 @@ def clear_voice_states_for_channel(
if not guild_record or not guild_record.voice_states:
return cache_utility.EmptyCacheView()

cached_voice_states = {}
cached_voice_states = dict(
filter(lambda item: item[1].channel_id == channel_id, guild_record.voice_states.items())
)

for user_id, voice_state in guild_record.voice_states.items():
if voice_state.channel_id == channel_id:
cached_voice_states[user_id] = voice_state
self._garbage_collect_member(guild_record, voice_state.member, decrement=1)
for user_id, voice_state in cached_voice_states.items():
del guild_record.voice_states[user_id]
self._garbage_collect_member(guild_record, voice_state.member, decrement=1)

if not guild_record.voice_states:
guild_record.voice_states = None
Expand Down Expand Up @@ -1348,11 +1343,12 @@ def delete_voice_state(
if not voice_state_data:
return None

self._garbage_collect_member(guild_record, voice_state_data.member, decrement=1)

if not guild_record.voice_states:
guild_record.voice_states = None
self._remove_guild_record_if_empty(guild_id, guild_record)

self._garbage_collect_member(guild_record, voice_state_data.member, decrement=1)
self._remove_guild_record_if_empty(guild_id, guild_record)
return self._build_voice_state(voice_state_data)

def get_voice_state(
Expand Down Expand Up @@ -1459,12 +1455,13 @@ def _garbage_collect_message(
*,
decrement: typing.Optional[int] = None,
override_ref: bool = False,
) -> typing.Optional[cache_utility.RefCell[cache_utility.MessageData]]:
) -> bool:
# A bool is returned to inform whether the message was removed or not
if decrement is not None:
self._increment_ref_count(message, -decrement)

if not self._can_remove_message(message) or override_ref:
return None
return False

self._garbage_collect_user(message.object.author, decrement=1)

Expand All @@ -1485,7 +1482,7 @@ def _garbage_collect_message(
if message.object.id in self._referenced_messages:
del self._referenced_messages[message.object.id]

return message
return True

def _on_message_expire(self, message: cache_utility.RefCell[cache_utility.MessageData], /) -> None:
if not self._garbage_collect_message(message):
Expand Down Expand Up @@ -1520,7 +1517,6 @@ def delete_message(

if not self._garbage_collect_message(message_data):
self._referenced_messages[message_id] = message_data
return None

return self._build_message(message_data)

Expand Down
2 changes: 1 addition & 1 deletion hikari/voices.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class VoiceState:
session_id: str = attr.field(hash=True, repr=True)
"""The string ID of this voice state's session."""

requested_to_speak_at: typing.Optional[datetime.datetime] = attr.field(eq=False, hash=False, repr=True)
requested_to_speak_at: typing.Optional[datetime.datetime] = attr.field(eq=False, hash=False, repr=False)
"""When the user requested to speak in a stage channel.
Will be `builtins.None` if they have not requested to speak.
Expand Down

0 comments on commit 1b4d0f5

Please sign in to comment.