Skip to content

Commit fae977c

Browse files
committed
refactor: Add ComponentGearMixin and ModalGearMixin for interaction management
1 parent 013d9f6 commit fae977c

File tree

8 files changed

+353
-63
lines changed

8 files changed

+353
-63
lines changed

discord/bot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
820820
await self.sync_commands()
821821
else:
822822
await self.sync_commands(check_guilds=[guild_id])
823-
return self._bot.dispatch("unknown_application_command", interaction)
823+
# return self._bot.dispatch("unknown_application_command", interaction)
824824

825825
if interaction.type is InteractionType.auto_complete:
826826
return self._bot.dispatch("application_command_auto_complete", interaction, command)
@@ -1160,7 +1160,7 @@ def __init__(self, description=None, *args, **options):
11601160
self._before_invoke = None
11611161
self._after_invoke = None
11621162

1163-
self._bot.add_listener(self.on_interaction, event=InteractionCreate)
1163+
# self._bot.add_listener(self.on_interaction, event=InteractionCreate)
11641164

11651165
async def on_connect(self):
11661166
if self.auto_sync_commands:

discord/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
648648
aiohttp.ClientError,
649649
asyncio.TimeoutError,
650650
) as exc:
651-
self.dispatch("disconnect")
651+
# self.dispatch("disconnect") # TODO: dispatch event
652652
if not reconnect:
653653
await self.close()
654654
if isinstance(exc, ConnectionClosed) and exc.code == 1000:

discord/events/gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class Resumed(Event):
5454
__event_name__: str = "RESUMED"
5555

5656
@classmethod
57-
async def __load__(cls, _data: Any, _state: ConnectionState) -> Self | None:
57+
async def __load__(cls, data: Any, state: ConnectionState) -> Self | None:
5858
return cls()
5959

6060

discord/events/interaction.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,20 @@
4343

4444
def _interaction_factory(payload: InteractionPayload) -> type[Interaction]:
4545
type: int = payload["type"]
46-
if type == InteractionType.application_command:
46+
if type == InteractionType.application_command.value:
4747
return ApplicationCommandInteraction
48-
if type == InteractionType.auto_complete:
48+
if type == InteractionType.auto_complete.value:
4949
return AutocompleteInteraction
50-
if type == InteractionType.component:
50+
if type == InteractionType.component.value:
5151
return ComponentInteraction
52-
if type == InteractionType.modal_submit:
52+
if type == InteractionType.modal_submit.value:
5353
return ModalInteraction
5454
return Interaction
5555

5656

5757
@lru_cache(maxsize=128)
58-
def _create_event_interaction_class(event_cls: type[Event], interaction_cls: type[Interaction]) -> type[Interaction]:
59-
class EventInteraction(interaction_cls, event_cls): # type: ignore
58+
def _create_event_interaction_class(interaction_cls: type[Interaction]) -> type[Interaction]:
59+
class EventInteraction(interaction_cls, Event): # type: ignore
6060
__slots__ = ()
6161

6262
@override
@@ -66,7 +66,12 @@ def __init__(self) -> None:
6666
@override
6767
@classmethod
6868
def event_type(self) -> type[Event]:
69-
return event_cls
69+
return InteractionCreate
70+
71+
@classmethod
72+
@override
73+
async def __load__(cls, data: InteractionPayload, state: ConnectionState) -> None:
74+
return None
7075

7176
return EventInteraction # type: ignore
7277

@@ -94,7 +99,7 @@ def __init__(self) -> None:
9499
async def __load__(cls, data: Any, state: ConnectionState) -> Self | None:
95100
factory = _interaction_factory(data)
96101
interaction = await factory._from_data(payload=data, state=state)
97-
interaction_event_cls = _create_event_interaction_class(Event, factory)
102+
interaction_event_cls = _create_event_interaction_class(factory)
98103
self = interaction_event_cls()
99104
self._populate_from_slots(interaction)
100105
return self

