Skip to content

Commit 0c82aac

Browse files
authored
Merge pull request #34 from MasuRii/fix/antigravity-credential-stuck-unavailable
fix(google-oauth): prevent credentials from becoming permanently stuck
2 parents 73a2395 + b7aa5d6 commit 0c82aac

File tree

7 files changed

+1080
-540
lines changed

7 files changed

+1080
-540
lines changed

src/proxy_app/main.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -570,23 +570,11 @@ async def process_credential(provider: str, path: str, provider_instance):
570570
)
571571

572572
# Log loaded credentials summary (compact, always visible for deployment verification)
573-
_api_summary = (
574-
", ".join([f"{p}:{len(c)}" for p, c in api_keys.items()])
575-
if api_keys
576-
else "none"
577-
)
578-
_oauth_summary = (
579-
", ".join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()])
580-
if oauth_credentials
581-
else "none"
582-
)
583-
_total_summary = ", ".join(
584-
[f"{p}:{len(c)}" for p, c in client.all_credentials.items()]
585-
)
586-
print(
587-
f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})"
588-
)
589-
client.background_refresher.start() # Start the background task
573+
#_api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
574+
#_oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
575+
#_total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
576+
#print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")
577+
client.background_refresher.start() # Start the background task
590578
app.state.rotating_client = client
591579

592580
# Warn if no provider credentials are configured

