Skip to content

Commit

Permalink
Fix some cache bugs and ensure member is always returned on delete calls
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Nov 7, 2021
1 parent 463a1c4 commit 5d59eb5
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
48 changes: 22 additions & 26 deletions hikari/impl/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def _garbage_collect_member(
*,
decrement: typing.Optional[int] = None,
deleting: bool = False,
) -> typing.Optional[cache_utility.RefCell[cache_utility.MemberData]]:
) -> None:
if deleting:
member.object.has_been_deleted = True

Expand All @@ -808,10 +808,10 @@ def _garbage_collect_member(

user_id = member.object.user.object.id
if not guild_record.members or user_id not in guild_record.members:
return None
return

if not self._can_remove_member(member):
return None
return

del guild_record.members[user_id]
self._garbage_collect_user(member.object.user, decrement=1)
Expand All @@ -820,8 +820,6 @@ def _garbage_collect_member(
guild_record.members = None
self._remove_guild_record_if_empty(member.object.guild_id, guild_record)

return member

def clear_members(
self,
) -> cache.CacheView[snowflakes.Snowflake, cache.CacheView[snowflakes.Snowflake, guilds.Member]]:
Expand All @@ -843,9 +841,10 @@ def clear_members_for_guild(
return cache_utility.EmptyCacheView()

cached_members = guild_record.members.freeze()
members_gen = (self._garbage_collect_member(guild_record, m, deleting=True) for m in cached_members.values())
# _garbage_collect_member will only return the member data object if they could be removed, else None.
cached_members = {member.object.user.object.id: member for member in members_gen if member}

for m in cached_members.values():
self._garbage_collect_member(guild_record, m, deleting=True)

self._remove_guild_record_if_empty(guild_id, guild_record)
return cache_utility.CacheMappingView(cached_members, builder=self._build_member) # type: ignore[type-var]

Expand All @@ -868,13 +867,8 @@ def delete_member(
if not member_data:
return None

if not guild_record.members:
guild_record.members = None
self._remove_guild_record_if_empty(guild_id, guild_record)

# _garbage_collect_member will only return the member data object if they could be removed, else None.
garbage_collected = self._garbage_collect_member(guild_record, member_data, deleting=True)
return self._build_member(member_data) if garbage_collected else None
self._garbage_collect_member(guild_record, member_data, deleting=True)
return self._build_member(member_data)

def get_member(
self,
Expand Down Expand Up @@ -1295,12 +1289,13 @@ def clear_voice_states_for_channel(
if not guild_record or not guild_record.voice_states:
return cache_utility.EmptyCacheView()

cached_voice_states = {}
cached_voice_states = dict(
filter(lambda item: item[1].channel_id == channel_id, guild_record.voice_states.items())
)

for user_id, voice_state in guild_record.voice_states.items():
if voice_state.channel_id == channel_id:
cached_voice_states[user_id] = voice_state
self._garbage_collect_member(guild_record, voice_state.member, decrement=1)
for user_id, voice_state in cached_voice_states.items():
del guild_record.voice_states[user_id]
self._garbage_collect_member(guild_record, voice_state.member, decrement=1)

if not guild_record.voice_states:
guild_record.voice_states = None
Expand Down Expand Up @@ -1348,11 +1343,12 @@ def delete_voice_state(
if not voice_state_data:
return None

self._garbage_collect_member(guild_record, voice_state_data.member, decrement=1)

if not guild_record.voice_states:
guild_record.voice_states = None
self._remove_guild_record_if_empty(guild_id, guild_record)

self._garbage_collect_member(guild_record, voice_state_data.member, decrement=1)
self._remove_guild_record_if_empty(guild_id, guild_record)
return self._build_voice_state(voice_state_data)

def get_voice_state(
Expand Down Expand Up @@ -1459,12 +1455,13 @@ def _garbage_collect_message(
*,
decrement: typing.Optional[int] = None,
override_ref: bool = False,
) -> typing.Optional[cache_utility.RefCell[cache_utility.MessageData]]:
) -> bool:
# A bool is returned to inform whether the message was removed or not
if decrement is not None:
self._increment_ref_count(message, -decrement)

if not self._can_remove_message(message) or override_ref:
return None
return False

self._garbage_collect_user(message.object.author, decrement=1)

Expand All @@ -1485,7 +1482,7 @@ def _garbage_collect_message(
if message.object.id in self._referenced_messages:
del self._referenced_messages[message.object.id]

return message
return True

def _on_message_expire(self, message: cache_utility.RefCell[cache_utility.MessageData], /) -> None:
if not self._garbage_collect_message(message):
Expand Down Expand Up @@ -1520,7 +1517,6 @@ def delete_message(

if not self._garbage_collect_message(message_data):
self._referenced_messages[message_id] = message_data
return None

return self._build_message(message_data)

Expand Down
2 changes: 1 addition & 1 deletion hikari/voices.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class VoiceState:
session_id: str = attr.field(hash=True, repr=True)
"""The string ID of this voice state's session."""

requested_to_speak_at: typing.Optional[datetime.datetime] = attr.field(eq=False, hash=False, repr=True)
requested_to_speak_at: typing.Optional[datetime.datetime] = attr.field(eq=False, hash=False, repr=False)
"""When the user requested to speak in a stage channel.
Will be `builtins.None` if they have not requested to speak.
Expand Down
49 changes: 32 additions & 17 deletions tests/hikari/impl/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ def test_delete_member_for_unknown_member_cache(self, cache_impl):

assert result is None

def test_delete_member_for_known_member(self, cache_impl):
def test_delete_member(self, cache_impl):
mock_member = mock.Mock(guilds.Member)
mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(67876)))
mock_member_data = mock.Mock(
Expand All @@ -1672,21 +1672,6 @@ def test_delete_member_for_known_member(self, cache_impl):
cache_impl._garbage_collect_user.assert_called_once_with(mock_user, decrement=1)
cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(42123), guild_record)

def test_delete_member_for_known_hard_referenced_member(self, cache_impl):
mock_member = cache_utilities.RefCell(mock.Mock(has_been_deleted=False), ref_count=1)
cache_impl._guild_entries = collections.FreezableDict(
{
snowflakes.Snowflake(42123): cache_utilities.GuildRecord(
members=collections.FreezableDict({snowflakes.Snowflake(67876): mock_member})
)
}
)

result = cache_impl.delete_member(StubModel(42123), StubModel(67876))

assert result is None
assert mock_member.object.has_been_deleted is True

def test_get_member_for_unknown_member_cache(self, cache_impl):
cache_impl._guild_entries = collections.FreezableDict(
{snowflakes.Snowflake(1234213): cache_utilities.GuildRecord()}
Expand Down Expand Up @@ -2222,11 +2207,41 @@ def test_delete_voice_state(self, cache_impl):

assert result is mock_voice_state
cache_impl._garbage_collect_member.assert_called_once_with(guild_record, mock_member_data, decrement=1)
cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(43123), guild_record)
cache_impl._remove_guild_record_if_empty.assert_not_called()
assert cache_impl._guild_entries[snowflakes.Snowflake(43123)].voice_states == {
snowflakes.Snowflake(6541234): mock_other_voice_state_data
}

def test_delete_voice_state_when_no_voice_states_left(self, cache_impl):
mock_member_data = object()
mock_voice_state_data = mock.Mock(cache_utilities.VoiceStateData, member=mock_member_data)
mock_voice_state = mock.Mock(voices.VoiceState)
cache_impl._build_voice_state = mock.Mock(return_value=mock_voice_state)
guild_record = cache_utilities.GuildRecord(
voice_states=collections.FreezableDict({snowflakes.Snowflake(12354345): mock_voice_state_data}),
members=collections.FreezableDict(
{snowflakes.Snowflake(12354345): mock_member_data, snowflakes.Snowflake(9955959): object()}
),
)
cache_impl._user_entries = collections.FreezableDict(
{snowflakes.Snowflake(12354345): object(), snowflakes.Snowflake(9393): object()}
)
cache_impl._guild_entries = collections.FreezableDict(
{
snowflakes.Snowflake(65234): mock.Mock(cache_utilities.GuildRecord),
snowflakes.Snowflake(43123): guild_record,
}
)
cache_impl._remove_guild_record_if_empty = mock.Mock()
cache_impl._garbage_collect_member = mock.Mock()

result = cache_impl.delete_voice_state(StubModel(43123), StubModel(12354345))

assert result is mock_voice_state
cache_impl._garbage_collect_member.assert_called_once_with(guild_record, mock_member_data, decrement=1)
cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(43123), guild_record)
assert cache_impl._guild_entries[snowflakes.Snowflake(43123)].voice_states is None

def test_delete_voice_state_unknown_state(self, cache_impl):
mock_other_voice_state_data = mock.Mock(cache_utilities.VoiceStateData)
cache_impl._build_voice_state = mock.Mock()
Expand Down

0 comments on commit 5d59eb5

Please sign in to comment.