Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 165 additions & 167 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@

T = TypeVar("T")
BE = TypeVar("BE", bound=BaseException)
MU = TypeVar("MU", bound="MaybeUnlock")
Response = Coroutine[Any, Any, T]

API_VERSION: int = 10
Expand All @@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str


class Route:
API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}"

def __init__(self, method: str, path: str, **parameters: Any) -> None:
self.path: str = path
self.method: str = method
url = self.base + self.path
if parameters:
url = url.format_map(
{
k: _uriquote(v) if isinstance(v, str) else v
for k, v in parameters.items()
}
)
self.url: str = url
def __init__(
self,
method: str,
path: str,
guild_id: str | None = None,
channel_id: str | None = None,
webhook_id: str | None = None,
webhook_token: str | None = None,
**parameters: str | int,
):
self.method = method
self.path = path

# major parameters:
self.channel_id: Snowflake | None = parameters.get("channel_id")
self.guild_id: Snowflake | None = parameters.get("guild_id")
self.webhook_id: Snowflake | None = parameters.get("webhook_id")
self.webhook_token: str | None = parameters.get("webhook_token")
# major parameters
self.guild_id = guild_id
self.channel_id = channel_id
self.webhook_id = webhook_id
self.webhook_token = webhook_token

@property
def base(self) -> str:
return self.API_BASE_URL.format(API_VERSION=API_VERSION)
self.parameters = parameters

@property
def bucket(self) -> str:
# the bucket is just method + path w/ major parameters
return f"{self.channel_id}:{self.guild_id}:{self.path}"
def merge(self, url: str):
return url + self.path.format(
guild_id=self.guild_id,
channel_id=self.channel_id,
webhook_id=self.webhook_id,
webhook_token=self.webhook_token,
**self.parameters,
)

def __eq__(self, route: 'Route') -> bool:
return (
route.channel_id == self.channel_id
or route.guild_id == self.guild_id
or route.webhook_id == self.webhook_id
or route.webhook_token == self.webhook_token
Comment on lines +141 to +143
Copy link

Copilot AI May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The equality implementation in Route uses a logical OR to compare individual attributes, which may incorrectly consider two distinct routes as equal if any one attribute matches. Consider using a logical AND for a stricter comparison if full equality is desired.

Suggested change
or route.guild_id == self.guild_id
or route.webhook_id == self.webhook_id
or route.webhook_token == self.webhook_token
and route.guild_id == self.guild_id
and route.webhook_id == self.webhook_id
and route.webhook_token == self.webhook_token

Copilot uses AI. Check for mistakes.
) and route.method == self.method

class MaybeUnlock:
def __init__(self, lock: asyncio.Lock) -> None:
self.lock: asyncio.Lock = lock
self._unlock: bool = True

def __enter__(self: MU) -> MU:
return self

def defer(self) -> None:
self._unlock = False
class Executor:
def __init__(self, route: Route) -> None:
self.route = route
self.is_global: bool | None = None
self._request_queue: asyncio.Queue[asyncio.Event] | None = None
self.rate_limited: bool = False

