Skip to content

Commit

Permalink
Threads cache (#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa authored Dec 4, 2022
1 parent b638e0c commit 20cee68
Show file tree
Hide file tree
Showing 8 changed files with 650 additions and 108 deletions.
1 change: 1 addition & 0 deletions changes/1384.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Threads cache.
169 changes: 169 additions & 0 deletions hikari/api/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,76 @@ def get_guild_channels_view_for_guild(
cache for the specified guild.
"""

@abc.abstractmethod
def get_thread(
self, thread: snowflakes.SnowflakeishOr[channels.PartialChannel], /
) -> typing.Optional[channels.GuildThreadChannel]:
"""Get a thread channel from the cache.
Parameters
----------
thread : hikari.snowflakes.SnowflakeishOr[hikari.channels.PartialChannel]
Object or ID of the thread to get from the cache.
Returns
-------
typing.Optional[hikari.channels.GuildThreadChannel]
The object of the thread that was found in the cache
or `builtins.None`.
"""

@abc.abstractmethod
def get_threads_view(self) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Get a view of the thread channels in the cache.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels found in the
cache.
"""

@abc.abstractmethod
def get_threads_view_for_channel(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
channel: snowflakes.SnowflakeishOr[channels.PartialChannel],
/,
) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Get a view of the thread channels in the cache for a specific guild.
Parameters
----------
guild : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild]
Object or ID of the guild to get the cached thread channels for.
channel : hikari.snowflakes.SnowflakeishOr[hikari.channels.PartialChannel]
Object or ID of the channel to get the cached thread channels for.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels found in the
cache for the specified channel.
"""

@abc.abstractmethod
def get_threads_view_for_guild(
self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], /
) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Get a view of the thread channels in the cache for a specific guild.
Parameters
----------
guild : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild]
Object or ID of the guild to get the cached thread channels for.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels found in the
cache for the specified guild.
"""

@abc.abstractmethod
def get_invite(self, code: typing.Union[invites.InviteCode, str], /) -> typing.Optional[invites.InviteWithMetadata]:
"""Get an invite object from the cache.
Expand Down Expand Up @@ -1072,6 +1142,105 @@ def update_guild_channel(
(else `builtins.None`).
""" # noqa: E501 - Line too long

@abc.abstractmethod
def clear_threads(self) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Remove all thread channels from the cache.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels that were
removed from the cache.
"""

@abc.abstractmethod
def clear_threads_for_channel(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
channel: snowflakes.SnowflakeishOr[channels.PartialChannel],
/,
) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Remove thread channels from the cache for a specific channel.
Parameters
----------
guild : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild]
Object or ID of the guild to remove cached threads for.
channel : hikari.snowflakes.SnowflakeishOr[hikari.channels.PartialChannel]
Object or ID of the channel to remove cached threads for.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels that were
removed from the cache.
"""

@abc.abstractmethod
def clear_threads_for_guild(
self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], /
) -> CacheView[snowflakes.Snowflake, channels.GuildThreadChannel]:
"""Remove thread channels from the cache for a specific guild.
Parameters
----------
guild : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild]
Object or ID of the guild to remove cached threads for.
Returns
-------
CacheView[hikari.snowflakes.Snowflake, hikari.channels.GuildThreadChannel]
A view of channel IDs to objects of the thread channels that were
removed from the cache.
"""

@abc.abstractmethod
def delete_thread(
self, thread: snowflakes.SnowflakeishOr[channels.PartialChannel], /
) -> typing.Optional[channels.GuildThreadChannel]:
"""Remove a thread channel from the cache.
Parameters
----------
thread : hikari.snowflakes.SnowflakeishOr[hikari.channels.PartialChannel]
Object or ID of the thread to remove from the cache.
Returns
-------
typing.Optional[hikari.channels.GuildThreadChannel]
The object of the thread that was removed from the cache if
found, else `builtins.None`.
"""

@abc.abstractmethod
def set_thread(self, channel: channels.GuildThreadChannel, /) -> None:
"""Add a thread channel to the cache.
Parameters
----------
channel : hikari.channels.GuildThreadChannel
The thread channel based object to add to the cache.
"""

