Skip to content

Commit a787500

Browse files
authored
refactor: allow stacking extension decorators (#1146)
* refactor: allow stacking for extension decorators * fix: command first then option * revert: remove random change * ref: oop * refactor: it's is not random lol
1 parent b2d3571 commit a787500

File tree

1 file changed

+51
-43
lines changed

1 file changed

+51
-43
lines changed

interactions/client/bot.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,38 +1856,40 @@ def __new__(cls, client: Client, *args, **kwargs) -> "Extension":
18561856
for name, func in getmembers(self, predicate=iscoroutinefunction):
18571857
# TODO we can make these all share the same list, might make it easier to load/unload
18581858
if hasattr(func, "__listener_name__"): # set by extension_listener
1859-
func = client.event(
1860-
func, name=func.__listener_name__
1861-
) # capture the return value for friendlier ext-ing
1859+
all_listener_names: List[str] = func.__listener_name__
1860+
for listener_name in all_listener_names:
1861+
func = client.event(
1862+
func, name=listener_name
1863+
) # capture the return value for friendlier ext-ing
18621864

1863-
listeners = self._listeners.get(func.__listener_name__, [])
1864-
listeners.append(func)
1865-
self._listeners[func.__listener_name__] = listeners
1865+
listeners = self._listeners.get(listener_name, [])
1866+
listeners.append(func)
1867+
self._listeners[listener_name] = listeners
18661868

18671869
if hasattr(func, "__component_data__"):
1868-
args, kwargs = func.__component_data__
1869-
func = client.component(*args, **kwargs)(func)
1870-
1871-
component = kwargs.get("component") or args[0]
1872-
comp_name = (
1873-
_component(component).custom_id
1874-
if isinstance(component, (Button, SelectMenu))
1875-
else component
1876-
)
1877-
comp_name = f"component_{comp_name}"
1870+
all_component_data: List[Tuple[tuple, dict]] = func.__component_data__
1871+
for args, kwargs in all_component_data:
1872+
func = client.component(*args, **kwargs)(func)
1873+
1874+
component = kwargs.get("component") or args[0]
1875+
comp_name = (
1876+
_component(component).custom_id
1877+
if isinstance(component, (Button, SelectMenu))
1878+
else component
1879+
)
1880+
comp_name = f"component_{comp_name}"
18781881

1879-
listeners = self._listeners.get(comp_name, [])
1880-
listeners.append(func)
1881-
self._listeners[comp_name] = listeners
1882+
listeners = self._listeners.get(comp_name, [])
1883+
listeners.append(func)
1884+
self._listeners[comp_name] = listeners
18821885

18831886
if hasattr(func, "__autocomplete_data__"):
18841887
all_args_kwargs = func.__autocomplete_data__
1885-
for _ in all_args_kwargs:
1886-
args, kwargs = _[0], _[1]
1888+
for args, kwargs in all_args_kwargs:
18871889
func = client.autocomplete(*args, **kwargs)(func)
18881890

1889-
name = kwargs.get("name") or args[0]
1890-
_command = kwargs.get("command") or args[1]
1891+
_command = kwargs.get("command") or args[0]
1892+
name = kwargs.get("name") or args[1]
18911893

18921894
_command: Union[Snowflake, int] = (
18931895
_command.id if isinstance(_command, ApplicationCommand) else _command
@@ -1900,16 +1902,17 @@ def __new__(cls, client: Client, *args, **kwargs) -> "Extension":
19001902
self._listeners[auto_name] = listeners
19011903

19021904
if hasattr(func, "__modal_data__"):
1903-
args, kwargs = func.__modal_data__
1904-
func = client.modal(*args, **kwargs)(func)
1905+
all_modal_data: List[Tuple[tuple, dict]] = func.__modal_data__
1906+
for args, kwargs in all_modal_data:
1907+
func = client.modal(*args, **kwargs)(func)
19051908

1906-
modal = kwargs.get("modal") or args[0]
1907-
_modal_id: str = modal.custom_id if isinstance(modal, Modal) else modal
1908-
modal_name = f"modal_{_modal_id}"
1909+
modal = kwargs.get("modal") or args[0]
1910+
_modal_id: str = modal.custom_id if isinstance(modal, Modal) else modal
1911+
modal_name = f"modal_{_modal_id}"
19091912

1910-
listeners = self._listeners.get(modal_name, [])
1911-
listeners.append(func)
1912-
self._listeners[modal_name] = listeners
1913+
listeners = self._listeners.get(modal_name, [])
1914+
listeners.append(func)
1915+
self._listeners[modal_name] = listeners
19131916

19141917
for _, cmd in getmembers(self, predicate=lambda command: isinstance(command, Command)):
19151918
cmd: Command
@@ -1974,22 +1977,26 @@ def decorator(coro) -> Command:
19741977
@wraps(Client.event)
19751978
def extension_listener(func: Optional[Coroutine] = None, name: Optional[str] = None):
19761979
def decorator(func: Coroutine):
1977-
func.__listener_name__ = name or func.__name__
1980+
if not hasattr(func, "__listener_name__"):
1981+
func.__listener_name__ = []
1982+
func.__listener_name__.append(name or func.__name__)
19781983

19791984
return func
19801985

19811986
if func:
19821987
# allows omitting `()` on `@listener`
1983-
func.__listener_name__ = name or func.__name__
1984-
return func
1988+
return decorator(func)
19851989

19861990
return decorator
19871991

19881992

19891993
@wraps(Client.component)
19901994
def extension_component(*args, **kwargs):
19911995
def decorator(func):
1992-
func.__component_data__ = (args, kwargs)
1996+
if not hasattr(func, "__component_data__"):
1997+
func.__component_data__ = []
1998+
func.__component_data__.append((args, kwargs))
1999+
19932000
return func
19942001

19952002
return decorator
@@ -1998,21 +2005,22 @@ def decorator(func):
19982005
@wraps(Client.autocomplete)
19992006
def extension_autocomplete(*args, **kwargs):
20002007
def decorator(func):
2001-
try:
2002-
if getattr(func, "__autocomplete_data__"):
2003-
func.__autocomplete_data__.append((args, kwargs))
2004-
except AttributeError:
2005-
func.__autocomplete_data__ = [(args, kwargs)]
2006-
finally:
2007-
return func
2008+
if not hasattr(func, "__autocomplete_data__"):
2009+
func.__autocomplete_data__ = []
2010+
func.__autocomplete_data__.append((args, kwargs))
2011+
2012+
return func
20082013

20092014
return decorator
20102015

20112016

20122017
@wraps(Client.modal)
20132018
def extension_modal(*args, **kwargs):
20142019
def decorator(func):
2015-
func.__modal_data__ = (args, kwargs)
2020+
if not hasattr(func, "__modal_data__"):
2021+
func.__modal_data__ = []
2022+
func.__modal_data__.append((args, kwargs))
2023+
20162024
return func
20172025

20182026
return decorator

0 commit comments

Comments
 (0)