src/rotator_library/background_refresher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class BackgroundRefresher:
1818
"""
1919

2020
def __init__(self, client: "RotatingClient"):
21+
self._client = client
22+
self._task: Optional[asyncio.Task] = None
23+
self._initialized = False
2124
try:
2225
interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600")
2326
self._interval = int(interval_str)
@@ -26,9 +29,6 @@ def __init__(self, client: "RotatingClient"):
2629
f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s."
2730
)
2831
self._interval = 600
29-
self._client = client
30-
self._task: Optional[asyncio.Task] = None
31-
self._initialized = False
3232

3333
def start(self):
3434
"""Starts the background refresh task."""

src/rotator_library/providers/google_oauth_base.py

Lines changed: 281 additions & 182 deletions
Large diffs are not rendered by default.

src/rotator_library/providers/iflow_auth_base.py

Lines changed: 243 additions & 138 deletions
Large diffs are not rendered by default.

src/rotator_library/providers/qwen_auth_base.py

Lines changed: 310 additions & 199 deletions
Large diffs are not rendered by default.
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# src/rotator_library/utils/__init__.py
22

33
from .headless_detection import is_headless_environment
4+
from .reauth_coordinator import get_reauth_coordinator, ReauthCoordinator
45

5-
__all__ = ['is_headless_environment']
6+
__all__ = ["is_headless_environment", "get_reauth_coordinator", "ReauthCoordinator"]
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# src/rotator_library/utils/reauth_coordinator.py
2+
3+
"""
4+
Global Re-authentication Coordinator
5+
6+
Ensures only ONE interactive OAuth flow runs at a time across ALL providers.
7+
This prevents port conflicts and user confusion when multiple credentials
8+
need re-authentication simultaneously.
9+
10+
When a credential needs interactive re-auth (expired refresh token, revoked, etc.),
11+
it queues a request here. The coordinator ensures only one re-auth happens at a time,
12+
regardless of which provider the credential belongs to.
13+
"""
14+
15+
import asyncio
16+
import logging
17+
import time
18+
from typing import Callable, Optional, Dict, Any, Awaitable
19+
from pathlib import Path
20+
21+
lib_logger = logging.getLogger("rotator_library")
22+
23+
24+
class ReauthCoordinator:
25+
"""
26+
Singleton coordinator for global re-authentication serialization.
27+
28+
When a credential needs interactive re-auth (expired refresh token, revoked, etc.),
29+
it queues a request here. The coordinator ensures only one re-auth happens at a time.
30+
31+
This is critical because:
32+
1. Different providers may use the same callback ports
33+
2. User can only complete one OAuth flow at a time
34+
3. Prevents race conditions in credential state management
35+
"""
36+
37+
_instance: Optional["ReauthCoordinator"] = None
38+
_initialized: bool = False # Class-level declaration for Pylint
39+
40+
def __new__(cls):
41+
# Singleton pattern - only one coordinator exists
42+
if cls._instance is None:
43+
cls._instance = super().__new__(cls)
44+
cls._instance._initialized = False
45+
return cls._instance
46+
47+
def __init__(self):
48+
if self._initialized:
49+
return
50+
51+
# Global semaphore - only 1 re-auth at a time
52+
self._reauth_semaphore: asyncio.Semaphore = asyncio.Semaphore(1)
53+
54+
# Tracking for observability
55+
self._pending_reauths: Dict[str, float] = {} # credential -> queue_time
56+
self._current_reauth: Optional[str] = None
57+
self._current_provider: Optional[str] = None
58+
self._reauth_start_time: Optional[float] = None
59+
60+
# Lock for tracking dict modifications
61+
self._tracking_lock: asyncio.Lock = asyncio.Lock()
62+
63+
# Statistics
64+
self._total_reauths: int = 0
65+
self._successful_reauths: int = 0
66+
self._failed_reauths: int = 0
67+
self._timeout_reauths: int = 0
68+
69+
self._initialized = True
70+
lib_logger.info("Global ReauthCoordinator initialized")
71+
72+
def _get_display_name(self, credential_path: str) -> str:
73+
"""Get a display-friendly name for a credential path."""
74+
if credential_path.startswith("env://"):
75+
return credential_path
76+
return Path(credential_path).name
77+
78+
async def execute_reauth(
79+
self,
80+
credential_path: str,
81+
provider_name: str,
82+
reauth_func: Callable[[], Awaitable[Dict[str, Any]]],
83+
timeout: float = 300.0, # 5 minutes default timeout
84+
) -> Dict[str, Any]:
85+
"""
86+
Execute a re-authentication function with global serialization.
87+
88+
Only one re-auth can run at a time across all providers.
89+
Other requests wait in queue.
90+
91+
Args:
92+
credential_path: Path/identifier of the credential needing re-auth
93+
provider_name: Name of the provider (for logging)
94+
reauth_func: Async function that performs the actual re-auth
95+
timeout: Maximum time to wait for re-auth to complete
96+
97+
Returns:
98+
The result from reauth_func (new credentials dict)
99+
100+
Raises:
101+
TimeoutError: If re-auth doesn't complete within timeout
102+
Exception: Any exception from reauth_func is re-raised
103+
"""
104+
display_name = self._get_display_name(credential_path)
105+
106+
# Track that this credential is waiting
107+
async with self._tracking_lock:
108+
self._pending_reauths[credential_path] = time.time()
109+
pending_count = len(self._pending_reauths)
110+
111+
# Log queue status
112+
if self._current_reauth:
113+
current_display = self._get_display_name(self._current_reauth)
114+
lib_logger.info(
115+
f"[ReauthCoordinator] Credential '{display_name}' ({provider_name}) queued for re-auth. "
116+
f"Position in queue: {pending_count}. "
117+
f"Currently processing: '{current_display}' ({self._current_provider})"
118+
)
119+
else:
120+
lib_logger.info(
121+
f"[ReauthCoordinator] Credential '{display_name}' ({provider_name}) requesting re-auth."
122+
)
123+
124+
try:
125+
# Acquire global semaphore - blocks until our turn
126+
async with self._reauth_semaphore:
127+
# Calculate how long we waited in queue
128+
async with self._tracking_lock:
129+
queue_time = self._pending_reauths.pop(credential_path, time.time())
130+
wait_duration = time.time() - queue_time
131+
self._current_reauth = credential_path
132+
self._current_provider = provider_name
133+
self._reauth_start_time = time.time()
134+
self._total_reauths += 1
135+
136+
if wait_duration > 1.0:
137+
lib_logger.info(
138+
f"[ReauthCoordinator] Starting re-auth for '{display_name}' ({provider_name}) "
139+
f"after waiting {wait_duration:.1f}s in queue"
140+
)
141+
else:
142+
lib_logger.info(
143+
f"[ReauthCoordinator] Starting re-auth for '{display_name}' ({provider_name})"
144+
)
145+
146+
try:
147+
# Execute the actual re-auth with timeout
148+
result = await asyncio.wait_for(reauth_func(), timeout=timeout)
149+
150+
async with self._tracking_lock:
151+
self._successful_reauths += 1
152+
duration = time.time() - self._reauth_start_time
153+
154+
lib_logger.info(
155+
f"[ReauthCoordinator] Re-auth SUCCESS for '{display_name}' ({provider_name}) "
156+
f"in {duration:.1f}s"
157+
)
158+
return result
159+
160+
except asyncio.TimeoutError:
161+
async with self._tracking_lock:
162+
self._failed_reauths += 1
163+
self._timeout_reauths += 1
164+
lib_logger.error(
165+
f"[ReauthCoordinator] Re-auth TIMEOUT for '{display_name}' ({provider_name}) "
166+
f"after {timeout}s. User did not complete OAuth flow in time."
167+
)
168+
raise TimeoutError(
169+
f"Re-authentication timed out after {timeout}s. "
170+
f"Please try again and complete the OAuth flow within the time limit."
171+
)
172+
173+
except Exception as e:
174+
async with self._tracking_lock:
175+
self._failed_reauths += 1
176+
lib_logger.error(
177+
f"[ReauthCoordinator] Re-auth FAILED for '{display_name}' ({provider_name}): {e}"
178+
)
179+
raise
180+
181+
finally:
182+
async with self._tracking_lock:
183+
self._current_reauth = None
184+
self._current_provider = None
185+
self._reauth_start_time = None
186+
187+
# Log if there are still pending reauths
188+
if self._pending_reauths:
189+
lib_logger.info(
190+
f"[ReauthCoordinator] {len(self._pending_reauths)} credential(s) "
191+
f"still waiting for re-auth"
192+
)
193+
194+
finally:
195+
# Ensure we're removed from pending even if something goes wrong
196+
async with self._tracking_lock:
197+
self._pending_reauths.pop(credential_path, None)
198+
199+
def is_reauth_in_progress(self) -> bool:
200+
"""Check if a re-auth is currently in progress."""
201+
return self._current_reauth is not None
202+
203+
def get_pending_count(self) -> int:
204+
"""Get number of credentials waiting for re-auth."""
205+
return len(self._pending_reauths)
206+
207+
def get_status(self) -> Dict[str, Any]:
208+
"""Get current coordinator status for debugging/monitoring."""
209+
return {
210+
"current_reauth": self._current_reauth,
211+
"current_provider": self._current_provider,
212+
"reauth_in_progress": self._current_reauth is not None,
213+
"reauth_duration": (time.time() - self._reauth_start_time)
214+
if self._reauth_start_time
215+
else None,
216+
"pending_count": len(self._pending_reauths),
217+
"pending_credentials": list(self._pending_reauths.keys()),
218+
"stats": {
219+
"total": self._total_reauths,
220+
"successful": self._successful_reauths,
221+
"failed": self._failed_reauths,
222+
"timeouts": self._timeout_reauths,
223+
},
224+
}
225+
226+
227+
# Global singleton instance
228+
_coordinator: Optional[ReauthCoordinator] = None
229+
230+
231+
def get_reauth_coordinator() -> ReauthCoordinator:
232+
"""Get the global ReauthCoordinator instance."""
233+
global _coordinator
234+
if _coordinator is None:
235+
_coordinator = ReauthCoordinator()
236+
return _coordinator

0 commit comments

Comments
 (0)