Skip to content

Commit 056575d

Browse files
committed
refactor: port over v3 rate limit code
1 parent f3243bf commit 056575d

File tree

1 file changed

+176
-155
lines changed

1 file changed

+176
-155
lines changed

discord/http.py

Lines changed: 176 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787

8888
T = TypeVar("T")
8989
BE = TypeVar("BE", bound=BaseException)
90-
MU = TypeVar("MU", bound="MaybeUnlock")
9190
Response = Coroutine[Any, Any, T]
9291

9392
API_VERSION: int = 10
@@ -106,61 +105,90 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str
106105

107106

108107
class Route:
109-
API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}"
110-
111-
def __init__(self, method: str, path: str, **parameters: Any) -> None:
112-
self.path: str = path
113-
self.method: str = method
114-
url = self.base + self.path
115-
if parameters:
116-
url = url.format_map(
117-
{
118-
k: _uriquote(v) if isinstance(v, str) else v
119-
for k, v in parameters.items()
120-
}
121-
)
122-
self.url: str = url
108+
def __init__(
109+
self,
110+
path: str,
111+
guild_id: str | None = None,
112+
channel_id: str | None = None,
113+
webhook_id: str | None = None,
114+
webhook_token: str | None = None,
115+
**parameters: str | int,
116+
):
117+
self.path = path
123118

124-
# major parameters:
125-
self.channel_id: Snowflake | None = parameters.get("channel_id")
126-
self.guild_id: Snowflake | None = parameters.get("guild_id")
127-
self.webhook_id: Snowflake | None = parameters.get("webhook_id")
128-
self.webhook_token: str | None = parameters.get("webhook_token")
119+
# major parameters
120+
self.guild_id = guild_id
121+
self.channel_id = channel_id
122+
self.webhook_id = webhook_id
123+
self.webhook_token = webhook_token
129124

130-
@property
131-
def base(self) -> str:
132-
return self.API_BASE_URL.format(API_VERSION=API_VERSION)
125+
self.parameters = parameters
133126

134-
@property
135-
def bucket(self) -> str:
136-
# the bucket is just method + path w/ major parameters
137-
return f"{self.channel_id}:{self.guild_id}:{self.path}"
127+
def merge(self, url: str):
128+
return url + self.path.format(
129+
guild_id=self.guild_id,
130+
channel_id=self.channel_id,
131+
webhook_id=self.webhook_id,
132+
webhook_token=self.webhook_token,
133+
**self.parameters,
134+
)
138135

136+
def __eq__(self, route: 'Route') -> bool:
137+
return (
138+
route.channel_id == self.channel_id
139+
or route.guild_id == self.guild_id
140+
or route.webhook_id == self.webhook_id
141+
or route.webhook_token == self.webhook_token
142+
)
139143

140-
class MaybeUnlock:
141-
def __init__(self, lock: asyncio.Lock) -> None:
142-
self.lock: asyncio.Lock = lock
143-
self._unlock: bool = True
144144

145-
def __enter__(self: MU) -> MU:
146-
return self
147145

148-
def defer(self) -> None:
149-
self._unlock = False
146+
class Executor:
147+
def __init__(self, route: Route) -> None:
148+
self.route = route
149+
self.is_global: bool | None = None
150+
self._request_queue: asyncio.Queue[asyncio.Event] | None = None
151+
self.rate_limited: bool = False
150152

