Skip to content

Commit

Permalink
Merge pull request #167 from nekokatt/task/161-typing-event-user
Browse files Browse the repository at this point in the history
Implemented `user` member on typing events.
  • Loading branch information
Nekokatt authored Sep 12, 2020
2 parents 0f0d2b9 + 2e64418 commit 5a12196
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 55 deletions.
106 changes: 86 additions & 20 deletions hikari/events/typing_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

from hikari import channels
from hikari import intents
from hikari import users
from hikari.api import special_endpoints
from hikari.events import base_events
from hikari.events import shard_events
from hikari.utilities import attr_extensions
Expand All @@ -45,7 +47,6 @@
from hikari import guilds
from hikari import snowflakes
from hikari import traits
from hikari import users
from hikari.api import shard as gateway_shard


Expand Down Expand Up @@ -97,6 +98,18 @@ def channel(self) -> typing.Optional[channels.TextChannel]:
The channel, if known.
"""

@property
@abc.abstractmethod
def user(self) -> typing.Optional[users.User]:
"""Get the cached user that is typing, if known.
Returns
-------
typing.Optional[hikari.users.User]
The user, if known.
"""

@abc.abstractmethod
async def fetch_channel(self) -> channels.TextChannel:
"""Perform an API call to fetch an up-to-date image of this channel.
Expand All @@ -105,10 +118,8 @@ async def fetch_channel(self) -> channels.TextChannel:
hikari.channels.TextChannel
The channel.
"""
channel = await self.app.rest.fetch_channel(self.channel_id)
assert isinstance(channel, channels.TextChannel)
return channel

@abc.abstractmethod
async def fetch_user(self) -> users.User:
"""Perform an API call to fetch an up-to-date image of this user.
Expand All @@ -117,7 +128,17 @@ async def fetch_user(self) -> users.User:
hikari.users.User
The user.
"""
return await self.app.rest.fetch_user(self.user_id)

def trigger_typing(self) -> special_endpoints.TypingIndicator:
"""Return a typing indicator for this channel that can be awaited.
Returns
-------
hikari.api.special_endpoints.TypingIndicator
A typing indicator context manager and awaitable to trigger typing
in a channel with.
"""
return self.app.rest.trigger_typing(self.channel_id)


@base_events.requires_intents(intents.Intents.GUILD_MESSAGE_TYPING)
Expand All @@ -135,9 +156,6 @@ class GuildTypingEvent(TypingEvent):
channel_id: snowflakes.Snowflake = attr.ib()
# <<inherited docstring from TypingEvent>>.

user_id: snowflakes.Snowflake = attr.ib(repr=True)
# <<inherited docstring from TypingEvent>>.

timestamp: datetime.datetime = attr.ib(repr=False)
# <<inherited docstring from TypingEvent>>.

Expand All @@ -150,19 +168,32 @@ class GuildTypingEvent(TypingEvent):
The ID of the guild that relates to this event.
"""

member: guilds.Member = attr.ib(repr=False)
user: guilds.Member = attr.ib(repr=False)
"""Member object of the user who triggered this typing event.
Unlike on `PrivateTypingEvent` instances, Discord will always send
this field in any payload.
Returns
-------
hikari.guilds.Member
Member of the user who triggered this typing event.
"""

@property
def channel(self) -> typing.Optional[channels.GuildTextChannel]:
# <<inherited docstring from TypingEvent>>.
return typing.cast("channels.GuildTextChannel", self.app.cache.get_guild_channel(self.channel_id))
def channel(self) -> typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel]:
"""Get the cached channel object this typing event occurred in.
Returns
-------
typing.Union[hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel]
The channel.
"""
channel = self.app.cache.get_guild_channel(self.channel_id)
assert isinstance(
channel, (channels.GuildTextChannel, channels.GuildNewsChannel)
), f"expected GuildTextChannel or GuildNewsChannel from cache, got {channel}"
return channel

@property
def guild(self) -> typing.Optional[guilds.GatewayGuild]:
Expand All @@ -177,10 +208,24 @@ def guild(self) -> typing.Optional[guilds.GatewayGuild]:
"""
return self.app.cache.get_available_guild(self.guild_id) or self.app.cache.get_unavailable_guild(self.guild_id)

if typing.TYPE_CHECKING:
@property
def user_id(self) -> snowflakes.Snowflake:
# <<inherited docstring from TypingEvent>>.
return self.user.id

async def fetch_channel(self) -> channels.GuildTextChannel:
...
async def fetch_channel(self) -> typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel]:
"""Perform an API call to fetch an up-to-date image of this channel.
Returns
-------
typing.Union[hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel]
The channel.
"""
channel = await self.app.rest.fetch_channel(self.channel_id)
assert isinstance(
channel, (channels.GuildTextChannel, channels.GuildNewsChannel)
), f"expected GuildTextChannel or GuildNewsChannel from API, got {channel}"
return channel

async def fetch_guild(self) -> guilds.Guild:
"""Perform an API call to fetch an up-to-date image of this guild.
Expand All @@ -202,7 +247,7 @@ async def fetch_guild_preview(self) -> guilds.GuildPreview:
"""
return await self.app.rest.fetch_guild_preview(self.guild_id)

async def fetch_member(self) -> guilds.Member:
async def fetch_user(self) -> guilds.Member:
"""Perform an API call to fetch an up-to-date image of this member.
Returns
Expand Down Expand Up @@ -232,7 +277,6 @@ class PrivateTypingEvent(TypingEvent):
# <<inherited docstring from TypingEvent>>.

timestamp: datetime.datetime = attr.ib(repr=False)

# <<inherited docstring from TypingEvent>>.

