forked from Pycord-Development/pycord
-
Notifications
You must be signed in to change notification settings - Fork 2
refactor: port over v3 rate limit code #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
056575d
refactor: port over v3 rate limit code
VincentRPS 5b32363
fix: route stuff
VincentRPS 63720c5
fix: AHHH
VincentRPS c60ae2e
fix: append api url to path
VincentRPS 7edd751
fix: final batch
VincentRPS 1368c1f
Merge branch 'master' into cool-api
plun1331 323978f
fix: oops put that in the wrong place
VincentRPS 5527c55
chore: take method into equality check
VincentRPS File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,7 +87,6 @@ | |
|
|
||
| T = TypeVar("T") | ||
| BE = TypeVar("BE", bound=BaseException) | ||
| MU = TypeVar("MU", bound="MaybeUnlock") | ||
| Response = Coroutine[Any, Any, T] | ||
|
|
||
| API_VERSION: int = 10 | ||
|
|
@@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str | |
|
|
||
|
|
||
| class Route: | ||
| API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}" | ||
|
|
||
| def __init__(self, method: str, path: str, **parameters: Any) -> None: | ||
| self.path: str = path | ||
| self.method: str = method | ||
| url = self.base + self.path | ||
| if parameters: | ||
| url = url.format_map( | ||
| { | ||
| k: _uriquote(v) if isinstance(v, str) else v | ||
| for k, v in parameters.items() | ||
| } | ||
| ) | ||
| self.url: str = url | ||
| def __init__( | ||
| self, | ||
| method: str, | ||
| path: str, | ||
| guild_id: str | None = None, | ||
| channel_id: str | None = None, | ||
| webhook_id: str | None = None, | ||
| webhook_token: str | None = None, | ||
| **parameters: str | int, | ||
| ): | ||
| self.method = method | ||
| self.path = path | ||
|
|
||
| # major parameters: | ||
| self.channel_id: Snowflake | None = parameters.get("channel_id") | ||
| self.guild_id: Snowflake | None = parameters.get("guild_id") | ||
| self.webhook_id: Snowflake | None = parameters.get("webhook_id") | ||
| self.webhook_token: str | None = parameters.get("webhook_token") | ||
| # major parameters | ||
| self.guild_id = guild_id | ||
| self.channel_id = channel_id | ||
| self.webhook_id = webhook_id | ||
| self.webhook_token = webhook_token | ||
|
|
||
| @property | ||
| def base(self) -> str: | ||
| return self.API_BASE_URL.format(API_VERSION=API_VERSION) | ||
| self.parameters = parameters | ||
|
|
||
| @property | ||
| def bucket(self) -> str: | ||
| # the bucket is just method + path w/ major parameters | ||
| return f"{self.channel_id}:{self.guild_id}:{self.path}" | ||
| def merge(self, url: str): | ||
| return url + self.path.format( | ||
| guild_id=self.guild_id, | ||
| channel_id=self.channel_id, | ||
| webhook_id=self.webhook_id, | ||
| webhook_token=self.webhook_token, | ||
| **self.parameters, | ||
| ) | ||
|
|
||
| def __eq__(self, route: 'Route') -> bool: | ||
| return ( | ||
| route.channel_id == self.channel_id | ||
| or route.guild_id == self.guild_id | ||
| or route.webhook_id == self.webhook_id | ||
| or route.webhook_token == self.webhook_token | ||
| ) and route.method == self.method | ||
|
|
||
| class MaybeUnlock: | ||
| def __init__(self, lock: asyncio.Lock) -> None: | ||
| self.lock: asyncio.Lock = lock | ||
| self._unlock: bool = True | ||
|
|
||
| def __enter__(self: MU) -> MU: | ||
| return self | ||
|
|
||
| def defer(self) -> None: | ||
| self._unlock = False | ||
| class Executor: | ||
| def __init__(self, route: Route) -> None: | ||
| self.route = route | ||
| self.is_global: bool | None = None | ||
| self._request_queue: asyncio.Queue[asyncio.Event] | None = None | ||
| self.rate_limited: bool = False | ||
|
|
||
| def __exit__( | ||
| self, | ||
| exc_type: type[BE] | None, | ||
| exc: BE | None, | ||
| traceback: TracebackType | None, | ||
| async def executed( | ||
| self, reset_after: int | float, limit: int, is_global: bool | ||
| ) -> None: | ||
| if self._unlock: | ||
| self.lock.release() | ||
| self.rate_limited = True | ||
| self.is_global = is_global | ||
| self._reset_after = reset_after | ||
| self._request_queue = asyncio.Queue() | ||
|
|
||
| await asyncio.sleep(reset_after) | ||
|
|
||
| self.is_global = False | ||
|
|
||
| # For some reason, the Discord voice websocket expects this header to be | ||
| # completely lowercase while aiohttp respects spec and does it as case-insensitive | ||
| aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore | ||
| # NOTE: This could break if someone did a second global rate limit somehow | ||
| requests_passed: int = 0 | ||
| for _ in range(self._request_queue.qsize() - 1): | ||
| if requests_passed == limit: | ||
| requests_passed = 0 | ||
| if not is_global: | ||
| await asyncio.sleep(reset_after) | ||
| else: | ||
| await asyncio.sleep(5) | ||
|
|
||
| requests_passed += 1 | ||
| e = await self._request_queue.get() | ||
| e.set() | ||
|
|
||
| async def wait(self) -> None: | ||
| if not self.rate_limited: | ||
| return | ||
|
|
||
| event = asyncio.Event() | ||
|
|
||
| if self._request_queue: | ||
| self._request_queue.put_nowait(event) | ||
| else: | ||
| raise ValueError( | ||
| 'Request queue does not exist, rate limit may have been solved.' | ||
| ) | ||
| await event.wait() | ||
|
|
||
|
|
||
| class HTTPClient: | ||
|
|
@@ -174,20 +204,20 @@ def __init__( | |
| proxy_auth: aiohttp.BasicAuth | None = None, | ||
| loop: asyncio.AbstractEventLoop | None = None, | ||
| unsync_clock: bool = True, | ||
| discord_api_url: str = "https://discord.com/api/v10" | ||
| ) -> None: | ||
| self.api_url = discord_api_url | ||
| self.loop: asyncio.AbstractEventLoop = ( | ||
| asyncio.get_event_loop() if loop is None else loop | ||
| ) | ||
| self.connector = connector | ||
| self.__session: aiohttp.ClientSession | utils.Undefined = MISSING # filled in static_login | ||
| self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() | ||
| self._global_over: asyncio.Event = asyncio.Event() | ||
| self._global_over.set() | ||
| self.token: str | None = None | ||
| self.bot_token: bool = False | ||
| self.proxy: str | None = proxy | ||
| self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth | ||
| self.use_clock: bool = not unsync_clock | ||
| self._executors: list[Executor] = [] | ||
|
|
||
| user_agent = ( | ||
| "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be cool if we could make it so a user can change this in some way or another |
||
|
|
@@ -226,15 +256,9 @@ async def request( | |
| form: Iterable[dict[str, Any]] | None = None, | ||
| **kwargs: Any, | ||
| ) -> Any: | ||
| bucket = route.bucket | ||
| bucket = route.merge(self.api_url) | ||
| method = route.method | ||
| url = route.url | ||
|
|
||
| lock = self._locks.get(bucket) | ||
| if lock is None: | ||
| lock = asyncio.Lock() | ||
| if bucket is not None: | ||
| self._locks[bucket] = lock | ||
| url = bucket | ||
|
|
||
| # header creation | ||
| headers: dict[str, str] = { | ||
|
|
@@ -266,123 +290,97 @@ async def request( | |
| if self.proxy_auth is not None: | ||
| kwargs["proxy_auth"] = self.proxy_auth | ||
|
|
||
| if not self._global_over.is_set(): | ||
| # wait until the global lock is complete | ||
| await self._global_over.wait() | ||
|
|
||
| response: aiohttp.ClientResponse | None = None | ||
| data: dict[str, Any] | str | None = None | ||
| await lock.acquire() | ||
| with MaybeUnlock(lock) as maybe_lock: | ||
| for tries in range(5): | ||
| if files: | ||
| for f in files: | ||
| f.reset(seek=tries) | ||
|
|
||
| if form: | ||
| form_data = aiohttp.FormData(quote_fields=False) | ||
| for params in form: | ||
| form_data.add_field(**params) | ||
| kwargs["data"] = form_data | ||
|
|
||
| try: | ||
| async with self.__session.request( | ||
| method, url, **kwargs | ||
| ) as response: | ||
| _log.debug( | ||
| "%s %s with %s has returned %s", | ||
| method, | ||
| url, | ||
| kwargs.get("data"), | ||
| response.status, | ||
|
|
||
| for executor in self._executors: | ||
| if executor.is_global or executor.route == route: | ||
| _log.debug(f'Pausing request to {route}: Found rate limit executor') | ||
| await executor.wait() | ||
|
|
||
| for tries in range(5): | ||
| if files: | ||
| for f in files: | ||
| f.reset(seek=tries) | ||
|
|
||
| if form: | ||
| form_data = aiohttp.FormData(quote_fields=False) | ||
| for params in form: | ||
| form_data.add_field(**params) | ||
| kwargs["data"] = form_data | ||
|
|
||
| try: | ||
| async with self.__session.request( | ||
| method, url, **kwargs | ||
| ) as response: | ||
| _log.debug( | ||
| "%s %s with %s has returned %s", | ||
| method, | ||
| url, | ||
| kwargs.get("data"), | ||
| response.status, | ||
| ) | ||
|
|
||
| # even errors have text involved in them so this is safe to call | ||
| data = await json_or_text(response) | ||
|
|
||
| # check if we have rate limit header information | ||
| remaining = response.headers.get("X-Ratelimit-Remaining") | ||
| if remaining == "0" and response.status != 429: | ||
| _log.debug(f'Request to {route} failed: Request returned rate limit') | ||
| executor = Executor(route=route) | ||
|
|
||
| self._executors.append(executor) | ||
| await executor.executed( | ||
| # NOTE: 5 is just a placeholder since this should always be present | ||
| reset_after=float(response.headers.get('X-RateLimit-Reset-After', "5")), | ||
| is_global=response.headers.get('X-RateLimit-Scope') == 'global', | ||
| limit=int(response.headers.get('X-RateLimit-Limit', 10)), | ||
| ) | ||
| self._executors.remove(executor) | ||
| continue | ||
|
|
||
| # even errors have text involved in them so this is safe to call | ||
| data = await json_or_text(response) | ||
|
|
||
| # check if we have rate limit header information | ||
| remaining = response.headers.get("X-Ratelimit-Remaining") | ||
| if remaining == "0" and response.status != 429: | ||
| # we've depleted our current bucket | ||
| delta = utils._parse_ratelimit_header( | ||
| response, use_clock=self.use_clock | ||
| ) | ||
| _log.debug( | ||
| ( | ||
| "A rate limit bucket has been exhausted (bucket:" | ||
| " %s, retry: %s)." | ||
| ), | ||
| bucket, | ||
| delta, | ||
| ) | ||
| maybe_lock.defer() | ||
| self.loop.call_later(delta, lock.release) | ||
|
|
||
| # the request was successful so just return the text/json | ||
| if 300 > response.status >= 200: | ||
| _log.debug("%s %s has received %s", method, url, data) | ||
| return data | ||
|
|
||
| # we are being rate limited | ||
| if response.status == 429: | ||
| if not response.headers.get("Via") or isinstance(data, str): | ||
| # Banned by Cloudflare more than likely. | ||
| raise HTTPException(response, data) | ||
|
|
||
| fmt = ( | ||
| "We are being rate limited. Retrying in %.2f seconds." | ||
| ' Handled under the bucket "%s"' | ||
| ) | ||
|
|
||
| # sleep a bit | ||
| retry_after: float = data["retry_after"] | ||
| _log.warning(fmt, retry_after, bucket) | ||
|
|
||
| # check if it's a global rate limit | ||
| is_global = data.get("global", False) | ||
| if is_global: | ||
| _log.warning( | ||
| ( | ||
| "Global rate limit has been hit. Retrying in" | ||
| " %.2f seconds." | ||
| ), | ||
| retry_after, | ||
| ) | ||
| self._global_over.clear() | ||
|
|
||
| await asyncio.sleep(retry_after) | ||
| _log.debug("Done sleeping for the rate limit. Retrying...") | ||
|
|
||
| # release the global lock now that the | ||
| # global rate limit has passed | ||
| if is_global: | ||
| self._global_over.set() | ||
| _log.debug("Global rate limit is now over.") | ||
|
|
||
| continue | ||
|
|
||
| # we've received a 500, 502, 503, or 504, unconditional retry | ||
| if response.status in {500, 502, 503, 504}: | ||
| await asyncio.sleep(1 + tries * 2) | ||
| continue | ||
|
|
||
| # the usual error cases | ||
| if response.status == 403: | ||
| raise Forbidden(response, data) | ||
| elif response.status == 404: | ||
| raise NotFound(response, data) | ||
| elif response.status >= 500: | ||
| raise DiscordServerError(response, data) | ||
| else: | ||
| raise HTTPException(response, data) | ||
|
|
||
| # This is handling exceptions from the request | ||
| except OSError as e: | ||
| # Connection reset by peer | ||
| if tries < 4 and e.errno in (54, 10054): | ||
| # the request was successful so just return the text/json | ||
| if 300 > response.status >= 200: | ||
| _log.debug("%s %s has received %s", method, url, data) | ||
| return data | ||
|
|
||
| # we are being rate limited | ||
| if response.status == 429: | ||
| _log.debug(f'Request to {route} failed: Request returned rate limit') | ||
| executor = Executor(route=route) | ||
|
|
||
| self._executors.append(executor) | ||
| await executor.executed( | ||
| reset_after=data['retry_after'], | ||
| is_global=response.headers.get('X-RateLimit-Scope') == 'global', | ||
| limit=int(response.headers.get('X-RateLimit-Limit', 10)), | ||
| ) | ||
| self._executors.remove(executor) | ||
| continue | ||
|
|
||
| # we've received a 500, 502, 503, or 504, unconditional retry | ||
| if response.status in {500, 502, 503, 504}: | ||
| await asyncio.sleep(1 + tries * 2) | ||
| continue | ||
| raise | ||
|
|
||
| # the usual error cases | ||
| if response.status == 403: | ||
| raise Forbidden(response, data) | ||
| elif response.status == 404: | ||
| raise NotFound(response, data) | ||
| elif response.status >= 500: | ||
| raise DiscordServerError(response, data) | ||
| else: | ||
| raise HTTPException(response, data) | ||
|
|
||
| # This is handling exceptions from the request | ||
| except OSError as e: | ||
| # Connection reset by peer | ||
| if tries < 4 and e.errno in (54, 10054): | ||
| await asyncio.sleep(1 + tries * 2) | ||
| continue | ||
| raise | ||
|
|
||
| if response is not None: | ||
| # We've run out of retries, raise. | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The equality implementation in Route uses a logical OR to compare individual attributes, which may incorrectly consider two distinct routes as equal if any one attribute matches. Consider using a logical AND for a stricter comparison if full equality is desired.