8787
8888 T = TypeVar ("T" )
8989 BE = TypeVar ("BE" , bound = BaseException )
90- MU = TypeVar ("MU" , bound = "MaybeUnlock" )
9190 Response = Coroutine [Any , Any , T ]
9291
9392API_VERSION : int = 10
@@ -106,61 +105,90 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str
106105
107106
108107class Route :
109- API_BASE_URL : str = "https://discord.com/api/v{API_VERSION}"
110-
111- def __init__ (self , method : str , path : str , ** parameters : Any ) -> None :
112- self .path : str = path
113- self .method : str = method
114- url = self .base + self .path
115- if parameters :
116- url = url .format_map (
117- {
118- k : _uriquote (v ) if isinstance (v , str ) else v
119- for k , v in parameters .items ()
120- }
121- )
122- self .url : str = url
108+ def __init__ (
109+ self ,
110+ path : str ,
111+ guild_id : str | None = None ,
112+ channel_id : str | None = None ,
113+ webhook_id : str | None = None ,
114+ webhook_token : str | None = None ,
115+ ** parameters : str | int ,
116+ ):
117+ self .path = path
123118
124- # major parameters:
125- self .channel_id : Snowflake | None = parameters . get ( "channel_id" )
126- self .guild_id : Snowflake | None = parameters . get ( "guild_id" )
127- self .webhook_id : Snowflake | None = parameters . get ( " webhook_id" )
128- self .webhook_token : str | None = parameters . get ( " webhook_token" )
119+ # major parameters
120+ self .guild_id = guild_id
121+ self .channel_id = channel_id
122+ self .webhook_id = webhook_id
123+ self .webhook_token = webhook_token
129124
130- @property
131- def base (self ) -> str :
132- return self .API_BASE_URL .format (API_VERSION = API_VERSION )
125+ self .parameters = parameters
133126
134- @property
135- def bucket (self ) -> str :
136- # the bucket is just method + path w/ major parameters
137- return f"{ self .channel_id } :{ self .guild_id } :{ self .path } "
127+ def merge (self , url : str ):
128+ return url + self .path .format (
129+ guild_id = self .guild_id ,
130+ channel_id = self .channel_id ,
131+ webhook_id = self .webhook_id ,
132+ webhook_token = self .webhook_token ,
133+ ** self .parameters ,
134+ )
138135
136+ def __eq__ (self , route : 'Route' ) -> bool :
137+ return (
138+ route .channel_id == self .channel_id
139+ or route .guild_id == self .guild_id
140+ or route .webhook_id == self .webhook_id
141+ or route .webhook_token == self .webhook_token
142+ )
139143
140- class MaybeUnlock :
141- def __init__ (self , lock : asyncio .Lock ) -> None :
142- self .lock : asyncio .Lock = lock
143- self ._unlock : bool = True
144144
145- def __enter__ (self : MU ) -> MU :
146- return self
147145
148- def defer (self ) -> None :
149- self ._unlock = False
146+ class Executor :
147+ def __init__ (self , route : Route ) -> None :
148+ self .route = route
149+ self .is_global : bool | None = None
150+ self ._request_queue : asyncio .Queue [asyncio .Event ] | None = None
151+ self .rate_limited : bool = False
150152
151- def __exit__ (
152- self ,
153- exc_type : type [BE ] | None ,
154- exc : BE | None ,
155- traceback : TracebackType | None ,
153+ async def executed (
154+ self , reset_after : int | float , limit : int , is_global : bool
156155 ) -> None :
157- if self ._unlock :
158- self .lock .release ()
156+ self .rate_limited = True
157+ self .is_global = is_global
158+ self ._reset_after = reset_after
159+ self ._request_queue = asyncio .Queue ()
160+
161+ await asyncio .sleep (reset_after )
162+
163+ self .is_global = False
164+
165+ # NOTE: This could break if someone did a second global rate limit somehow
166+ requests_passed : int = 0
167+ for _ in range (self ._request_queue .qsize () - 1 ):
168+ if requests_passed == limit :
169+ requests_passed = 0
170+ if not is_global :
171+ await asyncio .sleep (reset_after )
172+ else :
173+ await asyncio .sleep (5 )
159174
175+ requests_passed += 1
176+ e = await self ._request_queue .get ()
177+ e .set ()
160178
161- # For some reason, the Discord voice websocket expects this header to be
162- # completely lowercase while aiohttp respects spec and does it as case-insensitive
163- aiohttp .hdrs .WEBSOCKET = "websocket" # type: ignore
179+ async def wait (self ) -> None :
180+ if not self .rate_limited :
181+ return
182+
183+ event = asyncio .Event ()
184+
185+ if self ._request_queue :
186+ self ._request_queue .put_nowait (event )
187+ else :
188+ raise ValueError (
189+ 'Request queue does not exist, rate limit may have been solved.'
190+ )
191+ await event .wait ()
164192
165193
166194class HTTPClient :
@@ -180,14 +208,12 @@ def __init__(
180208 )
181209 self .connector = connector
182210 self .__session : aiohttp .ClientSession = MISSING # filled in static_login
183- self ._locks : weakref .WeakValueDictionary = weakref .WeakValueDictionary ()
184- self ._global_over : asyncio .Event = asyncio .Event ()
185- self ._global_over .set ()
186211 self .token : str | None = None
187212 self .bot_token : bool = False
188213 self .proxy : str | None = proxy
189214 self .proxy_auth : aiohttp .BasicAuth | None = proxy_auth
190215 self .use_clock : bool = not unsync_clock
216+ self ._executors = []
191217
192218 user_agent = (
193219 "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
@@ -230,12 +256,6 @@ async def request(
230256 method = route .method
231257 url = route .url
232258
233- lock = self ._locks .get (bucket )
234- if lock is None :
235- lock = asyncio .Lock ()
236- if bucket is not None :
237- self ._locks [bucket ] = lock
238-
239259 # header creation
240260 headers : dict [str , str ] = {
241261 "User-Agent" : self .user_agent ,
@@ -272,117 +292,118 @@ async def request(
272292
273293 response : aiohttp .ClientResponse | None = None
274294 data : dict [str , Any ] | str | None = None
275- await lock .acquire ()
276- with MaybeUnlock (lock ) as maybe_lock :
277- for tries in range (5 ):
278- if files :
279- for f in files :
280- f .reset (seek = tries )
281-
282- if form :
283- form_data = aiohttp .FormData (quote_fields = False )
284- for params in form :
285- form_data .add_field (** params )
286- kwargs ["data" ] = form_data
287-
288- try :
289- async with self .__session .request (
290- method , url , ** kwargs
291- ) as response :
292- _log .debug (
293- "%s %s with %s has returned %s" ,
294- method ,
295- url ,
296- kwargs .get ("data" ),
297- response .status ,
295+
296+ for executor in self ._executors :
297+ if executor .is_global or executor .route == route :
298+ _log .debug (f'Pausing request to { route } : Found rate limit executor' )
299+ await executor .wait ()
300+
301+ for tries in range (5 ):
302+ if files :
303+ for f in files :
304+ f .reset (seek = tries )
305+
306+ if form :
307+ form_data = aiohttp .FormData (quote_fields = False )
308+ for params in form :
309+ form_data .add_field (** params )
310+ kwargs ["data" ] = form_data
311+
312+ try :
313+ async with self .__session .request (
314+ method , url , ** kwargs
315+ ) as response :
316+ _log .debug (
317+ "%s %s with %s has returned %s" ,
318+ method ,
319+ url ,
320+ kwargs .get ("data" ),
321+ response .status ,
322+ )
323+
324+ # even errors have text involved in them so this is safe to call
325+ data = await json_or_text (response )
326+
327+ # check if we have rate limit header information
328+ remaining = response .headers .get ("X-Ratelimit-Remaining" )
329+ if remaining == "0" and response .status != 429 :
330+ _log .debug (f'Request to { route } failed: Request returned rate limit' )
331+ executor = Executor (route = route )
332+
333+ self ._executors .append (executor )
334+ await executor .executed (
335+ reset_after = data ['retry_after' ],
336+ is_global = response .headers .get ('X-RateLimit-Scope' ) == 'global' ,
337+ limit = int (response .headers .get ('X-RateLimit-Limit' , 10 )),
298338 )
339+ self ._executors .remove (executor )
340+ continue
299341
300- # even errors have text involved in them so this is safe to call
301- data = await json_or_text (response )
342+ # the request was successful so just return the text/json
343+ if 300 > response .status >= 200 :
344+ _log .debug ("%s %s has received %s" , method , url , data )
345+ return data
302346
303- # check if we have rate limit header information
304- remaining = response .headers .get ("X-Ratelimit-Remaining" )
305- if remaining == "0" and response .status != 429 :
306- # we've depleted our current bucket
307- delta = utils ._parse_ratelimit_header (
308- response , use_clock = self .use_clock
309- )
310- _log .debug (
347+ # we are being rate limited
348+ if response .status == 429 :
349+ if not response .headers .get ("Via" ) or isinstance (data , str ):
350+ # Banned by Cloudflare more than likely.
351+ raise HTTPException (response , data )
352+
353+ fmt = (
354+ "We are being rate limited. Retrying in %.2f seconds."
355+ ' Handled under the bucket "%s"'
356+ )
357+
358+ # sleep a bit
359+ retry_after : float = data ["retry_after" ]
360+ _log .warning (fmt , retry_after , bucket )
361+
362+ # check if it's a global rate limit
363+ is_global = data .get ("global" , False )
364+ if is_global :
365+ _log .warning (
311366 (
312- "A rate limit bucket has been exhausted (bucket: "
313- " %s, retry: %s) ."
367+ "Global rate limit has been hit. Retrying in "
368+ " %.2f seconds ."
314369 ),
315- bucket ,
316- delta ,
317- )
318- maybe_lock .defer ()
319- self .loop .call_later (delta , lock .release )
320-
321- # the request was successful so just return the text/json
322- if 300 > response .status >= 200 :
323- _log .debug ("%s %s has received %s" , method , url , data )
324- return data
325-
326- # we are being rate limited
327- if response .status == 429 :
328- if not response .headers .get ("Via" ) or isinstance (data , str ):
329- # Banned by Cloudflare more than likely.
330- raise HTTPException (response , data )
331-
332- fmt = (
333- "We are being rate limited. Retrying in %.2f seconds."
334- ' Handled under the bucket "%s"'
370+ retry_after ,
335371 )
372+ self ._global_over .clear ()
336373
337- # sleep a bit
338- retry_after : float = data ["retry_after" ]
339- _log .warning (fmt , retry_after , bucket )
340-
341- # check if it's a global rate limit
342- is_global = data .get ("global" , False )
343- if is_global :
344- _log .warning (
345- (
346- "Global rate limit has been hit. Retrying in"
347- " %.2f seconds."
348- ),
349- retry_after ,
350- )
351- self ._global_over .clear ()
352-
353- await asyncio .sleep (retry_after )
354- _log .debug ("Done sleeping for the rate limit. Retrying..." )
355-
356- # release the global lock now that the
357- # global rate limit has passed
358- if is_global :
359- self ._global_over .set ()
360- _log .debug ("Global rate limit is now over." )
361-
362- continue
363-
364- # we've received a 500, 502, 503, or 504, unconditional retry
365- if response .status in {500 , 502 , 503 , 504 }:
366- await asyncio .sleep (1 + tries * 2 )
367- continue
368-
369- # the usual error cases
370- if response .status == 403 :
371- raise Forbidden (response , data )
372- elif response .status == 404 :
373- raise NotFound (response , data )
374- elif response .status >= 500 :
375- raise DiscordServerError (response , data )
376- else :
377- raise HTTPException (response , data )
374+ await asyncio .sleep (retry_after )
375+ _log .debug ("Done sleeping for the rate limit. Retrying..." )
376+
377+ # release the global lock now that the
378+ # global rate limit has passed
379+ if is_global :
380+ self ._global_over .set ()
381+ _log .debug ("Global rate limit is now over." )
378382
379- # This is handling exceptions from the request
380- except OSError as e :
381- # Connection reset by peer
382- if tries < 4 and e . errno in ( 54 , 10054 ) :
383+ continue
384+
385+ # we've received a 500, 502, 503, or 504, unconditional retry
386+ if response . status in { 500 , 502 , 503 , 504 } :
383387 await asyncio .sleep (1 + tries * 2 )
384388 continue
385- raise
389+
390+ # the usual error cases
391+ if response .status == 403 :
392+ raise Forbidden (response , data )
393+ elif response .status == 404 :
394+ raise NotFound (response , data )
395+ elif response .status >= 500 :
396+ raise DiscordServerError (response , data )
397+ else :
398+ raise HTTPException (response , data )
399+
400+ # This is handling exceptions from the request
401+ except OSError as e :
402+ # Connection reset by peer
403+ if tries < 4 and e .errno in (54 , 10054 ):
404+ await asyncio .sleep (1 + tries * 2 )
405+ continue
406+ raise
386407
387408 if response is not None :
388409 # We've run out of retries, raise.
0 commit comments