@property
Expand All @@ -246,7 +290,29 @@ def channel(self) -> typing.Optional[channels.DMChannel]:
"""
return self.app.cache.get_dm(self.user_id)

if typing.TYPE_CHECKING:
@property
def user(self) -> typing.Optional[users.User]:
# <<inherited docstring from TypingEvent>>.
return self.app.cache.get_user(self.user_id)

async def fetch_channel(self) -> channels.DMChannel:
...
async def fetch_channel(self) -> channels.DMChannel:
"""Perform an API call to fetch an up-to-date image of this channel.
Returns
-------
hikari.channels.DMChannel
The channel.
"""
channel = await self.app.rest.fetch_channel(self.channel_id)
assert isinstance(channel, channels.DMChannel), f"expected DMChannel from API, got {channel}"
return channel

async def fetch_user(self) -> users.User:
"""Perform an API call to fetch an up-to-date image of the user.
Returns
-------
hikari.users.User
The user.
"""
return await self.app.rest.fetch_user(self.user_id)
3 changes: 1 addition & 2 deletions hikari/impl/event_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,8 @@ def deserialize_typing_start_event(
shard=shard,
channel_id=channel_id,
guild_id=guild_id,
user_id=user_id,
timestamp=timestamp,
member=member,
user=member,
)

return typing_events.PrivateTypingEvent(
Expand Down
109 changes: 76 additions & 33 deletions tests/hikari/events/test_typing_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,49 @@
from hikari import channels
from hikari import users
from hikari.events import typing_events
from tests.hikari import hikari_test_helpers


@pytest.mark.asyncio
class TestTypingEvent:
@pytest.fixture()
def event(self):
class StubEvent(typing_events.TypingEvent):
channel_id = 123
user_id = 456
timestamp = None
shard = None
app = mock.Mock(rest=mock.AsyncMock())
channel = object()
guild = object()

return StubEvent()

async def test_fetch_channel(self, event):
mock_channel = mock.Mock(spec_set=channels.TextChannel)
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel)
assert await event.fetch_channel() is mock_channel

event.app.rest.fetch_channel.assert_awaited_once_with(123)

async def test_fetch_user(self, event):
mock_user = mock.Mock(spec_set=users.User)
event.app.rest.fetch_user = mock.AsyncMock(return_value=mock_user)
cls = hikari_test_helpers.mock_class_namespace(
typing_events.TypingEvent,
channel_id=123,
user_id=456,
timestamp=object(),
shard=object(),
channel=object(),
)

assert await event.fetch_user() is mock_user
return cls()

event.app.rest.fetch_user.assert_awaited_once_with(456)
async def test_trigger_typing(self, event):
event.app.rest.trigger_typing = mock.Mock()
result = event.trigger_typing()
event.app.rest.trigger_typing.assert_called_once_with(123)
assert result is event.app.rest.trigger_typing.return_value


@pytest.mark.asyncio
class TestGuildTypingEvent:
@pytest.fixture()
def event(self):
return typing_events.GuildTypingEvent(
app=mock.AsyncMock(cache=mock.Mock()),
shard=None,
cls = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent)

return cls(
channel_id=123,
user_id=456,
timestamp=object(),
shard=object(),
app=mock.Mock(rest=mock.AsyncMock()),
guild_id=789,
timestamp=None,
member=None,
user=mock.Mock(id=456),
)

def test_channel(self, event):
@pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel])
async def test_channel(self, event, guild_channel_impl):
event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl))
result = event.channel

assert result is event.app.cache.get_guild_channel.return_value
Expand All @@ -93,10 +88,16 @@ def test_guild_when_unavailable(self, event):
event.app.cache.get_unavailable_guild.assert_called_once_with(789)
event.app.cache.get_available_guild.assert_called_once_with(789)

async def test_fetch_channel(self, event):
await event.fetch_member()
def test_user_id(self, event):
assert event.user_id == event.user.id
assert event.user_id == 456

event.app.rest.fetch_member.assert_awaited_once_with(789, 456)
@pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel])
async def test_fetch_channel(self, event, guild_channel_impl):
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=guild_channel_impl))
await event.fetch_channel()

event.app.rest.fetch_channel.assert_awaited_once_with(123)

async def test_fetch_guild(self, event):
await event.fetch_guild()
Expand All @@ -107,3 +108,45 @@ async def test_fetch_guild_preview(self, event):
await event.fetch_guild_preview()

event.app.rest.fetch_guild_preview.assert_awaited_once_with(789)

async def test_fetch_user(self, event):
await event.fetch_user()

event.app.rest.fetch_member.assert_awaited_once_with(789, 456)


@pytest.mark.asyncio
class TestPrivateTypingEvent:
@pytest.fixture()
def event(self):
cls = hikari_test_helpers.mock_class_namespace(typing_events.PrivateTypingEvent)

return cls(
channel_id=123,
timestamp=object(),
shard=object(),
app=mock.Mock(rest=mock.AsyncMock()),
user_id=456,
)

async def test_channel(self, event):
event.app.cache.get_dm = mock.Mock(return_value=mock.Mock(spec_set=channels.DMChannel))
result = event.channel
assert result is event.app.cache.get_dm.return_value
event.app.cache.get_dm.assert_called_once_with(456)

def test_user(self, event):
event.app.cache.get_user = mock.Mock(return_value=mock.Mock(spec_set=users.User))

assert event.user is event.app.cache.get_user.return_value

async def test_fetch_channel(self, event):
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=channels.DMChannel))
await event.fetch_channel()

event.app.rest.fetch_channel.assert_awaited_once_with(123)

async def test_fetch_user(self, event):
await event.fetch_user()

event.app.rest.fetch_user.assert_awaited_once_with(456)

0 comments on commit 5a12196

Please sign in to comment.