diff --git a/changes/1922.feature.md b/changes/1922.feature.md new file mode 100644 index 000000000..726d30bb4 --- /dev/null +++ b/changes/1922.feature.md @@ -0,0 +1 @@ +Add support for built-in polls diff --git a/hikari/__init__.py b/hikari/__init__.py index e821512f4..503c08e3f 100644 --- a/hikari/__init__.py +++ b/hikari/__init__.py @@ -85,6 +85,7 @@ from hikari.events.member_events import * from hikari.events.message_events import * from hikari.events.monetization_events import * +from hikari.events.poll_events import * from hikari.events.reaction_events import * from hikari.events.role_events import * from hikari.events.scheduled_events import * @@ -115,6 +116,7 @@ from hikari.messages import * from hikari.monetization import * from hikari.permissions import * +from hikari.polls import * from hikari.presences import * from hikari.scheduled_events import * from hikari.sessions import * diff --git a/hikari/__init__.pyi b/hikari/__init__.pyi index 4ea13a1bd..4f4c29d40 100644 --- a/hikari/__init__.pyi +++ b/hikari/__init__.pyi @@ -60,6 +60,7 @@ from hikari.events.lifetime_events import * from hikari.events.member_events import * from hikari.events.message_events import * from hikari.events.monetization_events import * +from hikari.events.poll_events import * from hikari.events.reaction_events import * from hikari.events.role_events import * from hikari.events.scheduled_events import * @@ -90,6 +91,7 @@ from hikari.locales import * from hikari.messages import * from hikari.monetization import * from hikari.permissions import * +from hikari.polls import * from hikari.presences import * from hikari.scheduled_events import * from hikari.sessions import * diff --git a/hikari/api/entity_factory.py b/hikari/api/entity_factory.py index 42b78f752..1bece870d 100644 --- a/hikari/api/entity_factory.py +++ b/hikari/api/entity_factory.py @@ -42,6 +42,7 @@ from hikari import invites as invite_models from hikari import messages as message_models from hikari import monetization as entitlement_models + from hikari import polls as poll_models from hikari import presences as presence_models from hikari import scheduled_events as scheduled_events_models from hikari import sessions as gateway_models @@ -1994,3 +1995,37 @@ def deserialize_stage_instance(self, payload: data_binding.JSONObject) -> stage_ hikari.stage_intances.StageInstance The deserialized stage instance object """ + + ############### + # POLL MODELS # + ############### + + @abc.abstractmethod + def deserialize_poll(self, payload: data_binding.JSONObject) -> poll_models.Poll: + """Parse a raw payload from Discord into a poll object. + + Parameters + ---------- + payload + The JSON payload to deserialize. + + Returns + ------- + hikari.polls.Poll + The deserialized poll object. + """ + + @abc.abstractmethod + def serialize_poll(self, poll: poll_models.PollBuilder) -> data_binding.JSONObject: + """Serialize a poll object to a json serializable dict. + + Parameters + ---------- + poll + The poll object to serialize. + + Returns + ------- + hikari.internal.data_binding.JSONObject + The serialized representation of the poll. + """ diff --git a/hikari/api/event_factory.py b/hikari/api/event_factory.py index d292e04c2..837d785d5 100644 --- a/hikari/api/event_factory.py +++ b/hikari/api/event_factory.py @@ -48,6 +48,7 @@ from hikari.events import member_events from hikari.events import message_events from hikari.events import monetization_events + from hikari.events import poll_events from hikari.events import reaction_events from hikari.events import role_events from hikari.events import scheduled_events @@ -1469,3 +1470,45 @@ def deserialize_stage_instance_delete_event( hikari.events.stage_events.StageInstanceDeleteEvent The parsed stage instance delete event object. """ + + ################ + # POLL EVENTS # + ################ + + @abc.abstractmethod + def deserialize_poll_vote_create_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> poll_events.PollVoteCreateEvent: + """Parse a raw payload from Discord into a poll vote create event object. + + Parameters + ---------- + shard + The shard that emitted this event. + payload + The dict payload to parse. + + Returns + ------- + hikari.events.poll_events.PollVoteCreateEvent + The parsed poll vote create event object. + """ + + @abc.abstractmethod + def deserialize_poll_vote_delete_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> poll_events.PollVoteDeleteEvent: + """Parse a raw payload from Discord into a poll vote delete event object. + + Parameters + ---------- + shard + The shard that emitted this event. + payload + The dict payload to parse. + + Returns + ------- + hikari.events.poll_events.PollVoteDeleteEvent + The parsed poll vote delete event object. + """ diff --git a/hikari/api/rest.py b/hikari/api/rest.py index 537f5ab41..397af3c15 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -49,6 +49,7 @@ from hikari import messages as messages_ from hikari import monetization from hikari import permissions as permissions_ + from hikari import polls from hikari import sessions from hikari import snowflakes from hikari import stage_instances @@ -1043,6 +1044,7 @@ async def create_message( components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, embed: undefined.UndefinedOr[embeds_.Embed] = undefined.UNDEFINED, embeds: undefined.UndefinedOr[typing.Sequence[embeds_.Embed]] = undefined.UNDEFINED, + poll: undefined.UndefinedOr[polls.PollBuilder] = undefined.UNDEFINED, sticker: undefined.UndefinedOr[snowflakes.SnowflakeishOr[stickers_.PartialSticker]] = undefined.UNDEFINED, stickers: undefined.UndefinedOr[ snowflakes.SnowflakeishSequence[stickers_.PartialSticker] @@ -1120,6 +1122,8 @@ async def create_message( If provided, the message embed. embeds If provided, the message embeds. + poll + If provided, the poll to create. sticker If provided, the object or ID of a sticker to send on the message. @@ -8654,3 +8658,80 @@ async def delete_stage_instance(self, channel: snowflakes.SnowflakeishOr[channel hikari.errors.InternalServerError If an internal error occurs on Discord while handling the request. """ + + @abc.abstractmethod + async def fetch_poll_voters( + self, + channel: snowflakes.SnowflakeishOr[channels_.TextableChannel], + message: snowflakes.SnowflakeishOr[messages_.PartialMessage], + answer_id: int, + /, + *, + after: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED, + limit: undefined.UndefinedOr[int] = undefined.UNDEFINED, + ) -> typing.Sequence[users.User]: + """Fetch users that voted for a specific answer. + + Parameters + ---------- + channel + The channel the poll is in. + message + The message the poll is in. + answer_id + The answers id. + after + The votes to collect, after this user voted. + limit + The amount of votes to collect. Maximum 100, default 25 + + Returns + ------- + typing.Sequence[users.User] + An sequence of Users. + + Raises + ------ + hikari.errors.BadRequestError + If any of the fields that are passed have an invalid value. + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.NotFoundError + If the entitlement was not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + + @abc.abstractmethod + async def end_poll( + self, + channel: snowflakes.SnowflakeishOr[channels_.TextableChannel], + message: snowflakes.SnowflakeishOr[messages_.PartialMessage], + /, + ) -> None: + """End a poll. + + Parameters + ---------- + channel + The channel the poll is in. + message + The message the poll is in. + + Raises + ------ + hikari.errors.BadRequestError + If any of the fields that are passed have an invalid value. + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.NotFoundError + If the entitlement was not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ diff --git a/hikari/events/poll_events.py b/hikari/events/poll_events.py new file mode 100644 index 000000000..be386e3bd --- /dev/null +++ b/hikari/events/poll_events.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Events related to polls.""" + +from __future__ import annotations + +__all__: typing.Sequence[str] = ("PollVoteCreateEvent", "PollVoteDeleteEvent") + +import typing + +import attrs + +from hikari import undefined +from hikari.events import shard_events +from hikari.internal import attrs_extensions + +if typing.TYPE_CHECKING: + from hikari import channels + from hikari import guilds + from hikari import snowflakes + from hikari import traits + from hikari import users + from hikari.api import shard as gateway_shard + + +@attrs_extensions.with_copy +@attrs.define(kw_only=True, weakref_slot=False) +class BasePollVoteEvent(shard_events.ShardEvent): + """Event base for any event that involves a user voting on a poll.""" + + app: traits.RESTAware = attrs.field(metadata={attrs_extensions.SKIP_DEEP_COPY: True}) + # <>. + + shard: gateway_shard.GatewayShard = attrs.field(metadata={attrs_extensions.SKIP_DEEP_COPY: True}) + # <>. + + user_id: snowflakes.Snowflake = attrs.field() + """ID of the user that added their vote to the poll.""" + + channel_id: snowflakes.Snowflake = attrs.field() + """ID of the channel that the poll is in.""" + + message_id: snowflakes.Snowflake = attrs.field() + """ID of the message that the poll is in.""" + + guild_id: undefined.UndefinedOr[snowflakes.Snowflake] = attrs.field() + """ID of the guild that the poll is in. + + This will be [hikari.undefined.UNDEFINED][] if the poll is in a DM channel. + """ + + answer_id: int = attrs.field() + """ID of the answer that the user voted for.""" + + def get_guild(self) -> typing.Optional[guilds.GatewayGuild]: + """Get the cached guild that this event relates to, if known. + + If not, return [`None`][]. + + Returns + ------- + typing.Optional[hikari.guilds.GatewayGuild] + The gateway guild this event relates to, if known. Otherwise, + this will return [`None`][]. + """ + if not isinstance(self.app, traits.CacheAware): + return None + + if isinstance(self.guild_id, undefined.UndefinedType): + return None + + return self.app.cache.get_available_guild(self.guild_id) or self.app.cache.get_unavailable_guild(self.guild_id) + + async def fetch_guild(self) -> typing.Optional[guilds.RESTGuild]: + """Perform an API call to fetch the guild that this event relates to. + + Returns + ------- + hikari.guilds.RESTGuild + The guild that this event occurred in. + + Raises + ------ + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.ForbiddenError + If you are not part of the guild. + hikari.errors.NotFoundError + If the guild is not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + if isinstance(self.guild_id, undefined.UndefinedType): + return None + + return await self.app.rest.fetch_guild(self.guild_id) + + def get_channel(self) -> typing.Optional[channels.PermissibleGuildChannel]: + """Get the cached channel that this event relates to, if known. + + If not, return [`None`][]. + + Returns + ------- + typing.Optional[hikari.channels.GuildChannel] + The cached channel this event relates to. If not known, this + will return [`None`][] instead. + """ + if not isinstance(self.app, traits.CacheAware): + return None + + return self.app.cache.get_guild_channel(self.channel_id) + + async def fetch_channel(self) -> channels.GuildChannel: + """Perform an API call to fetch the details about this channel. + + !!! note + For [`hikari.events.channel_events.GuildChannelDeleteEvent`][] events, this will always raise + an exception, since the channel will have already been removed. + + Returns + ------- + hikari.channels.GuildChannel + A derivative of [`hikari.channels.GuildChannel`][]. The + actual type will vary depending on the type of channel this event + concerns. + + Raises + ------ + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.ForbiddenError + If you are missing the [`hikari.permissions.Permissions.VIEW_CHANNEL`][] permission in the channel. + hikari.errors.NotFoundError + If the channel is not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + channel = await self.app.rest.fetch_channel(self.channel_id) + assert isinstance(channel, channels.GuildChannel) + return channel + + def get_user(self) -> typing.Optional[users.User]: + """Get the cached user that is typing, if known. + + Returns + ------- + typing.Optional[hikari.users.User] + The user, if known. + """ + if isinstance(self.app, traits.CacheAware): + return self.app.cache.get_user(self.user_id) + + return None + + async def fetch_user(self) -> users.User: + """Perform an API call to fetch an up-to-date image of this user. + + Returns + ------- + hikari.users.User + The user. + + Raises + ------ + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.NotFoundError + If the user is not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + return await self.app.rest.fetch_user(self.user_id) + + +@attrs_extensions.with_copy +@attrs.define(kw_only=True, weakref_slot=False) +class PollVoteCreateEvent(BasePollVoteEvent): + """Event that is fired when a user add their vote to a poll. + + If the poll allows multiple selection, one event will be fired for each vote. + """ + + +@attrs_extensions.with_copy +@attrs.define(kw_only=True, weakref_slot=False) +class PollVoteDeleteEvent(BasePollVoteEvent): + """Event that is fired when a user remove their vote to a poll. + + If the poll allows multiple selection, one event will be fired for each vote. + """ diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index f119ac105..138511a7c 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -47,6 +47,7 @@ from hikari import messages as message_models from hikari import monetization as monetization_models from hikari import permissions as permission_models +from hikari import polls as poll_models from hikari import presences as presence_models from hikari import scheduled_events as scheduled_events_models from hikari import sessions as gateway_models @@ -3031,7 +3032,7 @@ def _deserialize_message_interaction(self, payload: data_binding.JSONObject) -> user=self.deserialize_user(payload["user"]), ) - def deserialize_partial_message( # noqa: C901 - Too complex + def deserialize_partial_message( # noqa: C901, CFQ001 - Too complex, Exceeds allowed length self, payload: data_binding.JSONObject ) -> message_models.PartialMessage: author: undefined.UndefinedOr[user_models.User] = undefined.UNDEFINED @@ -3069,6 +3070,10 @@ def deserialize_partial_message( # noqa: C901 - Too complex if "embeds" in payload: embeds = [self.deserialize_embed(embed) for embed in payload["embeds"]] + poll: undefined.UndefinedOr[poll_models.Poll] = undefined.UNDEFINED + if "poll" in payload: + poll = self.deserialize_poll(payload["poll"]) + reactions: undefined.UndefinedOr[list[message_models.Reaction]] = undefined.UNDEFINED if "reactions" in payload: reactions = [self._deserialize_message_reaction(reaction) for reaction in payload["reactions"]] @@ -3142,6 +3147,7 @@ def deserialize_partial_message( # noqa: C901 - Too complex is_tts=payload.get("tts", undefined.UNDEFINED), attachments=attachments, embeds=embeds, + poll=poll, reactions=reactions, is_pinned=payload.get("pinned", undefined.UNDEFINED), webhook_id=snowflakes.Snowflake(payload["webhook_id"]) if "webhook_id" in payload else undefined.UNDEFINED, @@ -3181,6 +3187,10 @@ def deserialize_message(self, payload: data_binding.JSONObject) -> message_model embeds = [self.deserialize_embed(embed) for embed in payload["embeds"]] + poll: undefined.UndefinedOr[poll_models.Poll] = undefined.UNDEFINED + if "polls" in payload: + poll = self.deserialize_poll(payload["poll"]) + if "reactions" in payload: reactions = [self._deserialize_message_reaction(reaction) for reaction in payload["reactions"]] else: @@ -3241,6 +3251,7 @@ def deserialize_message(self, payload: data_binding.JSONObject) -> message_model is_tts=payload["tts"], attachments=attachments, embeds=embeds, + poll=poll, reactions=reactions, is_pinned=payload["pinned"], webhook_id=snowflakes.Snowflake(payload["webhook_id"]) if "webhook_id" in payload else None, @@ -3774,3 +3785,67 @@ def deserialize_sku(self, payload: data_binding.JSONObject) -> monetization_mode slug=payload["slug"], flags=monetization_models.SKUFlags(payload["flags"]), ) + + ############### + # POLL MODELS # + ############### + def deserialize_poll(self, payload: data_binding.JSONObject) -> poll_models.Poll: + question = payload["question"]["text"] + expiry = time.iso8601_datetime_string_to_datetime(payload["expiry"]) + allow_multiselect = payload["allow_multiselect"] + layout_type = poll_models.PollLayoutType(payload["layout_type"]) + + answers: typing.MutableSequence[poll_models.PollAnswer] = [] + for answer_payload in payload["answers"]: + answer_id = answer_payload["answer_id"] + + emoji = answer_payload["poll_media"]["emoji"] + poll_media = poll_models.PollMedia( + text=answer_payload["poll_media"]["text"], emoji=self.deserialize_emoji(emoji) if emoji else None + ) + + answers.append(poll_models.PollAnswer(answer_id=answer_id, poll_media=poll_media)) + + results = None + if (result_payload := payload.get("result")) is not None: + is_finalized = result_payload["is_finalized"] + + answer_counts = tuple( + poll_models.PollAnswerCount( + answer_id=payload["answer_id"], count=payload["count"], me_voted=payload["me_voted"] + ) + for payload in result_payload["answer_counts"] + ) + results = poll_models.PollResult(is_finalized=is_finalized, answer_counts=answer_counts) + + return poll_models.Poll( + question=question, + answers=answers, + expiry=expiry, + allow_multiselect=allow_multiselect, + layout_type=layout_type, + results=results, + ) + + def _serialize_poll_media(self, poll_media: poll_models.PollMedia) -> data_binding.JSONObject: + serialised_poll_media: typing.MutableMapping[str, typing.Any] = {"text": poll_media.text} + + if isinstance(poll_media.emoji, emoji_models.UnicodeEmoji): + serialised_poll_media["emoji"] = {"name": poll_media.emoji.name} + elif isinstance(poll_media.emoji, emoji_models.CustomEmoji): + serialised_poll_media["emoji"] = {"name": poll_media.emoji.name, "id": poll_media.emoji.id} + + return serialised_poll_media + + def serialize_poll(self, poll: poll_models.PollBuilder) -> data_binding.JSONObject: + answers: typing.MutableSequence[typing.Any] = [] + for answer in poll.answers: + answers.append({"poll_media": self._serialize_poll_media(answer.poll_media)}) + + return { + "question": self._serialize_poll_media(poll.question), + "answers": answers, + "duration": poll.duration, + "allow_multiselect": poll.allow_multiselect, + "layout_type": poll.layout_type.value, + } diff --git a/hikari/impl/event_factory.py b/hikari/impl/event_factory.py index 515015d07..896e124d2 100644 --- a/hikari/impl/event_factory.py +++ b/hikari/impl/event_factory.py @@ -45,6 +45,7 @@ from hikari.events import member_events from hikari.events import message_events from hikari.events import monetization_events +from hikari.events import poll_events from hikari.events import reaction_events from hikari.events import role_events from hikari.events import scheduled_events @@ -979,3 +980,37 @@ def deserialize_stage_instance_delete_event( return stage_events.StageInstanceDeleteEvent( shard=shard, stage_instance=self._app.entity_factory.deserialize_stage_instance(payload) ) + + ################ + # POLL EVENTS # + ################ + + def deserialize_poll_vote_create_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> poll_events.PollVoteCreateEvent: + return poll_events.PollVoteCreateEvent( + app=self._app, + shard=shard, + user_id=snowflakes.Snowflake(payload["user_id"]), + channel_id=snowflakes.Snowflake(payload["channel_id"]), + message_id=snowflakes.Snowflake(payload["message_id"]), + guild_id=( + snowflakes.Snowflake(payload["guild_id"]) if payload.get("guild_id", None) else undefined.UNDEFINED + ), + answer_id=payload["answer_id"], + ) + + def deserialize_poll_vote_delete_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> poll_events.PollVoteDeleteEvent: + return poll_events.PollVoteDeleteEvent( + app=self._app, + shard=shard, + user_id=snowflakes.Snowflake(payload["user_id"]), + channel_id=snowflakes.Snowflake(payload["channel_id"]), + message_id=snowflakes.Snowflake(payload["message_id"]), + guild_id=( + snowflakes.Snowflake(payload["guild_id"]) if payload.get("guild_id", None) else undefined.UNDEFINED + ), + answer_id=payload["answer_id"], + ) diff --git a/hikari/impl/event_manager.py b/hikari/impl/event_manager.py index b9fed4e30..9403747e9 100644 --- a/hikari/impl/event_manager.py +++ b/hikari/impl/event_manager.py @@ -43,6 +43,7 @@ from hikari.events import member_events from hikari.events import message_events from hikari.events import monetization_events +from hikari.events import poll_events from hikari.events import reaction_events from hikari.events import role_events from hikari.events import scheduled_events @@ -906,3 +907,17 @@ async def on_stage_instance_delete( self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject ) -> None: await self.dispatch(self._event_factory.deserialize_stage_instance_delete_event(shard, payload)) + + @event_manager_base.filtered(poll_events.PollVoteCreateEvent) + async def on_message_poll_vote_add( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> None: + """See https://discord.com/developers/docs/topics/gateway-events#message-poll-vote-add for more info.""" + await self.dispatch(self._event_factory.deserialize_poll_vote_create_event(shard, payload)) + + @event_manager_base.filtered(poll_events.PollVoteDeleteEvent) + async def on_message_poll_vote_remove( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> None: + """See https://discord.com/developers/docs/topics/gateway-events#message-poll-vote-remove for more info.""" + await self.dispatch(self._event_factory.deserialize_poll_vote_delete_event(shard, payload)) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index fbceb44c9..c8829e965 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -60,6 +60,7 @@ from hikari import messages as messages_ from hikari import monetization from hikari import permissions as permissions_ +from hikari import polls from hikari import scheduled_events from hikari import snowflakes from hikari import stage_instances @@ -1387,6 +1388,7 @@ def _build_message_payload( # noqa: C901- Function too complex ] = undefined.UNDEFINED, embed: undefined.UndefinedNoneOr[embeds_.Embed] = undefined.UNDEFINED, embeds: undefined.UndefinedNoneOr[typing.Sequence[embeds_.Embed]] = undefined.UNDEFINED, + poll: undefined.UndefinedNoneOr[polls.PollBuilder] = undefined.UNDEFINED, sticker: undefined.UndefinedOr[snowflakes.SnowflakeishOr[stickers_.PartialSticker]] = undefined.UNDEFINED, stickers: undefined.UndefinedOr[ snowflakes.SnowflakeishSequence[stickers_.PartialSticker] @@ -1474,6 +1476,7 @@ def _build_message_payload( # noqa: C901- Function too complex body.put("flags", flags) body.put("embeds", serialized_embeds) body.put("components", serialized_components) + body.put("poll", poll, conversion=self._entity_factory.serialize_poll) body.put_snowflake_array("sticker_ids", (sticker,) if sticker else stickers) if not edit or not undefined.all_undefined(mentions_everyone, mentions_reply, user_mentions, role_mentions): @@ -1518,6 +1521,7 @@ async def create_message( components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, embed: undefined.UndefinedOr[embeds_.Embed] = undefined.UNDEFINED, embeds: undefined.UndefinedOr[typing.Sequence[embeds_.Embed]] = undefined.UNDEFINED, + poll: undefined.UndefinedOr[polls.PollBuilder] = undefined.UNDEFINED, sticker: undefined.UndefinedOr[snowflakes.SnowflakeishOr[stickers_.PartialSticker]] = undefined.UNDEFINED, stickers: undefined.UndefinedOr[ snowflakes.SnowflakeishSequence[stickers_.PartialSticker] @@ -1544,6 +1548,7 @@ async def create_message( components=components, embed=embed, embeds=embeds, + poll=poll, sticker=sticker, stickers=stickers, tts=tts, @@ -4603,3 +4608,36 @@ async def edit_stage_instance( async def delete_stage_instance(self, channel: snowflakes.SnowflakeishOr[channels_.GuildStageChannel]) -> None: route = routes.DELETE_STAGE_INSTANCE.compile(channel=channel) await self._request(route) + + async def fetch_poll_voters( + self, + channel: snowflakes.SnowflakeishOr[channels_.TextableChannel], + message: snowflakes.SnowflakeishOr[messages_.PartialMessage], + answer_id: int, + /, + *, + after: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED, + limit: undefined.UndefinedOr[int] = undefined.UNDEFINED, + ) -> typing.Sequence[users.User]: + query = data_binding.StringMapBuilder() + + query.put("after", after) + query.put("limit", limit) + + route = routes.GET_POLL_ANSWER.compile(channel=channel, message=message, answer=answer_id) + + response = await self._request(route, query=query) + + assert isinstance(response, list) + + return [self._entity_factory.deserialize_user(payload) for payload in response] + + async def end_poll( + self, + channel: snowflakes.SnowflakeishOr[channels_.TextableChannel], + message: snowflakes.SnowflakeishOr[messages_.PartialMessage], + /, + ) -> None: + route = routes.POST_END_POLL.compile(channel=channel, message=message) + + await self._request(route) diff --git a/hikari/intents.py b/hikari/intents.py index fd9d7c4de..0f843fd4c 100644 --- a/hikari/intents.py +++ b/hikari/intents.py @@ -334,6 +334,20 @@ class Intents(enums.Flag): * `GUILD_SCHEDULED_EVENT_USER_REMOVE` """ + GUILD_MESSAGE_POLLS = 1 << 24 + """Subscribes to the events listed below. + + * `MESSAGE_POLL_VOTE_ADD` + * `MESSAGE_POLL_VOTE_REMOVE` + """ + + DIRECT_MESSAGE_POLLS = 1 << 25 + """Subscribes to the events listed below. + + * `MESSAGE_POLL_VOTE_ADD` + * `MESSAGE_POLL_VOTE_REMOVE` + """ + # Annoyingly, enums hide classmethods and staticmethods from __dir__ in # EnumMeta which means if I make methods to generate these, then stuff # will not be documented by pdoc. Alas, my dream of being smart with @@ -351,6 +365,7 @@ class Intents(enums.Flag): | GUILD_MESSAGE_TYPING | GUILD_MODERATION | GUILD_SCHEDULED_EVENTS + | GUILD_MESSAGE_POLLS ) """All unprivileged guild-related intents.""" @@ -373,12 +388,15 @@ class Intents(enums.Flag): use. """ - ALL_DMS = DM_MESSAGES | DM_MESSAGE_TYPING | DM_MESSAGE_REACTIONS + ALL_DMS = DM_MESSAGES | DM_MESSAGE_TYPING | DM_MESSAGE_REACTIONS | DIRECT_MESSAGE_POLLS """All direct message channel (non-guild bound) intents.""" ALL_MESSAGES = DM_MESSAGES | GUILD_MESSAGES """All message intents.""" + ALL_POLLS = GUILD_MESSAGE_POLLS | DIRECT_MESSAGE_POLLS + """All poll intents.""" + ALL_MESSAGE_REACTIONS = DM_MESSAGE_REACTIONS | GUILD_MESSAGE_REACTIONS """All message reaction intents.""" diff --git a/hikari/internal/cache.py b/hikari/internal/cache.py index b819d812b..95316b1c6 100644 --- a/hikari/internal/cache.py +++ b/hikari/internal/cache.py @@ -72,6 +72,7 @@ from hikari import applications from hikari import channels as channels_ from hikari import components as components_ + from hikari import polls as polls_ from hikari import traits from hikari import users as users_ from hikari.interactions import base_interactions @@ -722,6 +723,7 @@ class MessageData(BaseData[messages.Message]): attachments: tuple[messages.Attachment, ...] = attrs.field() embeds: tuple[embeds_.Embed, ...] = attrs.field() reactions: tuple[messages.Reaction, ...] = attrs.field() + poll: undefined.UndefinedOr[polls_.Poll] = attrs.field() is_pinned: bool = attrs.field() webhook_id: typing.Optional[snowflakes.Snowflake] = attrs.field() type: typing.Union[messages.MessageType, int] = attrs.field() @@ -790,6 +792,7 @@ def build_from_entity( mentions_everyone=message.mentions_everyone, attachments=tuple(map(copy.copy, message.attachments)), embeds=tuple(map(_copy_embed, message.embeds)), + poll=message.poll, reactions=tuple(map(copy.copy, message.reactions)), is_pinned=message.is_pinned, webhook_id=message.webhook_id, @@ -836,6 +839,7 @@ def build_entity(self, app: traits.RESTAware, /) -> messages.Message: mentions_everyone=self.mentions_everyone, attachments=tuple(map(copy.copy, self.attachments)), embeds=tuple(map(_copy_embed, self.embeds)), + poll=self.poll, reactions=tuple(map(copy.copy, self.reactions)), is_pinned=self.is_pinned, webhook_id=self.webhook_id, diff --git a/hikari/internal/routes.py b/hikari/internal/routes.py index 484c41fbb..114d218a8 100644 --- a/hikari/internal/routes.py +++ b/hikari/internal/routes.py @@ -345,6 +345,10 @@ def compile_to_file( PATCH_STAGE_INSTANCE: typing.Final[Route] = Route(PATCH, "/stage-instances/{channel}") DELETE_STAGE_INSTANCE: typing.Final[Route] = Route(DELETE, "/stage-instances/{channel}") +# Polls +GET_POLL_ANSWER: typing.Final[Route] = Route(GET, "/channels/{channel}/polls/{message}/answer/{answer}") +POST_END_POLL: typing.Final[Route] = Route(POST, "/channels/{channel}/polls/{message}/expire") + # Reactions GET_REACTIONS: typing.Final[Route] = Route(GET, "/channels/{channel}/messages/{message}/reactions/{emoji}") DELETE_ALL_REACTIONS: typing.Final[Route] = Route(DELETE, "/channels/{channel}/messages/{message}/reactions") diff --git a/hikari/messages.py b/hikari/messages.py index 31cc07434..23293e5e8 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -57,6 +57,7 @@ from hikari import channels as channels_ from hikari import embeds as embeds_ from hikari import emojis as emojis_ + from hikari import polls as polls_ from hikari import stickers as stickers_ from hikari import users as users_ from hikari.api import special_endpoints @@ -538,6 +539,9 @@ class PartialMessage(snowflakes.Unique): embeds: undefined.UndefinedOr[typing.Sequence[embeds_.Embed]] = attrs.field(hash=False, eq=False, repr=False) """The message embeds.""" + poll: undefined.UndefinedOr[polls_.Poll] = attrs.field(hash=False, eq=False, repr=False) + """The message poll.""" + reactions: undefined.UndefinedOr[typing.Sequence[Reaction]] = attrs.field(hash=False, eq=False, repr=False) """The message reactions.""" @@ -764,6 +768,7 @@ async def edit( ] = undefined.UNDEFINED, embed: undefined.UndefinedNoneOr[embeds_.Embed] = undefined.UNDEFINED, embeds: undefined.UndefinedNoneOr[typing.Sequence[embeds_.Embed]] = undefined.UNDEFINED, + poll: undefined.UndefinedOr[polls_.PollBuilder] = undefined.UNDEFINED, mentions_everyone: undefined.UndefinedOr[bool] = undefined.UNDEFINED, mentions_reply: undefined.UndefinedOr[bool] = undefined.UNDEFINED, user_mentions: undefined.UndefinedOr[ @@ -936,6 +941,7 @@ async def respond( components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, embed: undefined.UndefinedOr[embeds_.Embed] = undefined.UNDEFINED, embeds: undefined.UndefinedOr[typing.Sequence[embeds_.Embed]] = undefined.UNDEFINED, + poll: undefined.UndefinedOr[polls_.PollBuilder] = undefined.UNDEFINED, sticker: undefined.UndefinedOr[snowflakes.SnowflakeishOr[stickers_.PartialSticker]] = undefined.UNDEFINED, stickers: undefined.UndefinedOr[ snowflakes.SnowflakeishSequence[stickers_.PartialSticker] @@ -1103,6 +1109,7 @@ async def respond( components=components, embed=embed, embeds=embeds, + poll=poll, sticker=sticker, stickers=stickers, tts=tts, diff --git a/hikari/polls.py b/hikari/polls.py new file mode 100644 index 000000000..4c8ceb62f --- /dev/null +++ b/hikari/polls.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Polls and poll-related objects.""" # TODO: Improve this docstring + +from __future__ import annotations + +__all__: typing.Sequence[str] = ( + "PollMedia", + "PollAnswer", + "PollResult", + "PollAnswerCount", + "PollLayoutType", + "PartialPoll", + "PollBuilder", + "Poll", +) + +import typing + +import attrs + +from hikari import emojis +from hikari import undefined +from hikari.internal import attrs_extensions +from hikari.internal import enums + +if typing.TYPE_CHECKING: + import datetime + + +def _ensure_optional_emoji(emoji: typing.Optional[typing.Union[str, emojis.Emoji]]) -> emojis.Emoji | None: + """Ensure the object is a [hikari.emojis.Emoji][].""" + if emoji is not None: + return emojis.Emoji.parse(emoji) if isinstance(emoji, str) else emoji + return None + + +@attrs_extensions.with_copy +@attrs.define(hash=False, kw_only=True, weakref_slot=False) +class PollMedia: + """Common object backing a poll's questions and answers.""" + + text: typing.Optional[str] = attrs.field(default=None, repr=True) + """The text of the element, or [`None`][] if not present.""" + + emoji: typing.Optional[emojis.Emoji] = attrs.field(default=None, repr=True) + """The emoji of the element, or [`None`][] if not present.""" + + +@attrs_extensions.with_copy +@attrs.define(hash=False, kw_only=True, weakref_slot=False) +class PollAnswer: + """Represents an answer to a poll.""" + + answer_id: int = attrs.field(repr=True) + """The ID that labels this answer.""" + + poll_media: PollMedia = attrs.field(repr=True) + """The [media][hikari.polls.PollMedia] associated with this answer.""" + + +@attrs_extensions.with_copy +@attrs.define(hash=False, kw_only=True, weakref_slot=False) +class PollResult: + """Represents a poll result.""" + + is_finalized: bool = attrs.field(repr=True) + """Whether the poll is finalized and the votes are precisely counted.""" + + answer_counts: typing.Sequence[PollAnswerCount] = attrs.field(repr=True) + """The counts for each answer.""" + + +@attrs_extensions.with_copy +@attrs.define(hash=False, kw_only=True, weakref_slot=False) +class PollAnswerCount: + """Represents the count of a poll answer.""" + + answer_id: int = attrs.field(repr=True) + """The ID of the answer.""" + + count: int = attrs.field(repr=True) + """The number of votes for this answer.""" + + me_voted: bool = attrs.field(repr=True) + """Whether the current user voted for this answer.""" + + +class PollLayoutType(int, enums.Enum): + """Layout of a poll.""" + + DEFAULT = 1 + """The default layout of a poll.""" + + +class PartialPoll: + """Base class for all poll objects.""" + + __slots__: typing.Sequence[str] = ("_question", "_answers", "_allow_multiselect", "_layout_type", "_counter") + + def __init__(self, question: str, allow_multiselect: bool, layout_type: typing.Union[int, PollLayoutType]): + self._question = PollMedia(text=question) # Only text is supported for question + self._allow_multiselect = allow_multiselect + self._layout_type = layout_type + + @property + def question(self) -> PollMedia: + """Returns the question of the poll.""" + return self._question + + @question.setter + def question(self, value: str) -> None: + self._question = PollMedia(text=value) + + @property + def allow_multiselect(self) -> bool: + """Returns whether the poll allows multiple answers.""" + return self._allow_multiselect + + @allow_multiselect.setter + def allow_multiselect(self, value: bool) -> None: + self._allow_multiselect = value + + @property + def layout_type(self) -> PollLayoutType: + """Returns the layout type of the poll.""" + return PollLayoutType(self._layout_type) + + @layout_type.setter + def layout_type(self, value: typing.Union[int, PollLayoutType]) -> None: + self._layout_type = value + + +class PollBuilder(PartialPoll): + """Poll Builder. + + Build a new poll to send as a message to discord. + + Parameters + ---------- + question + The question you wish to ask. + """ # TODO: Improve this docstring + + __slots__: typing.Sequence[str] = ("_duration",) + + def __init__( + self, + question: str, + duration: int, + allow_multiselect: bool, + layout_type: typing.Union[int, PollLayoutType] = PollLayoutType.DEFAULT, + ): + super().__init__(question=question, allow_multiselect=allow_multiselect, layout_type=layout_type) + self._duration = duration + + # Answer is required, but we want users to user add_answer() instead of + # providing at initialization. + # + # Considering that answer ID can be arbitrary, `list`-based approaches + # like that of hikari.embeds.Embed._fields, while feasible to implement, + # would decrease long-term maintainability. I'm opting to use a `dict` + # here to simplify the implementation with some performance trade-off + # due to hashmap overhead. + self._answers: typing.MutableSequence[PollAnswer] = [] + + @property + def duration(self) -> int: + """Returns the duration of the poll.""" + return self._duration + + @duration.setter + def duration(self, value: int) -> None: + self._duration = value + + @property + def answers(self) -> typing.Iterable[PollAnswer]: + """Returns the answers of the poll. + + !!! note + Use [`hikari.polls.PollBuilder.add_answer`][] to add a new answer, + [`hikari.polls.PollBuilder.edit_answer`][] to edit an existing answer, or + [`hikari.polls.PollBuilder.remove_answer`][] to remove an answer. + """ + return self._answers + + def add_answer(self, text: str, emoji: typing.Optional[emojis.Emoji]) -> PartialPoll: + """ + Add an answer to the poll. + + Parameters + ---------- + text + The text of the answer to add. + emoji + The emoji associated with the answer. + + Returns + ------- + PartialPoll + This poll. Allows for call chaining. + """ + self._answers.append( + PollAnswer(answer_id=-1, poll_media=PollMedia(text=text, emoji=_ensure_optional_emoji(emoji))) + ) + + return self + + def edit_answer( + self, + index: int, + *, + text: typing.Optional[str] = None, + emoji: undefined.UndefinedNoneOr[typing.Union[str, emojis.Emoji]] = undefined.UNDEFINED, + ) -> PartialPoll: + """ + Edit an answer in the poll. + + Parameters + ---------- + index + The index of the answer you want to edit. + text + The new text of the answer. + emoji + The new emoji associated with the answer. + + Returns + ------- + PartialPoll + This poll. Allows for call chaining. + """ + answer = self._answers[index] + if text: + answer.poll_media.text = text + if emoji is not undefined.UNDEFINED: + answer.poll_media.emoji = _ensure_optional_emoji(emoji) + + return self + + def remove_answer(self, answer_id: int) -> PartialPoll: + """ + Remove an answer from the poll. + + Parameters + ---------- + answer_id + The ID of the answer to remove. + + Returns + ------- + PartialPoll + This poll. Allows for call chaining. + + Raises + ------ + KeyError + Raised when the answer ID is not found in the poll. + """ + del self._answers[answer_id] + + return self + + +class Poll(PartialPoll): + """Represents an existing poll.""" + + __slots__: typing.Sequence[str] = ("_expiry", "_results") + + def __init__( + self, + question: str, + answers: typing.Sequence[PollAnswer], + allow_multiselect: bool, + expiry: datetime.datetime, + results: typing.Optional[PollResult], + layout_type: typing.Union[int, PollLayoutType] = PollLayoutType.DEFAULT, + ): + super().__init__(question=question, allow_multiselect=allow_multiselect, layout_type=layout_type) + self._answers = answers + self._expiry = expiry + self._results = results + + @property + def answers(self) -> typing.Iterable[PollAnswer]: + """Returns the answers of the poll.""" + return self._answers + + @property + def expiry(self) -> datetime.datetime: + """Returns whether the poll has expired.""" + return self._expiry + + @property + def results(self) -> typing.Optional[PollResult]: + """Returns the result of the poll. + + !!! note + According to Discord, their backend does not always return `results`, + this is meant to be interpreted as "unknown result" rather than "no + result". Please refer to the + [official documentation](https://discord.com/developers/docs/resources/poll#poll-results-object) + for more information. + """ + return self._results diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index 65b0461cb..b3fb8547f 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -30,6 +30,7 @@ from hikari import guilds from hikari import invites from hikari import messages +from hikari import polls from hikari import snowflakes from hikari import stickers from hikari import undefined @@ -2759,6 +2760,7 @@ def test__build_message(self, cache_impl): mock_attachment = mock.MagicMock(messages.Attachment) mock_embed_field = mock.MagicMock(embeds.EmbedField) mock_embed = mock.MagicMock(embeds.Embed, fields=(mock_embed_field,)) + mock_poll = mock.MagicMock(polls.Poll) mock_sticker = mock.MagicMock(stickers.PartialSticker) mock_reaction = mock.MagicMock(messages.Reaction) mock_activity = mock.MagicMock(messages.MessageActivity) @@ -2788,6 +2790,7 @@ def test__build_message(self, cache_impl): mentions_everyone=False, attachments=(mock_attachment,), embeds=(mock_embed,), + poll=mock_poll, reactions=(mock_reaction,), is_pinned=False, webhook_id=snowflakes.Snowflake(3123123), @@ -2877,6 +2880,7 @@ def test__build_message_with_null_fields(self, cache_impl): mentions_everyone=undefined.UNDEFINED, attachments=(), embeds=(), + poll=None, reactions=(), is_pinned=False, webhook_id=None, diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index d89ba0458..5666241dd 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -42,6 +42,7 @@ from hikari import messages as message_models from hikari import monetization as monetization_models from hikari import permissions as permission_models +from hikari import polls from hikari import presences as presence_models from hikari import scheduled_events as scheduled_event_models from hikari import sessions as gateway_models @@ -5652,6 +5653,7 @@ def message_payload( custom_emoji_payload, partial_application_payload, embed_payload, + poll_payload, referenced_message, action_row_payload, partial_sticker_payload, @@ -5678,6 +5680,7 @@ def message_payload( "mention_channels": [{"id": "456", "guild_id": "678", "type": 1, "name": "hikari-testing"}], "attachments": [attachment_payload], "embeds": [embed_payload], + "poll": poll_payload, "reactions": [{"emoji": custom_emoji_payload, "count": 100, "me": True}], "pinned": True, "webhook_id": "1234", @@ -7232,3 +7235,60 @@ def test_deserialize_stage_instance(self, entity_factory_impl, stage_instance_pa assert stage_instance.topic == "Testing Testing, 123" assert stage_instance.privacy_level == stage_instance_models.StageInstancePrivacyLevel.GUILD_ONLY assert stage_instance.discoverable_disabled is False + + ########### + # POLLS # + ########### + + @pytest.fixture + def poll_payload(self): + return { + "question": {"text": "fruit"}, + "answers": [ + {"answer_id": 1, "poll_media": {"text": "apple", "emoji": {"name": "🍏"}}}, + {"answer_id": 2, "poll_media": {"text": "banana", "emoji": {"name": "🍌"}}}, + {"answer_id": 3, "poll_media": {"text": "carrot", "emoji": {"name": "🥕"}}}, + ], + "expiry": "2021-02-01T18:03:20.888000+00:00", + "allow_multiselect": True, + "layout_type": 1, + } + + def test_deserialize_poll(self, entity_factory_impl, poll_payload): + poll = entity_factory_impl.deserialize_poll(poll_payload) + + assert poll.question.text == "fruit" + assert poll.question.emoji is None + assert len(poll.answers) == 3 + assert poll.answers[0].answer_id == 1 + assert poll.answers[0].poll_media.text == "apple" + assert poll.answers[0].poll_media.emoji == "🍏" + assert poll.answers[1].answer_id == 2 + assert poll.answers[1].poll_media.text == "banana" + assert poll.answers[1].poll_media.emoji == "🍌" + assert poll.answers[2].answer_id == 3 + assert poll.answers[2].poll_media.text == "carrot" + assert poll.answers[2].poll_media.emoji == "🥕" + + assert poll.expiry == datetime.datetime(2021, 2, 1, 18, 3, 20, 888000, tzinfo=datetime.timezone.utc) + + def test_serialize_poll(self, entity_factory_impl): + poll = polls.PollBuilder("fruit", 1, allow_multiselect=True, layout_type=polls.PollLayoutType.DEFAULT) + + poll.add_answer("apple", "🍏") + poll.add_answer("banana", "🍌") + poll.add_answer("carrot", "🥕") + + payload = entity_factory_impl.serialize_poll(poll) + + assert payload == { + "question": {"text": "fruit"}, + "answers": [ + {"poll_media": {"text": "apple", "emoji": {"name": "🍏"}}}, + {"poll_media": {"text": "banana", "emoji": {"name": "🍌"}}}, + {"poll_media": {"text": "carrot", "emoji": {"name": "🥕"}}}, + ], + "duration": 1, + "allow_multiselect": True, + "layout_type": 1, + } diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index e7ef4cf33..1fb766963 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -39,6 +39,7 @@ from hikari.events import member_events from hikari.events import message_events from hikari.events import monetization_events +from hikari.events import poll_events from hikari.events import reaction_events from hikari.events import role_events from hikari.events import scheduled_events @@ -1571,3 +1572,33 @@ def test_deserialize_stage_instance_delete_event(self, event_factory, mock_app, assert event.shard is mock_shard assert event.app is event.stage_instance.app assert event.stage_instance == mock_app.entity_factory.deserialize_stage_instance.return_value + + ########### + # POLLS # + ########### + + def test_deserialize_poll_vote_create_event(self, event_factory, mock_app, mock_shard): + payload = { + "user_id": "3847382", + "channel_id": "4598743", + "message_id": "458437954", + "guild_id": "3589273", + "answer_id": 1, + } + + event = event_factory.deserialize_poll_vote_create_event(mock_shard, payload) + + assert isinstance(event, poll_events.PollVoteCreateEvent) + + def test_deserialize_poll_vote_delete_event(self, event_factory, mock_app, mock_shard): + payload = { + "user_id": "3847382", + "channel_id": "4598743", + "message_id": "458437954", + "guild_id": "3589273", + "answer_id": 1, + } + + event = event_factory.deserialize_poll_vote_delete_event(mock_shard, payload) + + assert isinstance(event, poll_events.PollVoteDeleteEvent) diff --git a/tests/hikari/impl/test_event_manager.py b/tests/hikari/impl/test_event_manager.py index fa620ce0c..41662ca0b 100644 --- a/tests/hikari/impl/test_event_manager.py +++ b/tests/hikari/impl/test_event_manager.py @@ -1761,3 +1761,35 @@ async def test_on_stage_instance_delete( event_manager_impl.dispatch.assert_awaited_once_with( event_factory.deserialize_stage_instance_delete_event.return_value ) + + @pytest.mark.asyncio + async def test_on_message_poll_vote_create( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: mock.Mock, + event_factory: event_factory_.EventFactory, + ): + mock_payload = mock.Mock() + + await event_manager_impl.on_message_poll_vote_add(shard, mock_payload) + + event_factory.deserialize_poll_vote_create_event.assert_called_once_with(shard, mock_payload) + event_manager_impl.dispatch.assert_awaited_once_with( + event_factory.deserialize_poll_vote_create_event.return_value + ) + + @pytest.mark.asyncio + async def test_on_message_poll_vote_delete( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: mock.Mock, + event_factory: event_factory_.EventFactory, + ): + mock_payload = mock.Mock() + + await event_manager_impl.on_message_poll_vote_remove(shard, mock_payload) + + event_factory.deserialize_poll_vote_delete_event.assert_called_once_with(shard, mock_payload) + event_manager_impl.dispatch.assert_awaited_once_with( + event_factory.deserialize_poll_vote_delete_event.return_value + ) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index b69d752ab..d204a15df 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -2520,13 +2520,14 @@ async def test_fetch_message(self, rest_client): rest_client._request.assert_awaited_once_with(expected_route) rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) - async def test_create_message_when_form(self, rest_client): + async def test_create_message_when_form(self, rest_client: rest.RESTClientImpl): attachment_obj = object() attachment_obj2 = object() component_obj = object() component_obj2 = object() embed_obj = object() embed_obj2 = object() + poll_obj = object() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") @@ -2543,6 +2544,7 @@ async def test_create_message_when_form(self, rest_client): components=[component_obj2], embed=embed_obj, embeds=[embed_obj2], + poll=poll_obj, sticker=54234, stickers=[564123, 431123], tts=True, @@ -2563,6 +2565,7 @@ async def test_create_message_when_form(self, rest_client): components=[component_obj2], embed=embed_obj, embeds=[embed_obj2], + poll=poll_obj, sticker=54234, stickers=[564123, 431123], tts=True, @@ -2622,6 +2625,7 @@ async def test_create_message_when_no_form(self, rest_client): components=[component_obj2], embed=embed_obj, embeds=[embed_obj2], + poll=undefined.UNDEFINED, sticker=543345, stickers=[123321, 6572345], tts=True, @@ -6653,3 +6657,31 @@ async def test_delete_stage_instance(self, rest_client): await rest_client.delete_stage_instance(channel=StubModel(7334)) rest_client._request.assert_called_once_with(expected_route) + + async def test_fetch_poll_voters(self, rest_client: rest.RESTClientImpl): + expected_route = routes.GET_POLL_ANSWER.compile( + channel=StubModel(45874392), message=StubModel(398475938475), answer=StubModel(4) + ) + + rest_client._request = mock.AsyncMock(return_value=[{"id": "1234"}]) + + with mock.patch.object( + rest_client._entity_factory, "deserialize_user", return_value=mock.Mock() + ) as patched_deserialize_user: + await rest_client.fetch_poll_voters( + StubModel(45874392), StubModel(398475938475), StubModel(4), after=StubModel(43587935), limit=6 + ) + + patched_deserialize_user.assert_called_once_with({"id": "1234"}) + + rest_client._request.assert_awaited_once_with(expected_route, query={"after": "43587935", "limit": "6"}) + + async def test_end_poll(self, rest_client: rest.RESTClientImpl): + expected_route = routes.POST_END_POLL.compile( + channel=StubModel(45874392), message=StubModel(398475938475), answer=StubModel(4) + ) + rest_client._request = mock.AsyncMock() + + await rest_client.end_poll(StubModel(45874392), StubModel(398475938475)) + + rest_client._request.assert_awaited_once_with(expected_route) diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 74d528d8d..247f7c390 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -107,6 +107,7 @@ def message(): mentions_everyone=False, attachments=(), embeds=(), + poll=object(), reactions=(), is_pinned=True, webhook_id=None, @@ -175,6 +176,7 @@ async def test_edit(self, message): message.channel_id = 456 embed = object() embeds = [object(), object()] + poll = object() component = object() components = object(), object() attachment = object() @@ -183,6 +185,7 @@ async def test_edit(self, message): content="test content", embed=embed, embeds=embeds, + poll=poll, attachment=attachment, attachments=[attachment, attachment], component=component, @@ -216,6 +219,7 @@ async def test_respond(self, message): message.channel_id = 456 embed = object() embeds = [object(), object()] + poll = object() roles = [object()] attachment = object() attachments = [object()] @@ -226,6 +230,7 @@ async def test_respond(self, message): content="test content", embed=embed, embeds=embeds, + poll=poll, attachment=attachment, attachments=attachments, component=component, @@ -246,6 +251,7 @@ async def test_respond(self, message): content="test content", embed=embed, embeds=embeds, + poll=poll, attachment=attachment, attachments=attachments, component=component, @@ -272,6 +278,7 @@ async def test_respond_when_reply_is_True(self, message): content=undefined.UNDEFINED, embed=undefined.UNDEFINED, embeds=undefined.UNDEFINED, + poll=undefined.UNDEFINED, attachment=undefined.UNDEFINED, attachments=undefined.UNDEFINED, component=undefined.UNDEFINED, @@ -298,6 +305,7 @@ async def test_respond_when_reply_is_False(self, message): content=undefined.UNDEFINED, embed=undefined.UNDEFINED, embeds=undefined.UNDEFINED, + poll=undefined.UNDEFINED, attachment=undefined.UNDEFINED, attachments=undefined.UNDEFINED, component=undefined.UNDEFINED, diff --git a/tests/hikari/test_polls.py b/tests/hikari/test_polls.py new file mode 100644 index 000000000..ea5799118 --- /dev/null +++ b/tests/hikari/test_polls.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.\ +from __future__ import annotations + +from hikari import polls + + +class TestPollBuilder: + def test_add_answer(self): + poll = polls.PollBuilder("question", 1, False) + + poll.add_answer("beanos", None) + + assert len(list(poll.answers)) == 1 + + assert list(poll.answers)[0] == polls.PollAnswer( + answer_id=-1, poll_media=polls.PollMedia(text="beanos", emoji=None) + ) + + def test_edit_answer(self): + poll = polls.PollBuilder("question", 1, False) + + poll.add_answer("beanos", None) + + assert len(list(poll.answers)) == 1 + + assert list(poll.answers)[0] == polls.PollAnswer( + answer_id=-1, poll_media=polls.PollMedia(text="beanos", emoji=None) + ) + + poll.edit_answer(0, emoji="🫘") + + assert list(poll.answers)[0] == polls.PollAnswer( + answer_id=-1, poll_media=polls.PollMedia(text="beanos", emoji="🫘") + ) + + def test_remove_answer(self): + poll = polls.PollBuilder("question", 1, False) + + poll.add_answer("beanos", None) + + assert len(list(poll.answers)) == 1 + + assert list(poll.answers)[0] == polls.PollAnswer( + answer_id=-1, poll_media=polls.PollMedia(text="beanos", emoji=None) + ) + + poll.remove_answer(0) + + assert len(list(poll.answers)) == 0