From 056575d4f7202c61bf8add1fe398c04bfbdf31b8 Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Sat, 10 May 2025 05:27:35 +0800 Subject: [PATCH 1/7] refactor: port over v3 rate limit code --- discord/http.py | 331 +++++++++++++++++++++++++----------------------- 1 file changed, 176 insertions(+), 155 deletions(-) diff --git a/discord/http.py b/discord/http.py index 2db704b268..05e5212f1c 100644 --- a/discord/http.py +++ b/discord/http.py @@ -87,7 +87,6 @@ T = TypeVar("T") BE = TypeVar("BE", bound=BaseException) - MU = TypeVar("MU", bound="MaybeUnlock") Response = Coroutine[Any, Any, T] API_VERSION: int = 10 @@ -106,61 +105,90 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str class Route: - API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}" - - def __init__(self, method: str, path: str, **parameters: Any) -> None: - self.path: str = path - self.method: str = method - url = self.base + self.path - if parameters: - url = url.format_map( - { - k: _uriquote(v) if isinstance(v, str) else v - for k, v in parameters.items() - } - ) - self.url: str = url + def __init__( + self, + path: str, + guild_id: str | None = None, + channel_id: str | None = None, + webhook_id: str | None = None, + webhook_token: str | None = None, + **parameters: str | int, + ): + self.path = path - # major parameters: - self.channel_id: Snowflake | None = parameters.get("channel_id") - self.guild_id: Snowflake | None = parameters.get("guild_id") - self.webhook_id: Snowflake | None = parameters.get("webhook_id") - self.webhook_token: str | None = parameters.get("webhook_token") + # major parameters + self.guild_id = guild_id + self.channel_id = channel_id + self.webhook_id = webhook_id + self.webhook_token = webhook_token - @property - def base(self) -> str: - return self.API_BASE_URL.format(API_VERSION=API_VERSION) + self.parameters = parameters - @property - def bucket(self) -> str: - # the bucket is just method + path w/ major parameters - return f"{self.channel_id}:{self.guild_id}:{self.path}" + def merge(self, url: str): + return url + self.path.format( + guild_id=self.guild_id, + channel_id=self.channel_id, + webhook_id=self.webhook_id, + webhook_token=self.webhook_token, + **self.parameters, + ) + def __eq__(self, route: 'Route') -> bool: + return ( + route.channel_id == self.channel_id + or route.guild_id == self.guild_id + or route.webhook_id == self.webhook_id + or route.webhook_token == self.webhook_token + ) -class MaybeUnlock: - def __init__(self, lock: asyncio.Lock) -> None: - self.lock: asyncio.Lock = lock - self._unlock: bool = True - def __enter__(self: MU) -> MU: - return self - def defer(self) -> None: - self._unlock = False +class Executor: + def __init__(self, route: Route) -> None: + self.route = route + self.is_global: bool | None = None + self._request_queue: asyncio.Queue[asyncio.Event] | None = None + self.rate_limited: bool = False - def __exit__( - self, - exc_type: type[BE] | None, - exc: BE | None, - traceback: TracebackType | None, + async def executed( + self, reset_after: int | float, limit: int, is_global: bool ) -> None: - if self._unlock: - self.lock.release() + self.rate_limited = True + self.is_global = is_global + self._reset_after = reset_after + self._request_queue = asyncio.Queue() + + await asyncio.sleep(reset_after) + + self.is_global = False + + # NOTE: This could break if someone did a second global rate limit somehow + requests_passed: int = 0 + for _ in range(self._request_queue.qsize() - 1): + if requests_passed == limit: + requests_passed = 0 + if not is_global: + await asyncio.sleep(reset_after) + else: + await asyncio.sleep(5) + requests_passed += 1 + e = await self._request_queue.get() + e.set() -# For some reason, the Discord voice websocket expects this header to be -# completely lowercase while aiohttp respects spec and does it as case-insensitive -aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore + async def wait(self) -> None: + if not self.rate_limited: + return + + event = asyncio.Event() + + if self._request_queue: + self._request_queue.put_nowait(event) + else: + raise ValueError( + 'Request queue does not exist, rate limit may have been solved.' + ) + await event.wait() class HTTPClient: @@ -180,14 +208,12 @@ def __init__( ) self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login - self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() - self._global_over: asyncio.Event = asyncio.Event() - self._global_over.set() self.token: str | None = None self.bot_token: bool = False self.proxy: str | None = proxy self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth self.use_clock: bool = not unsync_clock + self._executors = [] user_agent = ( "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" @@ -230,12 +256,6 @@ async def request( method = route.method url = route.url - lock = self._locks.get(bucket) - if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock - # header creation headers: dict[str, str] = { "User-Agent": self.user_agent, @@ -272,117 +292,118 @@ async def request( response: aiohttp.ClientResponse | None = None data: dict[str, Any] | str | None = None - await lock.acquire() - with MaybeUnlock(lock) as maybe_lock: - for tries in range(5): - if files: - for f in files: - f.reset(seek=tries) - - if form: - form_data = aiohttp.FormData(quote_fields=False) - for params in form: - form_data.add_field(**params) - kwargs["data"] = form_data - - try: - async with self.__session.request( - method, url, **kwargs - ) as response: - _log.debug( - "%s %s with %s has returned %s", - method, - url, - kwargs.get("data"), - response.status, + + for executor in self._executors: + if executor.is_global or executor.route == route: + _log.debug(f'Pausing request to {route}: Found rate limit executor') + await executor.wait() + + for tries in range(5): + if files: + for f in files: + f.reset(seek=tries) + + if form: + form_data = aiohttp.FormData(quote_fields=False) + for params in form: + form_data.add_field(**params) + kwargs["data"] = form_data + + try: + async with self.__session.request( + method, url, **kwargs + ) as response: + _log.debug( + "%s %s with %s has returned %s", + method, + url, + kwargs.get("data"), + response.status, + ) + + # even errors have text involved in them so this is safe to call + data = await json_or_text(response) + + # check if we have rate limit header information + remaining = response.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and response.status != 429: + _log.debug(f'Request to {route} failed: Request returned rate limit') + executor = Executor(route=route) + + self._executors.append(executor) + await executor.executed( + reset_after=data['retry_after'], + is_global=response.headers.get('X-RateLimit-Scope') == 'global', + limit=int(response.headers.get('X-RateLimit-Limit', 10)), ) + self._executors.remove(executor) + continue - # even errors have text involved in them so this is safe to call - data = await json_or_text(response) + # the request was successful so just return the text/json + if 300 > response.status >= 200: + _log.debug("%s %s has received %s", method, url, data) + return data - # check if we have rate limit header information - remaining = response.headers.get("X-Ratelimit-Remaining") - if remaining == "0" and response.status != 429: - # we've depleted our current bucket - delta = utils._parse_ratelimit_header( - response, use_clock=self.use_clock - ) - _log.debug( + # we are being rate limited + if response.status == 429: + if not response.headers.get("Via") or isinstance(data, str): + # Banned by Cloudflare more than likely. + raise HTTPException(response, data) + + fmt = ( + "We are being rate limited. Retrying in %.2f seconds." + ' Handled under the bucket "%s"' + ) + + # sleep a bit + retry_after: float = data["retry_after"] + _log.warning(fmt, retry_after, bucket) + + # check if it's a global rate limit + is_global = data.get("global", False) + if is_global: + _log.warning( ( - "A rate limit bucket has been exhausted (bucket:" - " %s, retry: %s)." + "Global rate limit has been hit. Retrying in" + " %.2f seconds." ), - bucket, - delta, - ) - maybe_lock.defer() - self.loop.call_later(delta, lock.release) - - # the request was successful so just return the text/json - if 300 > response.status >= 200: - _log.debug("%s %s has received %s", method, url, data) - return data - - # we are being rate limited - if response.status == 429: - if not response.headers.get("Via") or isinstance(data, str): - # Banned by Cloudflare more than likely. - raise HTTPException(response, data) - - fmt = ( - "We are being rate limited. Retrying in %.2f seconds." - ' Handled under the bucket "%s"' + retry_after, ) + self._global_over.clear() - # sleep a bit - retry_after: float = data["retry_after"] - _log.warning(fmt, retry_after, bucket) - - # check if it's a global rate limit - is_global = data.get("global", False) - if is_global: - _log.warning( - ( - "Global rate limit has been hit. Retrying in" - " %.2f seconds." - ), - retry_after, - ) - self._global_over.clear() - - await asyncio.sleep(retry_after) - _log.debug("Done sleeping for the rate limit. Retrying...") - - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - _log.debug("Global rate limit is now over.") - - continue - - # we've received a 500, 502, 503, or 504, unconditional retry - if response.status in {500, 502, 503, 504}: - await asyncio.sleep(1 + tries * 2) - continue - - # the usual error cases - if response.status == 403: - raise Forbidden(response, data) - elif response.status == 404: - raise NotFound(response, data) - elif response.status >= 500: - raise DiscordServerError(response, data) - else: - raise HTTPException(response, data) + await asyncio.sleep(retry_after) + _log.debug("Done sleeping for the rate limit. Retrying...") + + # release the global lock now that the + # global rate limit has passed + if is_global: + self._global_over.set() + _log.debug("Global rate limit is now over.") - # This is handling exceptions from the request - except OSError as e: - # Connection reset by peer - if tries < 4 and e.errno in (54, 10054): + continue + + # we've received a 500, 502, 503, or 504, unconditional retry + if response.status in {500, 502, 503, 504}: await asyncio.sleep(1 + tries * 2) continue - raise + + # the usual error cases + if response.status == 403: + raise Forbidden(response, data) + elif response.status == 404: + raise NotFound(response, data) + elif response.status >= 500: + raise DiscordServerError(response, data) + else: + raise HTTPException(response, data) + + # This is handling exceptions from the request + except OSError as e: + # Connection reset by peer + if tries < 4 and e.errno in (54, 10054): + await asyncio.sleep(1 + tries * 2) + continue + raise if response is not None: # We've run out of retries, raise. From 5b32363153638c9675fc80841eb4ac998fa425ef Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Sat, 10 May 2025 05:34:33 +0800 Subject: [PATCH 2/7] fix: route stuff --- discord/http.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/discord/http.py b/discord/http.py index 05e5212f1c..0389d88a90 100644 --- a/discord/http.py +++ b/discord/http.py @@ -107,6 +107,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str class Route: def __init__( self, + method: str, path: str, guild_id: str | None = None, channel_id: str | None = None, @@ -114,6 +115,7 @@ def __init__( webhook_token: str | None = None, **parameters: str | int, ): + self.method = method self.path = path # major parameters @@ -202,7 +204,9 @@ def __init__( proxy_auth: aiohttp.BasicAuth | None = None, loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, + discord_api_url: str = "https://discord.com/api/v10" ) -> None: + self.api_url = discord_api_url self.loop: asyncio.AbstractEventLoop = ( asyncio.get_event_loop() if loop is None else loop ) @@ -252,9 +256,9 @@ async def request( form: Iterable[dict[str, Any]] | None = None, **kwargs: Any, ) -> Any: - bucket = route.bucket + bucket = route.merge(self.api_url) method = route.method - url = route.url + url = route.path # header creation headers: dict[str, str] = { From 63720c59799c3d0620fcae341101d7e6b122f28f Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Sat, 10 May 2025 05:37:24 +0800 Subject: [PATCH 3/7] fix: AHHH --- discord/http.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/discord/http.py b/discord/http.py index 0389d88a90..53ce516670 100644 --- a/discord/http.py +++ b/discord/http.py @@ -290,10 +290,6 @@ async def request( if self.proxy_auth is not None: kwargs["proxy_auth"] = self.proxy_auth - if not self._global_over.is_set(): - # wait until the global lock is complete - await self._global_over.wait() - response: aiohttp.ClientResponse | None = None data: dict[str, Any] | str | None = None From c60ae2e45c06242e3cfd569eb2415684897677c7 Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Sat, 10 May 2025 12:52:37 +0800 Subject: [PATCH 4/7] fix: append api url to path --- discord/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/http.py b/discord/http.py index 53ce516670..39751e067d 100644 --- a/discord/http.py +++ b/discord/http.py @@ -258,7 +258,7 @@ async def request( ) -> Any: bucket = route.merge(self.api_url) method = route.method - url = route.path + url = self.api_url + route.path # header creation headers: dict[str, str] = { From 7edd751139e9eceeb533b60c6d76857ffc1b41e3 Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Sat, 10 May 2025 13:44:16 +0800 Subject: [PATCH 5/7] fix: final batch --- discord/http.py | 2 +- discord/webhook/async_.py | 5 +++-- discord/webhook/sync.py | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/discord/http.py b/discord/http.py index 39751e067d..e0a424415f 100644 --- a/discord/http.py +++ b/discord/http.py @@ -258,7 +258,7 @@ async def request( ) -> Any: bucket = route.merge(self.api_url) method = route.method - url = self.api_url + route.path + url = bucket # header creation headers: dict[str, str] = { diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 1661b1bb67..f1097deee3 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -104,8 +104,9 @@ async def __aexit__(self, type, value, traceback): class AsyncWebhookAdapter: - def __init__(self): + def __init__(self, *, discord_api_url: str = "https://discord.com/api/v10"): self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.api_url = discord_api_url async def request( self, @@ -144,7 +145,7 @@ async def request( response: aiohttp.ClientResponse | None = None data: dict[str, Any] | str | None = None method = route.method - url = route.url + url = route.merge(self.api_url) webhook_id = route.webhook_id async with AsyncDeferredLock(lock) as lock: diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index d2d3213d71..ea63b99554 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -96,8 +96,9 @@ def __exit__(self, type, value, traceback): class WebhookAdapter: - def __init__(self): + def __init__(self, *, discord_api_url: str = "https://discord.com/api/v10"): self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.api_url = discord_api_url def request( self, @@ -135,7 +136,7 @@ def request( data: dict[str, Any] | str | None = None file_data: dict[str, Any] | None = None method = route.method - url = route.url + url = route.merge(self.api_url) webhook_id = route.webhook_id with DeferredLock(lock) as lock: From 323978fc4f97271076aaeb18346d97dc73e285ee Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Wed, 14 May 2025 00:37:12 +0800 Subject: [PATCH 6/7] fix: oops put that in the wrong place --- discord/http.py | 45 +++++++++++---------------------------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/discord/http.py b/discord/http.py index a222f0cdf3..f6794ffcc9 100644 --- a/discord/http.py +++ b/discord/http.py @@ -217,7 +217,7 @@ def __init__( self.proxy: str | None = proxy self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth self.use_clock: bool = not unsync_clock - self._executors = [] + self._executors: list[Executor] = [] user_agent = ( "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" @@ -332,7 +332,8 @@ async def request( self._executors.append(executor) await executor.executed( - reset_after=data['retry_after'], + # NOTE: 5 is just a placeholder since this should always be present + reset_after=float(response.headers.get('X-RateLimit-Reset-After', "5")), is_global=response.headers.get('X-RateLimit-Scope') == 'global', limit=int(response.headers.get('X-RateLimit-Limit', 10)), ) @@ -346,40 +347,16 @@ async def request( # we are being rate limited if response.status == 429: - if not response.headers.get("Via") or isinstance(data, str): - # Banned by Cloudflare more than likely. - raise HTTPException(response, data) + _log.debug(f'Request to {route} failed: Request returned rate limit') + executor = Executor(route=route) - fmt = ( - "We are being rate limited. Retrying in %.2f seconds." - ' Handled under the bucket "%s"' + self._executors.append(executor) + await executor.executed( + reset_after=data['retry_after'], + is_global=response.headers.get('X-RateLimit-Scope') == 'global', + limit=int(response.headers.get('X-RateLimit-Limit', 10)), ) - - # sleep a bit - retry_after: float = data["retry_after"] - _log.warning(fmt, retry_after, bucket) - - # check if it's a global rate limit - is_global = data.get("global", False) - if is_global: - _log.warning( - ( - "Global rate limit has been hit. Retrying in" - " %.2f seconds." - ), - retry_after, - ) - self._global_over.clear() - - await asyncio.sleep(retry_after) - _log.debug("Done sleeping for the rate limit. Retrying...") - - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - _log.debug("Global rate limit is now over.") - + self._executors.remove(executor) continue # we've received a 500, 502, 503, or 504, unconditional retry From 5527c5510539373e45fe0d44fe8c2701140756ca Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Wed, 14 May 2025 05:06:23 +0800 Subject: [PATCH 7/7] chore: take method into equality check --- discord/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/http.py b/discord/http.py index f6794ffcc9..e8af71fe5b 100644 --- a/discord/http.py +++ b/discord/http.py @@ -141,7 +141,7 @@ def __eq__(self, route: 'Route') -> bool: or route.guild_id == self.guild_id or route.webhook_id == self.webhook_id or route.webhook_token == self.webhook_token - ) + ) and route.method == self.method