@abc.abstractmethod
def update_thread(
self, thread: channels.GuildThreadChannel, /
) -> typing.Tuple[typing.Optional[channels.GuildThreadChannel], typing.Optional[channels.GuildThreadChannel]]:
"""Update a thread channel in the cache.
Parameters
----------
thread : hikari.channels.GuildThreadChannel
The object of the thread channel to update in the cache.
Returns
-------
typing.Tuple[typing.Optional[hikari.channels.GuildThreadChannel], typing.Optional[hikari.channels.GuildThreadChannel]]
A tuple of the old cached thread channel if found (else `builtins.None`)
and the new cached thread channel if it could be cached
(else `builtins.None`).
"""

@abc.abstractmethod
def clear_invites(self) -> CacheView[str, invites.InviteWithMetadata]:
"""Remove all the invite objects from the cache.
Expand Down
4 changes: 4 additions & 0 deletions hikari/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class CacheComponents(enums.Flag):
GUILD_STICKERS = 1 << 11
"""Enables the guild stickers cache."""

GUILD_THREADS = 1 << 12
"""Enabled the guild threads cache."""

ALL = (
GUILDS
| GUILD_CHANNELS
Expand All @@ -89,6 +92,7 @@ class CacheComponents(enums.Flag):
| ME
| DM_CHANNEL_IDS
| GUILD_STICKERS
| GUILD_THREADS
)
"""Fully enables the cache."""

Expand Down
149 changes: 149 additions & 0 deletions hikari/impl/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CacheImpl(cache.MutableCache):
"_dm_channel_entries",
"_emoji_entries",
"_guild_channel_entries",
"_guild_thread_entries",
"_guild_entries",
"_intents",
"_invite_entries",
Expand All @@ -88,6 +89,7 @@ class CacheImpl(cache.MutableCache):
_emoji_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, cache_utility.KnownCustomEmojiData]
_dm_channel_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, snowflakes.Snowflake]
_guild_channel_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, channels_.PermissibleGuildChannel]
_guild_thread_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, channels_.GuildThreadChannel]
_guild_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, cache_utility.GuildRecord]
_invite_entries: collections.ExtendedMutableMapping[str, cache_utility.InviteData]
_role_entries: collections.ExtendedMutableMapping[snowflakes.Snowflake, guilds.Role]
Expand Down Expand Up @@ -118,6 +120,7 @@ def _create_cache(self) -> None:
self._dm_channel_entries = collections.LimitedCapacityCacheMap(limit=self._settings.max_dm_channel_ids)
self._emoji_entries = collections.FreezableDict()
self._guild_channel_entries = collections.FreezableDict()
self._guild_thread_entries = collections.FreezableDict()
self._guild_entries = collections.FreezableDict()
self._invite_entries = collections.FreezableDict()
self._role_entries = collections.FreezableDict()
Expand Down Expand Up @@ -568,6 +571,149 @@ def update_guild(
self.set_guild(guild)
return cached_guild, self.get_guild(guild.id)

def clear_threads(self) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return cache_utility.EmptyCacheView()

cached_threads = self._guild_thread_entries
self._guild_thread_entries = collections.FreezableDict()

for guild_id, guild_record in self._guild_entries.freeze().items():
if guild_record.threads:
guild_record.threads = None
self._remove_guild_record_if_empty(guild_id, guild_record)

return cache_utility.CacheMappingView(cached_threads)

def clear_threads_for_guild(
self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], /
) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return cache_utility.EmptyCacheView()

guild_id = snowflakes.Snowflake(guild)
guild_record = self._guild_entries.get(guild_id)
if not guild_record or not guild_record.threads:
return cache_utility.EmptyCacheView()

cached_threads = {sf: self._guild_thread_entries.pop(sf) for sf in guild_record.threads}
guild_record.threads = None
self._remove_guild_record_if_empty(guild_id, guild_record)
return cache_utility.CacheMappingView(cached_threads)

def clear_threads_for_channel(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
channel: snowflakes.SnowflakeishOr[channels_.PartialChannel],
/,
) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return cache_utility.EmptyCacheView()

channel_id = snowflakes.Snowflake(channel)
guild = snowflakes.Snowflake(guild)
guild_record = self._guild_entries.get(guild)
if not guild_record or not guild_record.threads:
return cache_utility.EmptyCacheView()

threads: typing.Dict[snowflakes.Snowflake, channels_.GuildThreadChannel] = {}
for thread in map(self._guild_thread_entries.__getitem__, tuple(guild_record.threads)):
if thread.parent_id == channel_id:
del self._guild_thread_entries[thread.id]
guild_record.threads.remove(thread.id)

if not guild_record.threads:
guild_record.threads = None
self._remove_guild_record_if_empty(guild, guild_record)

return cache_utility.CacheMappingView(threads)

def delete_thread(
self, thread: snowflakes.SnowflakeishOr[channels_.PartialChannel], /
) -> typing.Optional[channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return None

thread_id = snowflakes.Snowflake(thread)
thread = self._guild_thread_entries.pop(thread_id, None)

if not thread:
return None

guild_record = self._guild_entries.get(thread.guild_id)
if guild_record and guild_record.threads:
guild_record.threads.remove(thread_id)
if not guild_record.threads:
guild_record.threads = None
self._remove_guild_record_if_empty(thread.guild_id, guild_record)

return thread

def get_thread(
self, thread: snowflakes.SnowflakeishOr[channels_.PartialChannel], /
) -> typing.Optional[channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return None

thread = self._guild_thread_entries.get(snowflakes.Snowflake(thread))
return copy.copy(thread) if thread else None

def get_threads_view(self) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
return cache_utility.CacheMappingView(self._guild_thread_entries.freeze())

def get_threads_view_for_guild(
self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], /
) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return cache_utility.EmptyCacheView()

guild_record = self._guild_entries.get(snowflakes.Snowflake(guild))
if not guild_record or not guild_record.threads:
return cache_utility.EmptyCacheView()

return cache_utility.CacheMappingView(
{sf: self._guild_thread_entries[sf] for sf in guild_record.threads},
)

def get_threads_view_for_channel(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
channel: snowflakes.SnowflakeishOr[channels_.PartialChannel],
/,
) -> cache.CacheView[snowflakes.Snowflake, channels_.GuildThreadChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return cache_utility.EmptyCacheView()

record = self._guild_entries.get(snowflakes.Snowflake(guild))
if not record or not record.threads:
return cache_utility.EmptyCacheView()

threads = map(self._guild_thread_entries.__getitem__, record.threads)
channel = snowflakes.Snowflake(channel)
return cache_utility.CacheMappingView({thread.id: thread for thread in threads if thread.parent_id == channel})

def set_thread(self, thread: channels_.GuildThreadChannel, /) -> None:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return

self._guild_thread_entries[thread.id] = copy.copy(thread)
guild_record = self._get_or_create_guild_record(thread.guild_id)

if guild_record.threads is None:
guild_record.threads = collections.SnowflakeSet()

guild_record.threads.add(thread.id)

def update_thread(
self, thread: channels_.GuildThreadChannel, /
) -> typing.Tuple[typing.Optional[channels_.GuildThreadChannel], typing.Optional[channels_.GuildThreadChannel]]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_THREADS):
return None, None

cached_thread = self.get_thread(thread.id)
self.set_thread(thread)
return cached_thread, self.get_thread(thread.id)

def clear_guild_channels(self) -> cache.CacheView[snowflakes.Snowflake, channels_.PermissibleGuildChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_CHANNELS):
return cache_utility.EmptyCacheView()
Expand Down Expand Up @@ -629,6 +775,9 @@ def get_guild_channel(
return cache_utility.copy_guild_channel(channel) if channel else None

def get_guild_channels_view(self) -> cache.CacheView[snowflakes.Snowflake, channels_.PermissibleGuildChannel]:
if not self._is_cache_enabled_for(config_api.CacheComponents.GUILD_CHANNELS):
return cache_utility.EmptyCacheView()

return cache_utility.CacheMappingView(
self._guild_channel_entries.freeze(), builder=cache_utility.copy_guild_channel # type: ignore[type-var]
)
Expand Down
Loading

0 comments on commit 20cee68

Please sign in to comment.