diff --git a/discord/http.py b/discord/http.py index 45528ffe9b..e8af71fe5b 100644 --- a/discord/http.py +++ b/discord/http.py @@ -87,7 +87,6 @@ T = TypeVar("T") BE = TypeVar("BE", bound=BaseException) - MU = TypeVar("MU", bound="MaybeUnlock") Response = Coroutine[Any, Any, T] API_VERSION: int = 10 @@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str class Route: - API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}" - - def __init__(self, method: str, path: str, **parameters: Any) -> None: - self.path: str = path - self.method: str = method - url = self.base + self.path - if parameters: - url = url.format_map( - { - k: _uriquote(v) if isinstance(v, str) else v - for k, v in parameters.items() - } - ) - self.url: str = url + def __init__( + self, + method: str, + path: str, + guild_id: str | None = None, + channel_id: str | None = None, + webhook_id: str | None = None, + webhook_token: str | None = None, + **parameters: str | int, + ): + self.method = method + self.path = path - # major parameters: - self.channel_id: Snowflake | None = parameters.get("channel_id") - self.guild_id: Snowflake | None = parameters.get("guild_id") - self.webhook_id: Snowflake | None = parameters.get("webhook_id") - self.webhook_token: str | None = parameters.get("webhook_token") + # major parameters + self.guild_id = guild_id + self.channel_id = channel_id + self.webhook_id = webhook_id + self.webhook_token = webhook_token - @property - def base(self) -> str: - return self.API_BASE_URL.format(API_VERSION=API_VERSION) + self.parameters = parameters - @property - def bucket(self) -> str: - # the bucket is just method + path w/ major parameters - return f"{self.channel_id}:{self.guild_id}:{self.path}" + def merge(self, url: str): + return url + self.path.format( + guild_id=self.guild_id, + channel_id=self.channel_id, + webhook_id=self.webhook_id, + webhook_token=self.webhook_token, + **self.parameters, + ) + def __eq__(self, route: 'Route') -> bool: + return ( + route.channel_id == self.channel_id + or route.guild_id == self.guild_id + or route.webhook_id == self.webhook_id + or route.webhook_token == self.webhook_token + ) and route.method == self.method -class MaybeUnlock: - def __init__(self, lock: asyncio.Lock) -> None: - self.lock: asyncio.Lock = lock - self._unlock: bool = True - def __enter__(self: MU) -> MU: - return self - def defer(self) -> None: - self._unlock = False +class Executor: + def __init__(self, route: Route) -> None: + self.route = route + self.is_global: bool | None = None + self._request_queue: asyncio.Queue[asyncio.Event] | None = None + self.rate_limited: bool = False - def __exit__( - self, - exc_type: type[BE] | None, - exc: BE | None, - traceback: TracebackType | None, + async def executed( + self, reset_after: int | float, limit: int, is_global: bool ) -> None: - if self._unlock: - self.lock.release() + self.rate_limited = True + self.is_global = is_global + self._reset_after = reset_after + self._request_queue = asyncio.Queue() + + await asyncio.sleep(reset_after) + self.is_global = False -# For some reason, the Discord voice websocket expects this header to be -# completely lowercase while aiohttp respects spec and does it as case-insensitive -aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore + # NOTE: This could break if someone did a second global rate limit somehow + requests_passed: int = 0 + for _ in range(self._request_queue.qsize() - 1): + if requests_passed == limit: + requests_passed = 0 + if not is_global: + await asyncio.sleep(reset_after) + else: + await asyncio.sleep(5) + + requests_passed += 1 + e = await self._request_queue.get() + e.set() + + async def wait(self) -> None: + if not self.rate_limited: + return + + event = asyncio.Event() + + if self._request_queue: + self._request_queue.put_nowait(event) + else: + raise ValueError( + 'Request queue does not exist, rate limit may have been solved.' + ) + await event.wait() class HTTPClient: @@ -174,20 +204,20 @@ def __init__( proxy_auth: aiohttp.BasicAuth | None = None, loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, + discord_api_url: str = "https://discord.com/api/v10" ) -> None: + self.api_url = discord_api_url self.loop: asyncio.AbstractEventLoop = ( asyncio.get_event_loop() if loop is None else loop ) self.connector = connector self.__session: aiohttp.ClientSession | utils.Undefined = MISSING # filled in static_login - self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() - self._global_over: asyncio.Event = asyncio.Event() - self._global_over.set() self.token: str | None = None self.bot_token: bool = False self.proxy: str | None = proxy self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth self.use_clock: bool = not unsync_clock + self._executors: list[Executor] = [] user_agent = ( "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" @@ -226,15 +256,9 @@ async def request( form: Iterable[dict[str, Any]] | None = None, **kwargs: Any, ) -> Any: - bucket = route.bucket + bucket = route.merge(self.api_url) method = route.method - url = route.url - - lock = self._locks.get(bucket) - if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock + url = bucket # header creation headers: dict[str, str] = { @@ -266,123 +290,97 @@ async def request( if self.proxy_auth is not None: kwargs["proxy_auth"] = self.proxy_auth - if not self._global_over.is_set(): - # wait until the global lock is complete - await self._global_over.wait() - response: aiohttp.ClientResponse | None = None data: dict[str, Any] | str | None = None - await lock.acquire() - with MaybeUnlock(lock) as maybe_lock: - for tries in range(5): - if files: - for f in files: - f.reset(seek=tries) - - if form: - form_data = aiohttp.FormData(quote_fields=False) - for params in form: - form_data.add_field(**params) - kwargs["data"] = form_data - - try: - async with self.__session.request( - method, url, **kwargs - ) as response: - _log.debug( - "%s %s with %s has returned %s", - method, - url, - kwargs.get("data"), - response.status, + + for executor in self._executors: + if executor.is_global or executor.route == route: + _log.debug(f'Pausing request to {route}: Found rate limit executor') + await executor.wait() + + for tries in range(5): + if files: + for f in files: + f.reset(seek=tries) + + if form: + form_data = aiohttp.FormData(quote_fields=False) + for params in form: + form_data.add_field(**params) + kwargs["data"] = form_data + + try: + async with self.__session.request( + method, url, **kwargs + ) as response: + _log.debug( + "%s %s with %s has returned %s", + method, + url, + kwargs.get("data"), + response.status, + ) + + # even errors have text involved in them so this is safe to call + data = await json_or_text(response) + + # check if we have rate limit header information + remaining = response.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and response.status != 429: + _log.debug(f'Request to {route} failed: Request returned rate limit') + executor = Executor(route=route) + + self._executors.append(executor) + await executor.executed( + # NOTE: 5 is just a placeholder since this should always be present + reset_after=float(response.headers.get('X-RateLimit-Reset-After', "5")), + is_global=response.headers.get('X-RateLimit-Scope') == 'global', + limit=int(response.headers.get('X-RateLimit-Limit', 10)), ) + self._executors.remove(executor) + continue - # even errors have text involved in them so this is safe to call - data = await json_or_text(response) - - # check if we have rate limit header information - remaining = response.headers.get("X-Ratelimit-Remaining") - if remaining == "0" and response.status != 429: - # we've depleted our current bucket - delta = utils._parse_ratelimit_header( - response, use_clock=self.use_clock - ) - _log.debug( - ( - "A rate limit bucket has been exhausted (bucket:" - " %s, retry: %s)." - ), - bucket, - delta, - ) - maybe_lock.defer() - self.loop.call_later(delta, lock.release) - - # the request was successful so just return the text/json - if 300 > response.status >= 200: - _log.debug("%s %s has received %s", method, url, data) - return data - - # we are being rate limited - if response.status == 429: - if not response.headers.get("Via") or isinstance(data, str): - # Banned by Cloudflare more than likely. - raise HTTPException(response, data) - - fmt = ( - "We are being rate limited. Retrying in %.2f seconds." - ' Handled under the bucket "%s"' - ) - - # sleep a bit - retry_after: float = data["retry_after"] - _log.warning(fmt, retry_after, bucket) - - # check if it's a global rate limit - is_global = data.get("global", False) - if is_global: - _log.warning( - ( - "Global rate limit has been hit. Retrying in" - " %.2f seconds." - ), - retry_after, - ) - self._global_over.clear() - - await asyncio.sleep(retry_after) - _log.debug("Done sleeping for the rate limit. Retrying...") - - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - _log.debug("Global rate limit is now over.") - - continue - - # we've received a 500, 502, 503, or 504, unconditional retry - if response.status in {500, 502, 503, 504}: - await asyncio.sleep(1 + tries * 2) - continue - - # the usual error cases - if response.status == 403: - raise Forbidden(response, data) - elif response.status == 404: - raise NotFound(response, data) - elif response.status >= 500: - raise DiscordServerError(response, data) - else: - raise HTTPException(response, data) - - # This is handling exceptions from the request - except OSError as e: - # Connection reset by peer - if tries < 4 and e.errno in (54, 10054): + # the request was successful so just return the text/json + if 300 > response.status >= 200: + _log.debug("%s %s has received %s", method, url, data) + return data + + # we are being rate limited + if response.status == 429: + _log.debug(f'Request to {route} failed: Request returned rate limit') + executor = Executor(route=route) + + self._executors.append(executor) + await executor.executed( + reset_after=data['retry_after'], + is_global=response.headers.get('X-RateLimit-Scope') == 'global', + limit=int(response.headers.get('X-RateLimit-Limit', 10)), + ) + self._executors.remove(executor) + continue + + # we've received a 500, 502, 503, or 504, unconditional retry + if response.status in {500, 502, 503, 504}: await asyncio.sleep(1 + tries * 2) continue - raise + + # the usual error cases + if response.status == 403: + raise Forbidden(response, data) + elif response.status == 404: + raise NotFound(response, data) + elif response.status >= 500: + raise DiscordServerError(response, data) + else: + raise HTTPException(response, data) + + # This is handling exceptions from the request + except OSError as e: + # Connection reset by peer + if tries < 4 and e.errno in (54, 10054): + await asyncio.sleep(1 + tries * 2) + continue + raise if response is not None: # We've run out of retries, raise. diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 61ac6391e3..d13303e735 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -104,8 +104,9 @@ async def __aexit__(self, type, value, traceback): class AsyncWebhookAdapter: - def __init__(self): + def __init__(self, *, discord_api_url: str = "https://discord.com/api/v10"): self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.api_url = discord_api_url async def request( self, @@ -144,7 +145,7 @@ async def request( response: aiohttp.ClientResponse | None = None data: dict[str, Any] | str | None = None method = route.method - url = route.url + url = route.merge(self.api_url) webhook_id = route.webhook_id async with AsyncDeferredLock(lock) as lock: diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index dcb26482ed..33638d7b19 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -96,8 +96,9 @@ def __exit__(self, type, value, traceback): class WebhookAdapter: - def __init__(self): + def __init__(self, *, discord_api_url: str = "https://discord.com/api/v10"): self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.api_url = discord_api_url def request( self, @@ -135,7 +136,7 @@ def request( data: dict[str, Any] | str | None = None file_data: dict[str, Any] | None = None method = route.method - url = route.url + url = route.merge(self.api_url) webhook_id = route.webhook_id with DeferredLock(lock) as lock: