diff --git a/hikari/events/message_events.py b/hikari/events/message_events.py index cc7458e179..6c9ce8a0b1 100644 --- a/hikari/events/message_events.py +++ b/hikari/events/message_events.py @@ -46,18 +46,17 @@ import attr +from hikari import channels +from hikari import guilds from hikari import intents from hikari import snowflakes +from hikari import undefined from hikari import users from hikari.events import base_events from hikari.events import shard_events from hikari.utilities import attr_extensions if typing.TYPE_CHECKING: - # Do NOT remove the users import here. It **is** required, even if PyCharm - # tries to assure you otherwise. - from hikari import channels - from hikari import guilds from hikari import messages from hikari import traits from hikari.api import shard as gateway_shard @@ -131,9 +130,24 @@ def guild_id(self) -> snowflakes.Snowflake: """ @property - def channel(self) -> typing.Optional[channels.GuildTextChannel]: - # <>. - return typing.cast("channels.GuildTextChannel", self.app.cache.get_guild_channel(self.channel_id)) + def channel(self) -> typing.Union[None, channels.GuildTextChannel, channels.GuildNewsChannel]: + """Channel that the message was sent in, if known. + + Returns + ------- + typing.Union[builtins.None, hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel] + The channel the message was sent in, or `builtins.None` if not + known/cached. + + This otherwise will always be a `hikari.channels.GuildTextChannel` + if it is a normal message, or `hikari.channels.GuildNewsChannel` if + sent in an announcement channel. + """ + channel = self.app.cache.get_guild_channel(self.channel_id) + assert channel is None or isinstance( + channel, (channels.GuildTextChannel, channels.GuildNewsChannel) + ), f"expected cached channel to be None or a GuildTextChannel/GuildNewsChannel, not {channel}" + return channel @property def guild(self) -> typing.Optional[guilds.GatewayGuild]: @@ -157,6 +171,11 @@ def guild(self) -> typing.Optional[guilds.GatewayGuild]: class MessageCreateEvent(MessageEvent, abc.ABC): """Event base for any message creation event.""" + @property + def message_id(self) -> snowflakes.Snowflake: + # <>. + return self.message.id + @property @abc.abstractmethod def message(self) -> messages.Message: @@ -168,11 +187,6 @@ def message(self) -> messages.Message: The message object that was sent with this event. """ - @property - def message_id(self) -> snowflakes.Snowflake: - # <>. - return self.message.id - @property def channel_id(self) -> snowflakes.Snowflake: # <>. @@ -189,6 +203,17 @@ def author_id(self) -> snowflakes.Snowflake: """ return self.message.author.id + @property + @abc.abstractmethod + def author(self) -> users.User: + """User that sent the message. + + Returns + ------- + hikari.users.User + The user that sent the message. + """ + @base_events.requires_intents(intents.Intents.GUILD_MESSAGES, intents.Intents.PRIVATE_MESSAGES) @attr.s(kw_only=True, slots=True, weakref_slot=False) @@ -219,6 +244,23 @@ def message_id(self) -> snowflakes.Snowflake: # <>. return self.message.id + @property + def channel_id(self) -> snowflakes.Snowflake: + # <>. + return self.message.channel_id + + @property + @abc.abstractmethod + def channel(self) -> typing.Optional[channels.TextChannel]: + """Channel that the message was sent in, if known. + + Returns + ------- + typing.Optional[hikari.channels.TextChannel] + The text channel that the message was sent in, if known and cached, + otherwise, `builtins.None`. + """ + @property def author_id(self) -> snowflakes.Snowflake: """ID of the author that triggered this event. @@ -230,13 +272,23 @@ def author_id(self) -> snowflakes.Snowflake: """ # Looks like `author` is always present in this event variant. author = self.message.author - assert isinstance(author, users.PartialUser) + assert isinstance(author, users.User), "message.author was expected to be present" return author.id @property - def channel_id(self) -> snowflakes.Snowflake: - # <>. - return self.message.channel_id + @abc.abstractmethod + def author(self) -> typing.Optional[users.User]: + """User that sent the message. + + Returns + ------- + typing.Optional[hikari.users.User] + The user that sent the message, if known and cached, otherwise + `builtins.None`. + """ + author = self.message.author + assert isinstance(author, users.User), "message.author was expected to be present" + return author @base_events.requires_intents(intents.Intents.GUILD_MESSAGES, intents.Intents.PRIVATE_MESSAGES) @@ -278,6 +330,18 @@ def channel_id(self) -> snowflakes.Snowflake: # <>. return self.message.channel_id + @property + @abc.abstractmethod + def channel(self) -> typing.Optional[channels.TextChannel]: + """Channel that the message was sent in, if known. + + Returns + ------- + typing.Optional[hikari.channels.TextChannel] + The text channel that the message was sent in, if known and cached, + otherwise, `builtins.None`. + """ + @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) @attr_extensions.with_copy @@ -292,14 +356,27 @@ class GuildMessageCreateEvent(GuildMessageEvent, MessageCreateEvent): # <>. message: messages.Message = attr.ib() - # <>. @property def guild_id(self) -> snowflakes.Snowflake: # <>. # Always present in this event. - return typing.cast("snowflakes.Snowflake", self.message.guild_id) + guild_id = self.message.guild_id + assert isinstance(guild_id, snowflakes.Snowflake) + return guild_id + + @property + def author(self) -> guilds.Member: + """Member that sent the message. + + Returns + ------- + hikari.guilds.Member + The member that sent the message. This is a specialised + implementation of `hikari.users.User`. + """ + return typing.cast(guilds.Member, self.message.member) @base_events.requires_intents(intents.Intents.PRIVATE_MESSAGES) @@ -315,14 +392,25 @@ class PrivateMessageCreateEvent(PrivateMessageEvent, MessageCreateEvent): # <>. message: messages.Message = attr.ib() - # <>. @property def channel(self) -> typing.Optional[channels.DMChannel]: - # <>. + """Channel that the message was sent in, if known. + + Returns + ------- + typing.Optional[hikari.channels.DMChannel] + The DM channel that the message was sent in, if known and cached, + otherwise, `builtins.None`. + """ return self.app.cache.get_dm(self.author_id) + @property + def author(self) -> users.User: + # <>. + return self.message.author + @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) @attr_extensions.with_copy @@ -337,7 +425,6 @@ class GuildMessageUpdateEvent(GuildMessageEvent, MessageUpdateEvent): # <>. message: messages.PartialMessage = attr.ib() - # <>. @property @@ -348,6 +435,21 @@ def guild_id(self) -> snowflakes.Snowflake: assert isinstance(guild_id, snowflakes.Snowflake) return guild_id + @property + def author(self) -> typing.Union[guilds.Member, users.User]: + # <>. + member = self.message.member + if member is not undefined.UNDEFINED and member is not None: + return member + member = self.app.cache.get_member(self.guild_id, self.author_id) + if member is not None: + return member + + # This should always be present. + author = self.message.author + assert isinstance(author, users.User), "expected author to be present" + return author + @base_events.requires_intents(intents.Intents.PRIVATE_MESSAGES) @attr_extensions.with_copy @@ -362,7 +464,6 @@ class PrivateMessageUpdateEvent(PrivateMessageEvent, MessageUpdateEvent): # <>. message: messages.PartialMessage = attr.ib() - # <>. @property @@ -370,6 +471,13 @@ def channel(self) -> typing.Optional[channels.DMChannel]: # <>. return self.app.cache.get_dm(self.author_id) + @property + def author(self) -> typing.Optional[users.User]: + # Always present on an update event. + author = self.message.author + assert isinstance(author, users.User), "expected author to be present on PartialMessage" + return author + @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) @attr_extensions.with_copy @@ -384,14 +492,15 @@ class GuildMessageDeleteEvent(GuildMessageEvent, MessageDeleteEvent): # <>. message: messages.PartialMessage = attr.ib() - # <>. @property def guild_id(self) -> snowflakes.Snowflake: # <>. # Always present in this event. - return typing.cast("snowflakes.Snowflake", self.message.guild_id) + guild_id = self.message.guild_id + assert isinstance(guild_id, snowflakes.Snowflake), f"expected guild_id to be snowflake, not {guild_id}" + return guild_id @attr_extensions.with_copy @@ -407,7 +516,6 @@ class PrivateMessageDeleteEvent(PrivateMessageEvent, MessageDeleteEvent): # <>. message: messages.PartialMessage = attr.ib() - # <>. @property @@ -473,9 +581,24 @@ class GuildMessageBulkDeleteEvent(MessageBulkDeleteEvent): """ @property - def channel(self) -> typing.Optional[channels.GuildTextChannel]: - # <>. - return typing.cast("channels.GuildTextChannel", self.app.cache.get_guild_channel(self.channel_id)) + def channel(self) -> typing.Union[None, channels.GuildTextChannel, channels.GuildNewsChannel]: + """Get the cached channel the messages were sent in, if known. + + Returns + ------- + typing.Union[builtins.None, hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel] + The channel the messages were sent in, or `builtins.None` if not + known/cached. + + This otherwise will always be a `hikari.channels.GuildTextChannel` + if it is a normal message, or `hikari.channels.GuildNewsChannel` if + sent in an announcement channel. + """ + channel = self.app.cache.get_guild_channel(self.channel_id) + assert channel is None or isinstance( + channel, (channels.GuildTextChannel, channels.GuildNewsChannel) + ), f"expected cached channel to be None or a GuildTextChannel/GuildNewsChannel, not {channel}" + return channel @property def guild(self) -> typing.Optional[guilds.GatewayGuild]: diff --git a/tests/hikari/events/test_message_events.py b/tests/hikari/events/test_message_events.py index a189a2d4be..9ccb980561 100644 --- a/tests/hikari/events/test_message_events.py +++ b/tests/hikari/events/test_message_events.py @@ -22,6 +22,7 @@ import mock import pytest +from hikari import channels from hikari import messages from hikari import snowflakes from hikari import users @@ -39,9 +40,11 @@ def event(self): ) return cls() - def test_channel(self, event): - result = event.channel + @pytest.mark.parametrize("guild_channel_impl", [channels.GuildTextChannel, channels.GuildNewsChannel]) + def test_channel(self, event, guild_channel_impl): + event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) + result = event.channel assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(54123123123) @@ -64,13 +67,21 @@ def test_guild_when_unavailable(self, event): class TestMessageCreateEvent: @pytest.fixture() def event(self): - class MessageCreateEvent(message_events.MessageCreateEvent): - app = None - message = mock.Mock(messages.Message, guild_id=snowflakes.Snowflake(998866)) - shard = object() - channel = object() + cls = hikari_test_helpers.mock_class_namespace( + message_events.MessageCreateEvent, + app=object(), + message=mock.Mock( + spec_set=messages.Message, + author=mock.Mock( + spec_set=users.User, + ), + ), + shard=object(), + channel=object(), + author=object(), + ) - return MessageCreateEvent() + return cls() def test_message_id_property(self, event): event.message.id = 123 @@ -88,18 +99,21 @@ def test_author_id_property(self, event): class TestMessageUpdateEvent: @pytest.fixture() def event(self): - class MessageUpdateEvent(message_events.MessageUpdateEvent): - app = None - message = mock.Mock( + cls = hikari_test_helpers.mock_class_namespace( + message_events.MessageUpdateEvent, + app=object(), + message=mock.Mock( spec_set=messages.Message, author=mock.Mock( - spec_set=users.PartialUser, + spec_set=users.User, ), - ) - shard = object() - channel = object() + ), + shard=object(), + channel=object(), + author=object(), + ) - return MessageUpdateEvent() + return cls() def test_message_id_property(self, event): event.message.id = snowflakes.Snowflake(123) @@ -184,7 +198,9 @@ def event(self): message_ids=None, ) - def test_channel(self, event): + @pytest.mark.parametrize("guild_channel_impl", [channels.GuildTextChannel, channels.GuildNewsChannel]) + def test_channel(self, event, guild_channel_impl): + event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) result = event.channel assert result is event.app.cache.get_guild_channel.return_value