def __exit__(
self,
exc_type: type[BE] | None,
exc: BE | None,
traceback: TracebackType | None,
async def executed(
self, reset_after: int | float, limit: int, is_global: bool
) -> None:
if self._unlock:
self.lock.release()
self.rate_limited = True
self.is_global = is_global
self._reset_after = reset_after
self._request_queue = asyncio.Queue()

await asyncio.sleep(reset_after)

self.is_global = False

# For some reason, the Discord voice websocket expects this header to be
# completely lowercase while aiohttp respects spec and does it as case-insensitive
aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore
# NOTE: This could break if someone did a second global rate limit somehow
requests_passed: int = 0
for _ in range(self._request_queue.qsize() - 1):
if requests_passed == limit:
requests_passed = 0
if not is_global:
await asyncio.sleep(reset_after)
else:
await asyncio.sleep(5)

requests_passed += 1
e = await self._request_queue.get()
e.set()

async def wait(self) -> None:
if not self.rate_limited:
return

event = asyncio.Event()

if self._request_queue:
self._request_queue.put_nowait(event)
else:
raise ValueError(
'Request queue does not exist, rate limit may have been solved.'
)
await event.wait()


class HTTPClient:
Expand All @@ -174,20 +204,20 @@ def __init__(
proxy_auth: aiohttp.BasicAuth | None = None,
loop: asyncio.AbstractEventLoop | None = None,
unsync_clock: bool = True,
discord_api_url: str = "https://discord.com/api/v10"
) -> None:
self.api_url = discord_api_url
self.loop: asyncio.AbstractEventLoop = (
asyncio.get_event_loop() if loop is None else loop
)
self.connector = connector
self.__session: aiohttp.ClientSession | utils.Undefined = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self._global_over: asyncio.Event = asyncio.Event()
self._global_over.set()
self.token: str | None = None
self.bot_token: bool = False
self.proxy: str | None = proxy
self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth
self.use_clock: bool = not unsync_clock
self._executors: list[Executor] = []

user_agent = (
"DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be cool if we could make it so a user can change this in some way or another

Expand Down Expand Up @@ -226,15 +256,9 @@ async def request(
form: Iterable[dict[str, Any]] | None = None,
**kwargs: Any,
) -> Any:
bucket = route.bucket
bucket = route.merge(self.api_url)
method = route.method
url = route.url

lock = self._locks.get(bucket)
if lock is None:
lock = asyncio.Lock()
if bucket is not None:
self._locks[bucket] = lock
url = bucket

# header creation
headers: dict[str, str] = {
Expand Down Expand Up @@ -266,123 +290,97 @@ async def request(
if self.proxy_auth is not None:
kwargs["proxy_auth"] = self.proxy_auth

if not self._global_over.is_set():
# wait until the global lock is complete
await self._global_over.wait()

response: aiohttp.ClientResponse | None = None
data: dict[str, Any] | str | None = None
await lock.acquire()
with MaybeUnlock(lock) as maybe_lock:
for tries in range(5):
if files:
for f in files:
f.reset(seek=tries)

if form:
form_data = aiohttp.FormData(quote_fields=False)
for params in form:
form_data.add_field(**params)
kwargs["data"] = form_data

try:
async with self.__session.request(
method, url, **kwargs
) as response:
_log.debug(
"%s %s with %s has returned %s",
method,
url,
kwargs.get("data"),
response.status,

for executor in self._executors:
if executor.is_global or executor.route == route:
_log.debug(f'Pausing request to {route}: Found rate limit executor')
await executor.wait()

for tries in range(5):
if files:
for f in files:
f.reset(seek=tries)

if form:
form_data = aiohttp.FormData(quote_fields=False)
for params in form:
form_data.add_field(**params)
kwargs["data"] = form_data

try:
async with self.__session.request(
method, url, **kwargs
) as response:
_log.debug(
"%s %s with %s has returned %s",
method,
url,
kwargs.get("data"),
response.status,
)

# even errors have text involved in them so this is safe to call
data = await json_or_text(response)

# check if we have rate limit header information
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status != 429:
_log.debug(f'Request to {route} failed: Request returned rate limit')
executor = Executor(route=route)

self._executors.append(executor)
await executor.executed(
# NOTE: 5 is just a placeholder since this should always be present
reset_after=float(response.headers.get('X-RateLimit-Reset-After', "5")),
is_global=response.headers.get('X-RateLimit-Scope') == 'global',
limit=int(response.headers.get('X-RateLimit-Limit', 10)),
)
self._executors.remove(executor)
continue

# even errors have text involved in them so this is safe to call
data = await json_or_text(response)

# check if we have rate limit header information
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status != 429:
# we've depleted our current bucket
delta = utils._parse_ratelimit_header(
response, use_clock=self.use_clock
)
_log.debug(
(
"A rate limit bucket has been exhausted (bucket:"
" %s, retry: %s)."
),
bucket,
delta,
)
maybe_lock.defer()
self.loop.call_later(delta, lock.release)

# the request was successful so just return the text/json
if 300 > response.status >= 200:
_log.debug("%s %s has received %s", method, url, data)
return data

# we are being rate limited
if response.status == 429:
if not response.headers.get("Via") or isinstance(data, str):
# Banned by Cloudflare more than likely.
raise HTTPException(response, data)

fmt = (
"We are being rate limited. Retrying in %.2f seconds."
' Handled under the bucket "%s"'
)

# sleep a bit
retry_after: float = data["retry_after"]
_log.warning(fmt, retry_after, bucket)

# check if it's a global rate limit
is_global = data.get("global", False)
if is_global:
_log.warning(
(
"Global rate limit has been hit. Retrying in"
" %.2f seconds."
),
retry_after,
)
self._global_over.clear()

await asyncio.sleep(retry_after)
_log.debug("Done sleeping for the rate limit. Retrying...")

# release the global lock now that the
# global rate limit has passed
if is_global:
self._global_over.set()
_log.debug("Global rate limit is now over.")

continue

# we've received a 500, 502, 503, or 504, unconditional retry
if response.status in {500, 502, 503, 504}:
await asyncio.sleep(1 + tries * 2)
continue

# the usual error cases
if response.status == 403:
raise Forbidden(response, data)
elif response.status == 404:
raise NotFound(response, data)
elif response.status >= 500:
raise DiscordServerError(response, data)
else:
raise HTTPException(response, data)

# This is handling exceptions from the request
except OSError as e:
# Connection reset by peer
if tries < 4 and e.errno in (54, 10054):
# the request was successful so just return the text/json
if 300 > response.status >= 200:
_log.debug("%s %s has received %s", method, url, data)
return data

# we are being rate limited
if response.status == 429:
_log.debug(f'Request to {route} failed: Request returned rate limit')
executor = Executor(route=route)

self._executors.append(executor)
await executor.executed(
reset_after=data['retry_after'],
is_global=response.headers.get('X-RateLimit-Scope') == 'global',
limit=int(response.headers.get('X-RateLimit-Limit', 10)),
)
self._executors.remove(executor)
continue

# we've received a 500, 502, 503, or 504, unconditional retry
if response.status in {500, 502, 503, 504}:
await asyncio.sleep(1 + tries * 2)
continue
raise

# the usual error cases
if response.status == 403:
raise Forbidden(response, data)
elif response.status == 404:
raise NotFound(response, data)
elif response.status >= 500:
raise DiscordServerError(response, data)
else:
raise HTTPException(response, data)

# This is handling exceptions from the request
except OSError as e:
# Connection reset by peer
if tries < 4 and e.errno in (54, 10054):
await asyncio.sleep(1 + tries * 2)
continue
raise

if response is not None:
# We've run out of retries, raise.
Expand Down
Loading