diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..b7b6e892
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,44 @@
+# Git
+.git
+.gitignore
+
+# Python
+__pycache__
+*.py[cod]
+*$py.class
+*.so
+.Python
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Build
+*.egg-info/
+dist/
+build/
+.eggs/
+
+# Logs (will be mounted as volume)
+logs/
+
+# OAuth credentials (will be mounted as volume)
+oauth_creds/
+
+# Documentation
+*.md
+!README.md
+
+# GitHub
+.github/
+
+# Misc
+.DS_Store
+*.log
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..752ec8d8
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,46 @@
+# Build stage
+FROM python:3.11-slim as builder
+
+WORKDIR /app
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+
+# Copy the local rotator_library for editable install
+COPY src/rotator_library ./src/rotator_library
+
+# Install dependencies
+RUN pip install --no-cache-dir --user -r requirements.txt
+
+# Production stage
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Copy installed packages from builder
+COPY --from=builder /root/.local /root/.local
+
+# Make sure scripts in .local are usable
+ENV PATH=/root/.local/bin:$PATH
+
+# Copy application code
+COPY src/ ./src/
+
+# Create directories for logs and oauth credentials
+RUN mkdir -p logs oauth_creds
+
+# Expose the port the proxy listens on (matches the CMD below and docker-compose)
+EXPOSE 8317
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONPATH=/app/src
+
+# Default command - runs proxy with the correct PYTHONPATH
+CMD ["python", "src/proxy_app/main.py", "--port", "8317"]
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 00000000..3fabec7d
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,31 @@
+services:
+  llm-proxy:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    container_name: llm-api-proxy
+    restart: unless-stopped
+    ports:
+      - "8317:8317"
+    volumes:
+      # Mount .env files for configuration
+      - ./.env:/app/.env:ro
+      # Mount oauth_creds directory for OAuth credentials persistence
+      - ./oauth_creds:/app/oauth_creds
+      # Mount logs directory for persistent logging
+      - ./logs:/app/logs
+      # Mount key_usage.json for usage statistics persistence
+      - ./key_usage.json:/app/key_usage.json
+      # Optionally mount additional .env files (e.g., combined credential files)
+      # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro
+    environment:
+      # Skip OAuth interactive initialization in container (non-interactive)
+      - SKIP_OAUTH_INIT_CHECK=true
+      # Ensure Python output is not buffered
+      - PYTHONUNBUFFERED=1
+    healthcheck:
+      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8317/')"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 30s
diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py
index 258a69f3..e38d1e48 100644
--- a/src/proxy_app/main.py
+++ b/src/proxy_app/main.py
@@ -10,10 +10,18 @@
 
 # --- Argument Parsing (BEFORE heavy imports) ---
 parser = argparse.ArgumentParser(description="API Key Proxy Server")
-parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
+parser.add_argument(
+    "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
+) parser.add_argument("--port", type=int, default=8000, help="Port to run the server on.") -parser.add_argument("--enable-request-logging", action="store_true", help="Enable request logging.") -parser.add_argument("--add-credential", action="store_true", help="Launch the interactive tool to add a new OAuth credential.") +parser.add_argument( + "--enable-request-logging", action="store_true", help="Enable request logging." +) +parser.add_argument( + "--add-credential", + action="store_true", + help="Launch the interactive tool to add a new OAuth credential.", +) args, _ = parser.parse_known_args() # Add the 'src' directory to the Python path @@ -23,6 +31,7 @@ if len(sys.argv) == 1: # TUI MODE - Load ONLY what's needed for the launcher (fast path!) from proxy_app.launcher_tui import run_launcher_tui + run_launcher_tui() # Launcher modifies sys.argv and returns, or exits if user chose Exit # If we get here, user chose "Run Proxy" and sys.argv is modified @@ -32,6 +41,7 @@ # Check if credential tool mode (also doesn't need heavy proxy imports) if args.add_credential: from rotator_library.credential_tool import run_credential_tool + run_credential_tool() sys.exit(0) @@ -74,6 +84,7 @@ # Phase 2: Load Rich for loading spinner (lightweight) from rich.console import Console + _console = Console() # Phase 3: Heavy dependencies with granular loading messages @@ -82,7 +93,8 @@ from contextlib import asynccontextmanager from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware - from fastapi.responses import StreamingResponse + from fastapi.responses import StreamingResponse, JSONResponse + import uuid from fastapi.security import APIKeyHeader print(" → Loading core dependencies...") @@ -92,7 +104,7 @@ import json from typing import AsyncGenerator, Any, List, Optional, Union from pydantic import BaseModel, Field - + # --- Early Log Level Configuration --- logging.getLogger("LiteLLM").setLevel(logging.WARNING) @@ -100,7 +112,7 @@ with _console.status("[dim]Loading LiteLLM library...", spinner="dots"): import litellm -# Phase 4: Application imports with granular loading messages +# Phase 4: Application imports with granular loading messages print(" → Initializing proxy core...") with _console.status("[dim]Initializing proxy core...", spinner="dots"): from rotator_library import RotatingClient @@ -115,12 +127,15 @@ # Provider lazy loading happens during import, so time it here _provider_start = time.time() with _console.status("[dim]Discovering provider plugins...", spinner="dots"): - from rotator_library import PROVIDER_PLUGINS # This triggers lazy load via __getattr__ + from rotator_library import ( + PROVIDER_PLUGINS, + ) # This triggers lazy load via __getattr__ _provider_time = time.time() - _provider_start # Get count after import (without timing to avoid double-counting) _plugin_count = len(PROVIDER_PLUGINS) + # --- Pydantic Models --- class EmbeddingRequest(BaseModel): model: str @@ -129,15 +144,19 @@ class EmbeddingRequest(BaseModel): dimensions: Optional[int] = None user: Optional[str] = None + class ModelCard(BaseModel): """Basic model card for minimal response.""" + id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) owned_by: str = "Mirro-Proxy" + class ModelCapabilities(BaseModel): """Model capability flags.""" + tool_choice: bool = False function_calling: bool = False reasoning: bool = False @@ -146,8 +165,10 @@ class ModelCapabilities(BaseModel): prompt_caching: bool = False assistant_prefill: 
bool = False + class EnrichedModelCard(BaseModel): """Extended model card with pricing and capabilities.""" + id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) @@ -169,28 +190,150 @@ class EnrichedModelCard(BaseModel): # Debug info (optional) _sources: Optional[List[str]] = None _match_type: Optional[str] = None - + class Config: extra = "allow" # Allow extra fields from the service + class ModelList(BaseModel): """List of models response.""" + object: str = "list" data: List[ModelCard] + class EnrichedModelList(BaseModel): """List of enriched models with pricing and capabilities.""" + object: str = "list" data: List[EnrichedModelCard] + +# --- Anthropic API Models --- +class AnthropicTextBlock(BaseModel): + """Anthropic text content block.""" + + type: str = "text" + text: str + + +class AnthropicImageSource(BaseModel): + """Anthropic image source for base64 images.""" + + type: str = "base64" + media_type: str + data: str + + +class AnthropicImageBlock(BaseModel): + """Anthropic image content block.""" + + type: str = "image" + source: AnthropicImageSource + + +class AnthropicToolUseBlock(BaseModel): + """Anthropic tool use content block.""" + + type: str = "tool_use" + id: str + name: str + input: dict + + +class AnthropicToolResultBlock(BaseModel): + """Anthropic tool result content block.""" + + type: str = "tool_result" + tool_use_id: str + content: Union[str, List[Any]] + is_error: Optional[bool] = None + + +class AnthropicMessage(BaseModel): + """Anthropic message format.""" + + role: str + content: Union[ + str, + List[ + Union[ + AnthropicTextBlock, + AnthropicImageBlock, + AnthropicToolUseBlock, + AnthropicToolResultBlock, + dict, + ] + ], + ] + + +class AnthropicTool(BaseModel): + """Anthropic tool definition.""" + + name: str + description: Optional[str] = None + input_schema: dict + + +class AnthropicThinkingConfig(BaseModel): + """Anthropic thinking configuration.""" + + type: str # "enabled" or "disabled" + budget_tokens: Optional[int] = None + + +class AnthropicMessagesRequest(BaseModel): + """Anthropic Messages API request format.""" + + model: str + messages: List[AnthropicMessage] + max_tokens: int + system: Optional[Union[str, List[dict]]] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + stop_sequences: Optional[List[str]] = None + stream: Optional[bool] = False + tools: Optional[List[AnthropicTool]] = None + tool_choice: Optional[dict] = None + metadata: Optional[dict] = None + thinking: Optional[AnthropicThinkingConfig] = None + + +class AnthropicUsage(BaseModel): + """Anthropic usage statistics.""" + + input_tokens: int + output_tokens: int + cache_creation_input_tokens: Optional[int] = None + cache_read_input_tokens: Optional[int] = None + + +class AnthropicMessagesResponse(BaseModel): + """Anthropic Messages API response format.""" + + id: str + type: str = "message" + role: str = "assistant" + content: List[Union[AnthropicTextBlock, AnthropicToolUseBlock, dict]] + model: str + stop_reason: Optional[str] = None + stop_sequence: Optional[str] = None + usage: AnthropicUsage + + # Calculate total loading time _elapsed = time.time() - _start_time -print(f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)") +print( + f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)" +) # Clear screen and reprint header for clean startup view # This pushes loading messages up (still in 
scroll history) but shows a clean final screen import os as _os_module -_os_module.system('cls' if _os_module.name == 'nt' else 'clear') + +_os_module.system("cls" if _os_module.name == "nt" else "clear") # Reprint header print("━" * 70) @@ -198,7 +341,9 @@ class EnrichedModelList(BaseModel): print(f"Proxy API Key: {key_display}") print(f"GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy") print("━" * 70) -print(f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)") +print( + f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)" +) # Note: Debug logging will be added after logging configuration below @@ -211,52 +356,64 @@ class EnrichedModelList(BaseModel): console_handler = colorlog.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) formatter = colorlog.ColoredFormatter( - '%(log_color)s%(message)s', + "%(log_color)s%(message)s", log_colors={ - 'DEBUG': 'cyan', - 'INFO': 'green', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'red,bg_white', - } + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, ) console_handler.setFormatter(formatter) # Configure a file handler for INFO-level logs and higher info_file_handler = logging.FileHandler(LOG_DIR / "proxy.log", encoding="utf-8") info_file_handler.setLevel(logging.INFO) -info_file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +info_file_handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) # Configure a dedicated file handler for all DEBUG-level logs debug_file_handler = logging.FileHandler(LOG_DIR / "proxy_debug.log", encoding="utf-8") debug_file_handler.setLevel(logging.DEBUG) -debug_file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +debug_file_handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) + # Create a filter to ensure the debug handler ONLY gets DEBUG messages from the rotator_library class RotatorDebugFilter(logging.Filter): def filter(self, record): - return record.levelno == logging.DEBUG and record.name.startswith('rotator_library') + return record.levelno == logging.DEBUG and record.name.startswith( + "rotator_library" + ) + + debug_file_handler.addFilter(RotatorDebugFilter()) # Configure a console handler with color console_handler = colorlog.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) formatter = colorlog.ColoredFormatter( - '%(log_color)s%(message)s', + "%(log_color)s%(message)s", log_colors={ - 'DEBUG': 'cyan', - 'INFO': 'green', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'red,bg_white', - } + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, ) console_handler.setFormatter(formatter) + # Add a filter to prevent any LiteLLM logs from cluttering the console class NoLiteLLMLogFilter(logging.Filter): def filter(self, record): - return not record.name.startswith('LiteLLM') + return not record.name.startswith("LiteLLM") + + console_handler.addFilter(NoLiteLLMLogFilter()) # Get the root logger and set it to DEBUG to capture all messages @@ -306,18 +463,26 @@ def filter(self, record): for key, value in os.environ.items(): if key.startswith("IGNORE_MODELS_"): provider = key.replace("IGNORE_MODELS_", "").lower() - models_to_ignore = [model.strip() for model in 
value.split(',') if model.strip()] + models_to_ignore = [ + model.strip() for model in value.split(",") if model.strip() + ] ignore_models[provider] = models_to_ignore - logging.debug(f"Loaded ignore list for provider '{provider}': {models_to_ignore}") + logging.debug( + f"Loaded ignore list for provider '{provider}': {models_to_ignore}" + ) # Load model whitelist from environment variables whitelist_models = {} for key, value in os.environ.items(): if key.startswith("WHITELIST_MODELS_"): provider = key.replace("WHITELIST_MODELS_", "").lower() - models_to_whitelist = [model.strip() for model in value.split(',') if model.strip()] + models_to_whitelist = [ + model.strip() for model in value.split(",") if model.strip() + ] whitelist_models[provider] = models_to_whitelist - logging.debug(f"Loaded whitelist for provider '{provider}': {models_to_whitelist}") + logging.debug( + f"Loaded whitelist for provider '{provider}': {models_to_whitelist}" + ) # Load max concurrent requests per key from environment variables max_concurrent_requests_per_key = {} @@ -327,12 +492,19 @@ def filter(self, record): try: max_concurrent = int(value) if max_concurrent < 1: - logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1).") + logging.warning( + f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1)." + ) max_concurrent = 1 max_concurrent_requests_per_key[provider] = max_concurrent - logging.debug(f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}") + logging.debug( + f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}" + ) except ValueError: - logging.warning(f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1).") + logging.warning( + f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1)." + ) + # --- Lifespan Management --- @asynccontextmanager @@ -349,11 +521,11 @@ async def lifespan(app: FastAPI): if not skip_oauth_init and oauth_credentials: logging.info("Starting OAuth credential validation and deduplication...") processed_emails = {} # email -> {provider: path} - credentials_to_initialize = {} # provider -> [paths] + credentials_to_initialize = {} # provider -> [paths] final_oauth_credentials = {} # --- Pass 1: Pre-initialization Scan & Deduplication --- - #logging.info("Pass 1: Scanning for existing metadata to find duplicates...") + # logging.info("Pass 1: Scanning for existing metadata to find duplicates...") for provider, paths in oauth_credentials.items(): if provider not in credentials_to_initialize: credentials_to_initialize[provider] = [] @@ -362,9 +534,9 @@ async def lifespan(app: FastAPI): if path.startswith("env://"): credentials_to_initialize[provider].append(path) continue - + try: - with open(path, 'r') as f: + with open(path, "r") as f: data = json.load(f) metadata = data.get("_proxy_metadata", {}) email = metadata.get("email") @@ -372,28 +544,32 @@ async def lifespan(app: FastAPI): if email: if email not in processed_emails: processed_emails[email] = {} - + if provider in processed_emails[email]: original_path = processed_emails[email][provider] - logging.warning(f"Duplicate for '{email}' on '{provider}' found in pre-scan: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping.") + logging.warning( + f"Duplicate for '{email}' on '{provider}' found in pre-scan: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping." 
+ ) continue else: processed_emails[email][provider] = path - + credentials_to_initialize[provider].append(path) except (FileNotFoundError, json.JSONDecodeError) as e: - logging.warning(f"Could not pre-read metadata from '{path}': {e}. Will process during initialization.") + logging.warning( + f"Could not pre-read metadata from '{path}': {e}. Will process during initialization." + ) credentials_to_initialize[provider].append(path) - + # --- Pass 2: Parallel Initialization of Filtered Credentials --- - #logging.info("Pass 2: Initializing unique credentials and performing final check...") + # logging.info("Pass 2: Initializing unique credentials and performing final check...") async def process_credential(provider: str, path: str, provider_instance): """Process a single credential: initialize and fetch user info.""" try: await provider_instance.initialize_token(path) - if not hasattr(provider_instance, 'get_user_info'): + if not hasattr(provider_instance, "get_user_info"): return (provider, path, None, None) user_info = await provider_instance.get_user_info(path) @@ -401,7 +577,9 @@ async def process_credential(provider: str, path: str, provider_instance): return (provider, path, email, None) except Exception as e: - logging.error(f"Failed to process OAuth token for {provider} at '{path}': {e}") + logging.error( + f"Failed to process OAuth token for {provider} at '{path}': {e}" + ) return (provider, path, None, e) # Collect all tasks for parallel execution @@ -413,9 +591,9 @@ async def process_credential(provider: str, path: str, provider_instance): provider_plugin_class = PROVIDER_PLUGINS.get(provider) if not provider_plugin_class: continue - + provider_instance = provider_plugin_class() - + for path in paths: tasks.append(process_credential(provider, path, provider_instance)) @@ -430,7 +608,7 @@ async def process_credential(provider: str, path: str, provider_instance): continue provider, path, email, error = result - + # Skip if there was an error if error: continue @@ -444,7 +622,9 @@ async def process_credential(provider: str, path: str, provider_instance): # Handle empty email if not email: - logging.warning(f"Could not retrieve email for '{path}'. Treating as unique.") + logging.warning( + f"Could not retrieve email for '{path}'. Treating as unique." + ) if provider not in final_oauth_credentials: final_oauth_credentials[provider] = [] final_oauth_credentials[provider].append(path) @@ -453,10 +633,15 @@ async def process_credential(provider: str, path: str, provider_instance): # Deduplication check if email not in processed_emails: processed_emails[email] = {} - - if provider in processed_emails[email] and processed_emails[email][provider] != path: + + if ( + provider in processed_emails[email] + and processed_emails[email][provider] != path + ): original_path = processed_emails[email][provider] - logging.warning(f"Duplicate for '{email}' on '{provider}' found post-init: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping.") + logging.warning( + f"Duplicate for '{email}' on '{provider}' found post-init: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping." 
+ ) continue else: processed_emails[email][provider] = path @@ -467,7 +652,7 @@ async def process_credential(provider: str, path: str, provider_instance): # Update metadata (skip for env-based credentials - they don't have files) if not path.startswith("env://"): try: - with open(path, 'r+') as f: + with open(path, "r+") as f: data = json.load(f) metadata = data.get("_proxy_metadata", {}) metadata["email"] = email @@ -490,33 +675,47 @@ async def process_credential(provider: str, path: str, provider_instance): # The client now uses the root logger configuration client = RotatingClient( api_keys=api_keys, - oauth_credentials=oauth_credentials, # Pass OAuth config + oauth_credentials=oauth_credentials, # Pass OAuth config configure_logging=True, litellm_provider_params=litellm_provider_params, ignore_models=ignore_models, whitelist_models=whitelist_models, enable_request_logging=ENABLE_REQUEST_LOGGING, - max_concurrent_requests_per_key=max_concurrent_requests_per_key + max_concurrent_requests_per_key=max_concurrent_requests_per_key, ) - + # Log loaded credentials summary (compact, always visible for deployment verification) - _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" - _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" - _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()]) - print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})") - client.background_refresher.start() # Start the background task + _api_summary = ( + ", ".join([f"{p}:{len(c)}" for p, c in api_keys.items()]) + if api_keys + else "none" + ) + _oauth_summary = ( + ", ".join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) + if oauth_credentials + else "none" + ) + _total_summary = ", ".join( + [f"{p}:{len(c)}" for p, c in client.all_credentials.items()] + ) + print( + f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})" + ) + client.background_refresher.start() # Start the background task app.state.rotating_client = client - + # Warn if no provider credentials are configured if not client.all_credentials: logging.warning("=" * 70) logging.warning("⚠️ NO PROVIDER CREDENTIALS CONFIGURED") logging.warning("The proxy is running but cannot serve any LLM requests.") - logging.warning("Launch the credential tool to add API keys or OAuth credentials.") + logging.warning( + "Launch the credential tool to add API keys or OAuth credentials." 
+ ) logging.warning(" • Executable: Run with --add-credential flag") logging.warning(" • Source: python src/proxy_app/main.py --add-credential") logging.warning("=" * 70) - + os.environ["LITELLM_LOG"] = "ERROR" litellm.set_verbose = False litellm.drop_params = True @@ -527,29 +726,30 @@ async def process_credential(provider: str, path: str, provider_instance): else: app.state.embedding_batcher = None logging.info("RotatingClient initialized (EmbeddingBatcher disabled).") - + # Start model info service in background (fetches pricing/capabilities data) # This runs asynchronously and doesn't block proxy startup model_info_service = await init_model_info_service() app.state.model_info_service = model_info_service logging.info("Model info service started (fetching pricing data in background).") - + yield - - await client.background_refresher.stop() # Stop the background task on shutdown + + await client.background_refresher.stop() # Stop the background task on shutdown if app.state.embedding_batcher: await app.state.embedding_batcher.stop() await client.close() - + # Stop model info service - if hasattr(app.state, 'model_info_service') and app.state.model_info_service: + if hasattr(app.state, "model_info_service") and app.state.model_info_service: await app.state.model_info_service.stop() - + if app.state.embedding_batcher: logging.info("RotatingClient and EmbeddingBatcher closed.") else: logging.info("RotatingClient closed.") + # --- FastAPI App Setup --- app = FastAPI(lifespan=lifespan) @@ -563,25 +763,501 @@ async def process_credential(provider: str, path: str, provider_instance): ) api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + def get_rotating_client(request: Request) -> RotatingClient: """Dependency to get the rotating client instance from the app state.""" return request.app.state.rotating_client + def get_embedding_batcher(request: Request) -> EmbeddingBatcher: """Dependency to get the embedding batcher instance from the app state.""" return request.app.state.embedding_batcher + async def verify_api_key(auth: str = Depends(api_key_header)): """Dependency to verify the proxy API key.""" if not auth or auth != f"Bearer {PROXY_API_KEY}": raise HTTPException(status_code=401, detail="Invalid or missing API Key") return auth + +# --- Anthropic API Key Header --- +anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False) + + +async def verify_anthropic_api_key( + x_api_key: str = Depends(anthropic_api_key_header), + auth: str = Depends(api_key_header), +): + """ + Dependency to verify API key for Anthropic endpoints. + Accepts either x-api-key header (Anthropic style) or Authorization Bearer (OpenAI style). + """ + # Check x-api-key first (Anthropic style) + if x_api_key and x_api_key == PROXY_API_KEY: + return x_api_key + # Fall back to Bearer token (OpenAI style) + if auth and auth == f"Bearer {PROXY_API_KEY}": + return auth + raise HTTPException(status_code=401, detail="Invalid or missing API Key") + + +# --- Anthropic <-> OpenAI Format Translation --- +def anthropic_to_openai_messages( + anthropic_messages: List[dict], system: Optional[Union[str, List[dict]]] = None +) -> List[dict]: + """ + Convert Anthropic message format to OpenAI format. 
+ + Key differences: + - Anthropic: system is a separate field, content can be string or list of blocks + - OpenAI: system is a message with role="system", content is usually string + """ + openai_messages = [] + + # Handle system message + if system: + if isinstance(system, str): + openai_messages.append({"role": "system", "content": system}) + elif isinstance(system, list): + # System can be list of text blocks in Anthropic format + system_text = " ".join( + block.get("text", "") + for block in system + if isinstance(block, dict) and block.get("type") == "text" + ) + if system_text: + openai_messages.append({"role": "system", "content": system_text}) + + for msg in anthropic_messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if isinstance(content, str): + openai_messages.append({"role": role, "content": content}) + elif isinstance(content, list): + # Handle content blocks + openai_content = [] + tool_calls = [] + + for block in content: + if isinstance(block, dict): + block_type = block.get("type", "text") + + if block_type == "text": + openai_content.append( + {"type": "text", "text": block.get("text", "")} + ) + elif block_type == "image": + # Convert Anthropic image format to OpenAI + source = block.get("source", {}) + if source.get("type") == "base64": + openai_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{source.get('media_type', 'image/png')};base64,{source.get('data', '')}" + }, + } + ) + elif source.get("type") == "url": + openai_content.append( + { + "type": "image_url", + "image_url": {"url": source.get("url", "")}, + } + ) + elif block_type == "tool_use": + # Anthropic tool_use -> OpenAI tool_calls + tool_calls.append( + { + "id": block.get("id", ""), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": json.dumps(block.get("input", {})), + }, + } + ) + elif block_type == "tool_result": + # Tool results become separate messages in OpenAI format + tool_content = block.get("content", "") + if isinstance(tool_content, list): + tool_content = " ".join( + b.get("text", "") + for b in tool_content + if isinstance(b, dict) and b.get("type") == "text" + ) + openai_messages.append( + { + "role": "tool", + "tool_call_id": block.get("tool_use_id", ""), + "content": str(tool_content), + } + ) + continue # Don't add to current message + + # Build the message + if tool_calls: + # Assistant message with tool calls + msg_dict = {"role": role} + if openai_content: + # If there's text content alongside tool calls + text_parts = [ + c.get("text", "") + for c in openai_content + if c.get("type") == "text" + ] + msg_dict["content"] = " ".join(text_parts) if text_parts else None + else: + msg_dict["content"] = None + msg_dict["tool_calls"] = tool_calls + openai_messages.append(msg_dict) + elif openai_content: + # Check if it's just text or mixed content + if len(openai_content) == 1 and openai_content[0].get("type") == "text": + openai_messages.append( + {"role": role, "content": openai_content[0].get("text", "")} + ) + else: + openai_messages.append({"role": role, "content": openai_content}) + + return openai_messages + + +def anthropic_to_openai_tools( + anthropic_tools: Optional[List[dict]], +) -> Optional[List[dict]]: + """Convert Anthropic tool definitions to OpenAI format.""" + if not anthropic_tools: + return None + + openai_tools = [] + for tool in anthropic_tools: + openai_tools.append( + { + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": 
tool.get("description", ""), + "parameters": tool.get("input_schema", {}), + }, + } + ) + return openai_tools + + +def anthropic_to_openai_tool_choice( + anthropic_tool_choice: Optional[dict], +) -> Optional[Union[str, dict]]: + """Convert Anthropic tool_choice to OpenAI format.""" + if not anthropic_tool_choice: + return None + + choice_type = anthropic_tool_choice.get("type", "auto") + + if choice_type == "auto": + return "auto" + elif choice_type == "any": + return "required" + elif choice_type == "tool": + return { + "type": "function", + "function": {"name": anthropic_tool_choice.get("name", "")}, + } + elif choice_type == "none": + return "none" + + return "auto" + + +def openai_to_anthropic_response(openai_response: dict, original_model: str) -> dict: + """ + Convert OpenAI chat completion response to Anthropic Messages format. + """ + choice = openai_response.get("choices", [{}])[0] + message = choice.get("message", {}) + usage = openai_response.get("usage", {}) + + # Build content blocks + content_blocks = [] + + # Add thinking content block if reasoning_content is present + reasoning_content = message.get("reasoning_content") + if reasoning_content: + content_blocks.append({ + "type": "thinking", + "thinking": reasoning_content, + "signature": "", # Signature is typically empty for proxied responses + }) + + # Add text content if present + text_content = message.get("content") + if text_content: + content_blocks.append({"type": "text", "text": text_content}) + + # Add tool use blocks if present + tool_calls = message.get("tool_calls") or [] + for tc in tool_calls: + func = tc.get("function", {}) + try: + input_data = json.loads(func.get("arguments", "{}")) + except json.JSONDecodeError: + input_data = {} + + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", f"toolu_{int(time.time())}"), + "name": func.get("name", ""), + "input": input_data, + } + ) + + # Map finish_reason to stop_reason + finish_reason = choice.get("finish_reason", "end_turn") + stop_reason_map = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + "content_filter": "end_turn", + "function_call": "tool_use", + } + stop_reason = stop_reason_map.get(finish_reason, "end_turn") + + # Build usage + anthropic_usage = { + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + } + + # Add cache tokens if present + if usage.get("prompt_tokens_details"): + details = usage["prompt_tokens_details"] + if details.get("cached_tokens"): + anthropic_usage["cache_read_input_tokens"] = details["cached_tokens"] + + return { + "id": openai_response.get("id", f"msg_{int(time.time())}"), + "type": "message", + "role": "assistant", + "content": content_blocks, + "model": original_model, + "stop_reason": stop_reason, + "stop_sequence": None, + "usage": anthropic_usage, + } + + +async def anthropic_streaming_wrapper( + request: Request, + openai_stream: AsyncGenerator[str, None], + original_model: str, + request_id: str, +) -> AsyncGenerator[str, None]: + """ + Convert OpenAI streaming format to Anthropic streaming format. 
+ + Anthropic SSE events: + - message_start: Initial message metadata + - content_block_start: Start of a content block + - content_block_delta: Content chunk + - content_block_stop: End of a content block + - message_delta: Final message metadata (stop_reason, usage) + - message_stop: End of message + """ + message_started = False + content_block_started = False + thinking_block_started = False + current_block_index = 0 + accumulated_text = "" + accumulated_thinking = "" + tool_calls_by_index = {} # Track tool calls by their index + input_tokens = 0 + output_tokens = 0 + + try: + async for chunk_str in openai_stream: + if await request.is_disconnected(): + break + + if not chunk_str.strip() or not chunk_str.startswith("data:"): + continue + + data_content = chunk_str[len("data:") :].strip() + if data_content == "[DONE]": + # Close any open content blocks (thinking, text, or tool_use) + if thinking_block_started or content_block_started or tool_calls_by_index: + yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + + # Determine stop_reason based on whether we had tool calls + stop_reason = "tool_use" if tool_calls_by_index else "end_turn" + + # Send message_delta with final info + yield f'event: message_delta\ndata: {{"type": "message_delta", "delta": {{"stop_reason": "{stop_reason}", "stop_sequence": null}}, "usage": {{"output_tokens": {output_tokens}}}}}\n\n' + + # Send message_stop + yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' + break + + try: + chunk = json.loads(data_content) + except json.JSONDecodeError: + continue + + # Extract usage if present + if "usage" in chunk and chunk["usage"]: + input_tokens = chunk["usage"].get("prompt_tokens", input_tokens) + output_tokens = chunk["usage"].get("completion_tokens", output_tokens) + + # Send message_start on first chunk + if not message_started: + message_start = { + "type": "message_start", + "message": { + "id": request_id, + "type": "message", + "role": "assistant", + "content": [], + "model": original_model, + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": 0}, + }, + } + yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + message_started = True + + choices = chunk.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + finish_reason = choices[0].get("finish_reason") + + # Handle reasoning/thinking content (from OpenAI-style reasoning_content) + reasoning_content = delta.get("reasoning_content") + if reasoning_content: + if not thinking_block_started: + # Start a thinking content block + block_start = { + "type": "content_block_start", + "index": current_block_index, + "content_block": {"type": "thinking", "thinking": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + thinking_block_started = True + + # Send thinking delta + block_delta = { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "thinking_delta", "thinking": reasoning_content}, + } + yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + accumulated_thinking += reasoning_content + + # Handle text content + content = delta.get("content") + if content: + # If we were in a thinking block, close it first + if thinking_block_started and not content_block_started: + yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + current_block_index += 1 + 
thinking_block_started = False + + if not content_block_started: + # Start a text content block + block_start = { + "type": "content_block_start", + "index": current_block_index, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + content_block_started = True + + # Send content delta + block_delta = { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "text_delta", "text": content}, + } + yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + accumulated_text += content + + # Handle tool calls + tool_calls = delta.get("tool_calls", []) + for tc in tool_calls: + tc_index = tc.get("index", 0) + + if tc_index not in tool_calls_by_index: + # Close previous thinking block if open + if thinking_block_started: + yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + current_block_index += 1 + thinking_block_started = False + + # Close previous text block if open + if content_block_started: + yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n' + current_block_index += 1 + content_block_started = False + + # Start new tool use block + tool_calls_by_index[tc_index] = { + "id": tc.get("id", f"toolu_{tc_index}"), + "name": tc.get("function", {}).get("name", ""), + "arguments": "", + } + + block_start = { + "type": "content_block_start", + "index": current_block_index, + "content_block": { + "type": "tool_use", + "id": tool_calls_by_index[tc_index]["id"], + "name": tool_calls_by_index[tc_index]["name"], + "input": {}, + }, + } + yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n" + + # Accumulate arguments + func = tc.get("function", {}) + if func.get("name"): + tool_calls_by_index[tc_index]["name"] = func["name"] + if func.get("arguments"): + tool_calls_by_index[tc_index]["arguments"] += func["arguments"] + + # Send partial JSON delta + block_delta = { + "type": "content_block_delta", + "index": current_block_index, + "delta": { + "type": "input_json_delta", + "partial_json": func["arguments"], + }, + } + yield f"event: content_block_delta\ndata: {json.dumps(block_delta)}\n\n" + + # Note: We intentionally ignore finish_reason here. + # Block closing is handled when we receive [DONE] to avoid + # premature closes with providers that send finish_reason on each chunk. 
+ + except Exception as e: + logging.error(f"Error in Anthropic streaming wrapper: {e}") + error_event = { + "type": "error", + "error": {"type": "api_error", "message": str(e)}, + } + yield f"event: error\ndata: {json.dumps(error_event)}\n\n" + + async def streaming_response_wrapper( request: Request, request_data: dict, response_stream: AsyncGenerator[str, None], - logger: Optional[DetailedLogger] = None + logger: Optional[DetailedLogger] = None, ) -> AsyncGenerator[str, None]: """ Wraps a streaming response to log the full response after completion @@ -589,7 +1265,7 @@ async def streaming_response_wrapper( """ response_chunks = [] full_response = {} - + try: async for chunk_str in response_stream: if await request.is_disconnected(): @@ -597,7 +1273,7 @@ async def streaming_response_wrapper( break yield chunk_str if chunk_str.strip() and chunk_str.startswith("data:"): - content = chunk_str[len("data:"):].strip() + content = chunk_str[len("data:") :].strip() if content != "[DONE]": try: chunk_data = json.loads(content) @@ -613,15 +1289,17 @@ async def streaming_response_wrapper( "error": { "message": f"An unexpected error occurred during the stream: {str(e)}", "type": "proxy_internal_error", - "code": 500 + "code": 500, } } yield f"data: {json.dumps(error_payload)}\n\n" yield "data: [DONE]\n\n" # Also log this as a failed request if logger: - logger.log_final_response(status_code=500, headers=None, body={"error": str(e)}) - return # Stop further processing + logger.log_final_response( + status_code=500, headers=None, body={"error": str(e)} + ) + return # Stop further processing finally: if response_chunks: # --- Aggregation Logic --- @@ -645,36 +1323,56 @@ async def streaming_response_wrapper( final_message["content"] = "" if value: final_message["content"] += value - + elif key == "tool_calls": for tc_chunk in value: index = tc_chunk["index"] if index not in aggregated_tool_calls: - aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}} + aggregated_tool_calls[index] = { + "type": "function", + "function": {"name": "", "arguments": ""}, + } # Ensure 'function' key exists for this index before accessing its sub-keys if "function" not in aggregated_tool_calls[index]: - aggregated_tool_calls[index]["function"] = {"name": "", "arguments": ""} + aggregated_tool_calls[index]["function"] = { + "name": "", + "arguments": "", + } if tc_chunk.get("id"): aggregated_tool_calls[index]["id"] = tc_chunk["id"] if "function" in tc_chunk: if "name" in tc_chunk["function"]: if tc_chunk["function"]["name"] is not None: - aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"] + aggregated_tool_calls[index]["function"][ + "name" + ] += tc_chunk["function"]["name"] if "arguments" in tc_chunk["function"]: - if tc_chunk["function"]["arguments"] is not None: - aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"] - + if ( + tc_chunk["function"]["arguments"] + is not None + ): + aggregated_tool_calls[index]["function"][ + "arguments" + ] += tc_chunk["function"]["arguments"] + elif key == "function_call": if "function_call" not in final_message: - final_message["function_call"] = {"name": "", "arguments": ""} + final_message["function_call"] = { + "name": "", + "arguments": "", + } if "name" in value: if value["name"] is not None: - final_message["function_call"]["name"] += value["name"] + final_message["function_call"]["name"] += value[ + "name" + ] if "arguments" in value: if value["arguments"] is not None: - 
final_message["function_call"]["arguments"] += value["arguments"] - - else: # Generic key handling for other data like 'reasoning' + final_message["function_call"]["arguments"] += ( + value["arguments"] + ) + + else: # Generic key handling for other data like 'reasoning' # FIX: Role should always replace, never concatenate if key == "role": final_message[key] = value @@ -707,7 +1405,7 @@ async def streaming_response_wrapper( final_choice = { "index": 0, "message": final_message, - "finish_reason": finish_reason + "finish_reason": finish_reason, } full_response = { @@ -716,21 +1414,22 @@ async def streaming_response_wrapper( "created": first_chunk.get("created"), "model": first_chunk.get("model"), "choices": [final_choice], - "usage": usage_data + "usage": usage_data, } if logger: logger.log_final_response( status_code=200, headers=None, # Headers are not available at this stage - body=full_response + body=full_response, ) + @app.post("/v1/chat/completions") async def chat_completions( request: Request, client: RotatingClient = Depends(get_rotating_client), - _ = Depends(verify_api_key) + _=Depends(verify_api_key), ): """ OpenAI-compatible endpoint powered by the RotatingClient. @@ -749,16 +1448,24 @@ async def chat_completions( # instead of actual schemas, which can cause tool hallucination # Modes: "remove" = delete temperature key, "set" = change to 1.0, "false" = disabled override_temp_zero = os.getenv("OVERRIDE_TEMPERATURE_ZERO", "false").lower() - - if override_temp_zero in ("remove", "set", "true", "1", "yes") and "temperature" in request_data and request_data["temperature"] == 0: + + if ( + override_temp_zero in ("remove", "set", "true", "1", "yes") + and "temperature" in request_data + and request_data["temperature"] == 0 + ): if override_temp_zero == "remove": # Remove temperature key entirely del request_data["temperature"] - logging.debug("OVERRIDE_TEMPERATURE_ZERO=remove: Removed temperature=0 from request") + logging.debug( + "OVERRIDE_TEMPERATURE_ZERO=remove: Removed temperature=0 from request" + ) else: # Set to 1.0 (for "set", "true", "1", "yes") request_data["temperature"] = 1.0 - logging.debug("OVERRIDE_TEMPERATURE_ZERO=set: Converting temperature=0 to temperature=1.0") + logging.debug( + "OVERRIDE_TEMPERATURE_ZERO=set: Converting temperature=0 to temperature=1.0" + ) # If logging is enabled, perform all logging operations using the parsed data. if logger: @@ -766,9 +1473,17 @@ async def chat_completions( # Extract and log specific reasoning parameters for monitoring. 
model = request_data.get("model") - generation_cfg = request_data.get("generationConfig", {}) or request_data.get("generation_config", {}) or {} - reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get("reasoning_effort") - custom_reasoning_budget = request_data.get("custom_reasoning_budget") or generation_cfg.get("custom_reasoning_budget", False) + generation_cfg = ( + request_data.get("generationConfig", {}) + or request_data.get("generation_config", {}) + or {} + ) + reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get( + "reasoning_effort" + ) + custom_reasoning_budget = request_data.get( + "custom_reasoning_budget" + ) or generation_cfg.get("custom_reasoning_budget", False) logging.getLogger("rotator_library").debug( f"Handling reasoning parameters: model={model}, reasoning_effort={reasoning_effort}, custom_reasoning_budget={custom_reasoning_budget}" @@ -779,31 +1494,41 @@ async def chat_completions( url=str(request.url), headers=dict(request.headers), client_info=(request.client.host, request.client.port), - request_data=request_data + request_data=request_data, ) is_streaming = request_data.get("stream", False) if is_streaming: response_generator = client.acompletion(request=request, **request_data) return StreamingResponse( - streaming_response_wrapper(request, request_data, response_generator, logger), - media_type="text/event-stream" + streaming_response_wrapper( + request, request_data, response_generator, logger + ), + media_type="text/event-stream", ) else: response = await client.acompletion(request=request, **request_data) if logger: # Assuming response has status_code and headers attributes # This might need adjustment based on the actual response object - response_headers = response.headers if hasattr(response, 'headers') else None - status_code = response.status_code if hasattr(response, 'status_code') else 200 + response_headers = ( + response.headers if hasattr(response, "headers") else None + ) + status_code = ( + response.status_code if hasattr(response, "status_code") else 200 + ) logger.log_final_response( status_code=status_code, headers=response_headers, - body=response.model_dump() + body=response.model_dump(), ) return response - except (litellm.InvalidRequestError, ValueError, litellm.ContextWindowExceededError) as e: + except ( + litellm.InvalidRequestError, + ValueError, + litellm.ContextWindowExceededError, + ) as e: raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}") except litellm.AuthenticationError as e: raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}") @@ -824,16 +1549,197 @@ async def chat_completions( except json.JSONDecodeError: request_data = {"error": "Could not parse request body"} if logger: - logger.log_final_response(status_code=500, headers=None, body={"error": str(e)}) + logger.log_final_response( + status_code=500, headers=None, body={"error": str(e)} + ) raise HTTPException(status_code=500, detail=str(e)) + +# --- Anthropic Messages API Endpoint --- +@app.post("/v1/messages") +async def anthropic_messages( + request: Request, + body: AnthropicMessagesRequest, + client: RotatingClient = Depends(get_rotating_client), + _=Depends(verify_anthropic_api_key), +): + """ + Anthropic-compatible Messages API endpoint. + + Accepts requests in Anthropic's format and returns responses in Anthropic's format. + Internally translates to OpenAI format for processing via LiteLLM. + + This endpoint is compatible with Claude Code and other Anthropic API clients. 
+ """ + request_id = f"msg_{uuid.uuid4().hex[:24]}" + original_model = body.model + + # Initialize logger if enabled + logger = DetailedLogger() if ENABLE_REQUEST_LOGGING else None + + try: + # Convert Anthropic request to OpenAI format + anthropic_request = body.model_dump(exclude_none=True) + + openai_messages = anthropic_to_openai_messages( + anthropic_request.get("messages", []), anthropic_request.get("system") + ) + + openai_tools = anthropic_to_openai_tools(anthropic_request.get("tools")) + openai_tool_choice = anthropic_to_openai_tool_choice( + anthropic_request.get("tool_choice") + ) + + # Build OpenAI-compatible request + openai_request = { + "model": body.model, + "messages": openai_messages, + "max_tokens": body.max_tokens, + "stream": body.stream or False, + } + + if body.temperature is not None: + openai_request["temperature"] = body.temperature + if body.top_p is not None: + openai_request["top_p"] = body.top_p + if body.stop_sequences: + openai_request["stop"] = body.stop_sequences + if openai_tools: + openai_request["tools"] = openai_tools + if openai_tool_choice: + openai_request["tool_choice"] = openai_tool_choice + + # Handle Anthropic thinking config -> reasoning_effort translation + if body.thinking: + if body.thinking.type == "enabled": + # Map budget_tokens to reasoning_effort level + # Default to "medium" if enabled but budget not specified + budget = body.thinking.budget_tokens or 10000 + if budget >= 32000: + openai_request["reasoning_effort"] = "high" + openai_request["custom_reasoning_budget"] = True + elif budget >= 10000: + openai_request["reasoning_effort"] = "high" + elif budget >= 5000: + openai_request["reasoning_effort"] = "medium" + else: + openai_request["reasoning_effort"] = "low" + elif body.thinking.type == "disabled": + openai_request["reasoning_effort"] = "disable" + elif "opus" in body.model.lower(): + # Force high thinking for Opus models when no thinking config is provided + # Opus 4.5 always uses the -thinking variant, so we want maximum thinking budget + # Without this, the backend defaults to thinkingBudget: -1 (auto) instead of high + openai_request["reasoning_effort"] = "high" + openai_request["custom_reasoning_budget"] = True + + log_request_to_console( + url=str(request.url), + headers=dict(request.headers), + client_info=( + request.client.host if request.client else "unknown", + request.client.port if request.client else 0, + ), + request_data=openai_request, + ) + + if body.stream: + # Streaming response - acompletion returns a generator for streaming + response_generator = client.acompletion(request=request, **openai_request) + + return StreamingResponse( + anthropic_streaming_wrapper( + request, response_generator, original_model, request_id + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + else: + # Non-streaming response + response = await client.acompletion(request=request, **openai_request) + + # Convert OpenAI response to Anthropic format + openai_response = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + anthropic_response = openai_to_anthropic_response( + openai_response, original_model + ) + + # Override the ID with our request ID + anthropic_response["id"] = request_id + + if logger: + logger.log_final_response( + status_code=200, + headers=None, + body=anthropic_response, + ) + + return JSONResponse(content=anthropic_response) + + except ( + litellm.InvalidRequestError, + 
ValueError, + litellm.ContextWindowExceededError, + ) as e: + error_response = { + "type": "error", + "error": {"type": "invalid_request_error", "message": str(e)}, + } + raise HTTPException(status_code=400, detail=error_response) + except litellm.AuthenticationError as e: + error_response = { + "type": "error", + "error": {"type": "authentication_error", "message": str(e)}, + } + raise HTTPException(status_code=401, detail=error_response) + except litellm.RateLimitError as e: + error_response = { + "type": "error", + "error": {"type": "rate_limit_error", "message": str(e)}, + } + raise HTTPException(status_code=429, detail=error_response) + except (litellm.ServiceUnavailableError, litellm.APIConnectionError) as e: + error_response = { + "type": "error", + "error": {"type": "api_error", "message": str(e)}, + } + raise HTTPException(status_code=503, detail=error_response) + except litellm.Timeout as e: + error_response = { + "type": "error", + "error": {"type": "api_error", "message": f"Request timed out: {str(e)}"}, + } + raise HTTPException(status_code=504, detail=error_response) + except Exception as e: + logging.error(f"Anthropic messages endpoint error: {e}") + if logger: + logger.log_final_response( + status_code=500, + headers=None, + body={"error": str(e)}, + ) + error_response = { + "type": "error", + "error": {"type": "api_error", "message": str(e)}, + } + raise HTTPException(status_code=500, detail=error_response) + + @app.post("/v1/embeddings") async def embeddings( request: Request, body: EmbeddingRequest, client: RotatingClient = Depends(get_rotating_client), batcher: Optional[EmbeddingBatcher] = Depends(get_embedding_batcher), - _ = Depends(verify_api_key) + _=Depends(verify_api_key), ): """ OpenAI-compatible endpoint for creating embeddings. 
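# Illustrative example (not part of the patch): a minimal client call against the new
# Anthropic-compatible /v1/messages endpoint added above. It assumes the proxy is
# reachable on port 8317 (as in the docker-compose file) and that the PROXY_API_KEY
# environment variable holds the proxy's key; the model id is a placeholder for any
# model the proxy lists at /v1/models.
import json
import os
import urllib.request

payload = {
    "model": "gemini/gemini-2.5-pro",  # placeholder; use any model id from /v1/models
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    # A budget of 10000+ tokens is translated to reasoning_effort="high" by the
    # thinking-config mapping in the endpoint above.
    "thinking": {"type": "enabled", "budget_tokens": 16000},
}

req = urllib.request.Request(
    "http://localhost:8317/v1/messages",
    data=json.dumps(payload).encode("utf-8"),
    headers={
        "Content-Type": "application/json",
        # Anthropic-style auth; "Authorization: Bearer <key>" is also accepted.
        "x-api-key": os.environ.get("PROXY_API_KEY", ""),
    },
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read().decode("utf-8"))

# Non-streaming responses follow Anthropic's Messages shape:
# {"id": ..., "content": [...], "stop_reason": ..., "usage": {...}}
print(body["stop_reason"], body["usage"])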
@@ -847,7 +1753,7 @@ async def embeddings( url=str(request.url), headers=dict(request.headers), client_info=(request.client.host, request.client.port), - request_data=request_data + request_data=request_data, ) if USE_EMBEDDING_BATCHER and batcher: # --- Server-Side Batching Logic --- @@ -861,7 +1767,7 @@ async def embeddings( individual_request = request_data.copy() individual_request["input"] = single_input tasks.append(batcher.add_request(individual_request)) - + results = await asyncio.gather(*tasks) all_data = [] @@ -877,16 +1783,19 @@ async def embeddings( "object": "list", "model": results[0]["model"], "data": all_data, - "usage": { "prompt_tokens": total_prompt_tokens, "total_tokens": total_tokens }, + "usage": { + "prompt_tokens": total_prompt_tokens, + "total_tokens": total_tokens, + }, } response = litellm.EmbeddingResponse(**final_response_data) - + else: # --- Direct Pass-Through Logic --- request_data = body.model_dump(exclude_none=True) if isinstance(request_data.get("input"), str): request_data["input"] = [request_data["input"]] - + response = await client.aembedding(request=request, **request_data) return response @@ -894,7 +1803,11 @@ async def embeddings( except HTTPException as e: # Re-raise HTTPException to ensure it's not caught by the generic Exception handler raise e - except (litellm.InvalidRequestError, ValueError, litellm.ContextWindowExceededError) as e: + except ( + litellm.InvalidRequestError, + ValueError, + litellm.ContextWindowExceededError, + ) as e: raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}") except litellm.AuthenticationError as e: raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}") @@ -910,10 +1823,12 @@ async def embeddings( logging.error(f"Embedding request failed: {e}") raise HTTPException(status_code=500, detail=str(e)) + @app.get("/") def read_root(): return {"Status": "API Key Proxy is running"} + @app.get("/v1/models") async def list_models( request: Request, @@ -923,22 +1838,30 @@ async def list_models( ): """ Returns a list of available models in the OpenAI-compatible format. - + Query Parameters: enriched: If True (default), returns detailed model info with pricing and capabilities. If False, returns minimal OpenAI-compatible response. """ model_ids = await client.get_all_available_models(grouped=False) - - if enriched and hasattr(request.app.state, 'model_info_service'): + + if enriched and hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: # Return enriched model data enriched_data = model_info_service.enrich_model_list(model_ids) return {"object": "list", "data": enriched_data} - + # Fallback to basic model cards - model_cards = [{"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "Mirro-Proxy"} for model_id in model_ids] + model_cards = [ + { + "id": model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "Mirro-Proxy", + } + for model_id in model_ids + ] return {"object": "list", "data": model_cards} @@ -950,17 +1873,17 @@ async def get_model( ): """ Returns detailed information about a specific model. 
- + Path Parameters: model_id: The model ID (e.g., "anthropic/claude-3-opus", "openrouter/openai/gpt-4") """ - if hasattr(request.app.state, 'model_info_service'): + if hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: info = model_info_service.get_model_info(model_id) if info: return info.to_dict() - + # Return basic info if service not ready or model not found return { "id": model_id, @@ -978,7 +1901,7 @@ async def model_info_stats( """ Returns statistics about the model info service (for monitoring/debugging). """ - if hasattr(request.app.state, 'model_info_service'): + if hasattr(request.app.state, "model_info_service"): return request.app.state.model_info_service.get_stats() return {"error": "Model info service not initialized"} @@ -990,11 +1913,12 @@ async def list_providers(_=Depends(verify_api_key)): """ return list(PROVIDER_PLUGINS.keys()) + @app.post("/v1/token-count") async def token_count( - request: Request, + request: Request, client: RotatingClient = Depends(get_rotating_client), - _=Depends(verify_api_key) + _=Depends(verify_api_key), ): """ Calculates the token count for a given list of messages and a model. @@ -1005,7 +1929,9 @@ async def token_count( messages = data.get("messages") if not model or not messages: - raise HTTPException(status_code=400, detail="'model' and 'messages' are required.") + raise HTTPException( + status_code=400, detail="'model' and 'messages' are required." + ) count = client.token_count(**data) return {"token_count": count} @@ -1016,13 +1942,10 @@ async def token_count( @app.post("/v1/cost-estimate") -async def cost_estimate( - request: Request, - _=Depends(verify_api_key) -): +async def cost_estimate(request: Request, _=Depends(verify_api_key)): """ Estimates the cost for a request based on token counts and model pricing. 
- + Request body: { "model": "anthropic/claude-3-opus", @@ -1031,7 +1954,7 @@ async def cost_estimate( "cache_read_tokens": 0, # optional "cache_creation_tokens": 0 # optional } - + Returns: { "model": "anthropic/claude-3-opus", @@ -1051,25 +1974,28 @@ async def cost_estimate( completion_tokens = data.get("completion_tokens", 0) cache_read_tokens = data.get("cache_read_tokens", 0) cache_creation_tokens = data.get("cache_creation_tokens", 0) - + if not model: raise HTTPException(status_code=400, detail="'model' is required.") - + result = { "model": model, "cost": None, "currency": "USD", "pricing": {}, - "source": None + "source": None, } - + # Try model info service first - if hasattr(request.app.state, 'model_info_service'): + if hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: cost = model_info_service.calculate_cost( - model, prompt_tokens, completion_tokens, - cache_read_tokens, cache_creation_tokens + model, + prompt_tokens, + completion_tokens, + cache_read_tokens, + cache_creation_tokens, ) if cost is not None: cost_info = model_info_service.get_cost_info(model) @@ -1077,31 +2003,32 @@ async def cost_estimate( result["pricing"] = cost_info or {} result["source"] = "model_info_service" return result - + # Fallback to litellm try: import litellm + # Create a mock response for cost calculation model_info = litellm.get_model_info(model) input_cost = model_info.get("input_cost_per_token", 0) output_cost = model_info.get("output_cost_per_token", 0) - + if input_cost or output_cost: cost = (prompt_tokens * input_cost) + (completion_tokens * output_cost) result["cost"] = cost result["pricing"] = { "input_cost_per_token": input_cost, - "output_cost_per_token": output_cost + "output_cost_per_token": output_cost, } result["source"] = "litellm_fallback" return result except Exception: pass - + result["source"] = "unknown" result["error"] = "Pricing data not available for this model" return result - + except HTTPException: raise except Exception as e: @@ -1112,17 +2039,18 @@ async def cost_estimate( if __name__ == "__main__": # Define ENV_FILE for onboarding checks ENV_FILE = Path.cwd() / ".env" - + # Check if launcher TUI should be shown (no arguments provided) if len(sys.argv) == 1: # No arguments - show launcher TUI (lazy import) from proxy_app.launcher_tui import run_launcher_tui + run_launcher_tui() # Launcher modifies sys.argv and returns, or exits if user chose Exit # If we get here, user chose "Run Proxy" and sys.argv is modified # Re-parse arguments with modified sys.argv args = parser.parse_args() - + def needs_onboarding() -> bool: """ Check if the proxy needs onboarding (first-time setup). 
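[Illustrative aside, not part of this diff] The litellm fallback in /v1/cost-estimate prices a request purely from per-token rates (prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token) and, unlike the model info service path, does not use the optional cache_read_tokens/cache_creation_tokens. A worked sketch of that arithmetic follows; the per-token rates are made-up placeholders for a hypothetical model, not real pricing.

# Fallback cost arithmetic, with assumed (not real) per-token rates.
input_cost_per_token = 0.000003   # USD per prompt token (hypothetical)
output_cost_per_token = 0.000015  # USD per completion token (hypothetical)

prompt_tokens = 1000
completion_tokens = 500

cost = (prompt_tokens * input_cost_per_token) + (completion_tokens * output_cost_per_token)
print(f"estimated cost: {cost:.6f} USD")  # -> estimated cost: 0.010500 USD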
@@ -1132,40 +2060,49 @@ def needs_onboarding() -> bool: # PROXY_API_KEY is optional (will show warning if not set) if not ENV_FILE.is_file(): return True - + return False def show_onboarding_message(): """Display clear explanatory message for why onboarding is needed.""" - os.system('cls' if os.name == 'nt' else 'clear') # Clear terminal for clean presentation - console.print(Panel.fit( - "[bold cyan]🚀 LLM API Key Proxy - First Time Setup[/bold cyan]", - border_style="cyan" - )) + os.system( + "cls" if os.name == "nt" else "clear" + ) # Clear terminal for clean presentation + console.print( + Panel.fit( + "[bold cyan]🚀 LLM API Key Proxy - First Time Setup[/bold cyan]", + border_style="cyan", + ) + ) console.print("[bold yellow]⚠️ Configuration Required[/bold yellow]\n") - + console.print("The proxy needs initial configuration:") console.print(" [red]❌ No .env file found[/red]") - + console.print("\n[bold]Why this matters:[/bold]") console.print(" • The .env file stores your credentials and settings") console.print(" • PROXY_API_KEY protects your proxy from unauthorized access") console.print(" • Provider API keys enable LLM access") - + console.print("\n[bold]What happens next:[/bold]") console.print(" 1. We'll create a .env file with PROXY_API_KEY") console.print(" 2. You can add LLM provider credentials (API keys or OAuth)") console.print(" 3. The proxy will then start normally") - - console.print("\n[bold yellow]⚠️ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default.") + + console.print( + "\n[bold yellow]⚠️ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." + ) console.print(" You can remove it later if you want an unsecured proxy.\n") - - console.input("[bold green]Press Enter to launch the credential setup tool...[/bold green]") + + console.input( + "[bold green]Press Enter to launch the credential setup tool...[/bold green]" + ) # Check if user explicitly wants to add credentials if args.add_credential: # Import and call ensure_env_defaults to create .env and PROXY_API_KEY if needed from rotator_library.credential_tool import ensure_env_defaults + ensure_env_defaults() # Reload environment variables after ensure_env_defaults creates/updates .env load_dotenv(override=True) @@ -1176,36 +2113,41 @@ def show_onboarding_message(): # Import console from rich for better messaging from rich.console import Console from rich.panel import Panel + console = Console() - + # Show clear explanatory message show_onboarding_message() - + # Launch credential tool automatically from rotator_library.credential_tool import ensure_env_defaults + ensure_env_defaults() load_dotenv(override=True) run_credential_tool() - + # After credential tool exits, reload and re-check load_dotenv(override=True) # Re-read PROXY_API_KEY from environment PROXY_API_KEY = os.getenv("PROXY_API_KEY") - + # Verify onboarding is complete if needs_onboarding(): console.print("\n[bold red]❌ Configuration incomplete.[/bold red]") - console.print("The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n") + console.print( + "The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n" + ) sys.exit(1) else: console.print("\n[bold green]✅ Configuration complete![/bold green]") console.print("\nStarting proxy server...\n") - + # Validate PROXY_API_KEY before starting the server if not PROXY_API_KEY: - raise ValueError("PROXY_API_KEY environment variable not set. 
Please run with --add-credential to set up your environment.") - - import uvicorn - uvicorn.run(app, host=args.host, port=args.port) + raise ValueError( + "PROXY_API_KEY environment variable not set. Please run with --add-credential to set up your environment." + ) + import uvicorn + uvicorn.run(app, host=args.host, port=args.port) diff --git a/src/rotator_library/background_refresher.py b/src/rotator_library/background_refresher.py index 4c1fc26f..a6830fa8 100644 --- a/src/rotator_library/background_refresher.py +++ b/src/rotator_library/background_refresher.py @@ -8,28 +8,35 @@ if TYPE_CHECKING: from .client import RotatingClient -lib_logger = logging.getLogger('rotator_library') +lib_logger = logging.getLogger("rotator_library") + class BackgroundRefresher: """ A background task that periodically checks and refreshes OAuth tokens to ensure they remain valid. """ - def __init__(self, client: 'RotatingClient'): + + def __init__(self, client: "RotatingClient"): try: interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600") self._interval = int(interval_str) except ValueError: - lib_logger.warning(f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s.") + lib_logger.warning( + f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s." + ) self._interval = 600 self._client = client self._task: Optional[asyncio.Task] = None + self._initialized = False def start(self): """Starts the background refresh task.""" if self._task is None: self._task = asyncio.create_task(self._run()) - lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.") + lib_logger.info( + f"Background token refresher started. Check interval: {self._interval} seconds." + ) # [NEW] Log if custom interval is set async def stop(self): @@ -42,23 +49,107 @@ async def stop(self): pass lib_logger.info("Background token refresher stopped.") + async def _initialize_credentials(self): + """ + Initialize all providers by loading credentials and persisted tier data. + Called once before the main refresh loop starts. 
+ """ + if self._initialized: + return + + api_summary = {} # provider -> count + oauth_summary = {} # provider -> {"count": N, "tiers": {tier: count}} + + all_credentials = self._client.all_credentials + oauth_providers = self._client.oauth_providers + + for provider, credentials in all_credentials.items(): + if not credentials: + continue + + provider_plugin = self._client._get_provider_instance(provider) + + # Call initialize_credentials if provider supports it + if provider_plugin and hasattr(provider_plugin, "initialize_credentials"): + try: + await provider_plugin.initialize_credentials(credentials) + except Exception as e: + lib_logger.error( + f"Error initializing credentials for provider '{provider}': {e}" + ) + + # Build summary based on provider type + if provider in oauth_providers: + tier_breakdown = {} + if provider_plugin and hasattr( + provider_plugin, "get_credential_tier_name" + ): + for cred in credentials: + tier = provider_plugin.get_credential_tier_name(cred) + if tier: + tier_breakdown[tier] = tier_breakdown.get(tier, 0) + 1 + oauth_summary[provider] = { + "count": len(credentials), + "tiers": tier_breakdown, + } + else: + api_summary[provider] = len(credentials) + + # Log 3-line summary + total_providers = len(api_summary) + len(oauth_summary) + total_credentials = sum(api_summary.values()) + sum( + d["count"] for d in oauth_summary.values() + ) + + if total_providers > 0: + lib_logger.info( + f"Providers initialized: {total_providers} providers, {total_credentials} credentials" + ) + + # API providers line + if api_summary: + api_parts = [f"{p}:{c}" for p, c in sorted(api_summary.items())] + lib_logger.info(f" API: {', '.join(api_parts)}") + + # OAuth providers line with tier breakdown + if oauth_summary: + oauth_parts = [] + for provider, data in sorted(oauth_summary.items()): + if data["tiers"]: + tier_str = ", ".join( + f"{t}:{c}" for t, c in sorted(data["tiers"].items()) + ) + oauth_parts.append(f"{provider}:{data['count']} ({tier_str})") + else: + oauth_parts.append(f"{provider}:{data['count']}") + lib_logger.info(f" OAuth: {', '.join(oauth_parts)}") + + self._initialized = True + async def _run(self): """The main loop for the background task.""" + # Initialize credentials (load persisted tiers) before starting the refresh loop + await self._initialize_credentials() + while True: try: - #lib_logger.info("Running proactive token refresh check...") + # lib_logger.info("Running proactive token refresh check...") oauth_configs = self._client.get_oauth_credentials() for provider, paths in oauth_configs.items(): - provider_plugin = self._client._get_provider_instance(f"{provider}_oauth") - if provider_plugin and hasattr(provider_plugin, 'proactively_refresh'): + provider_plugin = self._client._get_provider_instance(provider) + if provider_plugin and hasattr( + provider_plugin, "proactively_refresh" + ): for path in paths: try: await provider_plugin.proactively_refresh(path) except Exception as e: - lib_logger.error(f"Error during proactive refresh for '{path}': {e}") + lib_logger.error( + f"Error during proactive refresh for '{path}': {e}" + ) await asyncio.sleep(self._interval) except asyncio.CancelledError: break except Exception as e: - lib_logger.error(f"Unexpected error in background refresher loop: {e}") \ No newline at end of file + lib_logger.error(f"Unexpected error in background refresher loop: {e}") diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index cf1bb1cf..befa39ed 100644 --- a/src/rotator_library/client.py +++ 
b/src/rotator_library/client.py @@ -447,12 +447,23 @@ def _get_provider_instance(self, provider_name: str): Args: provider_name: The name of the provider to get an instance for. + For OAuth providers, this may include "_oauth" suffix + (e.g., "antigravity_oauth"), but credentials are stored + under the base name (e.g., "antigravity"). Returns: Provider instance if credentials exist, None otherwise. """ + # For OAuth providers, credentials are stored under base name (without _oauth suffix) + # e.g., "antigravity_oauth" plugin → credentials under "antigravity" + credential_key = provider_name + if provider_name.endswith("_oauth"): + base_name = provider_name[:-6] # Remove "_oauth" + if base_name in self.oauth_providers: + credential_key = base_name + # Only initialize providers for which we have credentials - if provider_name not in self.all_credentials: + if credential_key not in self.all_credentials: lib_logger.debug( f"Skipping provider '{provider_name}' initialization: no credentials configured" ) @@ -824,13 +835,20 @@ async def _execute_with_retry( f"Request will likely fail." ) - # Build priority map for usage_manager + # Build priority map and tier names map for usage_manager + credential_tier_names = None if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): credential_priorities = {} + credential_tier_names = {} for cred in credentials_for_provider: priority = provider_plugin.get_credential_priority(cred) if priority is not None: credential_priorities[cred] = priority + # Also get tier name for logging + if hasattr(provider_plugin, "get_credential_tier_name"): + tier_name = provider_plugin.get_credential_tier_name(cred) + if tier_name: + credential_tier_names[cred] = tier_name if credential_priorities: lib_logger.debug( @@ -883,6 +901,7 @@ async def _execute_with_retry( deadline=deadline, max_concurrent=max_concurrent, credential_priorities=credential_priorities, + credential_tier_names=credential_tier_names, ) key_acquired = True tried_creds.add(current_cred) @@ -1371,13 +1390,20 @@ async def _streaming_acompletion_with_retry( f"Request will likely fail." 
) - # Build priority map for usage_manager + # Build priority map and tier names map for usage_manager + credential_tier_names = None if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): credential_priorities = {} + credential_tier_names = {} for cred in credentials_for_provider: priority = provider_plugin.get_credential_priority(cred) if priority is not None: credential_priorities[cred] = priority + # Also get tier name for logging + if hasattr(provider_plugin, "get_credential_tier_name"): + tier_name = provider_plugin.get_credential_tier_name(cred) + if tier_name: + credential_tier_names[cred] = tier_name if credential_priorities: lib_logger.debug( @@ -1433,6 +1459,7 @@ async def _streaming_acompletion_with_retry( deadline=deadline, max_concurrent=max_concurrent, credential_priorities=credential_priorities, + credential_tier_names=credential_tier_names, ) key_acquired = True tried_creds.add(current_cred) diff --git a/src/rotator_library/model_definitions.py b/src/rotator_library/model_definitions.py index 12219bcd..cb2aabf6 100644 --- a/src/rotator_library/model_definitions.py +++ b/src/rotator_library/model_definitions.py @@ -24,10 +24,23 @@ class ModelDefinitions: - IFLOW_MODELS='{"glm-4.6": {}}' - dict format, uses "glm-4.6" as both name and ID - IFLOW_MODELS='{"custom-name": {"id": "actual-id"}}' - dict format with custom ID - IFLOW_MODELS='{"model": {"id": "id", "options": {"temperature": 0.7}}}' - with options + + This class is a singleton - instantiated once and shared across all providers. """ + _instance: Optional["ModelDefinitions"] = None + _initialized: bool = False + + def __new__(cls, config_path: Optional[str] = None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self, config_path: Optional[str] = None): - """Initialize model definitions loader.""" + """Initialize model definitions loader (only runs once due to singleton).""" + if ModelDefinitions._initialized: + return + ModelDefinitions._initialized = True self.config_path = config_path self.definitions = {} self._load_definitions() @@ -49,7 +62,11 @@ def _load_definitions(self): # Handle array format: ["model-1", "model-2", "model-3"] elif isinstance(models_json, list): # Convert array to dict format with empty definitions - models_dict = {model_name: {} for model_name in models_json if isinstance(model_name, str)} + models_dict = { + model_name: {} + for model_name in models_json + if isinstance(model_name, str) + } self.definitions[provider_name] = models_dict lib_logger.info( f"Loaded {len(models_dict)} models for provider: {provider_name} (array format)" diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index b8226a8a..7ed85f4b 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -595,6 +595,11 @@ def get_credential_priority(self, credential: str) -> Optional[int]: Priority level (1-10) or None if tier not yet discovered """ tier = self.project_tier_cache.get(credential) + + # Lazy load from file if not in cache + if not tier: + tier = self._load_tier_from_file(credential) + if not tier: return None # Not yet discovered @@ -609,6 +614,60 @@ def get_credential_priority(self, credential: str) -> Optional[int]: # Legacy and unknown get even lower return 10 + def _load_tier_from_file(self, credential_path: str) -> Optional[str]: + """ + Load tier from credential file's _proxy_metadata and 
cache it. + + This is used as a fallback when the tier isn't in the memory cache, + typically on first access before initialize_credentials() has run. + + Args: + credential_path: Path to the credential file + + Returns: + Tier string if found, None otherwise + """ + # Skip env:// paths (environment-based credentials) + if self._parse_env_credential_path(credential_path) is not None: + return None + + try: + with open(credential_path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + tier = metadata.get("tier") + project_id = metadata.get("project_id") + + if tier: + self.project_tier_cache[credential_path] = tier + lib_logger.debug( + f"Lazy-loaded tier '{tier}' for credential: {Path(credential_path).name}" + ) + + if project_id and credential_path not in self.project_id_cache: + self.project_id_cache[credential_path] = project_id + + return tier + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not lazy-load tier from {credential_path}: {e}") + return None + + def get_credential_tier_name(self, credential: str) -> Optional[str]: + """ + Returns the human-readable tier name for a credential. + + Args: + credential: The credential path + + Returns: + Tier name string (e.g., "free-tier") or None if unknown + """ + tier = self.project_tier_cache.get(credential) + if not tier: + tier = self._load_tier_from_file(credential) + return tier + def get_model_tier_requirement(self, model: str) -> Optional[int]: """ Returns the minimum priority tier required for a model. @@ -622,6 +681,72 @@ def get_model_tier_requirement(self, model: str) -> Optional[int]: """ return None + async def initialize_credentials(self, credential_paths: List[str]) -> None: + """ + Load persisted tier information from credential files at startup. + + This ensures all credential priorities are known before any API calls, + preventing unknown credentials from getting priority 999. + """ + await self._load_persisted_tiers(credential_paths) + + async def _load_persisted_tiers( + self, credential_paths: List[str] + ) -> Dict[str, str]: + """ + Load persisted tier information from credential files into memory cache. 
+ + Args: + credential_paths: List of credential file paths + + Returns: + Dict mapping credential path to tier name for logging purposes + """ + loaded = {} + for path in credential_paths: + # Skip env:// paths (environment-based credentials) + if self._parse_env_credential_path(path) is not None: + continue + + # Skip if already in cache + if path in self.project_tier_cache: + continue + + try: + with open(path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + tier = metadata.get("tier") + project_id = metadata.get("project_id") + + if tier: + self.project_tier_cache[path] = tier + loaded[path] = tier + lib_logger.debug( + f"Loaded persisted tier '{tier}' for credential: {Path(path).name}" + ) + + if project_id: + self.project_id_cache[path] = project_id + + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not load persisted tier from {path}: {e}") + + if loaded: + # Log summary at debug level + tier_counts: Dict[str, int] = {} + for tier in loaded.values(): + tier_counts[tier] = tier_counts.get(tier, 0) + 1 + lib_logger.debug( + f"Antigravity: Loaded {len(loaded)} credential tiers from disk: " + + ", ".join( + f"{tier}={count}" for tier, count in sorted(tier_counts.items()) + ) + ) + + return loaded + # ========================================================================= # MODEL UTILITIES # ========================================================================= diff --git a/src/rotator_library/providers/gemini_cli_provider.py b/src/rotator_library/providers/gemini_cli_provider.py index 259fb831..e4109ef9 100644 --- a/src/rotator_library/providers/gemini_cli_provider.py +++ b/src/rotator_library/providers/gemini_cli_provider.py @@ -19,13 +19,15 @@ import uuid from datetime import datetime -lib_logger = logging.getLogger('rotator_library') +lib_logger = logging.getLogger("rotator_library") LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs" GEMINI_CLI_LOGS_DIR = LOGS_DIR / "gemini_cli_logs" + class _GeminiCliFileLogger: """A simple file logger for a single Gemini CLI transaction.""" + def __init__(self, model_name: str, enabled: bool = True): self.enabled = enabled if not self.enabled: @@ -34,8 +36,10 @@ def __init__(self, model_name: str, enabled: bool = True): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") request_id = str(uuid.uuid4()) # Sanitize model name for directory - safe_model_name = model_name.replace('/', '_').replace(':', '_') - self.log_dir = GEMINI_CLI_LOGS_DIR / f"{timestamp}_{safe_model_name}_{request_id}" + safe_model_name = model_name.replace("/", "_").replace(":", "_") + self.log_dir = ( + GEMINI_CLI_LOGS_DIR / f"{timestamp}_{safe_model_name}_{request_id}" + ) try: self.log_dir.mkdir(parents=True, exist_ok=True) except Exception as e: @@ -44,25 +48,32 @@ def __init__(self, model_name: str, enabled: bool = True): def log_request(self, payload: Dict[str, Any]): """Logs the request payload sent to Gemini.""" - if not self.enabled: return + if not self.enabled: + return try: - with open(self.log_dir / "request_payload.json", "w", encoding="utf-8") as f: + with open( + self.log_dir / "request_payload.json", "w", encoding="utf-8" + ) as f: json.dump(payload, f, indent=2, ensure_ascii=False) except Exception as e: lib_logger.error(f"_GeminiCliFileLogger: Failed to write request: {e}") def log_response_chunk(self, chunk: str): """Logs a raw chunk from the Gemini response stream.""" - if not self.enabled: return + if not self.enabled: + return try: with 
open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f: f.write(chunk + "\n") except Exception as e: - lib_logger.error(f"_GeminiCliFileLogger: Failed to write response chunk: {e}") + lib_logger.error( + f"_GeminiCliFileLogger: Failed to write response chunk: {e}" + ) def log_error(self, error_message: str): """Logs an error message.""" - if not self.enabled: return + if not self.enabled: + return try: with open(self.log_dir / "error.log", "a", encoding="utf-8") as f: f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n") @@ -71,12 +82,16 @@ def log_error(self, error_message: str): def log_final_response(self, response_data: Dict[str, Any]): """Logs the final, reassembled response.""" - if not self.enabled: return + if not self.enabled: + return try: with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f: json.dump(response_data, f, indent=2, ensure_ascii=False) except Exception as e: - lib_logger.error(f"_GeminiCliFileLogger: Failed to write final response: {e}") + lib_logger.error( + f"_GeminiCliFileLogger: Failed to write final response: {e}" + ) + CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal" @@ -84,11 +99,13 @@ def log_final_response(self, response_data: Dict[str, Any]): "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite", - "gemini-3-pro-preview" + "gemini-3-pro-preview", ] # Cache directory for Gemini CLI -CACHE_DIR = Path(__file__).resolve().parent.parent.parent.parent / "cache" / "gemini_cli" +CACHE_DIR = ( + Path(__file__).resolve().parent.parent.parent.parent / "cache" / "gemini_cli" +) GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json" # Gemini 3 tool fix system instruction (prevents hallucination) @@ -172,36 +189,49 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface): def __init__(self): super().__init__() self.model_definitions = ModelDefinitions() - self.project_id_cache: Dict[str, str] = {} # Cache project ID per credential path - self.project_tier_cache: Dict[str, str] = {} # Cache project tier per credential path - + self.project_id_cache: Dict[ + str, str + ] = {} # Cache project ID per credential path + self.project_tier_cache: Dict[ + str, str + ] = {} # Cache project tier per credential path + # Gemini 3 configuration from environment memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600) disk_ttl = _env_int("GEMINI_CLI_SIGNATURE_DISK_TTL", 86400) - + # Initialize signature cache for Gemini 3 thoughtSignatures self._signature_cache = ProviderCache( - GEMINI3_SIGNATURE_CACHE_FILE, memory_ttl, disk_ttl, - env_prefix="GEMINI_CLI_SIGNATURE" + GEMINI3_SIGNATURE_CACHE_FILE, + memory_ttl, + disk_ttl, + env_prefix="GEMINI_CLI_SIGNATURE", ) - + # Gemini 3 feature flags - self._preserve_signatures_in_client = _env_bool("GEMINI_CLI_PRESERVE_THOUGHT_SIGNATURES", True) - self._enable_signature_cache = _env_bool("GEMINI_CLI_ENABLE_SIGNATURE_CACHE", True) + self._preserve_signatures_in_client = _env_bool( + "GEMINI_CLI_PRESERVE_THOUGHT_SIGNATURES", True + ) + self._enable_signature_cache = _env_bool( + "GEMINI_CLI_ENABLE_SIGNATURE_CACHE", True + ) self._enable_gemini3_tool_fix = _env_bool("GEMINI_CLI_GEMINI3_TOOL_FIX", True) - self._gemini3_enforce_strict_schema = _env_bool("GEMINI_CLI_GEMINI3_STRICT_SCHEMA", True) - + self._gemini3_enforce_strict_schema = _env_bool( + "GEMINI_CLI_GEMINI3_STRICT_SCHEMA", True + ) + # Gemini 3 tool fix configuration - self._gemini3_tool_prefix = os.getenv("GEMINI_CLI_GEMINI3_TOOL_PREFIX", "gemini3_") + self._gemini3_tool_prefix = 
os.getenv( + "GEMINI_CLI_GEMINI3_TOOL_PREFIX", "gemini3_" + ) self._gemini3_description_prompt = os.getenv( "GEMINI_CLI_GEMINI3_DESCRIPTION_PROMPT", - "\n\n⚠️ STRICT PARAMETERS (use EXACTLY as shown): {params}. Do NOT use parameters from your training data - use ONLY these parameter names." + "\n\n⚠️ STRICT PARAMETERS (use EXACTLY as shown): {params}. Do NOT use parameters from your training data - use ONLY these parameter names.", ) self._gemini3_system_instruction = os.getenv( - "GEMINI_CLI_GEMINI3_SYSTEM_INSTRUCTION", - DEFAULT_GEMINI3_SYSTEM_INSTRUCTION + "GEMINI_CLI_GEMINI3_SYSTEM_INSTRUCTION", DEFAULT_GEMINI3_SYSTEM_INSTRUCTION ) - + lib_logger.debug( f"GeminiCli config: signatures_in_client={self._preserve_signatures_in_client}, " f"cache={self._enable_signature_cache}, gemini3_fix={self._enable_gemini3_tool_fix}, " @@ -211,75 +241,200 @@ def __init__(self): # ========================================================================= # CREDENTIAL PRIORITIZATION # ========================================================================= - + def get_credential_priority(self, credential: str) -> Optional[int]: """ Returns priority based on Gemini tier. Paid tiers: priority 1 (highest) Free/Legacy tiers: priority 2 Unknown: priority 10 (lowest) - + Args: credential: The credential path - + Returns: Priority level (1-10) or None if tier not yet discovered """ tier = self.project_tier_cache.get(credential) + + # Lazy load from file if not in cache + if not tier: + tier = self._load_tier_from_file(credential) + if not tier: return None # Not yet discovered - + # Paid tiers get highest priority - if tier not in ['free-tier', 'legacy-tier', 'unknown']: + if tier not in ["free-tier", "legacy-tier", "unknown"]: return 1 - + # Free tier gets lower priority - if tier == 'free-tier': + if tier == "free-tier": return 2 - + # Legacy and unknown get even lower return 10 - + + def _load_tier_from_file(self, credential_path: str) -> Optional[str]: + """ + Load tier from credential file's _proxy_metadata and cache it. + + This is used as a fallback when the tier isn't in the memory cache, + typically on first access before initialize_credentials() has run. + + Args: + credential_path: Path to the credential file + + Returns: + Tier string if found, None otherwise + """ + # Skip env:// paths (environment-based credentials) + if self._parse_env_credential_path(credential_path) is not None: + return None + + try: + with open(credential_path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + tier = metadata.get("tier") + project_id = metadata.get("project_id") + + if tier: + self.project_tier_cache[credential_path] = tier + lib_logger.debug( + f"Lazy-loaded tier '{tier}' for credential: {Path(credential_path).name}" + ) + + if project_id and credential_path not in self.project_id_cache: + self.project_id_cache[credential_path] = project_id + + return tier + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not lazy-load tier from {credential_path}: {e}") + return None + + def get_credential_tier_name(self, credential: str) -> Optional[str]: + """ + Returns the human-readable tier name for a credential. 
+ + Args: + credential: The credential path + + Returns: + Tier name string (e.g., "free-tier") or None if unknown + """ + tier = self.project_tier_cache.get(credential) + if not tier: + tier = self._load_tier_from_file(credential) + return tier + def get_model_tier_requirement(self, model: str) -> Optional[int]: """ Returns the minimum priority tier required for a model. Gemini 3 requires paid tier (priority 1). - + Args: model: The model name (with or without provider prefix) - + Returns: Minimum required priority level or None if no restrictions """ - model_name = model.split('/')[-1].replace(':thinking', '') - + model_name = model.split("/")[-1].replace(":thinking", "") + # Gemini 3 requires paid tier if model_name.startswith("gemini-3-"): return 1 # Only priority 1 (paid) credentials - + return None # All other models have no restrictions + async def initialize_credentials(self, credential_paths: List[str]) -> None: + """ + Load persisted tier information from credential files at startup. + This ensures all credential priorities are known before any API calls, + preventing unknown credentials from getting priority 999. + """ + await self._load_persisted_tiers(credential_paths) + + async def _load_persisted_tiers( + self, credential_paths: List[str] + ) -> Dict[str, str]: + """ + Load persisted tier information from credential files into memory cache. + + Args: + credential_paths: List of credential file paths + + Returns: + Dict mapping credential path to tier name for logging purposes + """ + loaded = {} + for path in credential_paths: + # Skip env:// paths (environment-based credentials) + if self._parse_env_credential_path(path) is not None: + continue + + # Skip if already in cache + if path in self.project_tier_cache: + continue + + try: + with open(path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + tier = metadata.get("tier") + project_id = metadata.get("project_id") + + if tier: + self.project_tier_cache[path] = tier + loaded[path] = tier + lib_logger.debug( + f"Loaded persisted tier '{tier}' for credential: {Path(path).name}" + ) + + if project_id: + self.project_id_cache[path] = project_id + + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not load persisted tier from {path}: {e}") + + if loaded: + # Log summary at debug level + tier_counts: Dict[str, int] = {} + for tier in loaded.values(): + tier_counts[tier] = tier_counts.get(tier, 0) + 1 + lib_logger.debug( + f"GeminiCli: Loaded {len(loaded)} credential tiers from disk: " + + ", ".join( + f"{tier}={count}" for tier, count in sorted(tier_counts.items()) + ) + ) + + return loaded # ========================================================================= # MODEL UTILITIES # ========================================================================= - + def _is_gemini_3(self, model: str) -> bool: """Check if model is Gemini 3 (requires special handling).""" - model_name = model.split('/')[-1].replace(':thinking', '') + model_name = model.split("/")[-1].replace(":thinking", "") return model_name.startswith("gemini-3-") - + def _strip_gemini3_prefix(self, name: str) -> str: """Strip the Gemini 3 namespace prefix from a tool name.""" if name and name.startswith(self._gemini3_tool_prefix): - return name[len(self._gemini3_tool_prefix):] + return name[len(self._gemini3_tool_prefix) :] return name - async def _discover_project_id(self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]) -> str: + async def _discover_project_id( + 
self, credential_path: str, access_token: str, litellm_params: Dict[str, Any] + ) -> str: """ Discovers the Google Cloud Project ID, with caching and onboarding for new accounts. - + This follows the official Gemini CLI discovery flow: 1. Check in-memory cache 2. Check configured project_id override (litellm_params or env var) @@ -293,7 +448,9 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li - PAID tier: pass cloudaicompanionProject=configured_project_id 6. Fallback to GCP Resource Manager project listing """ - lib_logger.debug(f"Starting project discovery for credential: {credential_path}") + lib_logger.debug( + f"Starting project discovery for credential: {credential_path}" + ) # Check in-memory cache first if credential_path in self.project_id_cache: @@ -305,7 +462,9 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li # This is REQUIRED for paid tier users per the official CLI behavior configured_project_id = litellm_params.get("project_id") if configured_project_id: - lib_logger.debug(f"Found configured project_id override: {configured_project_id}") + lib_logger.debug( + f"Found configured project_id override: {configured_project_id}" + ) # Load credentials from file to check for persisted project_id and tier # Skip for env:// paths (environment-based credentials don't persist to files) @@ -313,35 +472,44 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li if credential_index is None: # Only try to load from file if it's not an env:// path try: - with open(credential_path, 'r') as f: + with open(credential_path, "r") as f: creds = json.load(f) - + metadata = creds.get("_proxy_metadata", {}) persisted_project_id = metadata.get("project_id") persisted_tier = metadata.get("tier") - + if persisted_project_id: - lib_logger.info(f"Loaded persisted project ID from credential file: {persisted_project_id}") + lib_logger.info( + f"Loaded persisted project ID from credential file: {persisted_project_id}" + ) self.project_id_cache[credential_path] = persisted_project_id - + # Also load tier if available if persisted_tier: self.project_tier_cache[credential_path] = persisted_tier lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") - + return persisted_project_id except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: lib_logger.debug(f"Could not load persisted project ID from file: {e}") - lib_logger.debug("No cached or configured project ID found, initiating discovery...") - headers = {'Authorization': f'Bearer {access_token}', 'Content-Type': 'application/json'} + lib_logger.debug( + "No cached or configured project ID found, initiating discovery..." + ) + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } discovered_project_id = None discovered_tier = None async with httpx.AsyncClient() as client: # 1. Try discovery endpoint with loadCodeAssist - lib_logger.debug("Attempting project discovery via Code Assist loadCodeAssist endpoint...") + lib_logger.debug( + "Attempting project discovery via Code Assist loadCodeAssist endpoint..." 
+ ) try: # Build metadata - include duetProject only if we have a configured project core_client_metadata = { @@ -351,53 +519,65 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li } if configured_project_id: core_client_metadata["duetProject"] = configured_project_id - + # Build load request - pass configured_project_id if available, otherwise None load_request = { "cloudaicompanionProject": configured_project_id, # Can be None "metadata": core_client_metadata, } - - lib_logger.debug(f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}") - response = await client.post(f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", headers=headers, json=load_request, timeout=20) + + lib_logger.debug( + f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}" + ) + response = await client.post( + f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", + headers=headers, + json=load_request, + timeout=20, + ) response.raise_for_status() data = response.json() # Log full response for debugging - lib_logger.debug(f"loadCodeAssist full response keys: {list(data.keys())}") + lib_logger.debug( + f"loadCodeAssist full response keys: {list(data.keys())}" + ) # Extract and log ALL tier information for debugging - allowed_tiers = data.get('allowedTiers', []) - current_tier = data.get('currentTier') - + allowed_tiers = data.get("allowedTiers", []) + current_tier = data.get("currentTier") + lib_logger.debug(f"=== Tier Information ===") lib_logger.debug(f"currentTier: {current_tier}") lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}") for i, tier in enumerate(allowed_tiers): - tier_id = tier.get('id', 'unknown') - is_default = tier.get('isDefault', False) - user_defined = tier.get('userDefinedCloudaicompanionProject', False) - lib_logger.debug(f" Tier {i+1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}") + tier_id = tier.get("id", "unknown") + is_default = tier.get("isDefault", False) + user_defined = tier.get("userDefinedCloudaicompanionProject", False) + lib_logger.debug( + f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}" + ) lib_logger.debug(f"========================") # Determine the current tier ID current_tier_id = None if current_tier: - current_tier_id = current_tier.get('id') + current_tier_id = current_tier.get("id") lib_logger.debug(f"User has currentTier: {current_tier_id}") # Check if user is already known to server (has currentTier) if current_tier_id: # User is already onboarded - check for project from server - server_project = data.get('cloudaicompanionProject') - + server_project = data.get("cloudaicompanionProject") + # Check if this tier requires user-defined project (paid tiers) requires_user_project = any( - t.get('id') == current_tier_id and t.get('userDefinedCloudaicompanionProject', False) + t.get("id") == current_tier_id + and t.get("userDefinedCloudaicompanionProject", False) for t in allowed_tiers ) - is_free_tier = current_tier_id == 'free-tier' - + is_free_tier = current_tier_id == "free-tier" + if server_project: # Server returned a project - use it (server wins) # This is the normal case for FREE tier users @@ -407,11 +587,15 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li # No server project but we have configured one - use it # This is the PAID TIER case where server doesn't return a project project_id = configured_project_id - lib_logger.debug(f"No server project, using configured: {project_id}") 
+ lib_logger.debug( + f"No server project, using configured: {project_id}" + ) elif is_free_tier: # Free tier user without server project - this shouldn't happen normally # but let's not fail, just proceed to onboarding - lib_logger.debug("Free tier user with currentTier but no project - will try onboarding") + lib_logger.debug( + "Free tier user with currentTier but no project - will try onboarding" + ) project_id = None elif requires_user_project: # Paid tier requires a project ID to be set @@ -421,7 +605,9 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li ) else: # Unknown tier without project - proceed carefully - lib_logger.warning(f"Tier '{current_tier_id}' has no project and none configured - will try onboarding") + lib_logger.warning( + f"Tier '{current_tier_id}' has no project and none configured - will try onboarding" + ) project_id = None if project_id: @@ -430,54 +616,70 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li discovered_tier = current_tier_id # Log appropriately based on tier - is_paid = current_tier_id and current_tier_id not in ['free-tier', 'legacy-tier', 'unknown'] + is_paid = current_tier_id and current_tier_id not in [ + "free-tier", + "legacy-tier", + "unknown", + ] if is_paid: - lib_logger.info(f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}") + lib_logger.info( + f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}" + ) else: - lib_logger.info(f"Discovered Gemini project ID via loadCodeAssist: {project_id}") + lib_logger.info( + f"Discovered Gemini project ID via loadCodeAssist: {project_id}" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id - + # Persist to credential file - await self._persist_project_metadata(credential_path, project_id, discovered_tier) - + await self._persist_project_metadata( + credential_path, project_id, discovered_tier + ) + return project_id - + # 2. User needs onboarding - no currentTier - lib_logger.info("No existing Gemini session found (no currentTier), attempting to onboard user...") - + lib_logger.info( + "No existing Gemini session found (no currentTier), attempting to onboard user..." 
+ ) + # Determine which tier to onboard with onboard_tier = None for tier in allowed_tiers: - if tier.get('isDefault'): + if tier.get("isDefault"): onboard_tier = tier break - + # Fallback to LEGACY tier if no default (requires user project) if not onboard_tier and allowed_tiers: # Look for legacy-tier as fallback for tier in allowed_tiers: - if tier.get('id') == 'legacy-tier': + if tier.get("id") == "legacy-tier": onboard_tier = tier break # If still no tier, use first available if not onboard_tier: onboard_tier = allowed_tiers[0] - + if not onboard_tier: raise ValueError("No onboarding tiers available from server") - - tier_id = onboard_tier.get('id', 'free-tier') - requires_user_project = onboard_tier.get('userDefinedCloudaicompanionProject', False) - - lib_logger.debug(f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}") - + + tier_id = onboard_tier.get("id", "free-tier") + requires_user_project = onboard_tier.get( + "userDefinedCloudaicompanionProject", False + ) + + lib_logger.debug( + f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}" + ) + # Build onboard request based on tier type (following official CLI logic) # FREE tier: cloudaicompanionProject = None (server-managed) # PAID tier: cloudaicompanionProject = configured_project_id (user must provide) - is_free_tier = tier_id == 'free-tier' - + is_free_tier = tier_id == "free-tier" + if is_free_tier: # Free tier uses server-managed project onboard_request = { @@ -485,7 +687,9 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li "cloudaicompanionProject": None, # Server will create/manage "metadata": core_client_metadata, } - lib_logger.debug("Free tier onboarding: using server-managed project") + lib_logger.debug( + "Free tier onboarding: using server-managed project" + ) else: # Paid/legacy tier requires user-provided project if not configured_project_id and requires_user_project: @@ -499,51 +703,85 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li "metadata": { **core_client_metadata, "duetProject": configured_project_id, - } if configured_project_id else core_client_metadata, + } + if configured_project_id + else core_client_metadata, } - lib_logger.debug(f"Paid tier onboarding: using project {configured_project_id}") + lib_logger.debug( + f"Paid tier onboarding: using project {configured_project_id}" + ) lib_logger.debug("Initiating onboardUser request...") - lro_response = await client.post(f"{CODE_ASSIST_ENDPOINT}:onboardUser", headers=headers, json=onboard_request, timeout=30) + lro_response = await client.post( + f"{CODE_ASSIST_ENDPOINT}:onboardUser", + headers=headers, + json=onboard_request, + timeout=30, + ) lro_response.raise_for_status() lro_data = lro_response.json() - lib_logger.debug(f"Initial onboarding response: done={lro_data.get('done')}") + lib_logger.debug( + f"Initial onboarding response: done={lro_data.get('done')}" + ) for i in range(150): # Poll for up to 5 minutes (150 × 2s) - if lro_data.get('done'): - lib_logger.debug(f"Onboarding completed after {i} polling attempts") + if lro_data.get("done"): + lib_logger.debug( + f"Onboarding completed after {i} polling attempts" + ) break await asyncio.sleep(2) if (i + 1) % 15 == 0: # Log every 30 seconds - lib_logger.info(f"Still waiting for onboarding completion... ({(i+1)*2}s elapsed)") - lib_logger.debug(f"Polling onboarding status... 
(Attempt {i+1}/150)") - lro_response = await client.post(f"{CODE_ASSIST_ENDPOINT}:onboardUser", headers=headers, json=onboard_request, timeout=30) + lib_logger.info( + f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)" + ) + lib_logger.debug( + f"Polling onboarding status... (Attempt {i + 1}/150)" + ) + lro_response = await client.post( + f"{CODE_ASSIST_ENDPOINT}:onboardUser", + headers=headers, + json=onboard_request, + timeout=30, + ) lro_response.raise_for_status() lro_data = lro_response.json() - if not lro_data.get('done'): + if not lro_data.get("done"): lib_logger.error("Onboarding process timed out after 5 minutes") - raise ValueError("Onboarding process timed out after 5 minutes. Please try again or contact support.") + raise ValueError( + "Onboarding process timed out after 5 minutes. Please try again or contact support." + ) # Extract project ID from LRO response # Note: onboardUser returns response.cloudaicompanionProject as an object with .id - lro_response_data = lro_data.get('response', {}) - lro_project_obj = lro_response_data.get('cloudaicompanionProject', {}) - project_id = lro_project_obj.get('id') if isinstance(lro_project_obj, dict) else None - + lro_response_data = lro_data.get("response", {}) + lro_project_obj = lro_response_data.get("cloudaicompanionProject", {}) + project_id = ( + lro_project_obj.get("id") + if isinstance(lro_project_obj, dict) + else None + ) + # Fallback to configured project if LRO didn't return one if not project_id and configured_project_id: project_id = configured_project_id - lib_logger.debug(f"LRO didn't return project, using configured: {project_id}") - + lib_logger.debug( + f"LRO didn't return project, using configured: {project_id}" + ) + if not project_id: - lib_logger.error("Onboarding completed but no project ID in response and none configured") + lib_logger.error( + "Onboarding completed but no project ID in response and none configured" + ) raise ValueError( "Onboarding completed, but no project ID was returned. " "For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable." 
) - lib_logger.debug(f"Successfully extracted project ID from onboarding response: {project_id}") + lib_logger.debug( + f"Successfully extracted project ID from onboarding response: {project_id}" + ) # Cache tier info self.project_tier_cache[credential_path] = tier_id @@ -551,18 +789,24 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li lib_logger.debug(f"Cached tier information: {tier_id}") # Log concise message for paid projects - is_paid = tier_id and tier_id not in ['free-tier', 'legacy-tier'] + is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"] if is_paid: - lib_logger.info(f"Using Gemini paid tier '{tier_id}' with project: {project_id}") + lib_logger.info( + f"Using Gemini paid tier '{tier_id}' with project: {project_id}" + ) else: - lib_logger.info(f"Successfully onboarded user and discovered project ID: {project_id}") + lib_logger.info( + f"Successfully onboarded user and discovered project ID: {project_id}" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id - + # Persist to credential file - await self._persist_project_metadata(credential_path, project_id, discovered_tier) - + await self._persist_project_metadata( + credential_path, project_id, discovered_tier + ) + return project_id except httpx.HTTPStatusError as e: @@ -572,50 +816,86 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li except Exception: pass if e.response.status_code == 403: - lib_logger.error(f"Gemini Code Assist API access denied (403). Response: {error_body}") - lib_logger.error("Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions") + lib_logger.error( + f"Gemini Code Assist API access denied (403). Response: {error_body}" + ) + lib_logger.error( + "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions" + ) elif e.response.status_code == 404: - lib_logger.warning(f"Gemini Code Assist endpoint not found (404). Falling back to project listing.") + lib_logger.warning( + f"Gemini Code Assist endpoint not found (404). Falling back to project listing." + ) elif e.response.status_code == 412: # Precondition Failed - often means wrong project for free tier onboarding - lib_logger.error(f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier.") + lib_logger.error( + f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier." + ) else: - lib_logger.warning(f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing.") + lib_logger.warning( + f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing." + ) except httpx.RequestError as e: - lib_logger.warning(f"Gemini onboarding/discovery network error: {e}. Falling back to project listing.") + lib_logger.warning( + f"Gemini onboarding/discovery network error: {e}. Falling back to project listing." + ) # 3. Fallback to listing all available GCP projects (last resort) - lib_logger.debug("Attempting to discover project via GCP Resource Manager API...") + lib_logger.debug( + "Attempting to discover project via GCP Resource Manager API..." 
+ ) try: async with httpx.AsyncClient() as client: - lib_logger.debug("Querying Cloud Resource Manager for available projects...") - response = await client.get("https://cloudresourcemanager.googleapis.com/v1/projects", headers=headers, timeout=20) + lib_logger.debug( + "Querying Cloud Resource Manager for available projects..." + ) + response = await client.get( + "https://cloudresourcemanager.googleapis.com/v1/projects", + headers=headers, + timeout=20, + ) response.raise_for_status() - projects = response.json().get('projects', []) + projects = response.json().get("projects", []) lib_logger.debug(f"Found {len(projects)} total projects") - active_projects = [p for p in projects if p.get('lifecycleState') == 'ACTIVE'] + active_projects = [ + p for p in projects if p.get("lifecycleState") == "ACTIVE" + ] lib_logger.debug(f"Found {len(active_projects)} active projects") if not projects: - lib_logger.error("No GCP projects found for this account. Please create a project in Google Cloud Console.") + lib_logger.error( + "No GCP projects found for this account. Please create a project in Google Cloud Console." + ) elif not active_projects: - lib_logger.error("No active GCP projects found. Please activate a project in Google Cloud Console.") + lib_logger.error( + "No active GCP projects found. Please activate a project in Google Cloud Console." + ) else: - project_id = active_projects[0]['projectId'] - lib_logger.info(f"Discovered Gemini project ID from active projects list: {project_id}") - lib_logger.debug(f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)") + project_id = active_projects[0]["projectId"] + lib_logger.info( + f"Discovered Gemini project ID from active projects list: {project_id}" + ) + lib_logger.debug( + f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id - + # [NEW] Persist to credential file (no tier info from resource manager) - await self._persist_project_metadata(credential_path, project_id, None) - + await self._persist_project_metadata( + credential_path, project_id, None + ) + return project_id except httpx.HTTPStatusError as e: if e.response.status_code == 403: - lib_logger.error("Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission.") + lib_logger.error( + "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission." + ) else: - lib_logger.error(f"Failed to list GCP projects with status {e.response.status_code}: {e}") + lib_logger.error( + f"Failed to list GCP projects with status {e.response.status_code}: {e}" + ) except httpx.RequestError as e: lib_logger.error(f"Network error while listing GCP projects: {e}") @@ -626,20 +906,24 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li " 3. Account lacks necessary permissions\n" "To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file." 
) - - async def _persist_project_metadata(self, credential_path: str, project_id: str, tier: Optional[str]): + + async def _persist_project_metadata( + self, credential_path: str, project_id: str, tier: Optional[str] + ): """Persists project ID and tier to the credential file for faster future startups.""" # Skip persistence for env:// paths (environment-based credentials) credential_index = self._parse_env_credential_path(credential_path) if credential_index is not None: - lib_logger.debug(f"Skipping project metadata persistence for env:// credential path: {credential_path}") + lib_logger.debug( + f"Skipping project metadata persistence for env:// credential path: {credential_path}" + ) return - + try: # Load current credentials - with open(credential_path, 'r') as f: + with open(credential_path, "r") as f: creds = json.load(f) - + # Update metadata if "_proxy_metadata" not in creds: creds["_proxy_metadata"] = {} @@ -647,33 +931,36 @@ async def _persist_project_metadata(self, credential_path: str, project_id: str, creds["_proxy_metadata"]["project_id"] = project_id if tier: creds["_proxy_metadata"]["tier"] = tier - + # Save back using the existing save method (handles atomic writes and permissions) await self._save_credentials(credential_path, creds) - - lib_logger.debug(f"Persisted project_id and tier to credential file: {credential_path}") + + lib_logger.debug( + f"Persisted project_id and tier to credential file: {credential_path}" + ) except Exception as e: - lib_logger.warning(f"Failed to persist project metadata to credential file: {e}") + lib_logger.warning( + f"Failed to persist project metadata to credential file: {e}" + ) # Non-fatal - just means slower startup next time - def _check_mixed_tier_warning(self): """Check if mixed free/paid tier credentials are loaded and emit warning.""" if not self.project_tier_cache: return # No tiers loaded yet - + tiers = set(self.project_tier_cache.values()) if len(tiers) <= 1: return # All same tier or only one credential - + # Define paid vs free tiers - free_tiers = {'free-tier', 'legacy-tier', 'unknown'} + free_tiers = {"free-tier", "legacy-tier", "unknown"} paid_tiers = tiers - free_tiers - + # Check if we have both free and paid has_free = bool(tiers & free_tiers) has_paid = bool(paid_tiers) - + if has_free and has_paid: lib_logger.warning( f"Mixed Gemini tier credentials detected! You have both free-tier and paid-tier " @@ -688,12 +975,12 @@ def _cli_preview_fallback_order(self, model: str) -> List[str]: """ Returns a list of model names to try in order for rate limit fallback. First model in list is the original model, subsequent models are fallback options. - + Since all fallbacks have been deprecated, this now only returns the base model. The fallback logic will check if there are actual fallbacks available. 
""" # Remove provider prefix if present - model_name = model.split('/')[-1].replace(':thinking', '') + model_name = model.split("/")[-1].replace(":thinking", "") # Define fallback chains for models with preview versions # All fallbacks have been deprecated, so only base models are returned @@ -706,10 +993,12 @@ def _cli_preview_fallback_order(self, model: str) -> List[str]: # Return fallback chain if available, otherwise just return the original model return fallback_chains.get(model_name, [model_name]) - def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + def _transform_messages( + self, messages: List[Dict[str, Any]], model: str = "" + ) -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: """ Transform OpenAI messages to Gemini CLI format. - + Handles: - System instruction extraction - Multi-part content (text, images) @@ -720,14 +1009,14 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - system_instruction = None gemini_contents = [] is_gemini_3 = self._is_gemini_3(model) - + # Separate system prompt from other messages - if messages and messages[0].get('role') == 'system': - system_prompt_content = messages.pop(0).get('content', '') + if messages and messages[0].get("role") == "system": + system_prompt_content = messages.pop(0).get("content", "") if system_prompt_content: system_instruction = { "role": "user", - "parts": [{"text": system_prompt_content}] + "parts": [{"text": system_prompt_content}], } tool_call_id_to_name = {} @@ -735,18 +1024,22 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - if msg.get("role") == "assistant" and msg.get("tool_calls"): for tool_call in msg["tool_calls"]: if tool_call.get("type") == "function": - tool_call_id_to_name[tool_call["id"]] = tool_call["function"]["name"] + tool_call_id_to_name[tool_call["id"]] = tool_call["function"][ + "name" + ] # Process messages and consolidate consecutive tool responses # Per Gemini docs: parallel function responses must be in a single user message, # not interleaved as separate messages pending_tool_parts = [] # Accumulate tool responses - + for msg in messages: role = msg.get("role") content = msg.get("content") parts = [] - gemini_role = "model" if role == "assistant" else "user" # tool -> user in Gemini + gemini_role = ( + "model" if role == "assistant" else "user" + ) # tool -> user in Gemini # If we have pending tool parts and hit a non-tool message, flush them first if pending_tool_parts and role != "tool": @@ -773,16 +1066,22 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - # Parse: data:image/png;base64,iVBORw0KG... header, data = image_url.split(",", 1) mime_type = header.split(":")[1].split(";")[0] - parts.append({ - "inlineData": { - "mimeType": mime_type, - "data": data + parts.append( + { + "inlineData": { + "mimeType": mime_type, + "data": data, + } } - }) + ) except Exception as e: - lib_logger.warning(f"Failed to parse image data URL: {e}") + lib_logger.warning( + f"Failed to parse image data URL: {e}" + ) else: - lib_logger.warning(f"Non-data-URL images not supported: {image_url[:50]}...") + lib_logger.warning( + f"Non-data-URL images not supported: {image_url[:50]}..." 
+ ) elif role == "assistant": if isinstance(content, str): @@ -794,25 +1093,27 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - for tool_call in msg["tool_calls"]: if tool_call.get("type") == "function": try: - args_dict = json.loads(tool_call["function"]["arguments"]) + args_dict = json.loads( + tool_call["function"]["arguments"] + ) except (json.JSONDecodeError, TypeError): args_dict = {} - + tool_id = tool_call.get("id", "") func_name = tool_call["function"]["name"] - + # Add prefix for Gemini 3 if is_gemini_3 and self._enable_gemini3_tool_fix: func_name = f"{self._gemini3_tool_prefix}{func_name}" - + func_part = { "functionCall": { "name": func_name, "args": args_dict, - "id": tool_id + "id": tool_id, } } - + # Add thoughtSignature for Gemini 3 # Per Gemini docs: Only the FIRST parallel function call gets a signature. # Subsequent parallel calls should NOT have a thoughtSignature field. @@ -820,17 +1121,21 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - sig = tool_call.get("thought_signature") if not sig and tool_id and self._enable_signature_cache: sig = self._signature_cache.retrieve(tool_id) - + if sig: func_part["thoughtSignature"] = sig elif first_func_in_msg: # Only add bypass to the first function call if no sig available - func_part["thoughtSignature"] = "skip_thought_signature_validator" - lib_logger.warning(f"Missing thoughtSignature for first func call {tool_id}, using bypass") + func_part["thoughtSignature"] = ( + "skip_thought_signature_validator" + ) + lib_logger.warning( + f"Missing thoughtSignature for first func call {tool_id}, using bypass" + ) # Subsequent parallel calls: no signature field at all - + first_func_in_msg = False - + parts.append(func_part) elif role == "tool": @@ -840,17 +1145,19 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - # Add prefix for Gemini 3 if is_gemini_3 and self._enable_gemini3_tool_fix: function_name = f"{self._gemini3_tool_prefix}{function_name}" - + # Wrap the tool response in a 'result' object response_content = {"result": content} # Accumulate tool responses - they'll be combined into one user message - pending_tool_parts.append({ - "functionResponse": { - "name": function_name, - "response": response_content, - "id": tool_call_id + pending_tool_parts.append( + { + "functionResponse": { + "name": function_name, + "response": response_content, + "id": tool_call_id, + } } - }) + ) # Don't add parts here - tool responses are handled via pending_tool_parts continue @@ -861,15 +1168,17 @@ def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") - if pending_tool_parts: gemini_contents.append({"role": "user", "parts": pending_tool_parts}) - if not gemini_contents or gemini_contents[0]['role'] != 'user': + if not gemini_contents or gemini_contents[0]["role"] != "user": gemini_contents.insert(0, {"role": "user", "parts": [{"text": ""}]}) return system_instruction, gemini_contents - def _handle_reasoning_parameters(self, payload: Dict[str, Any], model: str) -> Optional[Dict[str, Any]]: + def _handle_reasoning_parameters( + self, payload: Dict[str, Any], model: str + ) -> Optional[Dict[str, Any]]: """ Map reasoning_effort to thinking configuration. 
- + - Gemini 2.5: thinkingBudget (integer tokens) - Gemini 3: thinkingLevel (string: "low"/"high") """ @@ -887,13 +1196,13 @@ def _handle_reasoning_parameters(self, payload: Dict[str, Any], model: str) -> O payload.pop("reasoning_effort", None) payload.pop("custom_reasoning_budget", None) return None - + # Gemini 3: String-based thinkingLevel if is_gemini_3: # Clean up the original payload payload.pop("reasoning_effort", None) payload.pop("custom_reasoning_budget", None) - + if reasoning_effort == "low": return {"thinkingLevel": "low", "include_thoughts": True} return {"thinkingLevel": "high", "include_thoughts": True} @@ -918,122 +1227,137 @@ def _handle_reasoning_parameters(self, payload: Dict[str, Any], model: str) -> O budget = budgets.get(reasoning_effort, -1) if reasoning_effort == "disable": budget = 0 - + if not custom_reasoning_budget: budget = budget // 4 # Clean up the original payload payload.pop("reasoning_effort", None) payload.pop("custom_reasoning_budget", None) - + return {"thinkingBudget": budget, "include_thoughts": True} - def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str, accumulator: Optional[Dict[str, Any]] = None): + def _convert_chunk_to_openai( + self, + chunk: Dict[str, Any], + model_id: str, + accumulator: Optional[Dict[str, Any]] = None, + ): """ Convert Gemini response chunk to OpenAI streaming format. - + Args: chunk: Gemini API response chunk model_id: Model name accumulator: Optional dict to accumulate data for post-processing (signatures, etc.) """ - response_data = chunk.get('response', chunk) - candidates = response_data.get('candidates', []) + response_data = chunk.get("response", chunk) + candidates = response_data.get("candidates", []) if not candidates: return candidate = candidates[0] - parts = candidate.get('content', {}).get('parts', []) + parts = candidate.get("content", {}).get("parts", []) is_gemini_3 = self._is_gemini_3(model_id) for part in parts: delta = {} - - has_func = 'functionCall' in part - has_text = 'text' in part - has_sig = bool(part.get('thoughtSignature')) - is_thought = part.get('thought') is True or (isinstance(part.get('thought'), str) and str(part.get('thought')).lower() == 'true') - + + has_func = "functionCall" in part + has_text = "text" in part + has_sig = bool(part.get("thoughtSignature")) + is_thought = part.get("thought") is True or ( + isinstance(part.get("thought"), str) + and str(part.get("thought")).lower() == "true" + ) + # Skip standalone signature parts (no function, no meaningful text) - if has_sig and not has_func and (not has_text or not part.get('text')): + if has_sig and not has_func and (not has_text or not part.get("text")): continue if has_func: - function_call = part['functionCall'] - function_name = function_call.get('name', 'unknown') - + function_call = part["functionCall"] + function_name = function_call.get("name", "unknown") + # Strip Gemini 3 prefix from tool name if is_gemini_3 and self._enable_gemini3_tool_fix: function_name = self._strip_gemini3_prefix(function_name) - + # Use provided ID or generate unique one with nanosecond precision - tool_call_id = function_call.get('id') or f"call_{function_name}_{int(time.time() * 1_000_000_000)}" - + tool_call_id = ( + function_call.get("id") + or f"call_{function_name}_{int(time.time() * 1_000_000_000)}" + ) + # Get current tool index from accumulator (default 0) and increment - current_tool_idx = accumulator.get('tool_idx', 0) if accumulator else 0 - + current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0 + 
tool_call = { "index": current_tool_idx, "id": tool_call_id, "type": "function", "function": { "name": function_name, - "arguments": json.dumps(function_call.get('args', {})) - } + "arguments": json.dumps(function_call.get("args", {})), + }, } - + # Handle thoughtSignature for Gemini 3 # Store signature for each tool call (needed for parallel tool calls) if is_gemini_3 and has_sig: - sig = part['thoughtSignature'] - + sig = part["thoughtSignature"] + if self._enable_signature_cache: self._signature_cache.store(tool_call_id, sig) lib_logger.debug(f"Stored signature for {tool_call_id}") - + if self._preserve_signatures_in_client: tool_call["thought_signature"] = sig - - delta['tool_calls'] = [tool_call] + + delta["tool_calls"] = [tool_call] # Mark that we've sent tool calls and increment tool_idx if accumulator is not None: - accumulator['has_tool_calls'] = True - accumulator['tool_idx'] = current_tool_idx + 1 - + accumulator["has_tool_calls"] = True + accumulator["tool_idx"] = current_tool_idx + 1 + elif has_text: # Use an explicit check for the 'thought' flag, as its type can be inconsistent if is_thought: - delta['reasoning_content'] = part['text'] + delta["reasoning_content"] = part["text"] else: - delta['content'] = part['text'] - + delta["content"] = part["text"] + if not delta: continue # Mark that we have tool calls for accumulator tracking # finish_reason determination is handled by the client - + # Mark stream complete if we have usageMetadata - is_final_chunk = 'usageMetadata' in response_data + is_final_chunk = "usageMetadata" in response_data if is_final_chunk and accumulator is not None: - accumulator['is_complete'] = True + accumulator["is_complete"] = True # Build choice - don't include finish_reason, let client handle it choice = {"index": 0, "delta": delta} - + openai_chunk = { - "choices": [choice], "model": model_id, "object": "chat.completion.chunk", - "id": chunk.get("responseId", f"chatcmpl-geminicli-{time.time()}"), "created": int(time.time()) + "choices": [choice], + "model": model_id, + "object": "chat.completion.chunk", + "id": chunk.get("responseId", f"chatcmpl-geminicli-{time.time()}"), + "created": int(time.time()), } - if 'usageMetadata' in response_data: - usage = response_data['usageMetadata'] + if "usageMetadata" in response_data: + usage = response_data["usageMetadata"] prompt_tokens = usage.get("promptTokenCount", 0) thoughts_tokens = usage.get("thoughtsTokenCount", 0) candidate_tokens = usage.get("candidatesTokenCount", 0) openai_chunk["usage"] = { - "prompt_tokens": prompt_tokens + thoughts_tokens, # Include thoughts in prompt tokens + "prompt_tokens": prompt_tokens + + thoughts_tokens, # Include thoughts in prompt tokens "completion_tokens": candidate_tokens, "total_tokens": usage.get("totalTokenCount", 0), } @@ -1042,14 +1366,18 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str, accumul if thoughts_tokens > 0: if "completion_tokens_details" not in openai_chunk["usage"]: openai_chunk["usage"]["completion_tokens_details"] = {} - openai_chunk["usage"]["completion_tokens_details"]["reasoning_tokens"] = thoughts_tokens - + openai_chunk["usage"]["completion_tokens_details"][ + "reasoning_tokens" + ] = thoughts_tokens + yield openai_chunk - def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse: + def _stream_to_completion_response( + self, chunks: List[litellm.ModelResponse] + ) -> litellm.ModelResponse: """ Manually reassembles streaming chunks into a complete response. 
- + Key improvements: - Determines finish_reason based on accumulated state - Priority: tool_calls > chunk's finish_reason (length, content_filter, etc.) > stop @@ -1069,7 +1397,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> # Process each chunk to aggregate content for chunk in chunks: - if not hasattr(chunk, 'choices') or not chunk.choices: + if not hasattr(chunk, "choices") or not chunk.choices: continue choice = chunk.choices[0] @@ -1092,25 +1420,48 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> for tc_chunk in delta["tool_calls"]: index = tc_chunk.get("index", 0) if index not in aggregated_tool_calls: - aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}} + aggregated_tool_calls[index] = { + "type": "function", + "function": {"name": "", "arguments": ""}, + } if "id" in tc_chunk: aggregated_tool_calls[index]["id"] = tc_chunk["id"] if "type" in tc_chunk: aggregated_tool_calls[index]["type"] = tc_chunk["type"] if "function" in tc_chunk: - if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None: - aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"] - if "arguments" in tc_chunk["function"] and tc_chunk["function"]["arguments"] is not None: - aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"] + if ( + "name" in tc_chunk["function"] + and tc_chunk["function"]["name"] is not None + ): + aggregated_tool_calls[index]["function"]["name"] += ( + tc_chunk["function"]["name"] + ) + if ( + "arguments" in tc_chunk["function"] + and tc_chunk["function"]["arguments"] is not None + ): + aggregated_tool_calls[index]["function"]["arguments"] += ( + tc_chunk["function"]["arguments"] + ) # Aggregate function calls (legacy format) if "function_call" in delta and delta["function_call"] is not None: if "function_call" not in final_message: final_message["function_call"] = {"name": "", "arguments": ""} - if "name" in delta["function_call"] and delta["function_call"]["name"] is not None: - final_message["function_call"]["name"] += delta["function_call"]["name"] - if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None: - final_message["function_call"]["arguments"] += delta["function_call"]["arguments"] + if ( + "name" in delta["function_call"] + and delta["function_call"]["name"] is not None + ): + final_message["function_call"]["name"] += delta["function_call"][ + "name" + ] + if ( + "arguments" in delta["function_call"] + and delta["function_call"]["arguments"] is not None + ): + final_message["function_call"]["arguments"] += delta[ + "function_call" + ]["arguments"] # Track finish_reason from chunks (respects length, content_filter, etc.) 
if choice.get("finish_reason"): @@ -1118,7 +1469,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> # Handle usage data from the last chunk that has it for chunk in reversed(chunks): - if hasattr(chunk, 'usage') and chunk.usage: + if hasattr(chunk, "usage") and chunk.usage: usage_data = chunk.usage break @@ -1139,12 +1490,12 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> finish_reason = chunk_finish_reason else: finish_reason = "stop" - + # Construct the final response final_choice = { "index": 0, "message": final_message, - "finish_reason": finish_reason + "finish_reason": finish_reason, } # Create the final ModelResponse @@ -1154,7 +1505,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> "created": first_chunk.created, "model": first_chunk.model, "choices": [final_choice], - "usage": usage_data + "usage": usage_data, } return litellm.ModelResponse(**final_response_data) @@ -1169,63 +1520,72 @@ def _gemini_cli_transform_schema(self, schema: Dict[str, Any]) -> Dict[str, Any] return schema # Handle nullable types - if 'type' in schema and isinstance(schema['type'], list): - types = schema['type'] - if 'null' in types: - schema['nullable'] = True - remaining_types = [t for t in types if t != 'null'] + if "type" in schema and isinstance(schema["type"], list): + types = schema["type"] + if "null" in types: + schema["nullable"] = True + remaining_types = [t for t in types if t != "null"] if len(remaining_types) == 1: - schema['type'] = remaining_types[0] + schema["type"] = remaining_types[0] elif len(remaining_types) > 1: - schema['type'] = remaining_types # Let's see if Gemini supports this + schema["type"] = ( + remaining_types # Let's see if Gemini supports this + ) else: - del schema['type'] + del schema["type"] # Recurse into properties - if 'properties' in schema and isinstance(schema['properties'], dict): - for prop_schema in schema['properties'].values(): + if "properties" in schema and isinstance(schema["properties"], dict): + for prop_schema in schema["properties"].values(): self._gemini_cli_transform_schema(prop_schema) # Recurse into items (for arrays) - if 'items' in schema and isinstance(schema['items'], dict): - self._gemini_cli_transform_schema(schema['items']) + if "items" in schema and isinstance(schema["items"], dict): + self._gemini_cli_transform_schema(schema["items"]) # Clean up unsupported properties schema.pop("strict", None) schema.pop("additionalProperties", None) - + return schema def _enforce_strict_schema(self, schema: Any) -> Any: """ Enforce strict JSON schema for Gemini 3 to prevent hallucinated parameters. - + Adds 'additionalProperties: false' recursively to all object schemas, which tells the model it CANNOT add properties not in the schema. 
""" if not isinstance(schema, dict): return schema - + result = {} for key, value in schema.items(): if isinstance(value, dict): result[key] = self._enforce_strict_schema(value) elif isinstance(value, list): - result[key] = [self._enforce_strict_schema(item) if isinstance(item, dict) else item for item in value] + result[key] = [ + self._enforce_strict_schema(item) + if isinstance(item, dict) + else item + for item in value + ] else: result[key] = value - + # Add additionalProperties: false to object schemas if result.get("type") == "object" and "properties" in result: result["additionalProperties"] = False - + return result - def _transform_tool_schemas(self, tools: List[Dict[str, Any]], model: str = "") -> List[Dict[str, Any]]: + def _transform_tool_schemas( + self, tools: List[Dict[str, Any]], model: str = "" + ) -> List[Dict[str, Any]]: """ Transforms a list of OpenAI-style tool schemas into the format required by the Gemini CLI API. This uses a custom schema transformer instead of litellm's generic one. - + For Gemini 3 models, also applies: - Namespace prefix to tool names - Parameter signature injection into descriptions @@ -1233,22 +1593,27 @@ def _transform_tool_schemas(self, tools: List[Dict[str, Any]], model: str = "") """ transformed_declarations = [] is_gemini_3 = self._is_gemini_3(model) - + for tool in tools: if tool.get("type") == "function" and "function" in tool: new_function = json.loads(json.dumps(tool["function"])) - + # The Gemini CLI API does not support the 'strict' property. new_function.pop("strict", None) # Gemini CLI expects 'parametersJsonSchema' instead of 'parameters' if "parameters" in new_function: - schema = self._gemini_cli_transform_schema(new_function["parameters"]) + schema = self._gemini_cli_transform_schema( + new_function["parameters"] + ) new_function["parametersJsonSchema"] = schema del new_function["parameters"] elif "parametersJsonSchema" not in new_function: # Set default empty schema if neither exists - new_function["parametersJsonSchema"] = {"type": "object", "properties": {}} + new_function["parametersJsonSchema"] = { + "type": "object", + "properties": {}, + } # Gemini 3 specific transformations if is_gemini_3 and self._enable_gemini3_tool_fix: @@ -1256,64 +1621,73 @@ def _transform_tool_schemas(self, tools: List[Dict[str, Any]], model: str = "") name = new_function.get("name", "") if name: new_function["name"] = f"{self._gemini3_tool_prefix}{name}" - + # Enforce strict schema (additionalProperties: false) - if self._gemini3_enforce_strict_schema and "parametersJsonSchema" in new_function: - new_function["parametersJsonSchema"] = self._enforce_strict_schema(new_function["parametersJsonSchema"]) - + if ( + self._gemini3_enforce_strict_schema + and "parametersJsonSchema" in new_function + ): + new_function["parametersJsonSchema"] = ( + self._enforce_strict_schema( + new_function["parametersJsonSchema"] + ) + ) + # Inject parameter signature into description new_function = self._inject_signature_into_description(new_function) transformed_declarations.append(new_function) - + return transformed_declarations - def _inject_signature_into_description(self, func_decl: Dict[str, Any]) -> Dict[str, Any]: + def _inject_signature_into_description( + self, func_decl: Dict[str, Any] + ) -> Dict[str, Any]: """Inject parameter signatures into tool description for Gemini 3.""" schema = func_decl.get("parametersJsonSchema", {}) if not schema: return func_decl - + required = schema.get("required", []) properties = schema.get("properties", {}) - + if not 
properties: return func_decl - + param_list = [] for prop_name, prop_data in properties.items(): if not isinstance(prop_data, dict): continue - + type_hint = self._format_type_hint(prop_data) is_required = prop_name in required param_list.append( f"{prop_name} ({type_hint}{', REQUIRED' if is_required else ''})" ) - + if param_list: sig_str = self._gemini3_description_prompt.replace( "{params}", ", ".join(param_list) ) func_decl["description"] = func_decl.get("description", "") + sig_str - + return func_decl def _format_type_hint(self, prop_data: Dict[str, Any], depth: int = 0) -> str: """Format a detailed type hint for a property schema.""" type_hint = prop_data.get("type", "unknown") - + # Handle enum values - show allowed options if "enum" in prop_data: enum_vals = prop_data["enum"] if len(enum_vals) <= 5: return f"string ENUM[{', '.join(repr(v) for v in enum_vals)}]" return f"string ENUM[{len(enum_vals)} options]" - + # Handle const values if "const" in prop_data: return f"string CONST={repr(prop_data['const'])}" - + if type_hint == "array": items = prop_data.get("items", {}) if isinstance(items, dict): @@ -1336,7 +1710,7 @@ def _format_type_hint(self, prop_data: Dict[str, Any], depth: int = 0) -> str: return "ARRAY_OF_OBJECTS" return f"ARRAY_OF_{item_type.upper()}" return "ARRAY" - + if type_hint == "object": nested_props = prop_data.get("properties", {}) nested_req = prop_data.get("required", []) @@ -1348,31 +1722,39 @@ def _format_type_hint(self, prop_data: Dict[str, Any], depth: int = 0) -> str: req = " REQUIRED" if n in nested_req else "" nested_list.append(f"{n}: {t}{req}") return f"object{{{', '.join(nested_list)}}}" - + return type_hint - def _inject_gemini3_system_instruction(self, request_payload: Dict[str, Any]) -> None: + def _inject_gemini3_system_instruction( + self, request_payload: Dict[str, Any] + ) -> None: """Inject Gemini 3 tool fix system instruction if tools are present.""" if not request_payload.get("request", {}).get("tools"): return - + existing_system = request_payload.get("request", {}).get("systemInstruction") - + if existing_system: # Prepend to existing system instruction existing_parts = existing_system.get("parts", []) if existing_parts and existing_parts[0].get("text"): - existing_parts[0]["text"] = self._gemini3_system_instruction + "\n\n" + existing_parts[0]["text"] + existing_parts[0]["text"] = ( + self._gemini3_system_instruction + + "\n\n" + + existing_parts[0]["text"] + ) else: existing_parts.insert(0, {"text": self._gemini3_system_instruction}) else: # Create new system instruction request_payload["request"]["systemInstruction"] = { "role": "user", - "parts": [{"text": self._gemini3_system_instruction}] + "parts": [{"text": self._gemini3_system_instruction}], } - def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]], model: str = "") -> Optional[Dict[str, Any]]: + def _translate_tool_choice( + self, tool_choice: Union[str, Dict[str, Any]], model: str = "" + ) -> Optional[Dict[str, Any]]: """ Translates OpenAI's `tool_choice` to Gemini's `toolConfig`. Handles Gemini 3 namespace prefixes for specific tool selection. 
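
The description-injection step in the hunk above can be exercised in isolation. The sketch below mirrors the `_inject_signature_into_description` / `_format_type_hint` pair on a standalone function declaration; `DESCRIPTION_TEMPLATE` is a hypothetical stand-in for the provider's `_gemini3_description_prompt` (its real wording is configured elsewhere), so treat this as an illustration of the transformation, not the exact output.

```python
# Minimal sketch of parameter-signature injection for Gemini 3 tool declarations.
# DESCRIPTION_TEMPLATE is an assumed placeholder, not the real prompt template.
from typing import Any, Dict

DESCRIPTION_TEMPLATE = " Parameters: {params}"  # assumption

def format_type_hint(prop: Dict[str, Any]) -> str:
    """Condensed variant of _format_type_hint: enums and arrays only."""
    if "enum" in prop:
        return f"string ENUM[{', '.join(repr(v) for v in prop['enum'])}]"
    if prop.get("type") == "array":
        item_type = prop.get("items", {}).get("type", "unknown")
        return f"ARRAY_OF_{item_type.upper()}"
    return prop.get("type", "unknown")

def inject_signature(func_decl: Dict[str, Any]) -> Dict[str, Any]:
    schema = func_decl.get("parametersJsonSchema", {})
    required = schema.get("required", [])
    params = [
        f"{name} ({format_type_hint(prop)}{', REQUIRED' if name in required else ''})"
        for name, prop in schema.get("properties", {}).items()
        if isinstance(prop, dict)
    ]
    if params:
        func_decl["description"] = func_decl.get("description", "") + \
            DESCRIPTION_TEMPLATE.replace("{params}", ", ".join(params))
    return func_decl

decl = {
    "name": "get_weather",
    "description": "Look up the forecast.",
    "parametersJsonSchema": {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "unit": {"type": "string", "enum": ["C", "F"]},
        },
        "required": ["city"],
    },
}
print(inject_signature(decl)["description"])
# -> Look up the forecast. Parameters: city (string, REQUIRED), unit (string ENUM['C', 'F'])
```

The point of the injection is that Gemini 3 sees the required/optional parameter shape twice, once in the schema and once in prose, which the surrounding changes use to reduce hallucinated arguments.
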
@@ -1397,18 +1779,20 @@ def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]], model: # Add Gemini 3 prefix if needed if is_gemini_3 and self._enable_gemini3_tool_fix: function_name = f"{self._gemini3_tool_prefix}{function_name}" - - mode = "ANY" # Force a call, but only to this function + + mode = "ANY" # Force a call, but only to this function config["functionCallingConfig"] = { "mode": mode, - "allowedFunctionNames": [function_name] + "allowedFunctionNames": [function_name], } return config config["functionCallingConfig"] = {"mode": mode} return config - async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: + async def acompletion( + self, client: httpx.AsyncClient, **kwargs + ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: model = kwargs["model"] credential_path = kwargs.pop("credential_identifier") enable_request_logging = kwargs.pop("enable_request_logging", False) @@ -1423,28 +1807,37 @@ async def do_call(attempt_model: str, is_fallback: bool = False): # Discover project ID only if not already cached project_id = self.project_id_cache.get(credential_path) if not project_id: - access_token = auth_header['Authorization'].split(' ')[1] - project_id = await self._discover_project_id(credential_path, access_token, kwargs.get("litellm_params", {})) + access_token = auth_header["Authorization"].split(" ")[1] + project_id = await self._discover_project_id( + credential_path, access_token, kwargs.get("litellm_params", {}) + ) # Log paid tier usage visibly on each request credential_tier = self.project_tier_cache.get(credential_path) - if credential_tier and credential_tier not in ['free-tier', 'legacy-tier', 'unknown']: - lib_logger.info(f"[PAID TIER] Using Gemini '{credential_tier}' subscription for this request") + if credential_tier and credential_tier not in [ + "free-tier", + "legacy-tier", + "unknown", + ]: + lib_logger.info( + f"[PAID TIER] Using Gemini '{credential_tier}' subscription for this request" + ) # Handle :thinking suffix - model_name = attempt_model.split('/')[-1].replace(':thinking', '') + model_name = attempt_model.split("/")[-1].replace(":thinking", "") # [NEW] Create a dedicated file logger for this request file_logger = _GeminiCliFileLogger( - model_name=model_name, - enabled=enable_request_logging + model_name=model_name, enabled=enable_request_logging ) - + is_gemini_3 = self._is_gemini_3(model_name) gen_config = { - "maxOutputTokens": kwargs.get("max_tokens", 64000), # Increased default - "temperature": kwargs.get("temperature", 1), # Default to 1 if not provided + "maxOutputTokens": kwargs.get("max_tokens", 64000), # Increased default + "temperature": kwargs.get( + "temperature", 1 + ), # Default to 1 if not provided } if "top_k" in kwargs: gen_config["topK"] = kwargs["top_k"] @@ -1456,7 +1849,9 @@ async def do_call(attempt_model: str, is_fallback: bool = False): if thinking_config: gen_config["thinkingConfig"] = thinking_config - system_instruction, contents = self._transform_messages(kwargs.get("messages", []), model_name) + system_instruction, contents = self._transform_messages( + kwargs.get("messages", []), model_name + ) request_payload = { "model": model_name, "project": project_id, @@ -1470,16 +1865,22 @@ async def do_call(attempt_model: str, is_fallback: bool = False): request_payload["request"]["systemInstruction"] = system_instruction if "tools" in kwargs and kwargs["tools"]: - function_declarations = 
self._transform_tool_schemas(kwargs["tools"], model_name) + function_declarations = self._transform_tool_schemas( + kwargs["tools"], model_name + ) if function_declarations: - request_payload["request"]["tools"] = [{"functionDeclarations": function_declarations}] + request_payload["request"]["tools"] = [ + {"functionDeclarations": function_declarations} + ] # [NEW] Handle tool_choice translation if "tool_choice" in kwargs and kwargs["tool_choice"]: - tool_config = self._translate_tool_choice(kwargs["tool_choice"], model_name) + tool_config = self._translate_tool_choice( + kwargs["tool_choice"], model_name + ) if tool_config: request_payload["request"]["toolConfig"] = tool_config - + # Inject Gemini 3 system instruction if using tools if is_gemini_3 and self._enable_gemini3_tool_fix: self._inject_gemini3_system_instruction(request_payload) @@ -1491,52 +1892,77 @@ async def do_call(attempt_model: str, is_fallback: bool = False): {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + { + "category": "HARM_CATEGORY_CIVIC_INTEGRITY", + "threshold": "BLOCK_NONE", + }, ] # Log the final payload for debugging and to the dedicated file - #lib_logger.debug(f"Gemini CLI Request Payload: {json.dumps(request_payload, indent=2)}") + # lib_logger.debug(f"Gemini CLI Request Payload: {json.dumps(request_payload, indent=2)}") file_logger.log_request(request_payload) - + url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent" async def stream_handler(): # Track state across chunks for tool indexing - accumulator = {"has_tool_calls": False, "tool_idx": 0, "is_complete": False} - + accumulator = { + "has_tool_calls": False, + "tool_idx": 0, + "is_complete": False, + } + final_headers = auth_header.copy() - final_headers.update({ - "User-Agent": "google-api-nodejs-client/9.15.1", - "X-Goog-Api-Client": "gl-node/22.17.0", - "Client-Metadata": "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI", - "Accept": "application/json", - }) + final_headers.update( + { + "User-Agent": "google-api-nodejs-client/9.15.1", + "X-Goog-Api-Client": "gl-node/22.17.0", + "Client-Metadata": "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI", + "Accept": "application/json", + } + ) try: - async with client.stream("POST", url, headers=final_headers, json=request_payload, params={"alt": "sse"}, timeout=600) as response: + async with client.stream( + "POST", + url, + headers=final_headers, + json=request_payload, + params={"alt": "sse"}, + timeout=600, + ) as response: # Read and log error body before raise_for_status for better debugging if response.status_code >= 400: try: error_body = await response.aread() - lib_logger.error(f"Gemini CLI API error {response.status_code}: {error_body.decode()}") - file_logger.log_error(f"API error {response.status_code}: {error_body.decode()}") + lib_logger.error( + f"Gemini CLI API error {response.status_code}: {error_body.decode()}" + ) + file_logger.log_error( + f"API error {response.status_code}: {error_body.decode()}" + ) except Exception: pass - + # This will raise an HTTPStatusError for 4xx/5xx responses response.raise_for_status() async for line in response.aiter_lines(): file_logger.log_response_chunk(line) - if line.startswith('data: '): + if line.startswith("data: "): data_str = line[6:] - if data_str == "[DONE]": break + if 
data_str == "[DONE]": + break try: chunk = json.loads(data_str) - for openai_chunk in self._convert_chunk_to_openai(chunk, model, accumulator): + for openai_chunk in self._convert_chunk_to_openai( + chunk, model, accumulator + ): yield litellm.ModelResponse(**openai_chunk) except json.JSONDecodeError: - lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}") - + lib_logger.warning( + f"Could not decode JSON from Gemini CLI: {line}" + ) + # Emit final chunk if stream ended without usageMetadata # Client will determine the correct finish_reason if not accumulator.get("is_complete"): @@ -1545,9 +1971,15 @@ async def stream_handler(): "object": "chat.completion.chunk", "created": int(time.time()), "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": None}], + "choices": [ + {"index": 0, "delta": {}, "finish_reason": None} + ], # Include minimal usage to signal this is the final chunk - "usage": {"prompt_tokens": 0, "completion_tokens": 1, "total_tokens": 1} + "usage": { + "prompt_tokens": 0, + "completion_tokens": 1, + "total_tokens": 1, + }, } yield litellm.ModelResponse(**final_chunk) @@ -1558,27 +1990,35 @@ async def stream_handler(): error_body = e.response.text except Exception: pass - + # Only log to file logger (for detailed logging) if error_body: - file_logger.log_error(f"HTTPStatusError {e.response.status_code}: {error_body}") + file_logger.log_error( + f"HTTPStatusError {e.response.status_code}: {error_body}" + ) else: - file_logger.log_error(f"HTTPStatusError {e.response.status_code}: {str(e)}") - + file_logger.log_error( + f"HTTPStatusError {e.response.status_code}: {str(e)}" + ) + if e.response.status_code == 429: # Extract retry-after time from the error body retry_after = extract_retry_after_from_body(error_body) - retry_info = f" (retry after {retry_after}s)" if retry_after else "" + retry_info = ( + f" (retry after {retry_after}s)" if retry_after else "" + ) error_msg = f"Gemini CLI rate limit exceeded{retry_info}" if error_body: error_msg = f"{error_msg} | {error_body}" # Only log at debug level - rotation happens silently - lib_logger.debug(f"Gemini CLI 429 rate limit: retry_after={retry_after}s") + lib_logger.debug( + f"Gemini CLI 429 rate limit: retry_after={retry_after}s" + ) raise RateLimitError( message=error_msg, llm_provider="gemini_cli", model=model, - response=e.response + response=e.response, ) # Re-raise other status errors to be handled by the main acompletion logic raise e @@ -1595,29 +2035,41 @@ async def logging_stream_wrapper(): yield chunk finally: if openai_chunks: - final_response = self._stream_to_completion_response(openai_chunks) + final_response = self._stream_to_completion_response( + openai_chunks + ) file_logger.log_final_response(final_response.dict()) return logging_stream_wrapper() # Check if there are actual fallback models available # If fallback_models is empty or contains only the base model (no actual fallbacks), skip fallback logic - has_fallbacks = len(fallback_models) > 1 and any(model != fallback_models[0] for model in fallback_models[1:]) - + has_fallbacks = len(fallback_models) > 1 and any( + model != fallback_models[0] for model in fallback_models[1:] + ) + lib_logger.debug(f"Fallback models available: {fallback_models}") if not has_fallbacks: - lib_logger.debug("No actual fallback models available, proceeding with single model attempt") - + lib_logger.debug( + "No actual fallback models available, proceeding with single model attempt" + ) + last_error = None for idx, attempt_model in 
enumerate(fallback_models): is_fallback = idx > 0 if is_fallback: # Silent rotation - only log at debug level - lib_logger.debug(f"Rate limited on previous model, trying fallback: {attempt_model}") + lib_logger.debug( + f"Rate limited on previous model, trying fallback: {attempt_model}" + ) elif has_fallbacks: - lib_logger.debug(f"Attempting primary model: {attempt_model} (with {len(fallback_models)-1} fallback(s) available)") + lib_logger.debug( + f"Attempting primary model: {attempt_model} (with {len(fallback_models) - 1} fallback(s) available)" + ) else: - lib_logger.debug(f"Attempting model: {attempt_model} (no fallbacks available)") + lib_logger.debug( + f"Attempting model: {attempt_model} (no fallbacks available)" + ) try: response_gen = await do_call(attempt_model, is_fallback) @@ -1633,10 +2085,14 @@ async def logging_stream_wrapper(): last_error = e # If this is not the last model in the fallback chain, continue to next model if idx + 1 < len(fallback_models): - lib_logger.debug(f"Rate limit hit on {attempt_model}, trying next fallback...") + lib_logger.debug( + f"Rate limit hit on {attempt_model}, trying next fallback..." + ) continue # If this was the last fallback option, log error and raise - lib_logger.warning(f"Rate limit exhausted on all fallback models (tried {len(fallback_models)} models)") + lib_logger.warning( + f"Rate limit exhausted on all fallback models (tried {len(fallback_models)} models)" + ) raise # Should not reach here, but raise last error if we do @@ -1651,7 +2107,7 @@ async def count_tokens( model: str, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, - litellm_params: Optional[Dict[str, Any]] = None + litellm_params: Optional[Dict[str, Any]] = None, ) -> Dict[str, int]: """ Counts tokens for the given prompt using the Gemini CLI :countTokens endpoint. 
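
The fallback loop earlier in this hunk follows a simple "try, and on a rate limit move to the next candidate" shape. A reduced sketch of that control flow is below; `call_model` is a hypothetical stand-in for `do_call`, and the local `RateLimitError` only plays the role of litellm's exception of the same name.

```python
# Sketch of the rate-limit fallback loop: try each model in order and re-raise
# only once every candidate has been exhausted. call_model is hypothetical.
import asyncio

class RateLimitError(Exception):
    """Stand-in for litellm.exceptions.RateLimitError."""

async def call_model(model: str) -> str:
    if model.endswith("-preview"):          # pretend the preview model is rate limited
        raise RateLimitError(f"{model} is rate limited")
    return f"response from {model}"

async def complete_with_fallbacks(fallback_models: list[str]) -> str:
    last_error: Exception | None = None
    for idx, attempt_model in enumerate(fallback_models):
        try:
            return await call_model(attempt_model)
        except RateLimitError as e:
            last_error = e
            if idx + 1 < len(fallback_models):
                continue                     # silent rotation to the next candidate
            raise                            # last option exhausted: surface the error
    raise last_error or RuntimeError("no models supplied")

print(asyncio.run(complete_with_fallbacks(
    ["gemini-2.5-pro-preview", "gemini-2.5-pro"]
)))  # -> response from gemini-2.5-pro
```

As in the real code, intermediate rate limits stay quiet (debug-level only) and the error is only surfaced once the whole chain has been tried.
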
@@ -1673,11 +2129,13 @@ async def count_tokens( # Discover project ID project_id = self.project_id_cache.get(credential_path) if not project_id: - access_token = auth_header['Authorization'].split(' ')[1] - project_id = await self._discover_project_id(credential_path, access_token, litellm_params or {}) + access_token = auth_header["Authorization"].split(" ")[1] + project_id = await self._discover_project_id( + credential_path, access_token, litellm_params or {} + ) # Handle :thinking suffix - model_name = model.split('/')[-1].replace(':thinking', '') + model_name = model.split("/")[-1].replace(":thinking", "") # Transform messages to Gemini format system_instruction, contents = self._transform_messages(messages) @@ -1695,35 +2153,41 @@ async def count_tokens( if tools: function_declarations = self._transform_tool_schemas(tools) if function_declarations: - request_payload["request"]["tools"] = [{"functionDeclarations": function_declarations}] + request_payload["request"]["tools"] = [ + {"functionDeclarations": function_declarations} + ] # Make the request url = f"{CODE_ASSIST_ENDPOINT}:countTokens" headers = auth_header.copy() - headers.update({ - "User-Agent": "google-api-nodejs-client/9.15.1", - "X-Goog-Api-Client": "gl-node/22.17.0", - "Client-Metadata": "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI", - "Accept": "application/json", - }) + headers.update( + { + "User-Agent": "google-api-nodejs-client/9.15.1", + "X-Goog-Api-Client": "gl-node/22.17.0", + "Client-Metadata": "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI", + "Accept": "application/json", + } + ) try: - response = await client.post(url, headers=headers, json=request_payload, timeout=30) + response = await client.post( + url, headers=headers, json=request_payload, timeout=30 + ) response.raise_for_status() data = response.json() # Extract token counts from response - total_tokens = data.get('totalTokens', 0) + total_tokens = data.get("totalTokens", 0) return { - 'prompt_tokens': total_tokens, - 'total_tokens': total_tokens, + "prompt_tokens": total_tokens, + "total_tokens": total_tokens, } except httpx.HTTPStatusError as e: lib_logger.error(f"Failed to count tokens: {e}") # Return 0 on error rather than raising - return {'prompt_tokens': 0, 'total_tokens': 0} + return {"prompt_tokens": 0, "total_tokens": 0} # Use the shared GeminiAuthBase for auth logic async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]: @@ -1738,9 +2202,11 @@ async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[s """ # Check for mixed tier credentials and warn if detected self._check_mixed_tier_warning() - + models = [] - env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates + env_var_ids = ( + set() + ) # Track IDs from env vars to prevent hardcoded/dynamic duplicates def extract_model_id(item) -> str: """Extract model ID from various formats (dict, string with/without provider prefix).""" @@ -1770,7 +2236,9 @@ def extract_model_id(item) -> str: # Track the ID to prevent hardcoded/dynamic duplicates if model_id: env_var_ids.add(model_id) - lib_logger.info(f"Loaded {len(static_models)} static models for gemini_cli from environment variables") + lib_logger.info( + f"Loaded {len(static_models)} static models for gemini_cli from environment variables" + ) # Source 2: Add hardcoded models (only if ID not already in env vars) for model_id in HARDCODED_MODELS: @@ -1782,7 +2250,7 @@ def extract_model_id(item) -> str: try: # 
Get access token for API calls auth_header = await self.get_auth_header(credential) - access_token = auth_header['Authorization'].split(' ')[1] + access_token = auth_header["Authorization"].split(" ")[1] # Try Vertex AI models endpoint # Note: Gemini may not support a simple /models endpoint like OpenAI @@ -1790,8 +2258,7 @@ def extract_model_id(item) -> str: models_url = f"https://generativelanguage.googleapis.com/v1beta/models" response = await client.get( - models_url, - headers={"Authorization": f"Bearer {access_token}"} + models_url, headers={"Authorization": f"Bearer {access_token}"} ) response.raise_for_status() @@ -1803,17 +2270,23 @@ def extract_model_id(item) -> str: for model in model_list: model_id = extract_model_id(model) # Only include Gemini models that aren't already in env vars - if model_id and model_id not in env_var_ids and model_id.startswith("gemini"): + if ( + model_id + and model_id not in env_var_ids + and model_id.startswith("gemini") + ): models.append(f"gemini_cli/{model_id}") env_var_ids.add(model_id) dynamic_count += 1 if dynamic_count > 0: - lib_logger.debug(f"Discovered {dynamic_count} additional models for gemini_cli from API") + lib_logger.debug( + f"Discovered {dynamic_count} additional models for gemini_cli from API" + ) except Exception as e: # Silently ignore dynamic discovery errors lib_logger.debug(f"Dynamic model discovery failed for gemini_cli: {e}") pass - return models \ No newline at end of file + return models diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index 8a20a64c..996f3a7e 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -3,13 +3,15 @@ import httpx import litellm + class ProviderInterface(ABC): """ An interface for API provider-specific functionality, including model discovery and custom API call handling for non-standard providers. """ + skip_cost_calculation: bool = False - + @abstractmethod async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: """ @@ -32,28 +34,38 @@ def has_custom_logic(self) -> bool: """ return False - async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: + async def acompletion( + self, client: httpx.AsyncClient, **kwargs + ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: """ Handles the entire completion call for non-standard providers. """ - raise NotImplementedError(f"{self.__class__.__name__} does not implement custom acompletion.") + raise NotImplementedError( + f"{self.__class__.__name__} does not implement custom acompletion." + ) - async def aembedding(self, client: httpx.AsyncClient, **kwargs) -> litellm.EmbeddingResponse: + async def aembedding( + self, client: httpx.AsyncClient, **kwargs + ) -> litellm.EmbeddingResponse: """Handles the entire embedding call for non-standard providers.""" - raise NotImplementedError(f"{self.__class__.__name__} does not implement custom aembedding.") - - def convert_safety_settings(self, settings: Dict[str, str]) -> Optional[List[Dict[str, Any]]]: + raise NotImplementedError( + f"{self.__class__.__name__} does not implement custom aembedding." + ) + + def convert_safety_settings( + self, settings: Dict[str, str] + ) -> Optional[List[Dict[str, Any]]]: """ Converts a generic safety settings dictionary to the provider-specific format. 
- + Args: settings: A dictionary with generic harm categories and thresholds. - + Returns: A list of provider-specific safety setting objects or None. """ return None - + # [NEW] Add new methods for OAuth providers async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]: """ @@ -67,23 +79,23 @@ async def proactively_refresh(self, credential_path: str): Proactively refreshes a token if it's nearing expiry. """ pass - + # [NEW] Credential Prioritization System def get_credential_priority(self, credential: str) -> Optional[int]: """ Returns the priority level for a credential. Lower numbers = higher priority (1 is highest). Returns None if provider doesn't use priorities. - + This allows providers to auto-detect credential tiers (e.g., paid vs free) and ensure higher-tier credentials are always tried first. - + Args: credential: The credential identifier (API key or path) - + Returns: Priority level (1-10) or None if no priority system - + Example: For Gemini CLI: - Paid tier credentials: priority 1 (highest) @@ -91,24 +103,53 @@ def get_credential_priority(self, credential: str) -> Optional[int]: - Unknown tier: priority 10 (lowest) """ return None - + def get_model_tier_requirement(self, model: str) -> Optional[int]: """ Returns the minimum priority tier required for a model. If a model requires priority 1, only credentials with priority <= 1 can use it. - + This allows providers to restrict certain models to specific credential tiers. For example, Gemini 3 models require paid-tier credentials. - + Args: model: The model name (with or without provider prefix) - + Returns: Minimum required priority level or None if no restrictions - + Example: For Gemini CLI: - gemini-3-*: requires priority 1 (paid tier only) - gemini-2.5-*: no restriction (None) """ - return None \ No newline at end of file + return None + + async def initialize_credentials(self, credential_paths: List[str]) -> None: + """ + Called at startup to initialize provider with all available credentials. + + Providers can override this to load cached tier data, discover priorities, + or perform any other initialization needed before the first API request. + + This is called once during startup by the BackgroundRefresher before + the main refresh loop begins. + + Args: + credential_paths: List of credential file paths for this provider + """ + pass + + def get_credential_tier_name(self, credential: str) -> Optional[str]: + """ + Returns the human-readable tier name for a credential. + + This is used for logging purposes to show which plan tier a credential belongs to. + + Args: + credential: The credential identifier (API key or path) + + Returns: + Tier name string (e.g., "free-tier", "paid-tier") or None if unknown + """ + return None diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index c72d9769..577bf4aa 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -22,24 +22,24 @@ class UsageManager: """ Manages usage statistics and cooldowns for API keys with asyncio-safe locking, asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. - + The credential rotation strategy can be configured via the `rotation_tolerance` parameter: - + - **tolerance = 0.0**: Deterministic least-used selection. The credential with the lowest usage count is always selected. This provides predictable, perfectly balanced load distribution but may be vulnerable to fingerprinting. 
- + - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected randomly with weights biased toward less-used ones. Credentials within 2 uses of the maximum can still be selected with reasonable probability. This provides security through unpredictability while maintaining good load balance. - + - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant selection probability. Useful for stress testing or maximum unpredictability, but may result in less balanced load distribution. - + The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` - + This ensures lower-usage credentials are preferred while tolerance controls how much randomness is introduced into the selection process. """ @@ -52,7 +52,7 @@ def __init__( ): """ Initialize the UsageManager. - + Args: file_path: Path to the usage data JSON file daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) @@ -139,7 +139,9 @@ async def _reset_daily_stats_if_needed(self): last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc ): - lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") + lib_logger.debug( + f"Performing daily reset for key {mask_credential(key)}" + ) needs_saving = True # Reset cooldowns @@ -194,24 +196,20 @@ def _initialize_key_states(self, keys: List[str]): "models_in_use": {}, # Dict[model_name, concurrent_count] } - def _select_weighted_random( - self, - candidates: List[tuple], - tolerance: float - ) -> str: + def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: """ Selects a credential using weighted random selection based on usage counts. - + Args: candidates: List of (credential_id, usage_count) tuples tolerance: Tolerance value for weight calculation - + Returns: Selected credential ID - + Formula: weight = (max_usage - credential_usage) + tolerance + 1 - + This formula ensures: - Lower usage = higher weight = higher selection probability - Tolerance adds variability: higher tolerance means more randomness @@ -219,63 +217,66 @@ def _select_weighted_random( """ if not candidates: raise ValueError("Cannot select from empty candidate list") - + if len(candidates) == 1: return candidates[0][0] - + # Extract usage counts usage_counts = [usage for _, usage in candidates] max_usage = max(usage_counts) - + # Calculate weights using the formula: (max - current) + tolerance + 1 weights = [] for credential, usage in candidates: weight = (max_usage - usage) + tolerance + 1 weights.append(weight) - + # Log weight distribution for debugging if lib_logger.isEnabledFor(logging.DEBUG): total_weight = sum(weights) weight_info = ", ".join( - f"{mask_credential(cred)}: w={w:.1f} ({w/total_weight*100:.1f}%)" + f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" for (cred, _), w in zip(candidates, weights) ) - #lib_logger.debug(f"Weighted selection candidates: {weight_info}") - + # lib_logger.debug(f"Weighted selection candidates: {weight_info}") + # Random selection with weights selected_credential = random.choices( - [cred for cred, _ in candidates], - weights=weights, - k=1 + [cred for cred, _ in candidates], weights=weights, k=1 )[0] - + return selected_credential async def acquire_key( - self, available_keys: List[str], model: str, deadline: float, + self, + available_keys: List[str], + model: str, + deadline: float, max_concurrent: int = 1, - credential_priorities: Optional[Dict[str, int]] = None + credential_priorities: 
Optional[Dict[str, int]] = None, + credential_tier_names: Optional[Dict[str, str]] = None, ) -> str: """ Acquires the best available key using a tiered, model-aware locking strategy, respecting a global deadline and credential priorities. - + Priority Logic: - Groups credentials by priority level (1=highest, 2=lower, etc.) - Always tries highest priority (lowest number) first - Within same priority, sorts by usage count (load balancing) - Only moves to next priority if all higher-priority keys exhausted/busy - + Args: available_keys: List of credential identifiers to choose from model: Model name being requested deadline: Timestamp after which to stop trying max_concurrent: Maximum concurrent requests allowed per credential credential_priorities: Optional dict mapping credentials to priority levels (1=highest) - + credential_tier_names: Optional dict mapping credentials to tier names (for logging) + Returns: Selected credential identifier - + Raises: NoAvailableKeysError: If no key could be acquired within the deadline """ @@ -294,16 +295,16 @@ async def acquire_key( async with self._data_lock: for key in available_keys: key_data = self._usage_data.get(key, {}) - + # Skip keys on cooldown if (key_data.get("key_cooldown_until") or 0) > now or ( key_data.get("model_cooldowns", {}).get(model) or 0 ) > now: continue - + # Get priority for this key (default to 999 if not specified) priority = credential_priorities.get(key, 999) - + # Get usage count for load balancing within priority groups usage_count = ( key_data.get("daily", {}) @@ -311,58 +312,75 @@ async def acquire_key( .get(model, {}) .get("success_count", 0) ) - + # Group by priority if priority not in priority_groups: priority_groups[priority] = [] priority_groups[priority].append((key, usage_count)) - + # Try priority groups in order (1, 2, 3, ...) 
sorted_priorities = sorted(priority_groups.keys()) - + for priority_level in sorted_priorities: keys_in_priority = priority_groups[priority_level] - + # Within each priority group, use existing tier1/tier2 logic tier1_keys, tier2_keys = [], [] for key, usage_count in keys_in_priority: key_state = self.key_states[key] - + # Tier 1: Completely idle keys (preferred) if not key_state["models_in_use"]: tier1_keys.append((key, usage_count)) # Tier 2: Keys that can accept more concurrent requests elif key_state["models_in_use"].get(model, 0) < max_concurrent: tier2_keys.append((key, usage_count)) - + # Apply weighted random selection or deterministic sorting - selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used" - + selection_method = ( + "weighted-random" + if self.rotation_tolerance > 0 + else "least-used" + ) + if self.rotation_tolerance > 0: # Weighted random selection within each tier if tier1_keys: - selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance) - tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key] + selected_key = self._select_weighted_random( + tier1_keys, self.rotation_tolerance + ) + tier1_keys = [ + (k, u) for k, u in tier1_keys if k == selected_key + ] if tier2_keys: - selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance) - tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key] + selected_key = self._select_weighted_random( + tier2_keys, self.rotation_tolerance + ) + tier2_keys = [ + (k, u) for k, u in tier2_keys if k == selected_key + ] else: # Deterministic: sort by usage within each tier tier1_keys.sort(key=lambda x: x[1]) tier2_keys.sort(key=lambda x: x[1]) - + # Try to acquire from Tier 1 first for key, usage in tier1_keys: state = self.key_states[key] async with state["lock"]: if not state["models_in_use"]: state["models_in_use"][model] = 1 + tier_name = ( + credential_tier_names.get(key, "unknown") + if credential_tier_names + else "unknown" + ) lib_logger.info( - f"Acquired Priority-{priority_level} Tier-1 key {mask_credential(key)} for model {model} " - f"(selection: {selection_method}, usage: {usage})" + f"Acquired key {mask_credential(key)} for model {model} " + f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, usage: {usage})" ) return key - + # Then try Tier 2 for key, usage in tier2_keys: state = self.key_states[key] @@ -370,35 +388,40 @@ async def acquire_key( current_count = state["models_in_use"].get(model, 0) if current_count < max_concurrent: state["models_in_use"][model] = current_count + 1 + tier_name = ( + credential_tier_names.get(key, "unknown") + if credential_tier_names + else "unknown" + ) lib_logger.info( - f"Acquired Priority-{priority_level} Tier-2 key {mask_credential(key)} for model {model} " - f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" + f"Acquired key {mask_credential(key)} for model {model} " + f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" ) return key - + # If we get here, all priority groups were exhausted but keys might become available # Collect all keys across all priorities for waiting all_potential_keys = [] for keys_list in priority_groups.values(): all_potential_keys.extend(keys_list) - + if not all_potential_keys: lib_logger.warning( "No keys are eligible (all on cooldown or filtered out). 
Waiting before re-evaluating." ) await asyncio.sleep(1) continue - + # Wait for the highest priority key with lowest usage best_priority = min(priority_groups.keys()) best_priority_keys = priority_groups[best_priority] best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] wait_condition = self.key_states[best_wait_key]["condition"] - + lib_logger.info( f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." ) - + else: # Original logic when no priorities specified tier1_keys, tier2_keys = [], [] @@ -430,16 +453,26 @@ async def acquire_key( tier2_keys.append((key, usage_count)) # Apply weighted random selection or deterministic sorting - selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used" - + selection_method = ( + "weighted-random" if self.rotation_tolerance > 0 else "least-used" + ) + if self.rotation_tolerance > 0: # Weighted random selection within each tier if tier1_keys: - selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance) - tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key] + selected_key = self._select_weighted_random( + tier1_keys, self.rotation_tolerance + ) + tier1_keys = [ + (k, u) for k, u in tier1_keys if k == selected_key + ] if tier2_keys: - selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance) - tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key] + selected_key = self._select_weighted_random( + tier2_keys, self.rotation_tolerance + ) + tier2_keys = [ + (k, u) for k, u in tier2_keys if k == selected_key + ] else: # Deterministic: sort by usage within each tier tier1_keys.sort(key=lambda x: x[1]) @@ -451,9 +484,15 @@ async def acquire_key( async with state["lock"]: if not state["models_in_use"]: state["models_in_use"][model] = 1 + tier_name = ( + credential_tier_names.get(key) + if credential_tier_names + else None + ) + tier_info = f"tier: {tier_name}, " if tier_name else "" lib_logger.info( - f"Acquired Tier 1 key {mask_credential(key)} for model {model} " - f"(selection: {selection_method}, usage: {usage})" + f"Acquired key {mask_credential(key)} for model {model} " + f"({tier_info}selection: {selection_method}, usage: {usage})" ) return key @@ -464,9 +503,15 @@ async def acquire_key( current_count = state["models_in_use"].get(model, 0) if current_count < max_concurrent: state["models_in_use"][model] = current_count + 1 + tier_name = ( + credential_tier_names.get(key) + if credential_tier_names + else None + ) + tier_info = f"tier: {tier_name}, " if tier_name else "" lib_logger.info( - f"Acquired Tier 2 key {mask_credential(key)} for model {model} " - f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" + f"Acquired key {mask_credential(key)} for model {model} " + f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" ) return key @@ -506,8 +551,6 @@ async def acquire_key( f"Could not acquire a key for model {model} within the global time budget." 
) - - async def release_key(self, key: str, model: str): """Releases a key's lock for a specific model and notifies waiting tasks.""" if key not in self.key_states: @@ -640,8 +683,11 @@ async def record_success( await self._save_usage() async def record_failure( - self, key: str, model: str, classified_error: ClassifiedError, - increment_consecutive_failures: bool = True + self, + key: str, + model: str, + classified_error: ClassifiedError, + increment_consecutive_failures: bool = True, ): """Records a failure and applies cooldowns based on an escalating backoff strategy. @@ -705,7 +751,9 @@ async def record_failure( # If cooldown wasn't set by specific error type, use escalating backoff if cooldown_seconds is None: backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120} - cooldown_seconds = backoff_tiers.get(count, 7200) # Default to 2 hours for "spent" keys + cooldown_seconds = backoff_tiers.get( + count, 7200 + ) # Default to 2 hours for "spent" keys lib_logger.warning( f"Failure #{count} for key {mask_credential(key)} with model {model}. " f"Error type: {classified_error.error_type}"