151-
def __exit__(
152-
self,
153-
exc_type: type[BE] | None,
154-
exc: BE | None,
155-
traceback: TracebackType | None,
153+
async def executed(
154+
self, reset_after: int | float, limit: int, is_global: bool
156155
) -> None:
157-
if self._unlock:
158-
self.lock.release()
156+
self.rate_limited = True
157+
self.is_global = is_global
158+
self._reset_after = reset_after
159+
self._request_queue = asyncio.Queue()
160+
161+
await asyncio.sleep(reset_after)
162+
163+
self.is_global = False
164+
165+
# NOTE: This could break if someone did a second global rate limit somehow
166+
requests_passed: int = 0
167+
for _ in range(self._request_queue.qsize() - 1):
168+
if requests_passed == limit:
169+
requests_passed = 0
170+
if not is_global:
171+
await asyncio.sleep(reset_after)
172+
else:
173+
await asyncio.sleep(5)
159174

175+
requests_passed += 1
176+
e = await self._request_queue.get()
177+
e.set()
160178

161-
# For some reason, the Discord voice websocket expects this header to be
162-
# completely lowercase while aiohttp respects spec and does it as case-insensitive
163-
aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore
179+
async def wait(self) -> None:
180+
if not self.rate_limited:
181+
return
182+
183+
event = asyncio.Event()
184+
185+
if self._request_queue:
186+
self._request_queue.put_nowait(event)
187+
else:
188+
raise ValueError(
189+
'Request queue does not exist, rate limit may have been solved.'
190+
)
191+
await event.wait()
164192

165193

166194
class HTTPClient:
@@ -180,14 +208,12 @@ def __init__(
180208
)
181209
self.connector = connector
182210
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
183-
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
184-
self._global_over: asyncio.Event = asyncio.Event()
185-
self._global_over.set()
186211
self.token: str | None = None
187212
self.bot_token: bool = False
188213
self.proxy: str | None = proxy
189214
self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth
190215
self.use_clock: bool = not unsync_clock
216+
self._executors = []
191217

192218
user_agent = (
193219
"DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
@@ -230,12 +256,6 @@ async def request(
230256
method = route.method
231257
url = route.url
232258

233-
lock = self._locks.get(bucket)
234-
if lock is None:
235-
lock = asyncio.Lock()
236-
if bucket is not None:
237-
self._locks[bucket] = lock
238-
239259
# header creation
240260
headers: dict[str, str] = {
241261
"User-Agent": self.user_agent,
@@ -272,117 +292,118 @@ async def request(
272292

273293
response: aiohttp.ClientResponse | None = None
274294
data: dict[str, Any] | str | None = None
275-
await lock.acquire()
276-
with MaybeUnlock(lock) as maybe_lock:
277-
for tries in range(5):
278-
if files:
279-
for f in files:
280-
f.reset(seek=tries)
281-
282-
if form:
283-
form_data = aiohttp.FormData(quote_fields=False)
284-
for params in form:
285-
form_data.add_field(**params)
286-
kwargs["data"] = form_data
287-
288-
try:
289-
async with self.__session.request(
290-
method, url, **kwargs
291-
) as response:
292-
_log.debug(
293-
"%s %s with %s has returned %s",
294-
method,
295-
url,
296-
kwargs.get("data"),
297-
response.status,
295+
296+
for executor in self._executors:
297+
if executor.is_global or executor.route == route:
298+
_log.debug(f'Pausing request to {route}: Found rate limit executor')
299+
await executor.wait()
300+
301+
for tries in range(5):
302+
if files:
303+
for f in files:
304+
f.reset(seek=tries)
305+
306+
if form:
307+
form_data = aiohttp.FormData(quote_fields=False)
308+
for params in form:
309+
form_data.add_field(**params)
310+
kwargs["data"] = form_data
311+
312+
try:
313+
async with self.__session.request(
314+
method, url, **kwargs
315+
) as response:
316+
_log.debug(
317+
"%s %s with %s has returned %s",
318+
method,
319+
url,
320+
kwargs.get("data"),
321+
response.status,
322+
)
323+
324+
# even errors have text involved in them so this is safe to call
325+
data = await json_or_text(response)
326+
327+
# check if we have rate limit header information
328+
remaining = response.headers.get("X-Ratelimit-Remaining")
329+
if remaining == "0" and response.status != 429:
330+
_log.debug(f'Request to {route} failed: Request returned rate limit')
331+
executor = Executor(route=route)
332+
333+
self._executors.append(executor)
334+
await executor.executed(
335+
reset_after=data['retry_after'],
336+
is_global=response.headers.get('X-RateLimit-Scope') == 'global',
337+
limit=int(response.headers.get('X-RateLimit-Limit', 10)),
298338
)
339+
self._executors.remove(executor)
340+
continue
299341

300-
# even errors have text involved in them so this is safe to call
301-
data = await json_or_text(response)
342+
# the request was successful so just return the text/json
343+
if 300 > response.status >= 200:
344+
_log.debug("%s %s has received %s", method, url, data)
345+
return data
302346

303-
# check if we have rate limit header information
304-
remaining = response.headers.get("X-Ratelimit-Remaining")
305-
if remaining == "0" and response.status != 429:
306-
# we've depleted our current bucket
307-
delta = utils._parse_ratelimit_header(
308-
response, use_clock=self.use_clock
309-
)
310-
_log.debug(
347+
# we are being rate limited
348+
if response.status == 429:
349+
if not response.headers.get("Via") or isinstance(data, str):
350+
# Banned by Cloudflare more than likely.
351+
raise HTTPException(response, data)
352+
353+
fmt = (
354+
"We are being rate limited. Retrying in %.2f seconds."
355+
' Handled under the bucket "%s"'
356+
)
357+
358+
# sleep a bit
359+
retry_after: float = data["retry_after"]
360+
_log.warning(fmt, retry_after, bucket)
361+
362+
# check if it's a global rate limit
363+
is_global = data.get("global", False)
364+
if is_global:
365+
_log.warning(
311366
(
312-
"A rate limit bucket has been exhausted (bucket:"
313-
" %s, retry: %s)."
367+
"Global rate limit has been hit. Retrying in"
368+
" %.2f seconds."
314369
),
315-
bucket,
316-
delta,
317-
)
318-
maybe_lock.defer()
319-
self.loop.call_later(delta, lock.release)
320-
321-
# the request was successful so just return the text/json
322-
if 300 > response.status >= 200:
323-
_log.debug("%s %s has received %s", method, url, data)
324-
return data
325-
326-
# we are being rate limited
327-
if response.status == 429:
328-
if not response.headers.get("Via") or isinstance(data, str):
329-
# Banned by Cloudflare more than likely.
330-
raise HTTPException(response, data)
331-
332-
fmt = (
333-
"We are being rate limited. Retrying in %.2f seconds."
334-
' Handled under the bucket "%s"'
370+
retry_after,
335371
)
372+
self._global_over.clear()
336373

337-
# sleep a bit
338-
retry_after: float = data["retry_after"]
339-
_log.warning(fmt, retry_after, bucket)
340-
341-
# check if it's a global rate limit
342-
is_global = data.get("global", False)
343-
if is_global:
344-
_log.warning(
345-
(
346-
"Global rate limit has been hit. Retrying in"
347-
" %.2f seconds."
348-
),
349-
retry_after,
350-
)
351-
self._global_over.clear()
352-
353-
await asyncio.sleep(retry_after)
354-
_log.debug("Done sleeping for the rate limit. Retrying...")
355-
356-
# release the global lock now that the
357-
# global rate limit has passed
358-
if is_global:
359-
self._global_over.set()
360-
_log.debug("Global rate limit is now over.")
361-
362-
continue
363-
364-
# we've received a 500, 502, 503, or 504, unconditional retry
365-
if response.status in {500, 502, 503, 504}:
366-
await asyncio.sleep(1 + tries * 2)
367-
continue
368-
369-
# the usual error cases
370-
if response.status == 403:
371-
raise Forbidden(response, data)
372-
elif response.status == 404:
373-
raise NotFound(response, data)
374-
elif response.status >= 500:
375-
raise DiscordServerError(response, data)
376-
else:
377-
raise HTTPException(response, data)
374+
await asyncio.sleep(retry_after)
375+
_log.debug("Done sleeping for the rate limit. Retrying...")
376+
377+
# release the global lock now that the
378+
# global rate limit has passed
379+
if is_global:
380+
self._global_over.set()
381+
_log.debug("Global rate limit is now over.")
378382

379-
# This is handling exceptions from the request
380-
except OSError as e:
381-
# Connection reset by peer
382-
if tries < 4 and e.errno in (54, 10054):
383+
continue
384+
385+
# we've received a 500, 502, 503, or 504, unconditional retry
386+
if response.status in {500, 502, 503, 504}:
383387
await asyncio.sleep(1 + tries * 2)
384388
continue
385-
raise
389+
390+
# the usual error cases
391+
if response.status == 403:
392+
raise Forbidden(response, data)
393+
elif response.status == 404:
394+
raise NotFound(response, data)
395+
elif response.status >= 500:
396+
raise DiscordServerError(response, data)
397+
else:
398+
raise HTTPException(response, data)
399+
400+
# This is handling exceptions from the request
401+
except OSError as e:
402+
# Connection reset by peer
403+
if tries < 4 and e.errno in (54, 10054):
404+
await asyncio.sleep(1 + tries * 2)
405+
continue
406+
raise
386407

387408
if response is not None:
388409
# We've run out of retries, raise.

0 commit comments

Comments
 (0)