discord/gears/components.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
"""
2+
The MIT License (MIT)
3+
4+
Copyright (c) 2021-present Pycord Development
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a
7+
copy of this software and associated documentation files (the "Software"),
8+
to deal in the Software without restriction, including without limitation
9+
the rights to use, copy, modify, merge, publish, distribute, sublicense,
10+
and/or sell copies of the Software, and to permit persons to whom the
11+
Software is furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in
14+
all copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17+
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22+
DEALINGS IN THE SOFTWARE.
23+
"""
24+
25+
from abc import ABC
26+
from functools import wraps
27+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, ParamSpec, Protocol, TypeAlias, TypeVar, Unpack
28+
29+
from ..events import InteractionCreate
30+
from ..interactions import ComponentInteraction, ModalInteraction
31+
from ..utils import MISSING, Undefined
32+
from ..utils.private import hybridmethod, maybe_awaitable
33+
from .base import GearBase
34+
35+
ComponentPredicate: TypeAlias = Callable[[str], bool | Awaitable[bool]]
36+
37+
38+
class ComponentListener(Protocol):
39+
async def __call__(self, interaction: ComponentInteraction[Any]) -> Any: ...
40+
41+
42+
CL_t = TypeVar("CL_t", bound=ComponentListener)
43+
44+
45+
class ModalListener(Protocol):
46+
async def __call__(self, interaction: ModalInteraction[Unpack[tuple[Any, ...]]]) -> Any: ...
47+
48+
49+
ML_t = TypeVar("ML_t", bound=ModalListener)
50+
51+
T = TypeVar("T", bound="ComponentListener | ModalListener")
52+
MG_t = TypeVar("MG_t", bound="ModalGearMixin")
53+
54+
55+
def _unwrap_predicate(
56+
maybe_predicate: Callable[[str], bool | Awaitable[bool]] | str,
57+
) -> Callable[[str], bool | Awaitable[bool]]:
58+
return lambda x: x == maybe_predicate if isinstance(maybe_predicate, str) else maybe_predicate
59+
60+
61+
P = ParamSpec("P")
62+
R = TypeVar("R")
63+
64+
65+
def _listener_factory(
66+
listener: Callable[P, Awaitable[R]],
67+
interaction_type: type[ModalInteraction | ComponentInteraction],
68+
predicate: ComponentPredicate,
69+
) -> Callable[P, Coroutine[Any, Any, R | None]]:
70+
@wraps(listener)
71+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
72+
# Assume last positional arg is the interaction
73+
if args:
74+
interaction: Any = args[-1]
75+
if isinstance(interaction, interaction_type) and await maybe_awaitable(predicate, interaction.custom_id):
76+
return await listener(*args, **kwargs)
77+
return None
78+
79+
return wrapper
80+
81+
82+
CG_t = TypeVar("CG_t", bound="ComponentGearMixin")
83+
84+
85+
class ComponentGearMixin(GearBase, ABC):
86+
"""A mixin that provides component handling for a :class:`discord.Gear`.
87+
88+
This mixin is used to handle components such as buttons, select menus, and other interactive elements.
89+
"""
90+
91+
def add_component_listener(
92+
self, predicate: Callable[[str], bool | Awaitable[bool]] | str, listener: ComponentListener
93+
) -> Callable[[InteractionCreate], Awaitable[None]]:
94+
"""Registers a component interaction listener.
95+
96+
This method can be used to register a function that will be called
97+
when a component interaction occurs that matches the provided predicate.
98+
99+
.. versionadded:: 3.0
100+
101+
Parameters
102+
----------
103+
predicate:
104+
A (potentially async) function that takes a string (the component's custom ID) and returns a boolean indicating whether the
105+
function should be called for that component. Alternatively, a string can be provided, which will match
106+
the component's custom ID exactly.
107+
108+
listener:
109+
The interaction callback to call when a component interaction occurs that matches the predicate.
110+
111+
Returns
112+
-------
113+
Callable[[InteractionCreate], Awaitable[None]]
114+
The registered listener. Use this to unregister the listener.
115+
"""
116+
actual_predicate: Callable[[str], bool | Awaitable[bool]] = _unwrap_predicate(predicate)
117+
actual_listener = _listener_factory(listener, ComponentInteraction, actual_predicate)
118+
self.add_listener(actual_listener, event=InteractionCreate)
119+
return actual_listener
120+
121+
if TYPE_CHECKING:
122+
123+
@classmethod
124+
def listen_component(
125+
cls: type[CG_t],
126+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
127+
) -> Callable[
128+
[Callable[[ComponentListener], Awaitable[None]] | Callable[[Any, ComponentListener], Awaitable[None]]],
129+
Callable[[InteractionCreate], Awaitable[None]],
130+
]:
131+
"""A shortcut decorator that registers a component interaction listener.
132+
133+
This decorator can be used to register a function that will be called
134+
when a component interaction occurs that matches the provided predicate.
135+
136+
.. versionadded:: 3.0
137+
138+
Parameters
139+
----------
140+
predicate:
141+
A (potentially async) function that takes a string (the component's custom ID) and returns a boolean indicating whether the
142+
function should be called for that component. Alternatively, a string can be provided, which will match
143+
the component's custom ID exactly.
144+
"""
145+
...
146+
else:
147+
# Instance function listeners (but not bound to an instance)
148+
@hybridmethod
149+
def listen_component(
150+
cls: type[CG_t], # noqa: N805
151+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
152+
) -> Callable[
153+
[Callable[[Any, ComponentInteraction[Any]], Awaitable[None]]],
154+
Callable[[Any, ComponentInteraction[Any]], Awaitable[None]],
155+
]:
156+
def decorator(
157+
func: Callable[[Any, ComponentInteraction[Any]], Awaitable[None]],
158+
) -> Callable[[Any, ComponentInteraction[Any]], Awaitable[None]]:
159+
actual_predicate: Callable[[str], bool | Awaitable[bool]] = _unwrap_predicate(predicate)
160+
161+
actual_listener = _listener_factory(func, ComponentInteraction, actual_predicate)
162+
163+
# Use parent's listen to register for InteractionCreate
164+
return cls.listen(InteractionCreate)(actual_listener)
165+
166+
return decorator
167+
168+
# Bare listeners (everything else)
169+
@listen_component.instancemethod
170+
def listen_component(
171+
self,
172+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
173+
) -> Callable[[ComponentListener], Callable[[InteractionCreate], Awaitable[None]]]:
174+
def decorator(
175+
func: ComponentListener,
176+
) -> Callable[[InteractionCreate], Awaitable[None]]:
177+
return self.add_component_listener(predicate, func)
178+
179+
return decorator
180+
181+
182+
MG_t = TypeVar("MG_t", bound="ModalGearMixin")
183+
184+
185+
class ModalGearMixin(GearBase, ABC):
186+
"""A mixin that provides modal handling for a :class:`discord.Gear`.
187+
188+
This mixin is used to handle modals interactions.
189+
"""
190+
191+
def add_modal_listener(
192+
self, predicate: Callable[[str], bool | Awaitable[bool]] | str, listener: ModalListener
193+
) -> Callable[[InteractionCreate], Awaitable[None]]:
194+
"""Registers a modal interaction listener.
195+
196+
This method can be used to register a function that will be called
197+
when a modal interaction occurs that matches the provided predicate.
198+
199+
.. versionadded:: 3.0
200+
201+
Parameters
202+
----------
203+
predicate:
204+
A (potentially async) function that takes a string (the modal's custom ID) and returns a boolean indicating whether the
205+
function should be called for that modal. Alternatively, a string can be provided, which will match
206+
the modal's custom ID exactly.
207+
208+
listener:
209+
The interaction callback to call when a modal interaction occurs that matches the predicate.
210+
211+
Returns
212+
-------
213+
Callable[[InteractionCreate], Awaitable[None]]
214+
The registered listener. Use this to unregister the listener.
215+
"""
216+
actual_predicate: Callable[[str], bool | Awaitable[bool]] = _unwrap_predicate(predicate)
217+
actual_listener = _listener_factory(listener, ModalInteraction, actual_predicate)
218+
self.add_listener(actual_listener, event=InteractionCreate)
219+
return actual_listener
220+
221+
if TYPE_CHECKING:
222+
223+
@classmethod
224+
def listen_modal(
225+
cls: type[MG_t],
226+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
227+
) -> Callable[
228+
[Callable[[ModalListener], Awaitable[None]] | Callable[[Any, ModalListener], Awaitable[None]]],
229+
Callable[[InteractionCreate], Awaitable[None]],
230+
]:
231+
"""A shortcut decorator that registers a modal interaction listener.
232+
233+
This decorator can be used to register a function that will be called
234+
when a modal interaction occurs that matches the provided predicate.
235+
236+
.. versionadded:: 3.0
237+
238+
Parameters
239+
----------
240+
predicate:
241+
A (potentially async) function that takes a string (the modal's custom ID) and returns a boolean indicating whether the
242+
function should be called for that modal. Alternatively, a string can be provided, which will match
243+
the modal's custom ID exactly.
244+
"""
245+
...
246+
else:
247+
# Instance function listeners (but not bound to an instance)
248+
@hybridmethod
249+
def listen_modal(
250+
cls: type[MG_t], # noqa: N805
251+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
252+
) -> Callable[
253+
[Callable[[Any, ModalInteraction[Unpack[tuple[Any, ...]]]], Awaitable[None]]],
254+
Callable[[Any, ModalInteraction[Unpack[tuple[Any, ...]]]], Awaitable[None]],
255+
]:
256+
def decorator(
257+
func: Callable[[Any, ModalInteraction[Unpack[tuple[Any, ...]]]], Awaitable[None]],
258+
) -> Callable[[Any, ModalInteraction[Unpack[tuple[Any, ...]]]], Awaitable[None]]:
259+
actual_predicate: Callable[[str], bool | Awaitable[bool]] = _unwrap_predicate(predicate)
260+
261+
actual_listener = _listener_factory(func, ModalInteraction, actual_predicate)
262+
263+
# Use parent's listen to register for InteractionCreate
264+
return cls.listen(InteractionCreate)(actual_listener)
265+
266+
return decorator
267+
268+
# Bare listeners (everything else)
269+
@listen_modal.instancemethod
270+
def listen_modal(
271+
self,
272+
predicate: Callable[[str], bool | Awaitable[bool]] | str,
273+
) -> Callable[[ModalListener], Callable[[InteractionCreate], Awaitable[None]]]:
274+
def decorator(
275+
func: ModalListener,
276+
) -> Callable[[InteractionCreate], Awaitable[None]]:
277+
return self.add_modal_listener(predicate, func)
278+
279+
return decorator

discord/gears/gear.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ..utils.annotations import get_annotations
4141
from ..utils.private import hybridmethod
4242
from .base import GearBase
43+
from .components import ComponentGearMixin, ModalGearMixin
4344

4445
_T = TypeVar("_T", bound="Gear")
4546
E = TypeVar("E", bound="Event", covariant=True)
@@ -60,7 +61,7 @@ class StaticAttributedEventCallback(AttributedEventCallback, Protocol):
6061
EventCallback: TypeAlias = Callable[[E], Awaitable[None]]
6162

6263

63-
class Gear(GearBase):
64+
class Gear(ModalGearMixin, GearBase):
6465
"""A gear is a modular component that can listen to and handle events.
6566
6667
You can subclass this class to create your own gears and attach them to your bot or other gears.
@@ -109,6 +110,8 @@ def __init__(self) -> None:
109110
self.add_listener(cast("EventCallback[Event]", callback), event=event, once=once)
110111
setattr(self, name, callback)
111112

113+
super().__init__()
114+
112115
def _handle_event(self, event: Event) -> Collection[Awaitable[Any]]:
113116
tasks: list[Awaitable[None]] = []
114117

0 commit comments

Comments
 (0)