diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 5d43b61..2ce1038 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -10,6 +10,7 @@ The project is a monorepo containing two primary components: * **Batch Manager**: Optimizes high-volume embedding requests. * **Detailed Logger**: Provides per-request file logging for debugging. * **OpenAI-Compatible Endpoints**: `/v1/chat/completions`, `/v1/embeddings`, etc. + * **Model Filter GUI**: Visual interface for configuring model ignore/whitelist rules per provider (see Section 6). 2. **The Resilience Library (`rotator_library`)**: This is the core engine that provides high availability. It is consumed by the proxy app to manage a pool of API keys, handle errors gracefully, and ensure requests are completed successfully even when individual keys or provider endpoints face issues. This architecture cleanly separates the API interface from the resilience logic, making the library a portable and powerful tool for any application needing robust API key management. @@ -1145,3 +1146,83 @@ stats = cache.get_stats() # Includes: {"disk_available": True, "disk_errors": 0, ...} ``` +--- + +## 6. Model Filter GUI + +The Model Filter GUI (`model_filter_gui.py`) provides a visual interface for configuring model ignore and whitelist rules per provider. It replaces the need to manually edit `IGNORE_MODELS_*` and `WHITELIST_MODELS_*` environment variables. + +### 6.1. Overview + +**Purpose**: Visually manage which models are exposed via the `/v1/models` endpoint for each provider. + +**Launch**: +```bash +python -c "from src.proxy_app.model_filter_gui import run_model_filter_gui; run_model_filter_gui()" +``` + +Or via the launcher TUI if integrated. + +### 6.2. Features + +#### Core Functionality + +- **Provider Selection**: Dropdown to switch between available providers with automatic model fetching +- **Ignore Rules**: Pattern-based rules (supports wildcards like `*-preview`, `gpt-4*`) to exclude models +- **Whitelist Rules**: Pattern-based rules to explicitly include models, overriding ignore rules +- **Real-time Preview**: Typing in rule input fields highlights affected models before committing +- **Rule-Model Linking**: Click a model to highlight the affecting rule; click a rule to highlight all affected models +- **Persistence**: Rules saved to `.env` file in standard `IGNORE_MODELS_` and `WHITELIST_MODELS_` format + +#### Dual-Pane Model View + +The interface displays two synchronized lists: + +| Left Pane | Right Pane | +|-----------|------------| +| All fetched models (plain text) | Same models with color-coded status | +| Shows total count | Shows available/ignored count | +| Scrolls in sync with right pane | Color indicates affecting rule | + +**Color Coding**: +- **Green**: Model is available (no rule affects it, or whitelisted) +- **Red/Orange tones**: Model is ignored (color matches the specific ignore rule) +- **Blue/Teal tones**: Model is explicitly whitelisted (color matches the whitelist rule) + +#### Rule Management + +- **Comma-separated input**: Add multiple rules at once (e.g., `*-preview, *-beta, gpt-3.5*`) +- **Wildcard support**: `*` matches any characters (e.g., `gemini-*-preview`) +- **Affected count**: Each rule shows how many models it affects +- **Tooltips**: Hover over a rule to see the list of affected models +- **Instant delete**: Click the × button to remove a rule immediately + +### 6.3. 
Keyboard Shortcuts + +| Shortcut | Action | +|----------|--------| +| `Ctrl+S` | Save changes to `.env` | +| `Ctrl+R` | Refresh models from provider | +| `Ctrl+F` | Focus search field | +| `F1` | Show help dialog | +| `Escape` | Clear search / Clear highlights | + +### 6.4. Context Menu + +Right-click on any model to access: + +- **Add to Ignore List**: Creates an ignore rule for the exact model name +- **Add to Whitelist**: Creates a whitelist rule for the exact model name +- **View Affecting Rule**: Highlights the rule that affects this model +- **Copy Model Name**: Copies the full model ID to clipboard + +### 6.5. Integration with Proxy + +The GUI modifies the same environment variables that the `RotatingClient` reads: + +1. **GUI saves rules** → Updates `.env` file +2. **Proxy reads on startup** → Loads `IGNORE_MODELS_*` and `WHITELIST_MODELS_*` +3. **Proxy applies rules** → `get_available_models()` filters based on rules + +**Note**: The proxy must be restarted to pick up rule changes made via the GUI (or use the Launcher TUI's reload functionality if available). + diff --git a/requirements.txt b/requirements.txt index edb2bce..64f6aca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,6 @@ aiohttp colorlog rich + +# GUI for model filter configuration +customtkinter diff --git a/src/proxy_app/ai_assistant/DESIGN.md b/src/proxy_app/ai_assistant/DESIGN.md new file mode 100644 index 0000000..6ead5bb --- /dev/null +++ b/src/proxy_app/ai_assistant/DESIGN.md @@ -0,0 +1,1254 @@ +# AI Assistant System - Design Document + +**Version**: 1.1 +**Status**: Draft +**Target Window**: Model Filter Configuration GUI (prototype) + +--- + +## 1. Overview + +### 1.1 Purpose + +A reusable AI assistant system that can be integrated into any GUI tool window. The assistant has full context of the window's state, can execute actions via tools, maintains checkpoints for undo capability, and supports streaming responses with thinking visibility. + +### 1.2 Core Principles + +| Principle | Description | +|-----------|-------------| +| **Reusability** | Window-agnostic core with window-specific adapters | +| **Full Context** | Assistant always has complete visibility into window state | +| **Agentic** | Multi-tool execution, self-correction, error handling | +| **Non-destructive** | Checkpoint system prevents data loss | +| **Responsive** | Streaming responses with thinking display | + +### 1.3 Implementation Strategy + +**Prototype Phase**: Pop-out window only, no embedded panel. The embedded compact mode will be added after the pop-out is tested and refined. + +--- + +## 2. 
System Architecture + +### 2.1 Component Hierarchy + +``` ++-------------------------------------------------------------------+ +| GUI Window | +| (e.g., ModelFilterGUI) | +| +---------------------------------------------------------------+| +| | WindowContextAdapter || +| | - Implements window-specific context extraction || +| | - Registers window-specific tools || +| | - Provides window-specific system prompt || +| +---------------------------------------------------------------+| ++-------------------------------------------------------------------+ + | + v ++-------------------------------------------------------------------+ +| AIAssistantCore | +| +--------------+ +--------------+ +--------------------------+ | +| | ChatSession | | ToolExecutor | | CheckpointManager | | +| | - History | | - Registry | | - Snapshots + Deltas | | +| | - Context | | - Validation | | - Hybrid storage | | +| | - Streaming | | - Execution | | - Temp file backup | | +| +--------------+ +--------------+ +--------------------------+ | ++-------------------------------------------------------------------+ + | + v ++-------------------------------------------------------------------+ +| LLMBridge | +| - Wraps RotatingClient | +| - Thread/async coordination | +| - Streaming chunk processing | +| - Model selection | ++-------------------------------------------------------------------+ + | + v ++-------------------------------------------------------------------+ +| AIChatWindow (UI) | +| +---------------------------------------------------------------+| +| | Popped-Out Mode || +| | - Full message display (dynamically sized) || +| | - Expanded input || +| | - Model selector (grouped by provider) || +| | - Thinking sections (collapsible, auto-collapse) || +| | - Checkpoint dropdown || +| | - All features enabled || +| +---------------------------------------------------------------+| ++-------------------------------------------------------------------+ +``` + +### 2.2 Data Flow + +``` +User Input + | + v +[Queue if busy] ------------------------------------------+ + | | + v | +WindowContextAdapter.get_full_context() | + | | + v | +Diff against last_known_context | + | | + v | +Build messages array: | + - Base system prompt | + - Window-specific system prompt | + - Context injection | + - Conversation history | + - User message | + | | + v | +LLMBridge.stream_completion() | + | | + +---> [Thinking chunks] ---> Display (collapsible) | + | | + +---> [Content chunks] ---> Display streaming | + | | + +---> [Tool calls] ---+ | + v | + Parse tool calls | + | | + v | + Has write tools? | + | | | + YES NO | + | | | + v | | + Create checkpoint | | + (if not exists | | + for this response)| | + | | | + v v | + Execute tools sequentially | + | | + v | + Collect results (success/failure) | + | | + v | + Feed results back to LLM | + | | + +-----> [Continue if more tool calls] ----+ + | + v + [Response complete] +``` + +--- + +## 3. Core Components + +### 3.1 WindowContextAdapter (Abstract Base Class) + +**Purpose**: Interface that each window must implement to connect to the assistant. 
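+
+A minimal sketch of the adapter base class (illustrative only; the method names and signatures mirror the table below, and `ToolDefinition` refers to the tool system described in Section 3.2):
+
+```python
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+
+class WindowContextAdapter(ABC):
+    """Connects one GUI window to the assistant core."""
+
+    @abstractmethod
+    def get_full_context(self) -> Dict[str, Any]:
+        """Return the complete structured state of the window."""
+
+    @abstractmethod
+    def get_window_system_prompt(self) -> str:
+        """Return window-specific instructions for the AI."""
+
+    @abstractmethod
+    def get_tools(self) -> List["ToolDefinition"]:
+        """Return the tools available for this window."""
+
+    @abstractmethod
+    def apply_state(self, state: Dict[str, Any]) -> None:
+        """Restore the window to a given state (used for checkpoint rollback)."""
+
+    @abstractmethod
+    def get_state_hash(self) -> str:
+        """Return a quick hash of the current state for change detection."""
+```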
+ +**Required Methods**: + +| Method | Return Type | Description | +|--------|-------------|-------------| +| `get_full_context()` | `Dict[str, Any]` | Complete structured state of the window | +| `get_window_system_prompt()` | `str` | Window-specific instructions for the AI | +| `get_tools()` | `List[ToolDefinition]` | Available tools for this window | +| `apply_state(state: Dict)` | `None` | Restore window to a given state (for checkpoints) | +| `get_state_hash()` | `str` | Quick hash for change detection | + +**Example Context Structure for ModelFilterGUI**: + +```python +{ + "window_type": "model_filter_gui", + "current_provider": "openai", + "models": { + "total_count": 45, + "available_count": 38, + "items": [ + { + "id": "openai/gpt-4o", + "display_name": "gpt-4o", + "status": "normal", # "normal" | "ignored" | "whitelisted" + "affecting_rule": null + }, + { + "id": "openai/gpt-4-turbo", + "display_name": "gpt-4-turbo", + "status": "ignored", + "affecting_rule": { + "pattern": "gpt-4-turbo*", + "type": "ignore" + } + } + # ... all models + ] + }, + "rules": { + "ignore": [ + { + "pattern": "gpt-4-turbo*", + "affected_count": 3, + "affected_models": ["gpt-4-turbo", "gpt-4-turbo-preview", "..."] + }, + { + "pattern": "*-preview", + "affected_count": 5, + "affected_models": ["..."] + } + ], + "whitelist": [ + { + "pattern": "gpt-4o", + "affected_count": 1, + "affected_models": ["gpt-4o"] + } + ] + }, + "ui_state": { + "search_query": "", + "has_unsaved_changes": true, + "highlighted_rule": null, + "highlighted_models": [] + }, + "available_providers": ["openai", "gemini", "anthropic"], + "changes_since_last_message": [ + { + "type": "rule_added", + "rule_type": "ignore", + "pattern": "o1*", + "timestamp": "..." + }, + { + "type": "provider_changed", + "from": "gemini", + "to": "openai", + "timestamp": "..." + } + ] +} +``` + +### 3.2 Tool Definition System + +**Tool Decorator Syntax**: + +```python +@assistant_tool( + name="add_ignore_rule", + description="Add a pattern to the ignore list. Models matching this pattern will be blocked.", + parameters={ + "pattern": { + "type": "string", + "description": "The pattern to ignore. Supports * wildcard." + } + }, + required=["pattern"], + is_write=True # Triggers checkpoint creation +) +def tool_add_ignore_rule(self, pattern: str) -> ToolResult: + """Add an ignore rule.""" + success = self._add_ignore_pattern(pattern) + if success: + return ToolResult( + success=True, + message=f"Added ignore rule: {pattern}", + data={ + "pattern": pattern, + "affected_models": self._get_affected_models(pattern) + } + ) + else: + return ToolResult( + success=False, + message=f"Pattern '{pattern}' is already covered by existing rule", + data={"existing_rules": self._get_covering_rules(pattern)} + ) +``` + +**ToolResult Structure**: + +```python +@dataclass +class ToolResult: + success: bool + message: str # Human-readable description + data: Optional[Dict[str, Any]] = None # Structured data for AI + error_code: Optional[str] = None # Machine-readable error type +``` + +**Tool Categories for ModelFilterGUI**: + +| Category | Tool Name | Write? 
| Description | +|----------|-----------|--------|-------------| +| **Rules** | `add_ignore_rule` | Yes | Add pattern to ignore list | +| | `remove_ignore_rule` | Yes | Remove pattern from ignore list | +| | `add_whitelist_rule` | Yes | Add pattern to whitelist | +| | `remove_whitelist_rule` | Yes | Remove pattern from whitelist | +| | `clear_all_ignore_rules` | Yes | Clear all ignore rules | +| | `clear_all_whitelist_rules` | Yes | Clear all whitelist rules | +| | `import_rules` | Yes | Bulk import rules | +| **Query** | `get_models_matching_pattern` | No | Preview pattern matches | +| | `get_model_details` | No | Get details for specific model | +| | `explain_model_status` | No | Explain why a model has its status | +| **Provider** | `switch_provider` | No | Change active provider | +| | `refresh_models` | No | Refetch models from provider | +| **State** | `save_changes` | Yes | Save to .env file | +| | `discard_changes` | Yes | Revert to saved state | + +### 3.3 CheckpointManager + +**Checkpoint Strategy**: Hybrid snapshot/delta approach. + +- Full snapshot stored every 10 checkpoints +- Deltas stored between snapshots +- All checkpoints persisted to temp file for crash recovery +- On rollback: Load nearest full snapshot, apply deltas forward to target + +``` +[Full #0] -> D1 -> D2 -> ... -> D9 -> [Full #10] -> D11 -> ... -> D19 -> [Full #20] +``` + +**Checkpoint Structure**: + +```python +@dataclass +class Checkpoint: + id: str # UUID + timestamp: datetime + description: str # Auto-generated from tools + tool_calls: List[ToolCallSummary] # What tools were called + message_index: int # Conversation position at checkpoint time + + # One of these will be populated: + full_state: Optional[Dict[str, Any]] # Full snapshot (every Nth) + delta: Optional[Dict[str, Any]] # Changes from previous + + is_full_snapshot: bool +``` + +**Delta Format**: + +```python +{ + "added": { + "rules.ignore": [{"pattern": "gpt-4*", "...": "..."}] + }, + "removed": { + "rules.whitelist": [{"pattern": "claude*", "...": "..."}] + }, + "modified": { + "ui_state.search_query": {"old": "", "new": "gpt"} + } +} +``` + +**Checkpoint Creation Logic**: + +1. Before executing the first `is_write=True` tool in a response +2. Check if a checkpoint already exists for this response +3. If not, create one and proceed +4. If yes (shouldn't happen), log warning and proceed anyway + +**Rollback Algorithm**: + +```python +def rollback_to(checkpoint_id: str): + target_index = find_checkpoint_index(checkpoint_id) + + # Find nearest full snapshot at or before target + snapshot_index = find_nearest_snapshot_before(target_index) + + # Load full snapshot + state = load_full_snapshot(snapshot_index) + + # Apply deltas from snapshot to target + for i in range(snapshot_index + 1, target_index + 1): + state = apply_delta(state, checkpoints[i].delta) + + # Apply state to window (atomic operation) + window_context.apply_state(state) + + # Rollback conversation history to that point + truncate_conversation_to_checkpoint(checkpoint_id) + + # Truncate checkpoint list (remove all after target) + checkpoints = checkpoints[:target_index + 1] + + # Mark current position + current_checkpoint_position = target_index +``` + +**Apply State Mechanics**: + +The `WindowContextAdapter.apply_state()` method restores window state atomically: + +1. Update internal data structures (e.g., `filter_engine.ignore_rules`) +2. Update UI input fields (e.g., search query) +3. 
Call existing refresh methods (e.g., `_on_rules_changed()`, `_update_model_display()`) + +This reuses existing UI update logic rather than duplicating it. The operation is atomic: +if any step fails, the entire restore is aborted and the window remains in its pre-restore state. + +**Conversation Rollback**: + +When rolling back to a checkpoint, conversation history is also rolled back: +- All messages after the checkpoint are removed +- The conversation state matches what it was when the checkpoint was created +- This ensures AI context and window state are always synchronized + +### 3.4 LLMBridge + +**Purpose**: Bridge between async `RotatingClient` and sync GUI thread. + +**Key Responsibilities**: + +- Manage `RotatingClient` lifecycle +- Handle thread/async coordination using `threading.Thread` + `asyncio.run()` +- Process streaming chunks and route to appropriate handlers +- Parse tool calls from responses (OpenAI-compatible native JSON tools) +- Manage model list fetching (same list as `/v1/models` endpoint) + +**Streaming Callback Interface**: + +```python +callbacks = { + "on_thinking_chunk": Callable[[str], None], + "on_content_chunk": Callable[[str], None], + "on_tool_calls": Callable[[List[ToolCall]], None], + "on_error": Callable[[str], None], + "on_complete": Callable[[], None], +} +``` + +**Thread Coordination Pattern**: + +```python +def stream_completion(messages, tools, model, callbacks): + def run_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + async def stream(): + response = await client.acompletion( + model=model, + messages=messages, + tools=tools, + stream=True + ) + + async for chunk in response: + parsed = parse_chunk(chunk) + + if parsed.reasoning_content: + schedule_on_gui_thread( + callbacks["on_thinking_chunk"], + parsed.reasoning_content + ) + + if parsed.content: + schedule_on_gui_thread( + callbacks["on_content_chunk"], + parsed.content + ) + + if parsed.tool_calls: + schedule_on_gui_thread( + callbacks["on_tool_calls"], + parsed.tool_calls + ) + + schedule_on_gui_thread(callbacks["on_complete"]) + + loop.run_until_complete(stream()) + except Exception as e: + schedule_on_gui_thread(callbacks["on_error"], str(e)) + finally: + loop.close() + + thread = threading.Thread(target=run_in_thread, daemon=True) + thread.start() +``` + +### 3.5 ChatSession + +**Purpose**: Manages conversation state and message history. 
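+
+To make the relationship between history and context concrete, here is a rough sketch of how a session's messages might be assembled before each LLM call, following the layering described in Section 2.2 (the helper name and the exact context framing are assumptions; tool messages are omitted for brevity):
+
+```python
+import json
+
+
+def build_request_messages(session, adapter, base_prompt: str, user_message: str) -> list:
+    """Assemble the messages array for one LLM call (simplified sketch)."""
+    context = adapter.get_full_context()
+    messages = [
+        {"role": "system", "content": base_prompt},                          # base assistant prompt
+        {"role": "system", "content": adapter.get_window_system_prompt()},   # window-specific prompt
+        {"role": "system", "content": "Current window context:\n" + json.dumps(context, indent=2)},
+    ]
+    # Conversation history kept by the session, then the new user message.
+    messages += [{"role": m.role, "content": m.content or ""} for m in session.messages]
+    messages.append({"role": "user", "content": user_message})
+    return messages
+```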
+ +**State**: + +```python +@dataclass +class ChatSession: + session_id: str + model: str + messages: List[Message] + pending_message: Optional[str] # Queued user message + is_streaming: bool + current_checkpoint_position: int + last_known_context_hash: str + + # Retry tracking + consecutive_invalid_tool_calls: int + max_tool_retries: int = 4 +``` + +**Message Types**: + +```python +@dataclass +class Message: + role: str # "user" | "assistant" | "tool" + content: Optional[str] + reasoning_content: Optional[str] # Thinking (from reasoning_content field) + tool_calls: Optional[List[ToolCall]] + tool_call_id: Optional[str] # For tool response messages + timestamp: datetime + +@dataclass +class ToolCall: + id: str + name: str + arguments: Dict[str, Any] + result: Optional[ToolResult] # Populated after execution +``` + +**Model Selection**: + +- Per-window, per-session persistence +- Model list fetched from `RotatingClient.get_all_available_models()` when window opens +- Model list is refreshed each time the AI chat window is opened +- Same filtered list that proxy serves at `/v1/models` endpoint +- If selected model becomes unavailable: show error, user must select another + +**Message Queue**: + +- If user sends message while streaming: queue it +- Process queued message after current response completes (not the entire agentic chain) +- Only one message can be queued at a time (new message replaces queued) + +**New Session Button**: + +When clicked: +1. Clear all conversation history +2. Clear all checkpoints +3. Model selection is preserved +4. Window state is NOT reset (current rules, models, etc. remain) +5. Fresh context snapshot is taken for the new session + +--- + +## 4. UI Components + +### 4.1 Design Principles + +**Dynamic Resizing**: + +- All panels and boxes use weight-based grid layout +- Only fixed sizes for: buttons, input field heights, padding +- Window is freely resizable by user +- Reference: Current `ModelFilterGUI` implementation using `grid_rowconfigure(weight=N)` + +**Consistency**: + +- Use same color constants as `model_filter_gui.py` +- Same font family and size scale +- Same border styles and corner radii + +### 4.2 AIChatWindow (Pop-Out) Layout + +``` ++------------------------------------------------------------------------+ +| AI Assistant - Model Filter Configuration [-] [_] [X] | ++------------------------------------------------------------------------+ +| +------------------------------------------+ +------------------------+ | +| | Model: [openai/gpt-4o v] | | Checkpoints [v] | | +| +------------------------------------------+ +------------------------+ | ++------------------------------------------------------------------------+ +| | +| +------------------------------------------------------------------+ | +| | Message Display | | +| | (scrollable, weight=3) | | +| | ----------------------------------------------------------------- | | +| | | | +| | v Thinking (collapsed - click to expand) | | +| | ---------------------------------------------------------------- | | +| | AI: I'll help you configure the model filters. I can see you | | +| | have 45 models from OpenAI, with 7 currently ignored. 
| | +| | | | +| | ---------------------------------------------------------------- | | +| | | | +| | You: Block all preview and experimental models | | +| | | | +| | ---------------------------------------------------------------- | | +| | | | +| | > Thinking (expanded) | | +| | +-------------------------------------------------------------+ | | +| | | I need to identify patterns that match preview and | | | +| | | experimental models. Looking at the model list, I see: | | | +| | | - gpt-4-turbo-preview | | | +| | | - gpt-4o-preview | | | +| | | ... | | | +| | +-------------------------------------------------------------+ | | +| | | | +| | AI: I'll add two patterns to block these: | | +| | - `*-preview` - blocks 5 preview models | | +| | - `*-experimental` - blocks 2 experimental models | | +| | | | +| | [checkmark] Tool: add_ignore_rule(pattern="*-preview") | | +| | Result: Added. 5 models now blocked. | | +| | | | +| | [checkmark] Tool: add_ignore_rule(pattern="*-experimental") | | +| | Result: Added. 2 models now blocked. | | +| | | | +| +------------------------------------------------------------------+ | +| | ++------------------------------------------------------------------------+ +| +------------------------------------------------------------------+ | +| | | | +| | Type your message here... | | +| | | | +| | (scrollable input, 3+ lines visible) | | +| | | | +| +------------------------------------------------------------------+ | +| [New Session] [Send ->] | ++------------------------------------------------------------------------+ +``` + +**Grid Layout Specification**: + +```python +# Window grid configuration +window.grid_columnconfigure(0, weight=1) + +# Row 0: Header (model selector + checkpoints) - fixed height +window.grid_rowconfigure(0, weight=0) + +# Row 1: Message display - weight=3 (takes most space) +window.grid_rowconfigure(1, weight=3, minsize=200) + +# Row 2: Input area - weight=1 (grows but less than messages) +window.grid_rowconfigure(2, weight=1, minsize=80) + +# Row 3: Buttons - fixed height +window.grid_rowconfigure(3, weight=0) +``` + +### 4.3 Component Details + +#### 4.3.1 Model Selector + +- Grouped dropdown by provider +- Format: `provider/model-name` +- Groups: openai, gemini, anthropic, etc. 
+- Persists selection for session + +``` ++--------------------------------+ +| openai/gpt-4o [v] | ++--------------------------------+ +| -- openai -- | +| gpt-4o | +| gpt-4o-mini | +| gpt-4-turbo | +| -- gemini -- | +| gemini-2.0-flash | +| gemini-1.5-pro | +| -- anthropic -- | +| claude-3-5-sonnet | ++--------------------------------+ +``` + +#### 4.3.2 Checkpoint Dropdown + +Clicking opens a popup/dropdown list: + +``` ++----------------------------------------------------------+ +| Checkpoints [X] | ++----------------------------------------------------------+ +| (*) Current State | +| --------------------------------------------------------| +| ( ) 14:32:15 - add_ignore_rule("*-preview") | +| -> Added 5 models to ignore | +| --------------------------------------------------------| +| ( ) 14:31:42 - add_ignore_rule("gpt-4*") | +| -> Added 3 models to ignore | +| --------------------------------------------------------| +| ( ) 14:30:00 - Session Start | +| -> Initial state | ++----------------------------------------------------------+ +| [Rollback to Selected] [Cancel] | ++----------------------------------------------------------+ +``` + +#### 4.3.3 Message Display + +Canvas-based virtual list for performance (reference: `VirtualModelList` in `model_filter_gui.py`). + +**Message Styling**: + +| Element | Style | +|---------|-------| +| User message | Right-aligned, accent background | +| AI message | Left-aligned, secondary background | +| Thinking block | Muted color, smaller font, collapsible | +| Tool execution | Monospace, subtle background, icon prefix | +| Tool success | Green checkmark prefix | +| Tool failure | Red X prefix | +| Timestamp | Muted, small, right-aligned | + +**Thinking Behavior**: + +- Starts expanded while streaming +- Auto-collapses when first chunk with `content` but no `reasoning_content` arrives +- Click to expand/collapse manually at any time +- Styled: muted text color, slightly smaller font, distinct background + +#### 4.3.4 Input Area + +- Multi-line text input (CTkTextbox) +- Minimum 3 lines visible +- Scrollable for longer input +- Keyboard shortcuts: + - `Ctrl+Enter`: Send message + - `Escape`: Clear input / cancel if streaming + +#### 4.3.5 Error Display + +Errors appear inline where response would be, replacing on success: + +``` ++----------------------------------------------------------+ +| [!] Connection Error | +| Failed to reach model. Check your network connection. | +| [Retry] [Cancel] | ++----------------------------------------------------------+ +``` + +- Styled with warning colors +- Replaced by response if retry succeeds +- Not added to message history + +--- + +## 5. Error Handling + +### 5.1 Invalid Tool Calls + +**Retry Logic**: + +1. If AI generates invalid tool call (bad parameters, unknown tool) +2. Silently feed error back to AI: "Tool call failed: [error]. Please correct and retry." +3. After 2nd retry failure: show subtle "Retrying..." indicator in UI +4. After 4 total failures: show error to user + +**Error Message to AI**: + +```json +{ + "role": "tool", + "tool_call_id": "call_xyz", + "content": { + "success": false, + "error": "Invalid parameter: 'patern' is not a valid parameter. Did you mean 'pattern'?", + "hint": "Please review the tool schema and retry." + } +} +``` + +### 5.2 Partial Tool Execution + +If AI calls 3 tools and 2nd fails: + +1. Tool 1 result: applied, success fed back +2. Tool 2 result: NOT applied, error fed back +3. 
Tool 3: still executed (errors are per-tool, not chain-breaking) + +AI receives all results and can decide how to proceed. + +### 5.3 Model Unavailability + +If selected model becomes unavailable (credential issue, rate limit): + +1. Show error in UI (not as chat message) +2. User must select different model from dropdown +3. No auto-fallback + +### 5.4 LLM Connection Errors + +1. Display inline error where response would appear +2. Provide Retry and Cancel buttons +3. If retry succeeds, error replaced by response +4. Error is NOT added to conversation history + +### 5.5 Tool Execution Timeout + +All tools have a default timeout. If a tool execution exceeds the timeout: + +1. Tool returns failure result with timeout error +2. Error is fed back to AI like any other tool failure +3. AI can decide to retry or inform user + +### 5.6 Streaming Cancellation + +If user presses Escape or clicks Cancel during streaming: + +1. Streaming is immediately stopped +2. Partial response is discarded (not added to conversation history) +3. Any tool calls that were pending are NOT executed +4. UI returns to ready state for new input + +### 5.7 Context Window Limits + +If conversation history approaches token limits: + +1. Show warning to user: "Conversation is getting long. Consider starting a new session." +2. Do NOT automatically truncate or summarize +3. User can click "New Session" to start fresh + +Given typical context windows of 120k-250k+ tokens, this should be rare. + +### 5.8 Concurrency and Window Locking + +When the assistant's turn begins (request is sent): + +1. The main window (e.g., ModelFilterGUI) is locked for user interaction +2. User cannot click buttons or modify fields during agent execution +3. Lock is released when agent turn completes (after all tool calls finish) + +This prevents race conditions between manual user actions and AI tool execution. + +--- + +## 6. System Prompts + +### 6.1 Base Assistant Prompt (All Windows) + +``` +You are an AI assistant embedded in a GUI application. Your role is to help users +accomplish tasks within this window by understanding their intent and executing +actions using the available tools. + +## Core Behaviors + +1. **Full Context Awareness**: You have complete visibility into the window's state. + Use this information to provide accurate, contextual help. + +2. **Tool Execution**: When the user requests an action, use the appropriate tools + to execute it. You may call multiple tools in sequence to accomplish complex tasks. + +3. **Verbose Feedback**: After executing tools, clearly explain what was done, + what changed, and any important consequences. Both you and the user will see + the tool results. + +4. **Error Handling**: If a tool fails, explain why and suggest alternatives. + If you receive an error about an invalid tool call, carefully re-examine the + tool schema and try again with corrected parameters. + +5. **Proactive Assistance**: If you notice potential issues or improvements, + mention them to the user. + +## Tool Execution Guidelines + +- Always confirm understanding before making destructive changes +- For bulk operations, summarize what will happen before executing +- If uncertain about user intent, ask for clarification +- Report all tool results, including partial successes +- You may call multiple tools in a single response when appropriate + +## Context Updates + +You will receive updates about changes to the window state in the +`changes_since_last_message` field. 
Use this to stay aware of what +the user may have done manually between messages. +``` + +### 6.2 Window-Specific Prompt (ModelFilterGUI) + +``` +## Model Filter Configuration Assistant + +You are helping the user configure model filtering rules for an LLM proxy server. + +### Domain Knowledge + +- **Ignore Rules**: Patterns that block models from being available through the proxy +- **Whitelist Rules**: Patterns that ensure models are always available (override ignore rules) +- **Pattern Syntax**: + - Exact match: `gpt-4` matches only "gpt-4" + - Wildcard: `gpt-4*` matches "gpt-4", "gpt-4-turbo", "gpt-4-vision", etc. + - Match anywhere: `*preview*` matches any model containing "preview" + +### Rule Priority + +Whitelist > Ignore > Default (available) + +A model that matches both an ignore rule and a whitelist rule will be AVAILABLE +(whitelist wins). + +### Common Tasks + +1. "Block all preview models" -> Use pattern `*-preview` or `*preview*` +2. "Only allow GPT-4o" -> Ignore `*`, whitelist `gpt-4o` +3. "What models are blocked?" -> Query the ignore rules and their affected models + +### Important Notes + +- Changes are not saved until the user explicitly saves (or you use save_changes tool) +- The `has_unsaved_changes` field in context tells you if there are pending changes +- Always inform the user if there are unsaved changes that might be lost + +### Current Context + +[Context will be injected dynamically from WindowContextAdapter.get_full_context()] +``` + +--- + +## 7. Message Flow Examples + +### 7.1 Simple Tool Execution + +**User**: "Add gpt-4* to the ignore list" + +**Processing**: +1. Parse intent: add ignore rule +2. Check checkpoint: first write tool, create checkpoint +3. Execute: `add_ignore_rule(pattern="gpt-4*")` +4. Tool returns: `{success: true, message: "Added...", data: {affected_models: [...]}}` +5. Generate response incorporating result + +**Response**: +"I've added `gpt-4*` to the ignore list. This blocks 4 models: gpt-4, gpt-4-turbo, gpt-4-vision, and gpt-4-turbo-preview." + +### 7.2 Multi-Tool with Partial Failure + +**User**: "Block all Claude and GPT-3 models" + +**Tool Calls**: +1. `add_ignore_rule(pattern="claude*")` -> Success, applied +2. `add_ignore_rule(pattern="gpt-3*")` -> Fails: "Pattern covered by existing 'gpt-*'" + +**Response**: +"I've added `claude*` to the ignore list, blocking 3 Claude models. + +However, I couldn't add `gpt-3*` because it's already covered by your existing rule `gpt-*`. All GPT-3 models are already blocked by that broader pattern." + +### 7.3 Invalid Tool Call with Self-Correction + +**User**: "Show me what matches 'preview'" + +**Attempt 1** (user doesn't see this): +`get_models_matching_pattern()` - missing required parameter +-> Error fed to AI: "Missing required parameter: pattern" + +**Attempt 2** (user doesn't see this): +`get_models_matching_pattern(pattern="preview")` +-> Error: "No matches. Hint: use wildcards for partial matching" + +**Attempt 3**: +`get_models_matching_pattern(pattern="*preview*")` +-> Success + +**Response** (user only sees final result): +"Here are 5 models containing 'preview': gpt-4-turbo-preview, gpt-4o-preview, ..." + +### 7.4 Agentic Multi-Step Task + +**User**: "Set up the filter so only the latest models from each major provider are available" + +**AI Thinking** (visible, collapsible): +"I need to: +1. First understand what models are available +2. Identify 'latest' models per provider +3. Create an ignore-all rule +4. Whitelist the latest models + +Let me check the current model list..." 
+ +**Tool Calls** (sequence): +1. `get_model_details()` - Query current state +2. `clear_all_ignore_rules()` - Start fresh +3. `add_ignore_rule(pattern="*")` - Block everything +4. `add_whitelist_rule(pattern="gpt-4o")` - Latest OpenAI +5. `add_whitelist_rule(pattern="gemini-2.0-flash")` - Latest Gemini +6. `add_whitelist_rule(pattern="claude-3-5-sonnet")` - Latest Anthropic + +**Response**: +"I've configured the filter to only allow the latest models: + +**Whitelisted (available):** +- openai/gpt-4o +- gemini/gemini-2.0-flash +- anthropic/claude-3-5-sonnet + +**Blocked:** All other 42 models + +Note: These changes are not saved yet. Use the Save button or ask me to save when you're satisfied with the configuration." + +--- + +## 8. Implementation Phases + +### Phase 1: Core Infrastructure + +**Directory**: `src/proxy_app/ai_assistant/` + +| File | Contents | +|------|----------| +| `__init__.py` | Public exports | +| `core.py` | `AIAssistantCore`, `ChatSession`, `Message` classes | +| `tools.py` | `@assistant_tool` decorator, `ToolDefinition`, `ToolResult`, `ToolExecutor` | +| `checkpoint.py` | `CheckpointManager`, `Checkpoint`, delta/snapshot logic | +| `bridge.py` | `LLMBridge`, threading/async coordination | +| `context.py` | `WindowContextAdapter` ABC, context diffing utilities | +| `prompts.py` | Base system prompt constant | + +### Phase 2: UI Components + +**Directory**: `src/proxy_app/ai_assistant/ui/` + +| File | Contents | +|------|----------| +| `__init__.py` | Public exports | +| `chat_window.py` | `AIChatWindow` - main pop-out window | +| `message_view.py` | Canvas-based message display widget | +| `thinking.py` | Collapsible thinking section widget | +| `checkpoint_ui.py` | Checkpoint dropdown/popup widget | +| `model_selector.py` | Grouped model dropdown widget | +| `styles.py` | Colors, fonts, shared constants | + +### Phase 3: ModelFilterGUI Integration + +**Directory**: `src/proxy_app/ai_assistant/adapters/` + +| File | Contents | +|------|----------| +| `__init__.py` | Public exports | +| `model_filter.py` | `ModelFilterWindowContext` implementation, all tools | + +**Modifications to existing files**: + +| File | Changes | +|------|---------| +| `model_filter_gui.py` | Add button to open AI assistant, wire up context adapter | + +### Phase 4: Polish & Edge Cases + +| Task | Description | +|------|-------------| +| Checkpoint persistence | Save/load checkpoints to temp file | +| Model list caching | Efficient model list refresh | +| Error handling | Retry logic, user-facing errors | +| Message queue | Queue messages during streaming | +| Silent retry | 4-attempt retry for invalid tools | + +--- + +## 9. File Structure + +``` +src/proxy_app/ ++-- ai_assistant/ +| +-- __init__.py +| +-- core.py # AIAssistantCore, ChatSession +| +-- tools.py # Tool decorator and executor +| +-- checkpoint.py # CheckpointManager +| +-- bridge.py # LLMBridge (RotatingClient wrapper) +| +-- context.py # WindowContextAdapter ABC +| +-- prompts.py # Base system prompt +| +-- ui/ +| | +-- __init__.py +| | +-- chat_window.py # AIChatWindow (pop-out) +| | +-- message_view.py # Message display canvas +| | +-- thinking.py # Collapsible thinking widget +| | +-- checkpoint_ui.py # Checkpoint dropdown/popup +| | +-- model_selector.py # Grouped model dropdown +| | +-- styles.py # UI constants +| +-- adapters/ +| | +-- __init__.py +| | +-- model_filter.py # ModelFilterWindowContext +| +-- DESIGN.md # This document ++-- model_filter_gui.py # Modified to include AI assistant button ++-- ... 
+``` + +--- + +## 10. Keyboard Shortcuts + +| Shortcut | Context | Action | +|----------|---------|--------| +| `Ctrl+Enter` | Input focused | Send message | +| `Escape` | Input focused | Clear input | +| `Escape` | Streaming | Cancel generation, discard partial response | + +--- + +## 11. Future Considerations (Out of Scope for v1) + +These items are noted for future planning but not implemented in v1: + +1. **Embedded Compact Mode**: After pop-out is stable, add compact panel for embedding in windows +2. **Conversation Persistence**: Save/load conversation history across sessions +3. **Conversation Export**: Export chat as markdown/text +4. **Custom Model Aliases**: User-defined shortcuts like "smart" -> "openai/gpt-4o" +5. **Multiple Sessions**: Support multiple concurrent assistant windows +6. **Voice Input**: Speech-to-text for input +7. **Image Support**: For multimodal models, support image context + +--- + +## 12. Dependencies + +**Required**: +- `customtkinter` - Already used by ModelFilterGUI +- `threading` - Standard library +- `asyncio` - Standard library +- `json` - Standard library +- `hashlib` - For context hashing +- `tempfile` - For checkpoint persistence +- `uuid` - For checkpoint IDs +- `dataclasses` - For data structures +- `abc` - For WindowContextAdapter +- `functools` - For decorator implementation +- `logging` - Standard library, for error/warning logging + +**From existing codebase**: +- `rotator_library.client.RotatingClient` - LLM communication +- UI constants from `model_filter_gui.py` - Colors, fonts, etc. + +--- + +## 13. Logging + +The AI assistant system logs errors, warnings, and important events to file. + +**Log Levels**: + +| Level | What is logged | +|-------|----------------| +| ERROR | Tool execution failures, LLM connection errors, checkpoint restore failures | +| WARNING | Invalid tool call retries, context size approaching limits, timeout occurrences | +| INFO | Session start/end, checkpoint creation, model changes | +| DEBUG | Full request/response payloads (disabled by default) | + +**Log Location**: Uses existing application logging infrastructure. + +**What is NOT logged**: +- Full conversation history (privacy) +- User input content (unless DEBUG level) +- Sensitive context data + +Request-response logs from `RotatingClient` already capture LLM interaction details, so the assistant layer focuses on assistant-specific events. 
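+
+A small illustration of the intended use of these levels (the logger name and the helper functions below are assumptions, not a fixed API):
+
+```python
+import logging
+
+# Assistant-layer logger; DEBUG payload logging stays disabled by default.
+logger = logging.getLogger("proxy_app.ai_assistant")
+
+
+def log_checkpoint_created(description: str) -> None:
+    logger.info("Checkpoint created: %s", description)
+
+
+def log_tool_retry(attempt: int, max_retries: int, error: str) -> None:
+    logger.warning("Invalid tool call (retry %d/%d): %s", attempt, max_retries, error)
+
+
+def log_tool_failure(tool_name: str, error: str) -> None:
+    logger.error("Tool execution failed: %s: %s", tool_name, error)
+```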
+ +--- + +## Appendix A: Context Diff Format + +When tracking changes between LLM calls, the diff format is: + +```python +{ + "changes_since_last_message": [ + { + "type": "rule_added", + "rule_type": "ignore", # or "whitelist" + "pattern": "gpt-4*", + "timestamp": "2024-01-15T14:32:15Z" + }, + { + "type": "rule_removed", + "rule_type": "ignore", + "pattern": "old-pattern*", + "timestamp": "2024-01-15T14:32:10Z" + }, + { + "type": "provider_changed", + "from": "openai", + "to": "gemini", + "timestamp": "2024-01-15T14:31:00Z" + }, + { + "type": "models_refreshed", + "provider": "openai", + "new_count": 45, + "timestamp": "2024-01-15T14:30:00Z" + }, + { + "type": "search_changed", + "query": "gpt", + "timestamp": "2024-01-15T14:29:00Z" + }, + { + "type": "changes_saved", + "timestamp": "2024-01-15T14:28:00Z" + }, + { + "type": "changes_discarded", + "timestamp": "2024-01-15T14:27:00Z" + } + ] +} +``` + +--- + +## Appendix B: OpenAI Tool Format + +Tools are sent to the LLM in OpenAI-compatible format: + +```json +{ + "tools": [ + { + "type": "function", + "function": { + "name": "add_ignore_rule", + "description": "Add a pattern to the ignore list...", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The pattern to ignore..." + } + }, + "required": ["pattern"] + } + } + } + ] +} +``` + +Tool calls are received as: + +```json +{ + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "add_ignore_rule", + "arguments": "{\"pattern\": \"gpt-4*\"}" + } + } + ] +} +``` + +Tool results are sent back as: + +```json +{ + "role": "tool", + "tool_call_id": "call_abc123", + "content": "{\"success\": true, \"message\": \"Added...\", \"data\": {...}}" +} +``` + +--- + +*End of Design Document* diff --git a/src/proxy_app/ai_assistant/__init__.py b/src/proxy_app/ai_assistant/__init__.py new file mode 100644 index 0000000..99a2825 --- /dev/null +++ b/src/proxy_app/ai_assistant/__init__.py @@ -0,0 +1,76 @@ +""" +AI Assistant System for GUI Windows. + +A reusable AI assistant that can be integrated into any GUI tool window. +Provides full context awareness, tool execution, checkpoints for undo, +and streaming responses with thinking visibility. + +Main Components: +- AIAssistantCore: Main orchestration class +- WindowContextAdapter: Abstract base for window integration +- LLMBridge: Async LLM communication bridge +- CheckpointManager: Undo/rollback capability +- Tool system: @assistant_tool decorator and execution + +Usage: + from proxy_app.ai_assistant import AIAssistantCore, WindowContextAdapter + + class MyWindowAdapter(WindowContextAdapter): + # Implement abstract methods + ... 
+ + core = AIAssistantCore( + window_adapter=my_adapter, + schedule_on_gui=lambda fn: window.after(0, fn), + ) +""" + +from .assistant_logger import AssistantLogger +from .bridge import LLMBridge, StreamCallbacks +from .checkpoint import Checkpoint, CheckpointManager +from .context import ( + WindowContextAdapter, + apply_delta, + compute_context_diff, + compute_delta, +) +from .core import AIAssistantCore, ChatSession, Message +from .prompts import BASE_ASSISTANT_PROMPT, MODEL_FILTER_SYSTEM_PROMPT +from .tools import ( + ToolCall, + ToolCallSummary, + ToolDefinition, + ToolExecutor, + ToolResult, + assistant_tool, +) + +__all__ = [ + # Core + "AIAssistantCore", + "ChatSession", + "Message", + # Bridge + "LLMBridge", + "StreamCallbacks", + # Checkpoint + "CheckpointManager", + "Checkpoint", + # Context + "WindowContextAdapter", + "compute_context_diff", + "compute_delta", + "apply_delta", + # Tools + "assistant_tool", + "ToolDefinition", + "ToolResult", + "ToolCall", + "ToolCallSummary", + "ToolExecutor", + # Prompts + "BASE_ASSISTANT_PROMPT", + "MODEL_FILTER_SYSTEM_PROMPT", + # Logging + "AssistantLogger", +] diff --git a/src/proxy_app/ai_assistant/adapters/__init__.py b/src/proxy_app/ai_assistant/adapters/__init__.py new file mode 100644 index 0000000..58113d8 --- /dev/null +++ b/src/proxy_app/ai_assistant/adapters/__init__.py @@ -0,0 +1,11 @@ +""" +Window Context Adapters for the AI Assistant. + +Each adapter connects a specific GUI window to the AI assistant system. +""" + +from .model_filter import ModelFilterWindowContext + +__all__ = [ + "ModelFilterWindowContext", +] diff --git a/src/proxy_app/ai_assistant/adapters/model_filter.py b/src/proxy_app/ai_assistant/adapters/model_filter.py new file mode 100644 index 0000000..bfe6214 --- /dev/null +++ b/src/proxy_app/ai_assistant/adapters/model_filter.py @@ -0,0 +1,940 @@ +""" +ModelFilterGUI Window Context Adapter. + +Implements WindowContextAdapter for the Model Filter GUI, providing: +- Full context extraction from FilterEngine and UI state +- All tools for manipulating filter rules +- State application for checkpoint rollback +""" + +import copy +import hashlib +import json +import logging +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING + +from ..context import WindowContextAdapter +from ..prompts import MODEL_FILTER_SYSTEM_PROMPT +from ..tools import ToolDefinition, ToolResult, assistant_tool + +if TYPE_CHECKING: + from ...model_filter_gui import ModelFilterGUI + +logger = logging.getLogger(__name__) + + +class ModelFilterWindowContext(WindowContextAdapter): + """ + Window context adapter for the Model Filter GUI. + + Provides complete context and tool access for the AI assistant + to help users configure model filtering rules. + """ + + def __init__(self, gui: "ModelFilterGUI"): + """ + Initialize the adapter. 
+ + Args: + gui: The ModelFilterGUI instance + """ + self._gui = gui + self._last_context_hash: str = "" + self._changes: List[Dict[str, Any]] = [] + self._tracking_changes: bool = False + + # ========================================================================= + # WindowContextAdapter Implementation + # ========================================================================= + + def get_full_context(self) -> Dict[str, Any]: + """Get the complete structured state of the window.""" + gui = self._gui + engine = gui.filter_engine + + # Get current models + models = gui.models or [] + + # Build model list with statuses + model_items = [] + engine.get_all_statuses(models) # Ensure cache is valid + + for model_id in models: + status = engine.get_model_status(model_id) + item = { + "id": model_id, + "display_name": status.display_name, + "status": status.status, # "normal", "ignored", "whitelisted" + } + if status.affecting_rule: + item["affecting_rule"] = { + "pattern": status.affecting_rule.pattern, + "type": status.affecting_rule.rule_type, + } + else: + item["affecting_rule"] = None + model_items.append(item) + + # Build rules lists + ignore_rules = [] + for rule in engine.ignore_rules: + ignore_rules.append( + { + "pattern": rule.pattern, + "affected_count": rule.affected_count, + "affected_models": rule.affected_models[ + :10 + ], # Limit for context size + } + ) + + whitelist_rules = [] + for rule in engine.whitelist_rules: + whitelist_rules.append( + { + "pattern": rule.pattern, + "affected_count": rule.affected_count, + "affected_models": rule.affected_models[:10], + } + ) + + # Get counts + available, total = engine.get_available_count(models) + + # Build context + context = { + "window_type": "model_filter_gui", + "current_provider": gui.current_provider, + "models": { + "total_count": total, + "available_count": available, + "items": model_items, + }, + "rules": { + "ignore": ignore_rules, + "whitelist": whitelist_rules, + }, + "ui_state": { + "search_query": gui.search_entry.get() + if hasattr(gui, "search_entry") + else "", + "has_unsaved_changes": engine.has_unsaved_changes(), + }, + "available_providers": gui.available_providers or [], + } + + # Add changes since last message if tracking + if self._changes: + context["changes_since_last_message"] = self._changes.copy() + self._changes.clear() + + return context + + def get_window_system_prompt(self) -> str: + """Get window-specific instructions for the AI.""" + return MODEL_FILTER_SYSTEM_PROMPT + + def get_tools(self) -> List[ToolDefinition]: + """Get the list of tools available for this window.""" + # Collect all methods with @assistant_tool decorator + tools = [] + for name in dir(self): + method = getattr(self, name) + if hasattr(method, "_tool_definition"): + tools.append(method._tool_definition) + return tools + + def apply_state(self, state: Dict[str, Any]) -> None: + """Restore the window to a given state (for checkpoint rollback).""" + gui = self._gui + engine = gui.filter_engine + + try: + # Clear current rules + engine.ignore_rules.clear() + engine.whitelist_rules.clear() + engine._invalidate_cache() + + # Restore ignore rules + for rule_data in state.get("rules", {}).get("ignore", []): + engine.add_ignore_rule(rule_data["pattern"]) + + # Restore whitelist rules + for rule_data in state.get("rules", {}).get("whitelist", []): + engine.add_whitelist_rule(rule_data["pattern"]) + + # Restore search query + search_query = state.get("ui_state", {}).get("search_query", "") + if hasattr(gui, "search_entry"): + 
gui.search_entry.delete(0, "end") + if search_query: + gui.search_entry.insert(0, search_query) + + # Trigger UI refresh + gui._on_rules_changed() + + logger.info("State restored successfully") + + except Exception as e: + logger.exception("Failed to apply state") + raise + + def get_state_hash(self) -> str: + """Get a quick hash of the current state for change detection.""" + engine = self._gui.filter_engine + + # Hash based on rules only (quick check) + ignore_patterns = [r.pattern for r in engine.ignore_rules] + whitelist_patterns = [r.pattern for r in engine.whitelist_rules] + + state_str = json.dumps( + { + "ignore": sorted(ignore_patterns), + "whitelist": sorted(whitelist_patterns), + } + ) + return hashlib.md5(state_str.encode()).hexdigest() + + def lock_window(self) -> None: + """Lock the window to prevent user interaction during AI execution.""" + gui = self._gui + try: + # Change cursor to indicate busy + gui.configure(cursor="wait") + + # Disable interactive widgets + self._set_widgets_state("disabled") + + logger.debug("Window locked") + except Exception as e: + logger.warning(f"Failed to lock window: {e}") + + def unlock_window(self) -> None: + """Unlock the window after AI execution completes.""" + gui = self._gui + try: + # Restore cursor + gui.configure(cursor="") + + # Re-enable widgets + self._set_widgets_state("normal") + + logger.debug("Window unlocked") + except Exception as e: + logger.warning(f"Failed to unlock window: {e}") + + def _set_widgets_state(self, state: str) -> None: + """Set the state of interactive widgets.""" + gui = self._gui + + # Disable/enable key widgets + widgets_to_toggle = [ + "provider_combo", + "search_entry", + "refresh_btn", + "help_btn", + ] + + for widget_name in widgets_to_toggle: + if hasattr(gui, widget_name): + widget = getattr(gui, widget_name) + try: + widget.configure(state=state) + except Exception: + pass + + # Handle rule panels + if hasattr(gui, "ignore_panel") and hasattr(gui.ignore_panel, "pattern_entry"): + try: + gui.ignore_panel.pattern_entry.configure(state=state) + except Exception: + pass + + if hasattr(gui, "whitelist_panel") and hasattr( + gui.whitelist_panel, "pattern_entry" + ): + try: + gui.whitelist_panel.pattern_entry.configure(state=state) + except Exception: + pass + + def on_ai_started(self) -> None: + """Called when the AI starts processing a request.""" + self._tracking_changes = True + self._changes.clear() + self.lock_window() + + def on_ai_completed(self) -> None: + """Called when the AI finishes processing.""" + self._tracking_changes = False + self.unlock_window() + + def _record_change(self, change_type: str, **kwargs) -> None: + """Record a change for context updates.""" + if self._tracking_changes: + change = { + "type": change_type, + "timestamp": datetime.now().isoformat(), + **kwargs, + } + self._changes.append(change) + + # ========================================================================= + # Tool Implementations - Rule Management + # ========================================================================= + + @assistant_tool( + name="add_ignore_rule", + description="Add a pattern to the ignore list. Models matching this pattern will be blocked from the proxy.", + parameters={ + "pattern": { + "type": "string", + "description": "The pattern to ignore. 
Supports * wildcard for prefix matching (e.g., 'gpt-4*' matches all gpt-4 models).", + } + }, + required=["pattern"], + is_write=True, + ) + def tool_add_ignore_rule(self, pattern: str) -> ToolResult: + """Add an ignore rule.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + pattern = pattern.strip() + if not pattern: + return ToolResult( + success=False, + message="Pattern cannot be empty", + error_code="invalid_pattern", + ) + + # Check if already covered + if engine.is_pattern_covered(pattern, "ignore"): + covering_rules = [ + r.pattern + for r in engine.ignore_rules + if engine.pattern_is_covered_by(pattern, r.pattern) + ] + return ToolResult( + success=False, + message=f"Pattern '{pattern}' is already covered by existing rule(s): {covering_rules}", + data={"covering_rules": covering_rules}, + error_code="pattern_covered", + ) + + # Check for patterns this would cover (smart merge) + covered = engine.get_covered_patterns(pattern, "ignore") + + # Add the rule + rule = engine.add_ignore_rule(pattern) + if rule is None: + return ToolResult( + success=False, + message=f"Pattern '{pattern}' already exists", + error_code="duplicate_pattern", + ) + + # Remove covered patterns (smart merge) + for covered_pattern in covered: + engine.remove_ignore_rule(covered_pattern) + + # Update UI + engine.update_affected_counts(models) + gui._on_rules_changed() + + # Get affected models + affected = engine.preview_pattern(pattern, "ignore", models) + + self._record_change("rule_added", rule_type="ignore", pattern=pattern) + + message = f"Added ignore rule: {pattern}. {len(affected)} model(s) now blocked." + if covered: + message += f" Removed {len(covered)} redundant rule(s): {covered}" + + return ToolResult( + success=True, + message=message, + data={ + "pattern": pattern, + "affected_count": len(affected), + "affected_models": affected[:10], + "removed_redundant": covered, + }, + ) + + @assistant_tool( + name="remove_ignore_rule", + description="Remove a pattern from the ignore list. Models previously blocked by this pattern will become available.", + parameters={ + "pattern": { + "type": "string", + "description": "The exact pattern to remove from the ignore list.", + } + }, + required=["pattern"], + is_write=True, + ) + def tool_remove_ignore_rule(self, pattern: str) -> ToolResult: + """Remove an ignore rule.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + # Get affected models before removal + affected = engine.preview_pattern(pattern, "ignore", models) + + if engine.remove_ignore_rule(pattern): + gui._on_rules_changed() + self._record_change("rule_removed", rule_type="ignore", pattern=pattern) + + return ToolResult( + success=True, + message=f"Removed ignore rule: {pattern}. {len(affected)} model(s) now available.", + data={ + "pattern": pattern, + "models_now_available": affected[:10], + }, + ) + else: + return ToolResult( + success=False, + message=f"Pattern '{pattern}' not found in ignore rules", + error_code="pattern_not_found", + ) + + @assistant_tool( + name="add_whitelist_rule", + description="Add a pattern to the whitelist. Models matching this pattern will always be available, even if they match ignore rules.", + parameters={ + "pattern": { + "type": "string", + "description": "The pattern to whitelist. 
Supports * wildcard for prefix matching.", + } + }, + required=["pattern"], + is_write=True, + ) + def tool_add_whitelist_rule(self, pattern: str) -> ToolResult: + """Add a whitelist rule.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + pattern = pattern.strip() + if not pattern: + return ToolResult( + success=False, + message="Pattern cannot be empty", + error_code="invalid_pattern", + ) + + # Check if already covered + if engine.is_pattern_covered(pattern, "whitelist"): + covering_rules = [ + r.pattern + for r in engine.whitelist_rules + if engine.pattern_is_covered_by(pattern, r.pattern) + ] + return ToolResult( + success=False, + message=f"Pattern '{pattern}' is already covered by existing rule(s): {covering_rules}", + data={"covering_rules": covering_rules}, + error_code="pattern_covered", + ) + + # Add the rule + rule = engine.add_whitelist_rule(pattern) + if rule is None: + return ToolResult( + success=False, + message=f"Pattern '{pattern}' already exists", + error_code="duplicate_pattern", + ) + + # Update UI + engine.update_affected_counts(models) + gui._on_rules_changed() + + affected = engine.preview_pattern(pattern, "whitelist", models) + self._record_change("rule_added", rule_type="whitelist", pattern=pattern) + + return ToolResult( + success=True, + message=f"Added whitelist rule: {pattern}. {len(affected)} model(s) are now guaranteed available.", + data={ + "pattern": pattern, + "affected_count": len(affected), + "affected_models": affected[:10], + }, + ) + + @assistant_tool( + name="remove_whitelist_rule", + description="Remove a pattern from the whitelist.", + parameters={ + "pattern": { + "type": "string", + "description": "The exact pattern to remove from the whitelist.", + } + }, + required=["pattern"], + is_write=True, + ) + def tool_remove_whitelist_rule(self, pattern: str) -> ToolResult: + """Remove a whitelist rule.""" + gui = self._gui + engine = gui.filter_engine + + if engine.remove_whitelist_rule(pattern): + gui._on_rules_changed() + self._record_change("rule_removed", rule_type="whitelist", pattern=pattern) + + return ToolResult( + success=True, + message=f"Removed whitelist rule: {pattern}", + data={"pattern": pattern}, + ) + else: + return ToolResult( + success=False, + message=f"Pattern '{pattern}' not found in whitelist rules", + error_code="pattern_not_found", + ) + + @assistant_tool( + name="clear_all_ignore_rules", + description="Remove all ignore rules. 
All models will become available (unless blocked by other means).", + parameters={}, + required=[], + is_write=True, + ) + def tool_clear_all_ignore_rules(self) -> ToolResult: + """Clear all ignore rules.""" + gui = self._gui + engine = gui.filter_engine + + count = len(engine.ignore_rules) + if count == 0: + return ToolResult( + success=True, + message="No ignore rules to remove", + data={"removed_count": 0}, + ) + + patterns = [r.pattern for r in engine.ignore_rules] + engine.ignore_rules.clear() + engine._invalidate_cache() + gui._on_rules_changed() + + self._record_change("rules_cleared", rule_type="ignore", count=count) + + return ToolResult( + success=True, + message=f"Removed all {count} ignore rule(s)", + data={ + "removed_count": count, + "removed_patterns": patterns, + }, + ) + + @assistant_tool( + name="clear_all_whitelist_rules", + description="Remove all whitelist rules.", + parameters={}, + required=[], + is_write=True, + ) + def tool_clear_all_whitelist_rules(self) -> ToolResult: + """Clear all whitelist rules.""" + gui = self._gui + engine = gui.filter_engine + + count = len(engine.whitelist_rules) + if count == 0: + return ToolResult( + success=True, + message="No whitelist rules to remove", + data={"removed_count": 0}, + ) + + patterns = [r.pattern for r in engine.whitelist_rules] + engine.whitelist_rules.clear() + engine._invalidate_cache() + gui._on_rules_changed() + + self._record_change("rules_cleared", rule_type="whitelist", count=count) + + return ToolResult( + success=True, + message=f"Removed all {count} whitelist rule(s)", + data={ + "removed_count": count, + "removed_patterns": patterns, + }, + ) + + # ========================================================================= + # Tool Implementations - Query + # ========================================================================= + + @assistant_tool( + name="get_models_matching_pattern", + description="Preview which models would be affected by a pattern without adding it as a rule. Useful for testing patterns before applying them.", + parameters={ + "pattern": { + "type": "string", + "description": "The pattern to test. Supports * wildcard.", + } + }, + required=["pattern"], + is_write=False, + ) + def tool_get_models_matching_pattern(self, pattern: str) -> ToolResult: + """Get models matching a pattern.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + pattern = pattern.strip() + if not pattern: + return ToolResult( + success=False, + message="Pattern cannot be empty", + error_code="invalid_pattern", + ) + + matches = engine.preview_pattern(pattern, "ignore", models) + + if not matches: + # Suggest using wildcards + hint = "" + if "*" not in pattern: + hint = " Hint: Use wildcards like '*preview*' to match models containing 'preview'." 
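+            # Illustrative note (assuming the glob-style matching described in the
+            # rule tooling): a pattern like "*preview*" would match a hypothetical
+            # "gemini-2.5-pro-preview", whereas a bare "preview" only matches a model
+            # literally named "preview"; hence the wildcard hint above.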
+ return ToolResult( + success=True, + message=f"No models match pattern '{pattern}'.{hint}", + data={ + "pattern": pattern, + "match_count": 0, + "matches": [], + }, + ) + + return ToolResult( + success=True, + message=f"Pattern '{pattern}' matches {len(matches)} model(s)", + data={ + "pattern": pattern, + "match_count": len(matches), + "matches": matches, + }, + ) + + @assistant_tool( + name="get_model_details", + description="Get detailed information about a specific model, including its current status and any rules affecting it.", + parameters={ + "model_id": { + "type": "string", + "description": "The model ID to look up (e.g., 'gpt-4' or 'openai/gpt-4').", + } + }, + required=["model_id"], + is_write=False, + ) + def tool_get_model_details(self, model_id: str) -> ToolResult: + """Get details about a specific model.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + # Try to find the model + found_model = None + for m in models: + if m == model_id or m.endswith(f"/{model_id}"): + found_model = m + break + + if not found_model: + return ToolResult( + success=False, + message=f"Model '{model_id}' not found in current provider", + error_code="model_not_found", + ) + + status = engine.get_model_status(found_model) + + data = { + "model_id": found_model, + "display_name": status.display_name, + "status": status.status, + "color": status.color, + } + + if status.affecting_rule: + data["affecting_rule"] = { + "pattern": status.affecting_rule.pattern, + "type": status.affecting_rule.rule_type, + } + + status_text = { + "normal": "available (no rules affecting it)", + "ignored": f"blocked by ignore rule '{status.affecting_rule.pattern}'" + if status.affecting_rule + else "blocked", + "whitelisted": f"whitelisted by rule '{status.affecting_rule.pattern}'" + if status.affecting_rule + else "whitelisted", + } + + return ToolResult( + success=True, + message=f"Model '{status.display_name}' is {status_text.get(status.status, status.status)}", + data=data, + ) + + @assistant_tool( + name="explain_model_status", + description="Explain why a model has its current status (normal, ignored, or whitelisted). 
Shows the rule priority and which rules are affecting it.", + parameters={ + "model_id": {"type": "string", "description": "The model ID to explain."} + }, + required=["model_id"], + is_write=False, + ) + def tool_explain_model_status(self, model_id: str) -> ToolResult: + """Explain why a model has its status.""" + gui = self._gui + engine = gui.filter_engine + models = gui.models or [] + + # Find the model + found_model = None + for m in models: + if m == model_id or m.endswith(f"/{model_id}"): + found_model = m + break + + if not found_model: + return ToolResult( + success=False, + message=f"Model '{model_id}' not found", + error_code="model_not_found", + ) + + # Check all matching rules + matching_ignore = [] + matching_whitelist = [] + + for rule in engine.ignore_rules: + if engine._pattern_matches(found_model, rule.pattern): + matching_ignore.append(rule.pattern) + + for rule in engine.whitelist_rules: + if engine._pattern_matches(found_model, rule.pattern): + matching_whitelist.append(rule.pattern) + + status = engine.get_model_status(found_model) + + explanation = [] + explanation.append(f"Model: {status.display_name}") + explanation.append(f"Current status: {status.status.upper()}") + explanation.append("") + explanation.append("Rule priority: Whitelist > Ignore > Normal") + explanation.append("") + + if matching_whitelist: + explanation.append(f"Matching WHITELIST rules: {matching_whitelist}") + else: + explanation.append("No matching whitelist rules") + + if matching_ignore: + explanation.append(f"Matching IGNORE rules: {matching_ignore}") + else: + explanation.append("No matching ignore rules") + + explanation.append("") + + if status.status == "whitelisted": + explanation.append( + f"Result: Model is AVAILABLE because whitelist rule '{status.affecting_rule.pattern}' takes priority" + ) + elif status.status == "ignored": + explanation.append( + f"Result: Model is BLOCKED because ignore rule '{status.affecting_rule.pattern}' matches and no whitelist overrides it" + ) + else: + explanation.append( + "Result: Model is AVAILABLE by default (no rules affect it)" + ) + + return ToolResult( + success=True, + message="\n".join(explanation), + data={ + "model_id": found_model, + "status": status.status, + "matching_ignore_rules": matching_ignore, + "matching_whitelist_rules": matching_whitelist, + "affecting_rule": status.affecting_rule.pattern + if status.affecting_rule + else None, + }, + ) + + # ========================================================================= + # Tool Implementations - Provider + # ========================================================================= + + @assistant_tool( + name="switch_provider", + description="Switch to a different provider to view and configure its model filters.", + parameters={ + "provider": { + "type": "string", + "description": "The provider name to switch to (e.g., 'openai', 'gemini', 'anthropic').", + } + }, + required=["provider"], + is_write=False, + ) + def tool_switch_provider(self, provider: str) -> ToolResult: + """Switch to a different provider.""" + gui = self._gui + available = gui.available_providers or [] + + provider = provider.lower().strip() + + if provider not in available: + return ToolResult( + success=False, + message=f"Provider '{provider}' is not available. 
Available providers: {available}", + data={"available_providers": available}, + error_code="provider_not_found", + ) + + if provider == gui.current_provider: + return ToolResult( + success=True, + message=f"Already on provider '{provider}'", + data={"provider": provider}, + ) + + # Switch provider via the combo box + if hasattr(gui, "provider_combo"): + gui.provider_combo.set(provider) + gui._on_provider_changed(provider) + + self._record_change( + "provider_changed", + from_provider=gui.current_provider, + to_provider=provider, + ) + + return ToolResult( + success=True, + message=f"Switched to provider '{provider}'", + data={"provider": provider}, + ) + + @assistant_tool( + name="refresh_models", + description="Refresh the model list from the current provider. Use this if models seem outdated or missing.", + parameters={}, + required=[], + is_write=False, + ) + def tool_refresh_models(self) -> ToolResult: + """Refresh the model list.""" + gui = self._gui + + # Trigger refresh + gui._refresh_models() + + self._record_change("models_refreshed", provider=gui.current_provider) + + return ToolResult( + success=True, + message=f"Refreshing models for provider '{gui.current_provider}'... The model list will update shortly.", + data={"provider": gui.current_provider}, + ) + + # ========================================================================= + # Tool Implementations - State + # ========================================================================= + + @assistant_tool( + name="save_changes", + description="Save the current rules to the .env file. Changes will persist across restarts.", + parameters={}, + required=[], + is_write=True, + ) + def tool_save_changes(self) -> ToolResult: + """Save changes to .env file.""" + gui = self._gui + engine = gui.filter_engine + + if not engine.has_unsaved_changes(): + return ToolResult( + success=True, + message="No unsaved changes to save", + data={"saved": False}, + ) + + if engine.save_to_env(gui.current_provider): + gui._update_status() + self._record_change("changes_saved") + + return ToolResult( + success=True, + message=f"Saved rules for provider '{gui.current_provider}' to .env file", + data={ + "saved": True, + "provider": gui.current_provider, + "ignore_rules": [r.pattern for r in engine.ignore_rules], + "whitelist_rules": [r.pattern for r in engine.whitelist_rules], + }, + ) + else: + return ToolResult( + success=False, + message="Failed to save changes to .env file", + error_code="save_failed", + ) + + @assistant_tool( + name="discard_changes", + description="Discard all unsaved changes and reload rules from the .env file.", + parameters={}, + required=[], + is_write=True, + ) + def tool_discard_changes(self) -> ToolResult: + """Discard unsaved changes.""" + gui = self._gui + engine = gui.filter_engine + + if not engine.has_unsaved_changes(): + return ToolResult( + success=True, + message="No unsaved changes to discard", + data={"discarded": False}, + ) + + engine.discard_changes() + gui._on_rules_changed() + + self._record_change("changes_discarded") + + return ToolResult( + success=True, + message="Discarded unsaved changes and reloaded rules from .env", + data={ + "discarded": True, + "ignore_rules": [r.pattern for r in engine.ignore_rules], + "whitelist_rules": [r.pattern for r in engine.whitelist_rules], + }, + ) diff --git a/src/proxy_app/ai_assistant/assistant_logger.py b/src/proxy_app/ai_assistant/assistant_logger.py new file mode 100644 index 0000000..89b88b8 --- /dev/null +++ b/src/proxy_app/ai_assistant/assistant_logger.py @@ 
-0,0 +1,390 @@ +""" +Detailed logger for AI Assistant requests and responses. + +Logs comprehensive details of each AI Assistant transaction to help debug +tool calling issues and understand the full request/response flow. +""" + +import json +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional +import logging + +from rotator_library.utils.resilient_io import ( + safe_write_json, + safe_log_write, + safe_mkdir, +) + +LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs" +ASSISTANT_LOGS_DIR = LOGS_DIR / "assistant_logs" + +logger = logging.getLogger(__name__) + + +class AssistantLogger: + """ + Logs comprehensive details of each AI Assistant conversation turn. + + Creates a directory per conversation turn containing: + - request.json: The full messages array and tools sent to the LLM + - streaming_chunks.jsonl: Each streaming chunk received + - tool_calls.json: Parsed tool calls from the response + - tool_results.json: Results from tool execution + - final_response.json: Accumulated response data + - metadata.json: Summary of the turn + """ + + def __init__(self, session_id: str): + """ + Initialize the logger for a conversation turn. + + Args: + session_id: The chat session ID + """ + self.session_id = session_id + self.turn_id = str(uuid.uuid4())[:8] + self.start_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_dir = ASSISTANT_LOGS_DIR / session_id / f"{timestamp}_{self.turn_id}" + self._dir_available = safe_mkdir(self.log_dir, logger) + + # Accumulate data for final summary + self._chunks_count = 0 + self._content_accumulated = "" + self._reasoning_accumulated = "" + self._raw_tool_calls: Dict[int, Dict[str, Any]] = {} + self._model = "" + self._error: Optional[str] = None + + def _write_json(self, filename: str, data: Dict[str, Any]) -> None: + """Helper to write data to a JSON file in the log directory.""" + if not self._dir_available: + self._dir_available = safe_mkdir(self.log_dir, logger) + if not self._dir_available: + return + + safe_write_json( + self.log_dir / filename, + data, + logger, + atomic=False, + indent=2, + ensure_ascii=False, + ) + + def log_request( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + model: str, + reasoning_effort: Optional[str] = None, + ) -> None: + """ + Log the request being sent to the LLM. + + Args: + messages: The messages array in OpenAI format + tools: The tools array in OpenAI format + model: The model being used + reasoning_effort: Optional reasoning effort level + """ + self._model = model + + request_data = { + "turn_id": self.turn_id, + "session_id": self.session_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "model": model, + "reasoning_effort": reasoning_effort, + "messages_count": len(messages), + "tools_count": len(tools) if tools else 0, + "messages": messages, + "tools": tools, + } + self._write_json("request.json", request_data) + + # Also log a summary to the main logger + logger.info( + f"[AssistantLogger:{self.turn_id}] Request: model={model}, " + f"messages={len(messages)}, tools={len(tools) if tools else 0}" + ) + + def log_chunk( + self, + chunk: Any, + parsed_content: Optional[str] = None, + parsed_reasoning: Optional[str] = None, + parsed_tool_calls: Optional[List[Dict[str, Any]]] = None, + ) -> None: + """ + Log a streaming chunk. 
+ + Args: + chunk: The raw chunk (can be string or object) + parsed_content: Extracted content from the chunk + parsed_reasoning: Extracted reasoning content from the chunk + parsed_tool_calls: Extracted tool calls from the chunk + """ + if not self._dir_available: + return + + self._chunks_count += 1 + + # Accumulate for summary + if parsed_content: + self._content_accumulated += parsed_content + if parsed_reasoning: + self._reasoning_accumulated += parsed_reasoning + + # Accumulate tool calls + if parsed_tool_calls: + for tc in parsed_tool_calls: + index = tc.get("index", 0) + if index not in self._raw_tool_calls: + self._raw_tool_calls[index] = { + "id": "", + "name": "", + "arguments": "", + "chunks": [], + } + + # Record this chunk's contribution + self._raw_tool_calls[index]["chunks"].append(tc) + + # Accumulate + if tc.get("id"): + self._raw_tool_calls[index]["id"] = tc["id"] + func = tc.get("function", {}) + if func.get("name"): + self._raw_tool_calls[index]["name"] = func["name"] + if func.get("arguments"): + self._raw_tool_calls[index]["arguments"] += func["arguments"] + + # Convert chunk to serializable format + if hasattr(chunk, "model_dump"): + chunk_data = chunk.model_dump() + elif hasattr(chunk, "__dict__"): + chunk_data = str(chunk) + else: + chunk_data = chunk + + log_entry = { + "chunk_number": self._chunks_count, + "timestamp_utc": datetime.utcnow().isoformat(), + "raw_chunk": chunk_data, + "parsed": { + "content": parsed_content, + "reasoning": parsed_reasoning, + "tool_calls": parsed_tool_calls, + }, + } + + content = json.dumps(log_entry, ensure_ascii=False, default=str) + "\n" + safe_log_write(self.log_dir / "streaming_chunks.jsonl", content, logger) + + def log_tool_calls_parsed(self, tool_calls: List[Any]) -> None: + """ + Log the final parsed tool calls. + + Args: + tool_calls: List of ToolCall objects + """ + tool_calls_data = { + "turn_id": self.turn_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "tool_calls_count": len(tool_calls), + "tool_calls": [ + { + "id": tc.id, + "name": tc.name, + "arguments": tc.arguments, + "id_empty": not tc.id, + "name_empty": not tc.name, + } + for tc in tool_calls + ], + "raw_accumulated": { + str(k): { + "id": v["id"], + "name": v["name"], + "arguments": v["arguments"], + "chunk_count": len(v.get("chunks", [])), + } + for k, v in self._raw_tool_calls.items() + }, + } + self._write_json("tool_calls.json", tool_calls_data) + + # Log summary + for tc in tool_calls: + status = [] + if not tc.id: + status.append("EMPTY_ID") + if not tc.name: + status.append("EMPTY_NAME") + status_str = f" [{', '.join(status)}]" if status else "" + logger.info( + f"[AssistantLogger:{self.turn_id}] Tool call: {tc.name or '(empty)'}" + f"({json.dumps(tc.arguments)[:100]}){status_str}" + ) + + def log_tool_execution( + self, + tool_call_id: str, + tool_name: str, + arguments: Dict[str, Any], + result_success: bool, + result_message: str, + result_data: Optional[Dict[str, Any]] = None, + error_code: Optional[str] = None, + ) -> None: + """ + Log a tool execution result. 
+ + Args: + tool_call_id: The tool call ID + tool_name: The tool name + arguments: The arguments passed to the tool + result_success: Whether the tool succeeded + result_message: The result message + result_data: Optional result data + error_code: Optional error code if failed + """ + if not self._dir_available: + return + + log_entry = { + "timestamp_utc": datetime.utcnow().isoformat(), + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "arguments": arguments, + "result": { + "success": result_success, + "message": result_message, + "data": result_data, + "error_code": error_code, + }, + } + + content = json.dumps(log_entry, ensure_ascii=False, default=str) + "\n" + safe_log_write(self.log_dir / "tool_results.jsonl", content, logger) + + # Log summary + status = "SUCCESS" if result_success else f"FAILED ({error_code})" + logger.info( + f"[AssistantLogger:{self.turn_id}] Tool result: {tool_name} -> {status}: {result_message[:100]}" + ) + + def log_error(self, error: str) -> None: + """Log an error that occurred during the turn.""" + self._error = error + logger.error(f"[AssistantLogger:{self.turn_id}] Error: {error}") + + error_data = { + "turn_id": self.turn_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "error": error, + } + self._write_json("error.json", error_data) + + def log_completion(self, finish_reason: Optional[str] = None) -> None: + """ + Log the completion of this turn and write final summary. + + Args: + finish_reason: The finish reason from the LLM (if available) + """ + end_time = time.time() + duration_ms = (end_time - self.start_time) * 1000 + + # Parse accumulated tool call arguments for summary + parsed_tool_calls = [] + for index in sorted(self._raw_tool_calls.keys()): + tc_data = self._raw_tool_calls[index] + try: + args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {} + except json.JSONDecodeError: + args = {"_parse_error": tc_data["arguments"]} + + parsed_tool_calls.append( + { + "id": tc_data["id"], + "name": tc_data["name"], + "arguments": args, + } + ) + + final_response = { + "turn_id": self.turn_id, + "session_id": self.session_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "duration_ms": round(duration_ms), + "model": self._model, + "chunks_received": self._chunks_count, + "finish_reason": finish_reason, + "content": self._content_accumulated, + "reasoning_content": self._reasoning_accumulated + if self._reasoning_accumulated + else None, + "tool_calls": parsed_tool_calls if parsed_tool_calls else None, + "error": self._error, + } + self._write_json("final_response.json", final_response) + + # Metadata summary + metadata = { + "turn_id": self.turn_id, + "session_id": self.session_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "duration_ms": round(duration_ms), + "model": self._model, + "chunks_received": self._chunks_count, + "content_length": len(self._content_accumulated), + "reasoning_length": len(self._reasoning_accumulated), + "tool_calls_count": len(parsed_tool_calls), + "tool_calls_summary": [ + {"name": tc["name"], "id_present": bool(tc["id"])} + for tc in parsed_tool_calls + ], + "finish_reason": finish_reason, + "had_error": self._error is not None, + } + self._write_json("metadata.json", metadata) + + logger.info( + f"[AssistantLogger:{self.turn_id}] Completed: " + f"duration={round(duration_ms)}ms, chunks={self._chunks_count}, " + f"content_len={len(self._content_accumulated)}, " + f"tool_calls={len(parsed_tool_calls)}, " + f"finish_reason={finish_reason}" + ) + + def log_messages_sent(self, 
messages: List[Dict[str, Any]]) -> None: + """ + Log the messages being sent in a continuation request. + + Useful for debugging the agentic loop. + + Args: + messages: The full messages array being sent + """ + messages_data = { + "turn_id": self.turn_id, + "timestamp_utc": datetime.utcnow().isoformat(), + "context": "agentic_loop_continuation", + "messages_count": len(messages), + "messages": messages, + } + + # Write to a separate file to track continuations + content = json.dumps(messages_data, ensure_ascii=False, default=str) + "\n" + safe_log_write(self.log_dir / "continuation_requests.jsonl", content, logger) + + logger.info( + f"[AssistantLogger:{self.turn_id}] Continuation request: " + f"messages={len(messages)}" + ) diff --git a/src/proxy_app/ai_assistant/bridge.py b/src/proxy_app/ai_assistant/bridge.py new file mode 100644 index 0000000..52b5a9c --- /dev/null +++ b/src/proxy_app/ai_assistant/bridge.py @@ -0,0 +1,562 @@ +""" +LLM Bridge for the AI Assistant. + +Provides the bridge between the async RotatingClient and the synchronous GUI thread. +Handles streaming, tool call parsing, and model list fetching. +""" + +import asyncio +import json +import logging +import threading +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from .assistant_logger import AssistantLogger +from .tools import ToolCall + +logger = logging.getLogger(__name__) + + +@dataclass +class StreamCallbacks: + """Callbacks for streaming response handling.""" + + on_thinking_chunk: Optional[Callable[[str], None]] = None + on_content_chunk: Optional[Callable[[str], None]] = None + on_tool_calls: Optional[Callable[[List[ToolCall]], None]] = None + on_error: Optional[Callable[[str], None]] = None + on_complete: Optional[Callable[[], None]] = None + + +@dataclass +class ParsedChunk: + """Parsed data from a streaming chunk.""" + + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + finish_reason: Optional[str] = None + is_done: bool = False + + +class LLMBridge: + """ + Bridge between async RotatingClient and synchronous GUI thread. + + Handles: + - RotatingClient lifecycle management + - Thread/async coordination using threading.Thread + asyncio.run() + - Streaming chunk processing + - Tool call parsing + - Model list fetching + """ + + def __init__( + self, + schedule_on_gui: Callable[[Callable], None], + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + session_id: Optional[str] = None, + ): + """ + Initialize the LLM Bridge. 
+ + Args: + schedule_on_gui: Function to schedule callbacks on the GUI thread + (typically window.after(0, callback)) + ignore_models: Model patterns to ignore (passed to RotatingClient) + whitelist_models: Model patterns to whitelist (passed to RotatingClient) + session_id: Session ID for logging (optional) + """ + self._schedule_on_gui = schedule_on_gui + self._ignore_models = ignore_models + self._whitelist_models = whitelist_models + self._session_id = session_id or str(uuid.uuid4())[:8] + self._client = None + self._current_thread: Optional[threading.Thread] = None + self._cancel_requested = False + self._models_cache: Optional[Dict[str, List[str]]] = None + self._current_logger: Optional[AssistantLogger] = None + + def _get_client(self): + """Get or create the RotatingClient instance.""" + if self._client is None: + # Import here to avoid circular imports and reduce startup time + import os + + from dotenv import load_dotenv + + from rotator_library import RotatingClient + from rotator_library.credential_manager import CredentialManager + + # Load environment variables + load_dotenv(override=True) + + # Discover API keys from environment variables (same as main.py) + api_keys = {} + for key, value in os.environ.items(): + if "_API_KEY" in key and key != "PROXY_API_KEY": + provider = key.split("_API_KEY")[0].lower() + if provider not in api_keys: + api_keys[provider] = [] + api_keys[provider].append(value) + + # Discover OAuth credentials via CredentialManager + cred_manager = CredentialManager(os.environ) + oauth_credentials = cred_manager.discover_and_prepare() + + # Discover model filtering rules from environment (same as main.py) + ignore_models = self._ignore_models or {} + whitelist_models = self._whitelist_models or {} + + # Load per-provider ignore/whitelist from env vars + for key, value in os.environ.items(): + if key.startswith("IGNORE_MODELS_") and value: + provider = key.replace("IGNORE_MODELS_", "").lower() + patterns = [p.strip() for p in value.split(",") if p.strip()] + if patterns: + ignore_models[provider] = patterns + elif key.startswith("WHITELIST_MODELS_") and value: + provider = key.replace("WHITELIST_MODELS_", "").lower() + patterns = [p.strip() for p in value.split(",") if p.strip()] + if patterns: + whitelist_models[provider] = patterns + + self._client = RotatingClient( + api_keys=api_keys, + oauth_credentials=oauth_credentials, + ignore_models=ignore_models if ignore_models else None, + whitelist_models=whitelist_models if whitelist_models else None, + configure_logging=False, # Use existing logging config + ) + return self._client + + def cancel_stream(self) -> None: + """Request cancellation of the current stream.""" + self._cancel_requested = True + logger.info("Stream cancellation requested") + + def is_streaming(self) -> bool: + """Check if a stream is currently in progress.""" + return self._current_thread is not None and self._current_thread.is_alive() + + @property + def current_logger(self) -> Optional[AssistantLogger]: + """Get the current assistant logger for this turn.""" + return self._current_logger + + def set_session_id(self, session_id: str) -> None: + """Update the session ID for logging.""" + self._session_id = session_id + + def stream_completion( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + model: str, + callbacks: StreamCallbacks, + reasoning_effort: Optional[str] = None, + ) -> None: + """ + Start a streaming completion request in a background thread. 
+ + Args: + messages: The message history in OpenAI format + tools: Tool definitions in OpenAI format + model: The model to use (e.g., "openai/gpt-4o") + callbacks: Callbacks for handling streaming events + reasoning_effort: Optional reasoning effort level ("low", "medium", "high") + """ + self._cancel_requested = False + + # Create logger for this turn + self._current_logger = AssistantLogger(self._session_id) + self._current_logger.log_request(messages, tools, model, reasoning_effort) + + def run_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete( + self._stream_async( + messages, tools, model, callbacks, reasoning_effort + ) + ) + except Exception as e: + logger.exception("Error in streaming thread") + if callbacks.on_error: + self._schedule_on_gui(lambda: callbacks.on_error(str(e))) + finally: + # Clean up pending tasks + try: + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + if pending: + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception: + pass + finally: + loop.close() + + self._current_thread = threading.Thread(target=run_in_thread, daemon=True) + self._current_thread.start() + + async def _stream_async( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + model: str, + callbacks: StreamCallbacks, + reasoning_effort: Optional[str] = None, + ) -> None: + """Async implementation of streaming completion.""" + client = self._get_client() + assistant_logger = self._current_logger + + # Accumulate tool calls across chunks + accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} + finish_reason: Optional[str] = None + + try: + # Build completion kwargs + completion_kwargs = { + "model": model, + "messages": messages, + "tools": tools if tools else None, + "stream": True, + } + + # Add reasoning_effort if specified + if reasoning_effort and reasoning_effort in ("low", "medium", "high"): + completion_kwargs["reasoning_effort"] = reasoning_effort + + logger.debug(f"Starting completion request: model={model}") + response = client.acompletion(**completion_kwargs) + + async for chunk in response: + if self._cancel_requested: + logger.info("Stream cancelled by user") + break + + parsed = self._parse_chunk(chunk) + + # Track finish reason + if parsed.finish_reason: + finish_reason = parsed.finish_reason + logger.debug(f"Got finish_reason: {finish_reason}") + + if parsed.is_done: + break + + # Log chunk with parsed data + if assistant_logger: + assistant_logger.log_chunk( + chunk, + parsed_content=parsed.content, + parsed_reasoning=parsed.reasoning_content, + parsed_tool_calls=parsed.tool_calls, + ) + + # Handle reasoning/thinking content + if parsed.reasoning_content and callbacks.on_thinking_chunk: + content = parsed.reasoning_content + self._schedule_on_gui( + lambda c=content: callbacks.on_thinking_chunk(c) + ) + + # Handle regular content + if parsed.content and callbacks.on_content_chunk: + content = parsed.content + self._schedule_on_gui( + lambda c=content: callbacks.on_content_chunk(c) + ) + + # Accumulate tool calls + if parsed.tool_calls: + for tc in parsed.tool_calls: + index = tc.get("index", 0) + if index not in accumulated_tool_calls: + accumulated_tool_calls[index] = { + "id": tc.get("id", ""), + "name": "", + "arguments": "", + } + + # Accumulate ID if provided + if tc.get("id"): + accumulated_tool_calls[index]["id"] = tc["id"] + + # Accumulate function name and arguments 
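+                        # Streaming deltas usually carry the function name once and the
+                        # JSON arguments as partial string fragments; the fragments are
+                        # concatenated here and only the final accumulated string is
+                        # parsed as JSON after the stream completes.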
+ func = tc.get("function", {}) + if func.get("name"): + accumulated_tool_calls[index]["name"] = func["name"] + if func.get("arguments"): + accumulated_tool_calls[index]["arguments"] += func[ + "arguments" + ] + + # Process accumulated tool calls with validation + if accumulated_tool_calls and callbacks.on_tool_calls: + tool_calls = [] + skipped_count = 0 + + for index in sorted(accumulated_tool_calls.keys()): + tc_data = accumulated_tool_calls[index] + + # Validate: skip tool calls with empty name + if not tc_data["name"]: + logger.warning( + f"Skipping tool call at index {index}: empty name. " + f"ID={tc_data['id']!r}, args={tc_data['arguments'][:100]!r}" + ) + skipped_count += 1 + continue + + # Generate ID if missing + tool_call_id = tc_data["id"] + if not tool_call_id: + tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + logger.warning( + f"Generated missing ID for tool call {tc_data['name']}: {tool_call_id}" + ) + + # Parse arguments + try: + arguments = ( + json.loads(tc_data["arguments"]) + if tc_data["arguments"] + else {} + ) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse tool arguments for {tc_data['name']}: " + f"{tc_data['arguments']!r} - {e}" + ) + arguments = {} + + tool_calls.append( + ToolCall( + id=tool_call_id, + name=tc_data["name"], + arguments=arguments, + ) + ) + + if skipped_count > 0: + logger.warning(f"Skipped {skipped_count} invalid tool call(s)") + + # Log parsed tool calls + if assistant_logger and tool_calls: + assistant_logger.log_tool_calls_parsed(tool_calls) + + if tool_calls: + self._schedule_on_gui( + lambda tc=tool_calls: callbacks.on_tool_calls(tc) + ) + elif skipped_count > 0 and not tool_calls: + # All tool calls were invalid - log error + error_msg = ( + f"All {skipped_count} tool call(s) were invalid (empty names)" + ) + logger.error(error_msg) + if assistant_logger: + assistant_logger.log_error(error_msg) + + # Log completion + if assistant_logger: + assistant_logger.log_completion(finish_reason) + + # Signal completion + if callbacks.on_complete and not self._cancel_requested: + self._schedule_on_gui(callbacks.on_complete) + + except Exception as e: + logger.exception("Error during streaming") + if assistant_logger: + assistant_logger.log_error(str(e)) + assistant_logger.log_completion(finish_reason="error") + if callbacks.on_error: + error_msg = self._format_error(e) + self._schedule_on_gui(lambda msg=error_msg: callbacks.on_error(msg)) + + def _parse_chunk(self, chunk: str) -> ParsedChunk: + """ + Parse a streaming chunk. 
+ + Args: + chunk: SSE-formatted chunk string (e.g., "data: {...}\n\n") + + Returns: + ParsedChunk with extracted data + """ + result = ParsedChunk() + + # Handle SSE format + if isinstance(chunk, str): + chunk = chunk.strip() + if chunk.startswith("data: "): + chunk = chunk[6:] + + if chunk == "[DONE]": + result.is_done = True + return result + + try: + data = json.loads(chunk) + except json.JSONDecodeError: + return result + elif hasattr(chunk, "choices"): + # It's already a parsed object (litellm response) + data = chunk + else: + return result + + # Extract from choices + if hasattr(data, "choices"): + choices = data.choices + elif isinstance(data, dict): + choices = data.get("choices", []) + else: + return result + + if not choices: + return result + + choice = choices[0] if isinstance(choices, list) else choices + + # Get delta (for streaming) or message + if hasattr(choice, "delta"): + delta = choice.delta + elif isinstance(choice, dict): + delta = choice.get("delta", choice.get("message", {})) + else: + delta = {} + + # Extract content + if hasattr(delta, "content"): + result.content = delta.content + elif isinstance(delta, dict): + result.content = delta.get("content") + + # Extract reasoning content (for models that support it) + if hasattr(delta, "reasoning_content"): + result.reasoning_content = delta.reasoning_content + elif isinstance(delta, dict): + result.reasoning_content = delta.get("reasoning_content") + + # Extract tool calls + if hasattr(delta, "tool_calls"): + tool_calls = delta.tool_calls + if tool_calls: + result.tool_calls = [ + { + "index": getattr(tc, "index", i), + "id": getattr(tc, "id", None), + "function": { + "name": getattr(tc.function, "name", None) + if hasattr(tc, "function") + else None, + "arguments": getattr(tc.function, "arguments", "") + if hasattr(tc, "function") + else "", + }, + } + for i, tc in enumerate(tool_calls) + ] + elif isinstance(delta, dict) and "tool_calls" in delta: + result.tool_calls = delta["tool_calls"] + + # Extract finish reason + if hasattr(choice, "finish_reason"): + result.finish_reason = choice.finish_reason + elif isinstance(choice, dict): + result.finish_reason = choice.get("finish_reason") + + return result + + def _format_error(self, error: Exception) -> str: + """Format an exception into a user-friendly error message.""" + error_str = str(error) + + # Check for common error types + if "rate_limit" in error_str.lower() or "429" in error_str: + return "Rate limit exceeded. Please try again in a moment." + elif "quota" in error_str.lower(): + return "API quota exceeded. Please check your API usage limits." + elif "authentication" in error_str.lower() or "401" in error_str: + return "Authentication failed. Please check your API credentials." + elif "connection" in error_str.lower() or "network" in error_str.lower(): + return "Connection error. Please check your network connection." + elif "timeout" in error_str.lower(): + return "Request timed out. Please try again." + elif "all_credentials_exhausted" in error_str.lower(): + return "All API credentials exhausted. Please add more credentials or wait for rate limits to reset." + + # Return truncated error for unknown errors + if len(error_str) > 200: + return f"Error: {error_str[:200]}..." + return f"Error: {error_str}" + + def fetch_models( + self, + on_success: Callable[[Dict[str, List[str]]], None], + on_error: Callable[[str], None], + ) -> None: + """ + Fetch available models in a background thread. + + Models are grouped by provider. 
+ + Args: + on_success: Callback with dict of provider -> model list + on_error: Callback with error message + """ + + def run_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + client = self._get_client() + models = loop.run_until_complete( + client.get_all_available_models(grouped=True) + ) + self._models_cache = models + self._schedule_on_gui(lambda: on_success(models)) + except Exception as e: + logger.exception("Error fetching models") + self._schedule_on_gui(lambda: on_error(str(e))) + finally: + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception: + pass + loop.close() + + thread = threading.Thread(target=run_in_thread, daemon=True) + thread.start() + + def get_cached_models(self) -> Optional[Dict[str, List[str]]]: + """Get the cached model list, if available.""" + return self._models_cache + + def close(self) -> None: + """Close the client and clean up resources.""" + if self._client is not None: + # Run close in a new event loop since it's async + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._client.close()) + loop.close() + except Exception as e: + logger.error(f"Error closing client: {e}") + finally: + self._client = None diff --git a/src/proxy_app/ai_assistant/checkpoint.py b/src/proxy_app/ai_assistant/checkpoint.py new file mode 100644 index 0000000..cde8d22 --- /dev/null +++ b/src/proxy_app/ai_assistant/checkpoint.py @@ -0,0 +1,408 @@ +""" +Checkpoint management for the AI Assistant. + +Provides undo capability through a hybrid snapshot/delta storage system +with temp file persistence for crash recovery. +""" + +import copy +import json +import logging +import tempfile +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .context import apply_delta, compute_delta +from .tools import ToolCallSummary + +logger = logging.getLogger(__name__) + +# Full snapshot is stored every N checkpoints +SNAPSHOT_INTERVAL = 10 + + +@dataclass +class Checkpoint: + """Represents a point-in-time snapshot of the window state.""" + + id: str # UUID + timestamp: datetime + description: str # Auto-generated from tools + tool_calls: List[ToolCallSummary] # What tools were called + message_index: int # Conversation position at checkpoint time + + # One of these will be populated: + full_state: Optional[Dict[str, Any]] = None # Full snapshot (every Nth) + delta: Optional[Dict[str, Any]] = None # Changes from previous + + is_full_snapshot: bool = False + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "id": self.id, + "timestamp": self.timestamp.isoformat(), + "description": self.description, + "tool_calls": [tc.to_dict() for tc in self.tool_calls], + "message_index": self.message_index, + "full_state": self.full_state, + "delta": self.delta, + "is_full_snapshot": self.is_full_snapshot, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint": + """Create from dictionary.""" + return cls( + id=data["id"], + timestamp=datetime.fromisoformat(data["timestamp"]), + description=data["description"], + tool_calls=[ + ToolCallSummary( + name=tc["name"], + arguments=tc["arguments"], + success=tc["success"], + message=tc["message"], + ) + for tc in data["tool_calls"] + ], + message_index=data["message_index"], + full_state=data.get("full_state"), + delta=data.get("delta"), + 
is_full_snapshot=data.get("is_full_snapshot", False), + ) + + def get_display_text(self) -> str: + """Get human-readable display text for UI.""" + time_str = self.timestamp.strftime("%H:%M:%S") + if self.tool_calls: + first_call = self.tool_calls[0] + tool_text = f'{first_call.name}("{list(first_call.arguments.values())[0] if first_call.arguments else ""}")' + if len(self.tool_calls) > 1: + tool_text += f" +{len(self.tool_calls) - 1} more" + return f"{time_str} - {tool_text}" + return f"{time_str} - {self.description}" + + +class CheckpointManager: + """ + Manages checkpoints for undo capability. + + Uses a hybrid snapshot/delta approach: + - Full snapshot stored every SNAPSHOT_INTERVAL checkpoints + - Deltas stored between snapshots + - All checkpoints persisted to temp file for crash recovery + """ + + def __init__(self, session_id: str): + """ + Initialize the checkpoint manager. + + Args: + session_id: Unique identifier for this session + """ + self.session_id = session_id + self._checkpoints: List[Checkpoint] = [] + self._current_position: int = ( + -1 + ) # Index of current checkpoint (-1 = no checkpoints) + self._temp_file: Optional[Path] = None + self._checkpoint_created_for_response: bool = False + + # Initialize temp file for crash recovery + self._init_temp_file() + + def _init_temp_file(self) -> None: + """Initialize the temp file for checkpoint persistence.""" + try: + temp_dir = Path(tempfile.gettempdir()) / "ai_assistant_checkpoints" + temp_dir.mkdir(exist_ok=True) + self._temp_file = temp_dir / f"{self.session_id}.json" + logger.info(f"Checkpoint temp file: {self._temp_file}") + except Exception as e: + logger.error(f"Failed to initialize temp file: {e}") + self._temp_file = None + + def _save_to_temp_file(self) -> None: + """Save checkpoints to temp file.""" + if not self._temp_file: + return + + try: + data = { + "session_id": self.session_id, + "current_position": self._current_position, + "checkpoints": [cp.to_dict() for cp in self._checkpoints], + } + self._temp_file.write_text(json.dumps(data, indent=2, default=str)) + except Exception as e: + logger.error(f"Failed to save checkpoints to temp file: {e}") + + def load_from_temp_file(self) -> bool: + """ + Load checkpoints from temp file (for crash recovery). + + Returns: + True if checkpoints were loaded, False otherwise + """ + if not self._temp_file or not self._temp_file.exists(): + return False + + try: + data = json.loads(self._temp_file.read_text()) + if data.get("session_id") != self.session_id: + return False + + self._checkpoints = [Checkpoint.from_dict(cp) for cp in data["checkpoints"]] + self._current_position = data["current_position"] + logger.info(f"Loaded {len(self._checkpoints)} checkpoints from temp file") + return True + except Exception as e: + logger.error(f"Failed to load checkpoints from temp file: {e}") + return False + + def clear_temp_file(self) -> None: + """Delete the temp file.""" + if self._temp_file and self._temp_file.exists(): + try: + self._temp_file.unlink() + except Exception as e: + logger.error(f"Failed to delete temp file: {e}") + + def start_response(self) -> None: + """ + Called when a new LLM response begins. + + Resets the flag tracking whether a checkpoint was created for this response. + """ + self._checkpoint_created_for_response = False + + def should_create_checkpoint(self) -> bool: + """ + Check if a checkpoint should be created for this response. 
+ + Returns: + True if no checkpoint has been created for this response yet + """ + return not self._checkpoint_created_for_response + + def create_checkpoint( + self, + state: Dict[str, Any], + tool_calls: List[ToolCallSummary], + message_index: int, + description: Optional[str] = None, + ) -> Checkpoint: + """ + Create a new checkpoint. + + Args: + state: The current window state to snapshot + tool_calls: Summary of tool calls that led to this checkpoint + message_index: Current position in conversation history + description: Optional description (auto-generated if not provided) + + Returns: + The created Checkpoint + """ + checkpoint_id = str(uuid.uuid4())[:8] + timestamp = datetime.now() + + # Generate description if not provided + if description is None: + if tool_calls: + first_call = tool_calls[0] + description = f"{first_call.name} executed" + if len(tool_calls) > 1: + description += f" (+{len(tool_calls) - 1} more)" + else: + description = "State saved" + + # Determine if this should be a full snapshot + is_full = len(self._checkpoints) % SNAPSHOT_INTERVAL == 0 + + if is_full: + # Full snapshot + checkpoint = Checkpoint( + id=checkpoint_id, + timestamp=timestamp, + description=description, + tool_calls=tool_calls, + message_index=message_index, + full_state=copy.deepcopy(state), + is_full_snapshot=True, + ) + else: + # Delta from previous state + previous_state = self._reconstruct_state(len(self._checkpoints) - 1) + delta = compute_delta(previous_state, state) if previous_state else None + + checkpoint = Checkpoint( + id=checkpoint_id, + timestamp=timestamp, + description=description, + tool_calls=tool_calls, + message_index=message_index, + delta=delta if delta else {"added": [], "removed": [], "modified": []}, + is_full_snapshot=False, + ) + + self._checkpoints.append(checkpoint) + self._current_position = len(self._checkpoints) - 1 + self._checkpoint_created_for_response = True + + # Persist to temp file + self._save_to_temp_file() + + logger.info( + f"Created checkpoint {checkpoint_id}: {description} " + f"(full={is_full}, position={self._current_position})" + ) + + return checkpoint + + def _find_nearest_snapshot_before(self, index: int) -> int: + """Find the nearest full snapshot at or before the given index.""" + for i in range(index, -1, -1): + if self._checkpoints[i].is_full_snapshot: + return i + return -1 + + def _reconstruct_state(self, target_index: int) -> Optional[Dict[str, Any]]: + """ + Reconstruct the state at a given checkpoint index. + + Args: + target_index: Index of the checkpoint to reconstruct + + Returns: + The reconstructed state, or None if reconstruction fails + """ + if target_index < 0 or target_index >= len(self._checkpoints): + return None + + # Find nearest full snapshot + snapshot_index = self._find_nearest_snapshot_before(target_index) + if snapshot_index < 0: + logger.error("No full snapshot found - cannot reconstruct state") + return None + + # Start with the full snapshot + state = copy.deepcopy(self._checkpoints[snapshot_index].full_state) + if state is None: + logger.error( + f"Checkpoint {snapshot_index} marked as snapshot but has no state" + ) + return None + + # Apply deltas from snapshot to target + for i in range(snapshot_index + 1, target_index + 1): + checkpoint = self._checkpoints[i] + if checkpoint.delta: + state = apply_delta(state, checkpoint.delta) + + return state + + def get_state_at(self, checkpoint_id: str) -> Optional[Dict[str, Any]]: + """ + Get the state at a specific checkpoint. 
+ + Args: + checkpoint_id: ID of the checkpoint + + Returns: + The state at that checkpoint, or None if not found + """ + index = self._find_checkpoint_index(checkpoint_id) + if index < 0: + return None + return self._reconstruct_state(index) + + def _find_checkpoint_index(self, checkpoint_id: str) -> int: + """Find the index of a checkpoint by ID.""" + for i, cp in enumerate(self._checkpoints): + if cp.id == checkpoint_id: + return i + return -1 + + def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]: + """ + Rollback to a specific checkpoint. + + This truncates all checkpoints after the target and returns the state + that should be applied to the window. + + Args: + checkpoint_id: ID of the checkpoint to rollback to + + Returns: + The state to apply, or None if rollback fails + """ + target_index = self._find_checkpoint_index(checkpoint_id) + if target_index < 0: + logger.error(f"Checkpoint {checkpoint_id} not found") + return None + + # Reconstruct the state + state = self._reconstruct_state(target_index) + if state is None: + return None + + # Truncate checkpoints after target + self._checkpoints = self._checkpoints[: target_index + 1] + self._current_position = target_index + + # Persist the truncated list + self._save_to_temp_file() + + logger.info( + f"Rolled back to checkpoint {checkpoint_id} (position={target_index})" + ) + + return state + + def get_checkpoints(self) -> List[Checkpoint]: + """Get all checkpoints.""" + return self._checkpoints.copy() + + def get_current_checkpoint(self) -> Optional[Checkpoint]: + """Get the current checkpoint, if any.""" + if 0 <= self._current_position < len(self._checkpoints): + return self._checkpoints[self._current_position] + return None + + def get_message_index_at(self, checkpoint_id: str) -> Optional[int]: + """ + Get the message index at a specific checkpoint. + + Args: + checkpoint_id: ID of the checkpoint + + Returns: + The message index, or None if checkpoint not found + """ + index = self._find_checkpoint_index(checkpoint_id) + if index < 0: + return None + return self._checkpoints[index].message_index + + def clear(self) -> None: + """Clear all checkpoints (for new session).""" + self._checkpoints.clear() + self._current_position = -1 + self._checkpoint_created_for_response = False + self._save_to_temp_file() + logger.info("Cleared all checkpoints") + + @property + def checkpoint_count(self) -> int: + """Get the number of checkpoints.""" + return len(self._checkpoints) + + @property + def current_position(self) -> int: + """Get the current checkpoint position.""" + return self._current_position diff --git a/src/proxy_app/ai_assistant/context.py b/src/proxy_app/ai_assistant/context.py new file mode 100644 index 0000000..b4a5f06 --- /dev/null +++ b/src/proxy_app/ai_assistant/context.py @@ -0,0 +1,246 @@ +""" +Window context adapter abstract base class and utilities. + +Each GUI window that wants to use the AI assistant must implement +the WindowContextAdapter interface. +""" + +import hashlib +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from .tools import ToolDefinition + + +class WindowContextAdapter(ABC): + """ + Abstract base class that each window must implement to connect to the AI assistant. 
+ + The adapter provides: + - Full context extraction for the LLM + - Window-specific system prompts + - Tool definitions for the window + - State application for checkpoint rollback + """ + + @abstractmethod + def get_full_context(self) -> Dict[str, Any]: + """ + Get the complete structured state of the window. + + This should include all relevant information the AI needs to understand + the current state and make decisions. The returned dictionary will be + serialized and included in the LLM context. + + Returns: + Dictionary containing the full window state + """ + pass + + @abstractmethod + def get_window_system_prompt(self) -> str: + """ + Get window-specific instructions for the AI. + + This prompt is appended to the base assistant prompt to provide + domain-specific knowledge and guidelines. + + Returns: + String containing the window-specific system prompt + """ + pass + + @abstractmethod + def get_tools(self) -> List[ToolDefinition]: + """ + Get the list of tools available for this window. + + Returns: + List of ToolDefinition objects + """ + pass + + @abstractmethod + def apply_state(self, state: Dict[str, Any]) -> None: + """ + Restore the window to a given state (for checkpoint rollback). + + This method should atomically restore the window state. If restoration + fails partway through, the window should remain in its pre-restore state. + + Args: + state: The state dictionary to restore (from a checkpoint) + """ + pass + + def get_state_hash(self) -> str: + """ + Get a quick hash of the current state for change detection. + + Override this for more efficient change detection if needed. + Default implementation hashes the full context. + + Returns: + String hash of the current state + """ + context = self.get_full_context() + context_str = json.dumps(context, sort_keys=True, default=str) + return hashlib.md5(context_str.encode()).hexdigest() + + def lock_window(self) -> None: + """ + Lock the window to prevent user interaction during AI execution. + + Override this to implement window locking (gray out widgets, change cursor, etc.) + Default implementation does nothing. + """ + pass + + def unlock_window(self) -> None: + """ + Unlock the window after AI execution completes. + + Override this to restore window interactivity. + Default implementation does nothing. + """ + pass + + def on_ai_started(self) -> None: + """ + Called when the AI starts processing a request. + + Override to perform any setup needed (e.g., start tracking changes). + Default implementation just locks the window. + """ + self.lock_window() + + def on_ai_completed(self) -> None: + """ + Called when the AI finishes processing (success or failure). + + Override to perform any cleanup needed (e.g., stop tracking changes). + Default implementation just unlocks the window. + """ + self.unlock_window() + + +def compute_context_diff( + old_context: Dict[str, Any], new_context: Dict[str, Any] +) -> Dict[str, Any]: + """ + Compute the differences between two context snapshots. 
+ + Args: + old_context: The previous context snapshot + new_context: The current context snapshot + + Returns: + Dictionary with 'added', 'removed', and 'modified' keys + """ + + def _diff_recursive( + old: Any, new: Any, path: str = "" + ) -> Dict[str, List[Dict[str, Any]]]: + """Recursively diff two values.""" + result: Dict[str, List[Dict[str, Any]]] = { + "added": [], + "removed": [], + "modified": [], + } + + if type(old) != type(new): + result["modified"].append({"path": path, "old": old, "new": new}) + return result + + if isinstance(old, dict) and isinstance(new, dict): + all_keys = set(old.keys()) | set(new.keys()) + for key in all_keys: + sub_path = f"{path}.{key}" if path else key + if key not in old: + result["added"].append({"path": sub_path, "value": new[key]}) + elif key not in new: + result["removed"].append({"path": sub_path, "value": old[key]}) + else: + sub_diff = _diff_recursive(old[key], new[key], sub_path) + result["added"].extend(sub_diff["added"]) + result["removed"].extend(sub_diff["removed"]) + result["modified"].extend(sub_diff["modified"]) + + elif isinstance(old, list) and isinstance(new, list): + if old != new: + result["modified"].append({"path": path, "old": old, "new": new}) + + elif old != new: + result["modified"].append({"path": path, "old": old, "new": new}) + + return result + + return _diff_recursive(old_context, new_context) + + +def compute_delta( + old_state: Dict[str, Any], new_state: Dict[str, Any] +) -> Dict[str, Any]: + """ + Compute a delta that can be applied to old_state to produce new_state. + + Used for checkpoint storage - stores only differences between states. + + Args: + old_state: The previous state + new_state: The new state + + Returns: + Delta dictionary with 'added', 'removed', 'modified' keys + """ + return compute_context_diff(old_state, new_state) + + +def apply_delta(state: Dict[str, Any], delta: Dict[str, Any]) -> Dict[str, Any]: + """ + Apply a delta to a state to produce a new state. + + Args: + state: The base state + delta: The delta to apply + + Returns: + New state with delta applied + """ + import copy + + result = copy.deepcopy(state) + + def _set_nested(d: Dict, path: str, value: Any) -> None: + """Set a nested value by path.""" + keys = path.split(".") + for key in keys[:-1]: + if key not in d: + d[key] = {} + d = d[key] + d[keys[-1]] = value + + def _delete_nested(d: Dict, path: str) -> None: + """Delete a nested value by path.""" + keys = path.split(".") + for key in keys[:-1]: + if key not in d: + return + d = d[key] + if keys[-1] in d: + del d[keys[-1]] + + # Apply additions + for item in delta.get("added", []): + _set_nested(result, item["path"], item["value"]) + + # Apply modifications + for item in delta.get("modified", []): + _set_nested(result, item["path"], item["new"]) + + # Apply removals (do this last) + for item in delta.get("removed", []): + _delete_nested(result, item["path"]) + + return result diff --git a/src/proxy_app/ai_assistant/core.py b/src/proxy_app/ai_assistant/core.py new file mode 100644 index 0000000..2bdd93c --- /dev/null +++ b/src/proxy_app/ai_assistant/core.py @@ -0,0 +1,642 @@ +""" +AI Assistant Core orchestration logic. + +Manages conversation sessions, tool execution, context injection, +and coordinates between the UI, LLM bridge, and window adapter. 
+""" + +import json +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional + +from .bridge import LLMBridge, StreamCallbacks +from .checkpoint import CheckpointManager +from .context import WindowContextAdapter +from .prompts import BASE_ASSISTANT_PROMPT +from .tools import ToolCall, ToolCallSummary, ToolExecutor, ToolResult + +logger = logging.getLogger(__name__) + + +@dataclass +class Message: + """Represents a message in the conversation.""" + + role: str # "user" | "assistant" | "tool" | "system" + content: Optional[str] = None + reasoning_content: Optional[str] = None # Thinking (from reasoning_content field) + tool_calls: Optional[List[ToolCall]] = None + tool_call_id: Optional[str] = None # For tool response messages + timestamp: datetime = field(default_factory=datetime.now) + + def to_openai_format(self) -> Dict[str, Any]: + """Convert to OpenAI-compatible message format.""" + msg: Dict[str, Any] = {"role": self.role} + + # Content handling: when tool_calls are present, some providers require + # content to be present (even if null/empty) + if self.content is not None: + msg["content"] = self.content + elif self.tool_calls: + # Ensure content is present when tool_calls exist + msg["content"] = None + + if self.tool_calls: + msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments), + }, + } + for tc in self.tool_calls + ] + + if self.tool_call_id is not None: + msg["tool_call_id"] = self.tool_call_id + + return msg + + +@dataclass +class ChatSession: + """Manages conversation state and message history.""" + + session_id: str + model: str + messages: List[Message] = field(default_factory=list) + pending_message: Optional[str] = None # Queued user message + is_streaming: bool = False + current_checkpoint_position: int = -1 + last_known_context_hash: str = "" + + # Retry tracking + consecutive_invalid_tool_calls: int = 0 + max_tool_retries: int = 4 + + +class AIAssistantCore: + """ + Core orchestration for the AI Assistant. + + Manages: + - ChatSession lifecycle + - Tool execution with checkpoints + - Context injection and diffing + - Message queuing + - Agentic loops (multi-turn tool execution) + """ + + def __init__( + self, + window_adapter: WindowContextAdapter, + schedule_on_gui: Callable[[Callable], None], + default_model: str = "openai/gpt-4o", + ): + """ + Initialize the AI Assistant Core. 
+ + Args: + window_adapter: The window-specific context adapter + schedule_on_gui: Function to schedule callbacks on GUI thread + default_model: Default model to use + """ + self._adapter = window_adapter + self._schedule_on_gui = schedule_on_gui + + # Create session + session_id = str(uuid.uuid4())[:8] + self._session = ChatSession(session_id=session_id, model=default_model) + + # Initialize components + self._bridge = LLMBridge(schedule_on_gui, session_id=session_id) + self._checkpoint_manager = CheckpointManager(session_id) + self._tool_executor = ToolExecutor(window_adapter.get_tools()) + + # Callbacks for UI updates + self._ui_callbacks: Dict[str, Callable] = {} + + # Current response state + self._current_thinking: str = "" + self._current_content: str = "" + self._pending_tool_calls: List[ToolCall] = [] + + # Reasoning effort setting (None = auto/don't send) + self._reasoning_effort: Optional[str] = None + + @property + def session(self) -> ChatSession: + """Get the current chat session.""" + return self._session + + @property + def checkpoint_manager(self) -> CheckpointManager: + """Get the checkpoint manager.""" + return self._checkpoint_manager + + @property + def bridge(self) -> LLMBridge: + """Get the LLM bridge.""" + return self._bridge + + def set_ui_callbacks( + self, + on_thinking_chunk: Optional[Callable[[str], None]] = None, + on_content_chunk: Optional[Callable[[str], None]] = None, + on_tool_start: Optional[Callable[[ToolCall], None]] = None, + on_tool_result: Optional[Callable[[ToolCall, ToolResult], None]] = None, + on_message_complete: Optional[Callable[[Message], None]] = None, + on_error: Optional[ + Callable[[str, bool], None] + ] = None, # (message, is_retryable) + on_stream_complete: Optional[Callable[[], None]] = None, + ) -> None: + """ + Set UI callbacks for response handling. + + Args: + on_thinking_chunk: Called with each thinking/reasoning chunk + on_content_chunk: Called with each content chunk + on_tool_start: Called when a tool execution begins + on_tool_result: Called when a tool execution completes + on_message_complete: Called when a full message is added to history + on_error: Called on errors (message, is_retryable) + on_stream_complete: Called when streaming/processing is fully complete + """ + self._ui_callbacks = { + "on_thinking_chunk": on_thinking_chunk, + "on_content_chunk": on_content_chunk, + "on_tool_start": on_tool_start, + "on_tool_result": on_tool_result, + "on_message_complete": on_message_complete, + "on_error": on_error, + "on_stream_complete": on_stream_complete, + } + + def set_model(self, model: str) -> None: + """Set the model to use.""" + self._session.model = model + logger.info(f"Model set to: {model}") + + def set_reasoning_effort(self, effort: Optional[str]) -> None: + """ + Set the reasoning effort level. + + Args: + effort: One of "low", "medium", "high", or None (auto/don't send) + """ + if effort is not None and effort not in ("low", "medium", "high"): + logger.warning(f"Invalid reasoning effort: {effort}") + effort = None + self._reasoning_effort = effort + logger.info(f"Reasoning effort set to: {effort or 'auto'}") + + def send_message(self, content: str) -> bool: + """ + Send a user message. + + If currently streaming, the message is queued. 
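+
+        Example (hypothetical calls): a second message sent while a response
+        is still streaming replaces any previously queued message:
+
+            core.send_message("Ignore all beta models")
+            core.send_message("Also whitelist gpt-4o")  # queued, sent after streaming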
+ + Args: + content: The user message content + + Returns: + True if message was sent/queued, False if invalid + """ + content = content.strip() + if not content: + return False + + if self._session.is_streaming: + # Queue the message (replaces any existing queued message) + self._session.pending_message = content + logger.info("Message queued (currently streaming)") + return True + + # Add user message to history + user_message = Message(role="user", content=content) + self._session.messages.append(user_message) + + # Notify UI + if self._ui_callbacks.get("on_message_complete"): + self._ui_callbacks["on_message_complete"](user_message) + + # Start the response + self._start_response() + return True + + def _start_response(self) -> None: + """Start generating an LLM response.""" + self._session.is_streaming = True + self._current_thinking = "" + self._current_content = "" + self._pending_tool_calls = [] + self._session.consecutive_invalid_tool_calls = 0 + + # Signal checkpoint manager + self._checkpoint_manager.start_response() + + # Notify adapter that AI is starting + self._adapter.on_ai_started() + + # Build messages array + messages = self._build_messages() + + # Get tools in OpenAI format + tools = self._tool_executor.get_tools_openai_format() + + # Set up callbacks + callbacks = StreamCallbacks( + on_thinking_chunk=self._handle_thinking_chunk, + on_content_chunk=self._handle_content_chunk, + on_tool_calls=self._handle_tool_calls, + on_error=self._handle_error, + on_complete=self._handle_stream_complete, + ) + + # Start streaming + self._bridge.stream_completion( + messages=messages, + tools=tools, + model=self._session.model, + callbacks=callbacks, + reasoning_effort=self._reasoning_effort, + ) + + def _build_messages(self) -> List[Dict[str, Any]]: + """Build the messages array for the LLM request.""" + messages = [] + + # System prompt (base + window-specific) + system_prompt = ( + BASE_ASSISTANT_PROMPT + "\n\n" + self._adapter.get_window_system_prompt() + ) + + # Add current context + context = self._adapter.get_full_context() + context_str = json.dumps(context, indent=2, default=str) + system_prompt += f"\n\n### Current Context\n\n```json\n{context_str}\n```" + + messages.append({"role": "system", "content": system_prompt}) + + # Add conversation history + for msg in self._session.messages: + messages.append(msg.to_openai_format()) + + # Log summary of messages being sent + role_counts = {} + for msg in messages: + role = msg.get("role", "unknown") + role_counts[role] = role_counts.get(role, 0) + 1 + # Log tool_calls in assistant messages + if role == "assistant" and msg.get("tool_calls"): + tc_names = [tc["function"]["name"] for tc in msg.get("tool_calls", [])] + logger.debug(f" Assistant message has tool_calls: {tc_names}") + # Log tool responses + if role == "tool": + logger.debug( + f" Tool response: id={msg.get('tool_call_id')}, " + f"content_preview={str(msg.get('content', ''))[:100]}" + ) + + logger.debug(f"Built messages array: {role_counts}") + + return messages + + def _handle_thinking_chunk(self, chunk: str) -> None: + """Handle a thinking/reasoning chunk.""" + self._current_thinking += chunk + + if self._ui_callbacks.get("on_thinking_chunk"): + self._ui_callbacks["on_thinking_chunk"](chunk) + + def _handle_content_chunk(self, chunk: str) -> None: + """Handle a content chunk.""" + self._current_content += chunk + + if self._ui_callbacks.get("on_content_chunk"): + self._ui_callbacks["on_content_chunk"](chunk) + + def _handle_tool_calls(self, tool_calls: List[ToolCall]) 
-> None: + """Handle tool calls from the LLM.""" + self._pending_tool_calls = tool_calls + + def _handle_error(self, error: str) -> None: + """Handle a streaming error.""" + logger.error(f"Streaming error: {error}") + self._session.is_streaming = False + self._adapter.on_ai_completed() + + if self._ui_callbacks.get("on_error"): + # Most errors are retryable + is_retryable = "authentication" not in error.lower() + self._ui_callbacks["on_error"](error, is_retryable) + + def _handle_stream_complete(self) -> None: + """Handle stream completion.""" + logger.info( + f"Stream complete: content_len={len(self._current_content)}, " + f"thinking_len={len(self._current_thinking)}, " + f"tool_calls={len(self._pending_tool_calls)}" + ) + + # Create assistant message + assistant_message = Message( + role="assistant", + content=self._current_content if self._current_content else None, + reasoning_content=self._current_thinking + if self._current_thinking + else None, + tool_calls=self._pending_tool_calls if self._pending_tool_calls else None, + ) + self._session.messages.append(assistant_message) + + logger.debug( + f"Added assistant message to history. Total messages: {len(self._session.messages)}" + ) + + # Notify UI of message completion + if self._ui_callbacks.get("on_message_complete"): + self._ui_callbacks["on_message_complete"](assistant_message) + + # Execute tool calls if any + if self._pending_tool_calls: + logger.info( + f"Processing {len(self._pending_tool_calls)} pending tool call(s)" + ) + self._execute_tool_calls(self._pending_tool_calls) + else: + # No tool calls - response is complete + logger.info("No tool calls - finishing response") + self._finish_response() + + def _execute_tool_calls(self, tool_calls: List[ToolCall]) -> None: + """Execute pending tool calls and handle results.""" + logger.info(f"Executing {len(tool_calls)} tool call(s)") + + # Check if we need to create a checkpoint + if self._tool_executor.has_write_tools(tool_calls): + if self._checkpoint_manager.should_create_checkpoint(): + # Create checkpoint before executing write tools + state = self._adapter.get_full_context() + summaries = [ + ToolCallSummary( + name=tc.name, + arguments=tc.arguments, + success=True, # Will be updated + message="Pending", + ) + for tc in tool_calls + ] + self._checkpoint_manager.create_checkpoint( + state=state, + tool_calls=summaries, + message_index=len(self._session.messages) - 1, + ) + + # Execute each tool + all_results: List[Message] = [] + has_errors = False + assistant_logger = self._bridge.current_logger + + for tool_call in tool_calls: + logger.info( + f"Executing tool: {tool_call.name} (id={tool_call.id}) " + f"args={json.dumps(tool_call.arguments)[:200]}" + ) + + # Notify UI + if self._ui_callbacks.get("on_tool_start"): + self._ui_callbacks["on_tool_start"](tool_call) + + # Execute + result = self._tool_executor.execute(tool_call, self._adapter) + tool_call.result = result + + logger.info( + f"Tool result: {tool_call.name} -> " + f"{'SUCCESS' if result.success else 'FAILED'}: {result.message[:100]}" + ) + + # Log to assistant logger + if assistant_logger: + assistant_logger.log_tool_execution( + tool_call_id=tool_call.id, + tool_name=tool_call.name, + arguments=tool_call.arguments, + result_success=result.success, + result_message=result.message, + result_data=result.data, + error_code=result.error_code, + ) + + # Notify UI + if self._ui_callbacks.get("on_tool_result"): + self._ui_callbacks["on_tool_result"](tool_call, result) + + # Create tool response message + tool_message = 
Message( + role="tool", + content=result.to_json(), + tool_call_id=tool_call.id, + ) + all_results.append(tool_message) + + if not result.success: + has_errors = True + + # Add tool results to history + self._session.messages.extend(all_results) + logger.info( + f"Added {len(all_results)} tool result message(s) to history. " + f"Total messages: {len(self._session.messages)}" + ) + + # Handle retry logic for invalid tool calls + if has_errors: + self._session.consecutive_invalid_tool_calls += 1 + if ( + self._session.consecutive_invalid_tool_calls + >= self._session.max_tool_retries + ): + # Max retries exceeded - show error to user + logger.warning("Max tool retries exceeded") + if self._ui_callbacks.get("on_error"): + self._ui_callbacks["on_error"]( + "Tool execution failed after multiple retries. Please try a different approach.", + False, + ) + self._finish_response() + return + + # Show retry indicator after 2nd failure + if self._session.consecutive_invalid_tool_calls >= 2: + logger.info("Tool retry in progress (shown to user)") + # UI will show "Retrying..." based on the failed tool results + + # Continue the agentic loop - get next response from LLM + self._continue_agentic_loop() + + def _continue_agentic_loop(self) -> None: + """Continue the conversation after tool execution.""" + logger.info("Continuing agentic loop after tool execution") + + # Reset current response state + self._current_thinking = "" + self._current_content = "" + self._pending_tool_calls = [] + + # Build messages (includes tool results) + messages = self._build_messages() + tools = self._tool_executor.get_tools_openai_format() + + # Log the continuation request + assistant_logger = self._bridge.current_logger + if assistant_logger: + assistant_logger.log_messages_sent(messages) + + logger.debug( + f"Continuation request: {len(messages)} messages, " + f"last message role={messages[-1]['role'] if messages else 'N/A'}" + ) + + callbacks = StreamCallbacks( + on_thinking_chunk=self._handle_thinking_chunk, + on_content_chunk=self._handle_content_chunk, + on_tool_calls=self._handle_tool_calls, + on_error=self._handle_error, + on_complete=self._handle_stream_complete, + ) + + self._bridge.stream_completion( + messages=messages, + tools=tools, + model=self._session.model, + callbacks=callbacks, + reasoning_effort=self._reasoning_effort, + ) + + def _finish_response(self) -> None: + """Finish the response cycle.""" + self._session.is_streaming = False + self._adapter.on_ai_completed() + + # Notify UI + if self._ui_callbacks.get("on_stream_complete"): + self._ui_callbacks["on_stream_complete"]() + + # Check for queued message + if self._session.pending_message: + queued = self._session.pending_message + self._session.pending_message = None + logger.info("Processing queued message") + self.send_message(queued) + + def cancel(self) -> None: + """Cancel the current response.""" + if not self._session.is_streaming: + return + + self._bridge.cancel_stream() + self._session.is_streaming = False + self._adapter.on_ai_completed() + + # Discard partial response (don't add to history) + # The last user message stays in history + + logger.info("Response cancelled") + + def rollback_to_checkpoint(self, checkpoint_id: str) -> bool: + """ + Rollback to a specific checkpoint. 
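+
+        Sketch of intended use (checkpoints come from the manager):
+
+            checkpoints = core.checkpoint_manager.get_checkpoints()
+            if checkpoints:
+                core.rollback_to_checkpoint(checkpoints[-1].id)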
+ + Args: + checkpoint_id: ID of the checkpoint to rollback to + + Returns: + True if rollback succeeded + """ + if self._session.is_streaming: + logger.warning("Cannot rollback while streaming") + return False + + # Get the state and message index + state = self._checkpoint_manager.rollback_to(checkpoint_id) + if state is None: + return False + + message_index = self._checkpoint_manager.get_message_index_at(checkpoint_id) + if message_index is not None: + # Truncate conversation history + self._session.messages = self._session.messages[: message_index + 1] + + # Apply state to window + try: + self._adapter.apply_state(state) + logger.info(f"Rolled back to checkpoint {checkpoint_id}") + return True + except Exception as e: + logger.exception(f"Failed to apply state for checkpoint {checkpoint_id}") + return False + + def new_session(self) -> None: + """Start a new session (clear history and checkpoints).""" + if self._session.is_streaming: + self.cancel() + + # Clear history + self._session.messages.clear() + self._session.pending_message = None + self._session.consecutive_invalid_tool_calls = 0 + + # Clear checkpoints + self._checkpoint_manager.clear() + + # Generate new session ID + new_session_id = str(uuid.uuid4())[:8] + self._session.session_id = new_session_id + self._checkpoint_manager.session_id = new_session_id + self._bridge.set_session_id(new_session_id) + + logger.info(f"Started new session: {new_session_id}") + + def retry_last(self) -> bool: + """ + Retry the last user message. + + Returns: + True if retry was started + """ + if self._session.is_streaming: + return False + + # Find the last user message + last_user_idx = -1 + for i in range(len(self._session.messages) - 1, -1, -1): + if self._session.messages[i].role == "user": + last_user_idx = i + break + + if last_user_idx < 0: + return False + + # Remove all messages after (and including) the last user message + last_user_content = self._session.messages[last_user_idx].content + self._session.messages = self._session.messages[:last_user_idx] + + # Re-send the message + if last_user_content: + self.send_message(last_user_content) + return True + + return False + + def close(self) -> None: + """Clean up resources.""" + self._bridge.close() + self._checkpoint_manager.clear_temp_file() diff --git a/src/proxy_app/ai_assistant/prompts.py b/src/proxy_app/ai_assistant/prompts.py new file mode 100644 index 0000000..a3ac1b4 --- /dev/null +++ b/src/proxy_app/ai_assistant/prompts.py @@ -0,0 +1,75 @@ +""" +System prompts for the AI Assistant. + +Contains the base assistant prompt and window-specific prompts. +""" + +BASE_ASSISTANT_PROMPT = """You are an AI assistant embedded in a GUI application. Your role is to help users \ +accomplish tasks within this window by understanding their intent and executing \ +actions using the available tools. + +## Core Behaviors + +1. **Full Context Awareness**: You have complete visibility into the window's state. \ +Use this information to provide accurate, contextual help. + +2. **Tool Execution**: When the user requests an action, use the appropriate tools \ +to execute it. You may call multiple tools in sequence to accomplish complex tasks. + +3. **Verbose Feedback**: After executing tools, clearly explain what was done, \ +what changed, and any important consequences. Both you and the user will see \ +the tool results. + +4. **Error Handling**: If a tool fails, explain why and suggest alternatives. 
\ +If you receive an error about an invalid tool call, carefully re-examine the \ +tool schema and try again with corrected parameters. + +5. **Proactive Assistance**: If you notice potential issues or improvements, \ +mention them to the user. + +## Tool Execution Guidelines + +- Always confirm understanding before making destructive changes +- For bulk operations, summarize what will happen before executing +- If uncertain about user intent, ask for clarification +- Report all tool results, including partial successes +- You may call multiple tools in a single response when appropriate + +## Context Updates + +You will receive updates about changes to the window state in the \ +`changes_since_last_message` field. Use this to stay aware of what \ +the user may have done manually between messages.""" + + +MODEL_FILTER_SYSTEM_PROMPT = """## Model Filter Configuration Assistant + +You are helping the user configure model filtering rules for an LLM proxy server. + +### Domain Knowledge + +- **Ignore Rules**: Patterns that block models from being available through the proxy +- **Whitelist Rules**: Patterns that ensure models are always available (override ignore rules) +- **Pattern Syntax**: + - Exact match: `gpt-4` matches only "gpt-4" + - Wildcard: `gpt-4*` matches "gpt-4", "gpt-4-turbo", "gpt-4-vision", etc. + - Match anywhere: `*preview*` matches any model containing "preview" + +### Rule Priority + +Whitelist > Ignore > Default (available) + +A model that matches both an ignore rule and a whitelist rule will be AVAILABLE \ +(whitelist wins). + +### Common Tasks + +1. "Block all preview models" -> Use pattern `*-preview` or `*preview*` +2. "Only allow GPT-4o" -> Ignore `*`, whitelist `gpt-4o` +3. "What models are blocked?" -> Query the ignore rules and their affected models + +### Important Notes + +- Changes are not saved until the user explicitly saves (or you use save_changes tool) +- The `has_unsaved_changes` field in context tells you if there are pending changes +- Always inform the user if there are unsaved changes that might be lost""" diff --git a/src/proxy_app/ai_assistant/tools.py b/src/proxy_app/ai_assistant/tools.py new file mode 100644 index 0000000..ca7c9de --- /dev/null +++ b/src/proxy_app/ai_assistant/tools.py @@ -0,0 +1,312 @@ +""" +Tool definition system for the AI Assistant. + +Provides the @assistant_tool decorator, ToolDefinition, ToolResult, and ToolExecutor classes. 
+""" + +import json +import logging +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Callable, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolResult: + """Result of a tool execution.""" + + success: bool + message: str # Human-readable description + data: Optional[Dict[str, Any]] = None # Structured data for AI + error_code: Optional[str] = None # Machine-readable error type + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result = {"success": self.success, "message": self.message} + if self.data is not None: + result["data"] = self.data + if self.error_code is not None: + result["error_code"] = self.error_code + return result + + def to_json(self) -> str: + """Convert to JSON string for LLM tool response.""" + return json.dumps(self.to_dict()) + + +@dataclass +class ToolDefinition: + """Definition of an assistant tool.""" + + name: str + description: str + parameters: Dict[str, Any] # JSON Schema for parameters + required: List[str] = field(default_factory=list) + is_write: bool = False # If True, triggers checkpoint creation + handler: Optional[Callable[..., ToolResult]] = None + + def to_openai_format(self) -> Dict[str, Any]: + """Convert to OpenAI-compatible tool format.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": self.parameters, + "required": self.required, + }, + }, + } + + +@dataclass +class ToolCall: + """Represents a tool call from the LLM.""" + + id: str + name: str + arguments: Dict[str, Any] + result: Optional[ToolResult] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "id": self.id, + "name": self.name, + "arguments": self.arguments, + "result": self.result.to_dict() if self.result else None, + } + + +@dataclass +class ToolCallSummary: + """Summary of a tool call for checkpoint description.""" + + name: str + arguments: Dict[str, Any] + success: bool + message: str + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "name": self.name, + "arguments": self.arguments, + "success": self.success, + "message": self.message, + } + + +def assistant_tool( + name: str, + description: str, + parameters: Dict[str, Dict[str, Any]], + required: Optional[List[str]] = None, + is_write: bool = False, +) -> Callable: + """ + Decorator to mark a method as an assistant tool. + + Args: + name: The tool name (used in LLM function calls) + description: Human-readable description of what the tool does + parameters: JSON Schema properties dict for the tool parameters + required: List of required parameter names + is_write: If True, a checkpoint will be created before execution + + Example: + @assistant_tool( + name="add_ignore_rule", + description="Add a pattern to the ignore list.", + parameters={ + "pattern": { + "type": "string", + "description": "The pattern to ignore. Supports * wildcard." + } + }, + required=["pattern"], + is_write=True + ) + def tool_add_ignore_rule(self, pattern: str) -> ToolResult: + ... 
+ """ + if required is None: + required = [] + + def decorator(func: Callable[..., ToolResult]) -> Callable[..., ToolResult]: + # Store tool metadata on the function + func._tool_definition = ToolDefinition( + name=name, + description=description, + parameters=parameters, + required=required, + is_write=is_write, + handler=func, + ) + + @wraps(func) + def wrapper(*args, **kwargs) -> ToolResult: + return func(*args, **kwargs) + + # Copy the tool definition to the wrapper + wrapper._tool_definition = func._tool_definition + return wrapper + + return decorator + + +class ToolExecutor: + """ + Executes tool calls with validation. + + Collects tools from a WindowContextAdapter and executes them + when called by the LLM. + """ + + def __init__(self, tools: List[ToolDefinition]): + """ + Initialize with a list of tool definitions. + + Args: + tools: List of ToolDefinition objects + """ + self._tools: Dict[str, ToolDefinition] = {tool.name: tool for tool in tools} + self._timeout: float = 30.0 # Default timeout in seconds + + @property + def tools(self) -> List[ToolDefinition]: + """Get list of all registered tools.""" + return list(self._tools.values()) + + def get_tool(self, name: str) -> Optional[ToolDefinition]: + """Get a tool by name.""" + return self._tools.get(name) + + def has_write_tools(self, tool_calls: List[ToolCall]) -> bool: + """Check if any of the tool calls are write operations.""" + for call in tool_calls: + tool = self._tools.get(call.name) + if tool and tool.is_write: + return True + return False + + def get_tools_openai_format(self) -> List[Dict[str, Any]]: + """Get all tools in OpenAI-compatible format.""" + return [tool.to_openai_format() for tool in self._tools.values()] + + def validate_tool_call(self, tool_call: ToolCall) -> Optional[str]: + """ + Validate a tool call before execution. + + Returns None if valid, or an error message if invalid. + """ + tool = self._tools.get(tool_call.name) + + if tool is None: + available = list(self._tools.keys()) + return f"Unknown tool: '{tool_call.name}'. Available tools: {available}" + + # Check required parameters + for param in tool.required: + if param not in tool_call.arguments: + return f"Missing required parameter: '{param}'" + + # Check for unknown parameters + known_params = set(tool.parameters.keys()) + provided_params = set(tool_call.arguments.keys()) + unknown = provided_params - known_params + + if unknown: + # Provide helpful suggestions for typos + suggestions = [] + for unk in unknown: + for known in known_params: + if unk.lower() == known.lower() or ( + len(unk) > 2 and unk[:-1] == known[:-1] + ): + suggestions.append(f"'{unk}' -> did you mean '{known}'?") + if suggestions: + return f"Invalid parameter(s): {unknown}. {' '.join(suggestions)}" + return f"Invalid parameter(s): {unknown}. 
Valid parameters: {known_params}" + + # Type validation (basic) + for param_name, param_value in tool_call.arguments.items(): + if param_name not in tool.parameters: + continue + param_schema = tool.parameters[param_name] + expected_type = param_schema.get("type") + + if expected_type == "string" and not isinstance(param_value, str): + return f"Parameter '{param_name}' must be a string, got {type(param_value).__name__}" + elif expected_type == "number" and not isinstance( + param_value, (int, float) + ): + return f"Parameter '{param_name}' must be a number, got {type(param_value).__name__}" + elif expected_type == "integer" and not isinstance(param_value, int): + return f"Parameter '{param_name}' must be an integer, got {type(param_value).__name__}" + elif expected_type == "boolean" and not isinstance(param_value, bool): + return f"Parameter '{param_name}' must be a boolean, got {type(param_value).__name__}" + elif expected_type == "array" and not isinstance(param_value, list): + return f"Parameter '{param_name}' must be an array, got {type(param_value).__name__}" + elif expected_type == "object" and not isinstance(param_value, dict): + return f"Parameter '{param_name}' must be an object, got {type(param_value).__name__}" + + return None # Valid + + def execute(self, tool_call: ToolCall, context: Any) -> ToolResult: + """ + Execute a tool call. + + Args: + tool_call: The tool call to execute + context: The WindowContextAdapter instance (passed as self to the tool method) + + Returns: + ToolResult with success/failure status + """ + # Validate first + validation_error = self.validate_tool_call(tool_call) + if validation_error: + logger.warning( + f"Tool validation failed for {tool_call.name}: {validation_error}" + ) + return ToolResult( + success=False, + message=validation_error, + error_code="invalid_parameters", + ) + + tool = self._tools[tool_call.name] + + try: + # Execute the tool handler + result = tool.handler(context, **tool_call.arguments) + + if not isinstance(result, ToolResult): + logger.error( + f"Tool {tool_call.name} returned non-ToolResult: {type(result)}" + ) + return ToolResult( + success=False, + message=f"Tool returned invalid result type: {type(result).__name__}", + error_code="internal_error", + ) + + if result.success: + logger.info(f"Tool {tool_call.name} succeeded: {result.message}") + else: + logger.warning(f"Tool {tool_call.name} failed: {result.message}") + + return result + + except Exception as e: + logger.exception(f"Tool {tool_call.name} raised exception") + return ToolResult( + success=False, + message=f"Tool execution error: {str(e)}", + error_code="execution_error", + ) diff --git a/src/proxy_app/ai_assistant/ui/__init__.py b/src/proxy_app/ai_assistant/ui/__init__.py new file mode 100644 index 0000000..099fde0 --- /dev/null +++ b/src/proxy_app/ai_assistant/ui/__init__.py @@ -0,0 +1,52 @@ +""" +UI Components for the AI Assistant. + +Provides all the visual components for the chat interface. 
+""" + +from .chat_window import AIChatWindow +from .checkpoint_ui import CheckpointDropdown, CheckpointItem, CheckpointPopup +from .message_view import MessageView +from .model_selector import ModelSelector +from .styles import ( + AI_MESSAGE_BG, + ERROR_BG, + MESSAGE_PADDING, + MESSAGE_SPACING, + THINKING_BG, + THINKING_TEXT, + TOOL_BG, + TOOL_FAILURE_COLOR, + TOOL_SUCCESS_COLOR, + USER_MESSAGE_BG, + apply_button_style, + get_font, + get_scrollbar_style, +) +from .thinking import ThinkingSection + +__all__ = [ + # Main window + "AIChatWindow", + # Components + "MessageView", + "ModelSelector", + "CheckpointDropdown", + "CheckpointPopup", + "CheckpointItem", + "ThinkingSection", + # Styles + "get_font", + "apply_button_style", + "get_scrollbar_style", + "USER_MESSAGE_BG", + "AI_MESSAGE_BG", + "THINKING_BG", + "THINKING_TEXT", + "TOOL_BG", + "TOOL_SUCCESS_COLOR", + "TOOL_FAILURE_COLOR", + "ERROR_BG", + "MESSAGE_SPACING", + "MESSAGE_PADDING", +] diff --git a/src/proxy_app/ai_assistant/ui/chat_window.py b/src/proxy_app/ai_assistant/ui/chat_window.py new file mode 100644 index 0000000..448ab69 --- /dev/null +++ b/src/proxy_app/ai_assistant/ui/chat_window.py @@ -0,0 +1,479 @@ +""" +AI Chat Window - Main Pop-Out Window. + +The primary UI for interacting with the AI assistant. +""" + +import customtkinter as ctk +import tkinter as tk +from typing import Any, Callable, Dict, List, Optional + +from ..checkpoint import Checkpoint +from ..core import AIAssistantCore, Message +from ..context import WindowContextAdapter +from ..tools import ToolCall, ToolResult +from .checkpoint_ui import CheckpointDropdown +from .message_view import MessageView +from .model_selector import ModelSelector +from .styles import ( + ACCENT_BLUE, + ACCENT_GREEN, + BG_HOVER, + BG_PRIMARY, + BG_SECONDARY, + BG_TERTIARY, + BORDER_COLOR, + FONT_FAMILY, + FONT_SIZE_NORMAL, + INPUT_MAX_HEIGHT, + INPUT_MIN_HEIGHT, + TEXT_MUTED, + TEXT_PRIMARY, + TEXT_SECONDARY, + get_font, +) + + +class AIChatWindow(ctk.CTkToplevel): + """ + Pop-out AI Chat Window. + + Features: + - Full message display with streaming + - Model selector (grouped by provider) + - Checkpoint dropdown for undo + - Multi-line input with keyboard shortcuts + - New Session and Send buttons + """ + + def __init__( + self, + parent, + window_adapter: WindowContextAdapter, + title: str = "AI Assistant", + default_model: str = "openai/gpt-4o", + **kwargs, + ): + """ + Initialize the chat window. 
+ + Args: + parent: Parent window + window_adapter: The window context adapter + title: Window title + default_model: Default model to use + **kwargs: Additional toplevel arguments + """ + super().__init__(parent, **kwargs) + + self._parent = parent + self._adapter = window_adapter + self._default_model = default_model + + # Initialize core + self._core = AIAssistantCore( + window_adapter=window_adapter, + schedule_on_gui=lambda fn: self.after(0, fn), + default_model=default_model, + ) + + # Window configuration + self.title(title) + self.configure(fg_color=BG_PRIMARY) + self.geometry("600x700") + self.minsize(450, 500) + + # Track pending tool calls for UI updates + self._pending_tool_frames: Dict[str, ctk.CTkFrame] = {} + + self._create_widgets() + self._setup_callbacks() + self._setup_bindings() + + # Load models + self._load_models() + + # Focus input + self.after(100, lambda: self.input_text.focus_set()) + + def _create_widgets(self) -> None: + """Create the UI widgets.""" + # Configure grid + self.grid_columnconfigure(0, weight=1) + self.grid_rowconfigure(0, weight=0) # Header + self.grid_rowconfigure(1, weight=3, minsize=200) # Messages + self.grid_rowconfigure(2, weight=0) # Input + self.grid_rowconfigure(3, weight=0) # Buttons + + # Header + self._create_header() + + # Message display + self.message_view = MessageView(self) + self.message_view.grid(row=1, column=0, sticky="nsew", padx=8, pady=(8, 0)) + + # Input area + self._create_input_area() + + # Buttons + self._create_buttons() + + def _create_header(self) -> None: + """Create the header with model selector, reasoning effort, and checkpoints.""" + header = ctk.CTkFrame(self, fg_color="transparent") + header.grid(row=0, column=0, sticky="ew", padx=8, pady=(8, 0)) + header.grid_columnconfigure(1, weight=1) # Spacer column + + # Left side: Model selector + self.model_selector = ModelSelector( + header, + on_model_changed=self._on_model_changed, + ) + self.model_selector.grid(row=0, column=0, sticky="w") + + # Middle: Reasoning effort selector + reasoning_frame = ctk.CTkFrame(header, fg_color="transparent") + reasoning_frame.grid(row=0, column=1, sticky="w", padx=(16, 0)) + + ctk.CTkLabel( + reasoning_frame, + text="Reasoning:", + font=get_font("normal"), + text_color=TEXT_SECONDARY, + ).pack(side="left", padx=(0, 8)) + + self._reasoning_effort_var = ctk.StringVar(value="Auto") + self.reasoning_dropdown = ctk.CTkComboBox( + reasoning_frame, + values=["Auto", "Low", "Medium", "High"], + variable=self._reasoning_effort_var, + font=get_font("normal"), + dropdown_font=get_font("normal"), + fg_color=BG_TERTIARY, + border_color=BORDER_COLOR, + button_color=BG_SECONDARY, + button_hover_color=BG_HOVER, + dropdown_fg_color=BG_SECONDARY, + dropdown_hover_color=BG_HOVER, + text_color=TEXT_PRIMARY, + dropdown_text_color=TEXT_PRIMARY, + width=100, + state="readonly", + command=self._on_reasoning_changed, + ) + self.reasoning_dropdown.pack(side="left") + + # Right side: Checkpoint dropdown + self.checkpoint_dropdown = CheckpointDropdown( + header, + on_rollback=self._on_rollback, + ) + self.checkpoint_dropdown.grid(row=0, column=2, sticky="e") + + def _create_input_area(self) -> None: + """Create the input area with buttons stacked on the right.""" + input_frame = ctk.CTkFrame(self, fg_color="transparent") + input_frame.grid(row=2, column=0, sticky="ew", padx=8, pady=(8, 0)) + input_frame.grid_columnconfigure(0, weight=1) # Input takes remaining space + input_frame.grid_columnconfigure(1, weight=0) # Buttons column fixed + + # Multi-line text 
input (left side)
+        self.input_text = ctk.CTkTextbox(
+            input_frame,
+            font=get_font("normal"),
+            fg_color=BG_TERTIARY,
+            text_color=TEXT_PRIMARY,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            corner_radius=8,
+            height=INPUT_MIN_HEIGHT,
+            wrap="word",
+        )
+        self.input_text.grid(row=0, column=0, sticky="nsew")
+
+        # Button stack (right side)
+        btn_stack = ctk.CTkFrame(input_frame, fg_color="transparent")
+        btn_stack.grid(row=0, column=1, sticky="ns", padx=(8, 0))
+
+        # New Session button (top)
+        self.new_session_btn = ctk.CTkButton(
+            btn_stack,
+            text="New Session",
+            font=get_font("small"),
+            fg_color=BG_SECONDARY,
+            hover_color=BG_HOVER,
+            text_color=TEXT_PRIMARY,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            width=90,
+            height=26,
+            command=self._on_new_session,
+        )
+        self.new_session_btn.pack(side="top", pady=(0, 4))
+
+        # Send button (bottom)
+        self.send_btn = ctk.CTkButton(
+            btn_stack,
+            text="Send →",
+            font=get_font("small", bold=True),
+            fg_color=ACCENT_BLUE,
+            hover_color="#3a8eef",
+            text_color=TEXT_PRIMARY,
+            width=90,
+            height=26,
+            command=self._on_send,
+        )
+        self.send_btn.pack(side="top")
+
+        # Placeholder text handling
+        self._placeholder_active = True
+        self._show_placeholder()
+
+    def _show_placeholder(self) -> None:
+        """Show placeholder text in input."""
+        self.input_text.delete("1.0", "end")
+        self.input_text.insert("1.0", "Type your message here... (Ctrl+Enter to send)")
+        self.input_text.configure(text_color=TEXT_MUTED)
+        self._placeholder_active = True
+
+    def _hide_placeholder(self) -> None:
+        """Hide placeholder text."""
+        if self._placeholder_active:
+            self.input_text.delete("1.0", "end")
+            self.input_text.configure(text_color=TEXT_PRIMARY)
+            self._placeholder_active = False
+
+    def _create_buttons(self) -> None:
+        """Create the status bar (buttons moved to input area)."""
+        status_frame = ctk.CTkFrame(self, fg_color="transparent")
+        status_frame.grid(row=3, column=0, sticky="ew", padx=8, pady=(0, 4))
+
+        # Status label
+        self.status_label = ctk.CTkLabel(
+            status_frame,
+            text="",
+            font=get_font("small"),
+            text_color=TEXT_MUTED,
+        )
+        self.status_label.pack(side="left")
+
+    def _setup_callbacks(self) -> None:
+        """Set up AI core callbacks."""
+        self._core.set_ui_callbacks(
+            on_thinking_chunk=self._on_thinking_chunk,
+            on_content_chunk=self._on_content_chunk,
+            on_tool_start=self._on_tool_start,
+            on_tool_result=self._on_tool_result,
+            on_message_complete=self._on_message_complete,
+            on_error=self._on_error,
+            on_stream_complete=self._on_stream_complete,
+        )
+
+    def _setup_bindings(self) -> None:
+        """Set up keyboard bindings."""
+        # Ctrl+Enter to send
+        self.input_text.bind("<Control-Return>", lambda e: self._on_send())
+
+        # Escape to cancel or clear
+        self.bind("<Escape>", self._on_escape)
+
+        # Focus handling for placeholder
+        self.input_text.bind("<FocusIn>", lambda e: self._hide_placeholder())
+        self.input_text.bind("<FocusOut>", self._on_input_focus_out)
+
+        # Window close
+        self.protocol("WM_DELETE_WINDOW", self._on_close)
+
+    def _on_input_focus_out(self, event) -> None:
+        """Handle input focus out."""
+        content = self.input_text.get("1.0", "end").strip()
+        if not content or content == "Type your message here... 
(Ctrl+Enter to send)": + self._show_placeholder() + + def _load_models(self) -> None: + """Load available models.""" + self.model_selector.set_loading() + + self._core.bridge.fetch_models( + on_success=self._on_models_loaded, + on_error=self._on_models_error, + ) + + def _on_models_loaded(self, models: Dict[str, List[str]]) -> None: + """Handle models loaded.""" + self.model_selector.set_models(models) + + # Try to select default model + if not self.model_selector.set_selected_model(self._default_model): + # Use first available + pass + + def _on_models_error(self, error: str) -> None: + """Handle models load error.""" + self.model_selector.set_error(f"Failed: {error[:30]}...") + + def _on_model_changed(self, model: str) -> None: + """Handle model selection change.""" + self._core.set_model(model) + + def _on_reasoning_changed(self, choice: str) -> None: + """Handle reasoning effort selection change.""" + # Map UI values to API values + effort_map = { + "Auto": None, # Don't send the parameter + "Low": "low", + "Medium": "medium", + "High": "high", + } + effort = effort_map.get(choice) + self._core.set_reasoning_effort(effort) + + def _on_send(self) -> None: + """Handle send button click.""" + if self._placeholder_active: + return + + content = self.input_text.get("1.0", "end").strip() + if not content: + return + + # Clear input + self.input_text.delete("1.0", "end") + + # Send message + if self._core.send_message(content): + self._set_streaming_state(True) + + def _on_new_session(self) -> None: + """Handle new session button click.""" + self._core.new_session() + self.message_view.clear() + self.checkpoint_dropdown.set_checkpoints([]) + self._set_streaming_state(False) + self.status_label.configure(text="") + + def _on_escape(self, event=None) -> None: + """Handle escape key.""" + if self._core.session.is_streaming: + self._core.cancel() + self._set_streaming_state(False) + self.status_label.configure(text="Cancelled") + else: + # Clear input + self.input_text.delete("1.0", "end") + self._show_placeholder() + + def _on_rollback(self, checkpoint_id: str) -> None: + """Handle checkpoint rollback.""" + if self._core.rollback_to_checkpoint(checkpoint_id): + # Rebuild message display + self.message_view.clear() + for msg in self._core.session.messages: + self.message_view.add_message(msg) + + # Update checkpoints + self.checkpoint_dropdown.set_checkpoints( + self._core.checkpoint_manager.get_checkpoints() + ) + + self.status_label.configure(text="Rolled back") + + # ========================================================================= + # AI Core Callbacks + # ========================================================================= + + def _on_thinking_chunk(self, chunk: str) -> None: + """Handle thinking chunk from AI.""" + self.message_view.append_thinking(chunk) + + def _on_content_chunk(self, chunk: str) -> None: + """Handle content chunk from AI.""" + self.message_view.append_content(chunk) + + def _on_tool_start(self, tool_call: ToolCall) -> None: + """Handle tool execution start.""" + frame = self.message_view.add_tool_call(tool_call) + self._pending_tool_frames[tool_call.id] = frame + self.status_label.configure(text=f"Executing: {tool_call.name}...") + + def _on_tool_result(self, tool_call: ToolCall, result: ToolResult) -> None: + """Handle tool execution result.""" + if tool_call.id in self._pending_tool_frames: + frame = self._pending_tool_frames.pop(tool_call.id) + self.message_view.update_tool_result(frame, result) + + if result.success: + 
self.status_label.configure(text=f"✓ {tool_call.name}") + else: + self.status_label.configure( + text=f"✗ {tool_call.name}: {result.message[:30]}..." + ) + + def _on_message_complete(self, message: Message) -> None: + """Handle complete message.""" + if message.role == "user": + self.message_view.add_user_message(message.content or "", message.timestamp) + # Start AI message for streaming + self.message_view.start_ai_message() + elif message.role == "assistant": + # Finish streaming message + self.message_view.finish_ai_message() + + def _on_error(self, error: str, is_retryable: bool) -> None: + """Handle error.""" + self._set_streaming_state(False) + + self.message_view.add_error( + message=error, + is_retryable=is_retryable, + on_retry=self._on_retry if is_retryable else None, + on_cancel=lambda: self.status_label.configure(text=""), + ) + + self.status_label.configure(text="Error occurred") + + def _on_retry(self) -> None: + """Handle retry.""" + if self._core.retry_last(): + self._set_streaming_state(True) + + def _on_stream_complete(self) -> None: + """Handle stream completion.""" + self._set_streaming_state(False) + + # Update checkpoints + self.checkpoint_dropdown.set_checkpoints( + self._core.checkpoint_manager.get_checkpoints() + ) + + self.status_label.configure(text="") + + # ========================================================================= + # State Management + # ========================================================================= + + def _set_streaming_state(self, is_streaming: bool) -> None: + """Update UI for streaming state.""" + if is_streaming: + self.send_btn.configure( + text="Cancel", fg_color=BG_SECONDARY, command=self._on_escape + ) + self.new_session_btn.configure(state="disabled") + self.status_label.configure(text="Thinking...") + else: + self.send_btn.configure( + text="Send →", fg_color=ACCENT_BLUE, command=self._on_send + ) + self.new_session_btn.configure(state="normal") + + def _on_close(self) -> None: + """Handle window close.""" + # Cancel any streaming + if self._core.session.is_streaming: + self._core.cancel() + + # Clean up + self._core.close() + + # Destroy window + self.destroy() diff --git a/src/proxy_app/ai_assistant/ui/checkpoint_ui.py b/src/proxy_app/ai_assistant/ui/checkpoint_ui.py new file mode 100644 index 0000000..19df555 --- /dev/null +++ b/src/proxy_app/ai_assistant/ui/checkpoint_ui.py @@ -0,0 +1,482 @@ +""" +Checkpoint UI Components. + +Provides a dropdown/popup for viewing and rolling back to checkpoints. +""" + +import customtkinter as ctk +import tkinter as tk +from typing import Callable, List, Optional + +from ..checkpoint import Checkpoint +from .styles import ( + ACCENT_BLUE, + ACCENT_RED, + ACCENT_YELLOW, + BG_HOVER, + BG_PRIMARY, + BG_SECONDARY, + BG_TERTIARY, + BORDER_COLOR, + FONT_FAMILY, + FONT_SIZE_NORMAL, + FONT_SIZE_SMALL, + HIGHLIGHT_BG, + TEXT_MUTED, + TEXT_PRIMARY, + TEXT_SECONDARY, + get_font, +) + + +class CheckpointDropdown(ctk.CTkFrame): + """ + Dropdown button that opens a checkpoint selection popup. + + Shows the list of checkpoints and allows rolling back to any point. + """ + + def __init__( + self, parent, on_rollback: Optional[Callable[[str], None]] = None, **kwargs + ): + """ + Initialize the checkpoint dropdown. 
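+
+        Example (sketch; `handle_rollback` and `core` are assumed to exist):
+
+            dropdown = CheckpointDropdown(header, on_rollback=handle_rollback)
+            dropdown.set_checkpoints(core.checkpoint_manager.get_checkpoints())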
+ + Args: + parent: Parent widget + on_rollback: Callback when rollback is confirmed (receives checkpoint_id) + **kwargs: Additional frame arguments + """ + super().__init__(parent, fg_color="transparent", **kwargs) + + self._on_rollback = on_rollback + self._checkpoints: List[Checkpoint] = [] + self._popup: Optional[CheckpointPopup] = None + + self._create_widgets() + + def _create_widgets(self) -> None: + """Create the UI widgets.""" + # Dropdown button with undo icon + self.button = ctk.CTkButton( + self, + text="↩️", + font=get_font("normal"), + fg_color=BG_SECONDARY, + hover_color=BG_HOVER, + text_color=TEXT_PRIMARY, + border_width=1, + border_color=BORDER_COLOR, + width=40, + height=28, + command=self._toggle_popup, + ) + self.button.pack(side="left") + + # Count badge (hidden when 0) + self.badge = ctk.CTkLabel( + self, + text="0", + font=get_font("small"), + fg_color=ACCENT_BLUE, + text_color=TEXT_PRIMARY, + corner_radius=10, + width=20, + height=20, + ) + # Initially hidden + + def _toggle_popup(self) -> None: + """Toggle the checkpoint popup.""" + if self._popup is not None and self._popup.winfo_exists(): + self._popup.destroy() + self._popup = None + else: + self._show_popup() + + def _show_popup(self) -> None: + """Show the checkpoint selection popup.""" + if not self._checkpoints: + # Show a simple message if no checkpoints + self._popup = CheckpointPopup( + self, + checkpoints=[], + on_rollback=self._handle_rollback, + on_close=self._close_popup, + ) + else: + self._popup = CheckpointPopup( + self, + checkpoints=self._checkpoints, + on_rollback=self._handle_rollback, + on_close=self._close_popup, + ) + + # Position below the button + x = self.button.winfo_rootx() + y = self.button.winfo_rooty() + self.button.winfo_height() + 4 + self._popup.geometry(f"+{x}+{y}") + + def _handle_rollback(self, checkpoint_id: str) -> None: + """Handle rollback confirmation.""" + self._close_popup() + if self._on_rollback: + self._on_rollback(checkpoint_id) + + def _close_popup(self) -> None: + """Close the popup.""" + if self._popup is not None: + self._popup.destroy() + self._popup = None + + def set_checkpoints(self, checkpoints: List[Checkpoint]) -> None: + """ + Update the list of checkpoints. + + Args: + checkpoints: List of Checkpoint objects + """ + self._checkpoints = checkpoints + + # Update badge + count = len(checkpoints) + if count > 0: + self.badge.configure(text=str(count)) + self.badge.pack(side="left", padx=(4, 0)) + else: + self.badge.pack_forget() + + # Update popup if open + if self._popup is not None and self._popup.winfo_exists(): + self._popup.update_checkpoints(checkpoints) + + @property + def checkpoint_count(self) -> int: + """Get the number of checkpoints.""" + return len(self._checkpoints) + + +class CheckpointPopup(ctk.CTkToplevel): + """ + Popup window for checkpoint selection. + + Shows a list of checkpoints with timestamps and descriptions, + allows selecting one for rollback. + """ + + def __init__( + self, + parent, + checkpoints: List[Checkpoint], + on_rollback: Callable[[str], None], + on_close: Callable[[], None], + **kwargs, + ): + """ + Initialize the popup. 
+
+        Args:
+            parent: Parent widget
+            checkpoints: List of checkpoints to display
+            on_rollback: Callback when rollback is confirmed
+            on_close: Callback when popup is closed
+        """
+        super().__init__(parent, **kwargs)
+
+        self._checkpoints = checkpoints
+        self._on_rollback = on_rollback
+        self._on_close = on_close
+        self._selected_id: Optional[str] = None
+
+        # Window configuration
+        self.title("Checkpoints")
+        self.configure(fg_color=BG_PRIMARY)
+        self.overrideredirect(True)  # No window decorations
+        self.attributes("-topmost", True)
+
+        # Close on focus loss or Escape
+        self.bind("<FocusOut>", self._on_focus_out)
+        self.bind("<Escape>", lambda e: self._close())
+
+        self._create_widgets()
+
+        # Focus the window
+        self.focus_force()
+
+    def _create_widgets(self) -> None:
+        """Create the UI widgets."""
+        # Main container with border
+        self.container = ctk.CTkFrame(
+            self,
+            fg_color=BG_SECONDARY,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            corner_radius=8,
+        )
+        self.container.pack(fill="both", expand=True, padx=2, pady=2)
+
+        # Header
+        header = ctk.CTkFrame(self.container, fg_color="transparent")
+        header.pack(fill="x", padx=8, pady=(8, 4))
+
+        ctk.CTkLabel(
+            header,
+            text="Checkpoints",
+            font=get_font("large", bold=True),
+            text_color=TEXT_PRIMARY,
+        ).pack(side="left")
+
+        ctk.CTkButton(
+            header,
+            text="×",
+            font=get_font("large"),
+            fg_color="transparent",
+            hover_color=BG_HOVER,
+            text_color=TEXT_MUTED,
+            width=24,
+            height=24,
+            command=self._close,
+        ).pack(side="right")
+
+        # Scrollable list container
+        self.list_frame = ctk.CTkScrollableFrame(
+            self.container,
+            fg_color=BG_TERTIARY,
+            corner_radius=4,
+            height=250,
+            width=350,
+        )
+        self.list_frame.pack(fill="both", expand=True, padx=8, pady=4)
+
+        # Populate list
+        self._populate_list()
+
+        # Buttons
+        button_frame = ctk.CTkFrame(self.container, fg_color="transparent")
+        button_frame.pack(fill="x", padx=8, pady=(4, 8))
+
+        self.rollback_btn = ctk.CTkButton(
+            button_frame,
+            text="Rollback to Selected",
+            font=get_font("normal"),
+            fg_color=ACCENT_YELLOW,
+            hover_color="#d4a910",
+            text_color=BG_PRIMARY,
+            height=28,
+            state="disabled",
+            command=self._confirm_rollback,
+        )
+        self.rollback_btn.pack(side="left", fill="x", expand=True, padx=(0, 4))
+
+        ctk.CTkButton(
+            button_frame,
+            text="Cancel",
+            font=get_font("normal"),
+            fg_color=BG_TERTIARY,
+            hover_color=BG_HOVER,
+            text_color=TEXT_PRIMARY,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            height=28,
+            width=70,
+            command=self._close,
+        ).pack(side="right")
+
+    def _populate_list(self) -> None:
+        """Populate the checkpoint list."""
+        # Clear existing items
+        for widget in self.list_frame.winfo_children():
+            widget.destroy()
+
+        if not self._checkpoints:
+            ctk.CTkLabel(
+                self.list_frame,
+                text="No checkpoints yet.\n\nCheckpoints are created when\nthe AI makes changes.",
+                font=get_font("normal"),
+                text_color=TEXT_MUTED,
+                justify="center",
+            ).pack(pady=40)
+            return
+
+        # Current state indicator
+        current_item = CheckpointItem(
+            self.list_frame,
+            text="Current State",
+            subtext="No changes to undo",
+            is_current=True,
+            on_select=lambda: self._select(None),
+        )
+        current_item.pack(fill="x", pady=(0, 4))
+
+        # Add separator
+        sep = ctk.CTkFrame(self.list_frame, fg_color=BORDER_COLOR, height=1)
+        sep.pack(fill="x", pady=4)
+
+        # Add checkpoints (newest first)
+        for checkpoint in reversed(self._checkpoints):
+            item = CheckpointItem(
+                self.list_frame,
+                text=checkpoint.get_display_text(),
+                subtext=self._get_checkpoint_subtext(checkpoint),
+                checkpoint_id=checkpoint.id,
+                on_select=lambda 
cid=checkpoint.id: self._select(cid), + ) + item.pack(fill="x", pady=2) + + def _get_checkpoint_subtext(self, checkpoint: Checkpoint) -> str: + """Get subtext for a checkpoint.""" + if checkpoint.tool_calls: + first = checkpoint.tool_calls[0] + return f"→ {first.message}" + return checkpoint.description + + def _select(self, checkpoint_id: Optional[str]) -> None: + """Select a checkpoint.""" + self._selected_id = checkpoint_id + + # Update button state + if checkpoint_id is None: + self.rollback_btn.configure(state="disabled") + else: + self.rollback_btn.configure(state="normal") + + def _confirm_rollback(self) -> None: + """Confirm and execute rollback.""" + if self._selected_id: + self._on_rollback(self._selected_id) + + def _close(self) -> None: + """Close the popup.""" + self._on_close() + + def _on_focus_out(self, event) -> None: + """Handle focus loss.""" + # Check if focus went to a child widget + if event.widget == self: + # Small delay to allow button clicks to register + self.after(100, self._check_focus) + + def _check_focus(self) -> None: + """Check if we should close after focus loss.""" + try: + focused = self.focus_get() + if focused is None or not str(focused).startswith(str(self)): + self._close() + except tk.TclError: + self._close() + + def update_checkpoints(self, checkpoints: List[Checkpoint]) -> None: + """Update the checkpoint list.""" + self._checkpoints = checkpoints + self._selected_id = None + self._populate_list() + self.rollback_btn.configure(state="disabled") + + +class CheckpointItem(ctk.CTkFrame): + """ + Individual checkpoint item in the list. + """ + + def __init__( + self, + parent, + text: str, + subtext: str = "", + checkpoint_id: Optional[str] = None, + is_current: bool = False, + on_select: Optional[Callable[[], None]] = None, + **kwargs, + ): + """ + Initialize the checkpoint item. 
+
+        Args:
+            parent: Parent widget
+            text: Main text (timestamp + tool)
+            subtext: Description text
+            checkpoint_id: ID of the checkpoint (None for current state)
+            is_current: Whether this is the current state indicator
+            on_select: Callback when selected
+        """
+        super().__init__(
+            parent,
+            fg_color=HIGHLIGHT_BG if is_current else "transparent",
+            corner_radius=4,
+            cursor="hand2",
+            **kwargs,
+        )
+
+        self._checkpoint_id = checkpoint_id
+        self._is_current = is_current
+        self._is_selected = False
+        self._on_select = on_select
+
+        self._create_widgets(text, subtext)
+        self._bind_events()
+
+    def _create_widgets(self, text: str, subtext: str) -> None:
+        """Create the UI widgets."""
+        # Radio indicator
+        self.radio = ctk.CTkLabel(
+            self,
+            text="○" if not self._is_current else "●",
+            font=get_font("normal"),
+            text_color=ACCENT_BLUE if self._is_current else TEXT_MUTED,
+            width=20,
+        )
+        self.radio.pack(side="left", padx=(8, 4), pady=6)
+
+        # Text container
+        text_frame = ctk.CTkFrame(self, fg_color="transparent")
+        text_frame.pack(side="left", fill="x", expand=True, pady=4)
+
+        # Main text
+        self.text_label = ctk.CTkLabel(
+            text_frame,
+            text=text,
+            font=get_font("normal"),
+            text_color=TEXT_PRIMARY,
+            anchor="w",
+        )
+        self.text_label.pack(fill="x")
+
+        # Subtext
+        if subtext:
+            self.subtext_label = ctk.CTkLabel(
+                text_frame,
+                text=subtext,
+                font=get_font("small"),
+                text_color=TEXT_MUTED,
+                anchor="w",
+            )
+            self.subtext_label.pack(fill="x")
+
+    def _bind_events(self) -> None:
+        """Bind mouse events."""
+        widgets = [self, self.radio, self.text_label]
+        if hasattr(self, "subtext_label"):
+            widgets.append(self.subtext_label)
+
+        for widget in widgets:
+            widget.bind("<Button-1>", self._on_click)
+            widget.bind("<Enter>", self._on_enter)
+            widget.bind("<Leave>", self._on_leave)
+
+    def _on_click(self, event=None) -> None:
+        """Handle click."""
+        if not self._is_current and self._on_select:
+            self._is_selected = True
+            self.radio.configure(text="●", text_color=ACCENT_BLUE)
+            self.configure(fg_color=HIGHLIGHT_BG)
+            self._on_select()
+
+    def _on_enter(self, event=None) -> None:
+        """Handle mouse enter."""
+        if not self._is_current and not self._is_selected:
+            self.configure(fg_color=BG_HOVER)
+
+    def _on_leave(self, event=None) -> None:
+        """Handle mouse leave."""
+        if not self._is_current and not self._is_selected:
+            self.configure(fg_color="transparent")
diff --git a/src/proxy_app/ai_assistant/ui/message_view.py b/src/proxy_app/ai_assistant/ui/message_view.py
new file mode 100644
index 0000000..bd10627
--- /dev/null
+++ b/src/proxy_app/ai_assistant/ui/message_view.py
@@ -0,0 +1,586 @@
+"""
+Message Display Widget for the AI Assistant.
+
+Provides a scrollable view for displaying conversation messages,
+including user messages, AI responses, thinking sections, and tool results.
+""" + +import customtkinter as ctk +import tkinter as tk +from dataclasses import dataclass +from datetime import datetime +from typing import Callable, List, Optional + +from ..core import Message +from ..tools import ToolCall, ToolResult +from .styles import ( + ACCENT_BLUE, + ACCENT_GREEN, + ACCENT_RED, + AI_MESSAGE_BG, + BG_HOVER, + BG_PRIMARY, + BG_SECONDARY, + BG_TERTIARY, + BORDER_COLOR, + CORNER_RADIUS, + ERROR_BG, + FONT_FAMILY, + FONT_SIZE_NORMAL, + FONT_SIZE_SMALL, + MESSAGE_PADDING, + MESSAGE_SPACING, + TEXT_MUTED, + TEXT_PRIMARY, + TEXT_SECONDARY, + THINKING_BG, + THINKING_TEXT, + TOOL_BG, + TOOL_FAILURE_COLOR, + TOOL_SUCCESS_COLOR, + USER_MESSAGE_BG, + get_font, +) +from .thinking import ThinkingSection + + +@dataclass +class StreamingState: + """State for the currently streaming message.""" + + thinking_text: str = "" + content_text: str = "" + is_streaming: bool = False + thinking_widget: Optional[ThinkingSection] = None + content_widget: Optional[ctk.CTkLabel] = None + message_frame: Optional[ctk.CTkFrame] = None + + +class MessageView(ctk.CTkFrame): + """ + Scrollable message display for conversation history. + + Features: + - User messages (right-aligned, accent background) + - AI messages (left-aligned) + - Collapsible thinking sections + - Tool execution display with results + - Streaming text support + - Auto-scroll to bottom + """ + + def __init__(self, parent, **kwargs): + """ + Initialize the message view. + + Args: + parent: Parent widget + **kwargs: Additional frame arguments + """ + super().__init__(parent, fg_color=BG_TERTIARY, **kwargs) + + self._streaming = StreamingState() + self._message_widgets: List[ctk.CTkFrame] = [] + + self._create_widgets() + + def _create_widgets(self) -> None: + """Create the UI widgets.""" + # Scrollable container + self.scroll_frame = ctk.CTkScrollableFrame( + self, + fg_color=BG_TERTIARY, + corner_radius=0, + ) + self.scroll_frame.pack(fill="both", expand=True) + + # Configure scroll frame grid + self.scroll_frame.grid_columnconfigure(0, weight=1) + + # Welcome message + self._show_welcome() + + def _show_welcome(self) -> None: + """Show the welcome message.""" + welcome = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + welcome.pack(fill="x", pady=20) + + ctk.CTkLabel( + welcome, + text="AI Assistant", + font=get_font("title", bold=True), + text_color=TEXT_PRIMARY, + ).pack() + + ctk.CTkLabel( + welcome, + text="Ask me to help configure your model filters.\nI can add rules, explain model statuses, and more.", + font=get_font("normal"), + text_color=TEXT_SECONDARY, + justify="center", + ).pack(pady=(8, 0)) + + self._welcome_widget = welcome + + def _hide_welcome(self) -> None: + """Hide the welcome message.""" + if hasattr(self, "_welcome_widget") and self._welcome_widget: + self._welcome_widget.destroy() + self._welcome_widget = None + + def add_user_message( + self, content: str, timestamp: Optional[datetime] = None + ) -> None: + """ + Add a user message to the display. 
+ + Args: + content: Message content + timestamp: Optional timestamp + """ + self._hide_welcome() + + # Container (right-aligned) + container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + container.pack(fill="x", pady=(MESSAGE_SPACING, 0), padx=MESSAGE_PADDING) + + # Message bubble (right side) + bubble = ctk.CTkFrame( + container, + fg_color=USER_MESSAGE_BG, + corner_radius=CORNER_RADIUS, + ) + bubble.pack(side="right", anchor="e") + + # Content + label = ctk.CTkLabel( + bubble, + text=content, + font=get_font("normal"), + text_color=TEXT_PRIMARY, + wraplength=400, + justify="left", + anchor="w", + ) + label.pack(padx=MESSAGE_PADDING, pady=MESSAGE_PADDING) + + # Timestamp + if timestamp: + time_str = timestamp.strftime("%H:%M") + time_label = ctk.CTkLabel( + container, + text=time_str, + font=get_font("small"), + text_color=TEXT_MUTED, + ) + time_label.pack(side="right", padx=(0, 8)) + + self._message_widgets.append(container) + self._scroll_to_bottom() + + def start_ai_message(self) -> None: + """Start a new AI message (for streaming).""" + self._hide_welcome() + + # Container + container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + container.pack(fill="x", pady=(MESSAGE_SPACING, 0), padx=MESSAGE_PADDING) + + # Message bubble (left side) + bubble = ctk.CTkFrame( + container, + fg_color=AI_MESSAGE_BG, + corner_radius=CORNER_RADIUS, + ) + bubble.pack(side="left", anchor="w", fill="x", expand=True) + + # Thinking section (initially hidden, created when needed) + self._streaming = StreamingState( + is_streaming=True, + message_frame=bubble, + ) + + self._message_widgets.append(container) + self._scroll_to_bottom() + + def append_thinking(self, chunk: str) -> None: + """ + Append thinking content to the current streaming message. + + Args: + chunk: Thinking text chunk + """ + if not self._streaming.is_streaming: + return + + self._streaming.thinking_text += chunk + + # Create thinking widget if not exists + if self._streaming.thinking_widget is None: + self._streaming.thinking_widget = ThinkingSection( + self._streaming.message_frame, + initial_expanded=True, + ) + self._streaming.thinking_widget.pack( + fill="x", padx=MESSAGE_PADDING, pady=(MESSAGE_PADDING, 0) + ) + + self._streaming.thinking_widget.append_text(chunk) + self._scroll_to_bottom() + + def append_content(self, chunk: str) -> None: + """ + Append content to the current streaming message. 
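+
+        Typically called once per streamed delta; a hypothetical loop:
+            for chunk in ("Added ", "an ignore ", "rule."):
+                view.append_content(chunk)
+        The first call also auto-collapses an expanded thinking section.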
+ + Args: + chunk: Content text chunk + """ + if not self._streaming.is_streaming: + return + + self._streaming.content_text += chunk + + # Auto-collapse thinking when content arrives + if ( + self._streaming.thinking_widget + and self._streaming.thinking_widget.is_expanded + ): + self._streaming.thinking_widget.auto_collapse() + + # Create or update content widget + if self._streaming.content_widget is None: + self._streaming.content_widget = ctk.CTkLabel( + self._streaming.message_frame, + text=self._streaming.content_text, + font=get_font("normal"), + text_color=TEXT_PRIMARY, + wraplength=500, + justify="left", + anchor="w", + ) + self._streaming.content_widget.pack( + fill="x", padx=MESSAGE_PADDING, pady=MESSAGE_PADDING + ) + else: + self._streaming.content_widget.configure(text=self._streaming.content_text) + + self._scroll_to_bottom() + + def finish_ai_message(self) -> None: + """Finish the current streaming AI message.""" + self._streaming.is_streaming = False + self._streaming = StreamingState() + + def add_ai_message( + self, + content: Optional[str], + thinking: Optional[str] = None, + timestamp: Optional[datetime] = None, + ) -> None: + """ + Add a complete AI message (non-streaming). + + Args: + content: Message content + thinking: Optional thinking content + timestamp: Optional timestamp + """ + self._hide_welcome() + + # Container + container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + container.pack(fill="x", pady=(MESSAGE_SPACING, 0), padx=MESSAGE_PADDING) + + # Message bubble + bubble = ctk.CTkFrame( + container, + fg_color=AI_MESSAGE_BG, + corner_radius=CORNER_RADIUS, + ) + bubble.pack(side="left", anchor="w", fill="x", expand=True) + + # Thinking section (collapsed by default) + if thinking: + thinking_section = ThinkingSection(bubble, initial_expanded=False) + thinking_section.set_text(thinking) + thinking_section.pack( + fill="x", padx=MESSAGE_PADDING, pady=(MESSAGE_PADDING, 0) + ) + + # Content + if content: + label = ctk.CTkLabel( + bubble, + text=content, + font=get_font("normal"), + text_color=TEXT_PRIMARY, + wraplength=500, + justify="left", + anchor="w", + ) + label.pack(fill="x", padx=MESSAGE_PADDING, pady=MESSAGE_PADDING) + + self._message_widgets.append(container) + self._scroll_to_bottom() + + def add_tool_call( + self, tool_call: ToolCall, result: Optional[ToolResult] = None + ) -> ctk.CTkFrame: + """ + Add a tool call display. 
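+
+        Sketch of the two-phase flow (`tc` is an already-built ToolCall):
+            frame = view.add_tool_call(tc)           # rendered with a pending icon
+            # ... execute the tool, obtain a ToolResult `result` ...
+            view.update_tool_result(frame, result)   # icon flips to ✓ or ✗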
+ + Args: + tool_call: The tool call + result: Optional result (if already executed) + + Returns: + The tool frame widget (for updating with result later) + """ + # Container + container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + container.pack(fill="x", pady=(4, 0), padx=MESSAGE_PADDING) + + # Tool bubble + bubble = ctk.CTkFrame( + container, + fg_color=TOOL_BG, + corner_radius=6, + ) + bubble.pack(side="left", anchor="w") + + # Header with icon + header = ctk.CTkFrame(bubble, fg_color="transparent") + header.pack(fill="x", padx=8, pady=(6, 2)) + + # Icon (pending, success, or failure) + if result is None: + icon = "⋯" + icon_color = TEXT_MUTED + elif result.success: + icon = "✓" + icon_color = TOOL_SUCCESS_COLOR + else: + icon = "✗" + icon_color = TOOL_FAILURE_COLOR + + icon_label = ctk.CTkLabel( + header, + text=icon, + font=get_font("normal", bold=True), + text_color=icon_color, + width=16, + ) + icon_label.pack(side="left") + + # Tool name and arguments + args_str = ", ".join(f'{k}="{v}"' for k, v in tool_call.arguments.items()) + tool_text = f"{tool_call.name}({args_str})" + + ctk.CTkLabel( + header, + text=tool_text, + font=get_font("small", monospace=True), + text_color=TEXT_SECONDARY, + ).pack(side="left", padx=(4, 0)) + + # Result message + if result: + result_label = ctk.CTkLabel( + bubble, + text=f"→ {result.message}", + font=get_font("small"), + text_color=TEXT_MUTED, + anchor="w", + ) + result_label.pack(fill="x", padx=8, pady=(0, 6)) + + # Store references for updating + bubble._icon_label = icon_label + bubble._result_container = bubble + + self._message_widgets.append(container) + self._scroll_to_bottom() + + return bubble + + def update_tool_result(self, tool_frame: ctk.CTkFrame, result: ToolResult) -> None: + """ + Update a tool call display with its result. + + Args: + tool_frame: The tool frame returned by add_tool_call + result: The tool result + """ + if hasattr(tool_frame, "_icon_label"): + if result.success: + tool_frame._icon_label.configure( + text="✓", text_color=TOOL_SUCCESS_COLOR + ) + else: + tool_frame._icon_label.configure( + text="✗", text_color=TOOL_FAILURE_COLOR + ) + + # Add result message + result_label = ctk.CTkLabel( + tool_frame, + text=f"→ {result.message}", + font=get_font("small"), + text_color=TEXT_MUTED, + anchor="w", + ) + result_label.pack(fill="x", padx=8, pady=(0, 6)) + + self._scroll_to_bottom() + + def add_error( + self, + message: str, + is_retryable: bool = True, + on_retry: Optional[Callable[[], None]] = None, + on_cancel: Optional[Callable[[], None]] = None, + ) -> ctk.CTkFrame: + """ + Add an error display. 
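+
+        Example wiring (callback names are placeholders):
+            view.add_error(
+                "Provider timed out",
+                is_retryable=True,
+                on_retry=resend_last_request,
+                on_cancel=abort_turn,
+            )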
+ + Args: + message: Error message + is_retryable: Whether retry is possible + on_retry: Retry callback + on_cancel: Cancel callback + + Returns: + The error frame (for removal on retry success) + """ + # Container + container = ctk.CTkFrame(self.scroll_frame, fg_color="transparent") + container.pack(fill="x", pady=(MESSAGE_SPACING, 0), padx=MESSAGE_PADDING) + + # Error bubble + bubble = ctk.CTkFrame( + container, + fg_color=ERROR_BG, + corner_radius=CORNER_RADIUS, + border_width=1, + border_color=ACCENT_RED, + ) + bubble.pack(side="left", anchor="w", fill="x", expand=True) + + # Icon and title + header = ctk.CTkFrame(bubble, fg_color="transparent") + header.pack(fill="x", padx=MESSAGE_PADDING, pady=(MESSAGE_PADDING, 4)) + + ctk.CTkLabel( + header, + text="⚠", + font=get_font("large"), + text_color=ACCENT_RED, + ).pack(side="left") + + ctk.CTkLabel( + header, + text="Error", + font=get_font("normal", bold=True), + text_color=ACCENT_RED, + ).pack(side="left", padx=(8, 0)) + + # Message + ctk.CTkLabel( + bubble, + text=message, + font=get_font("normal"), + text_color=TEXT_PRIMARY, + wraplength=450, + justify="left", + anchor="w", + ).pack(fill="x", padx=MESSAGE_PADDING) + + # Buttons + if is_retryable or on_cancel: + btn_frame = ctk.CTkFrame(bubble, fg_color="transparent") + btn_frame.pack(fill="x", padx=MESSAGE_PADDING, pady=(8, MESSAGE_PADDING)) + + if is_retryable and on_retry: + ctk.CTkButton( + btn_frame, + text="Retry", + font=get_font("small"), + fg_color=ACCENT_BLUE, + hover_color="#3a8eef", + height=26, + width=60, + command=lambda: self._handle_retry(container, on_retry), + ).pack(side="left", padx=(0, 8)) + + if on_cancel: + ctk.CTkButton( + btn_frame, + text="Cancel", + font=get_font("small"), + fg_color=BG_TERTIARY, + hover_color=BG_HOVER, + border_width=1, + border_color=BORDER_COLOR, + height=26, + width=60, + command=on_cancel, + ).pack(side="left") + + self._message_widgets.append(container) + self._scroll_to_bottom() + + return container + + def _handle_retry( + self, error_frame: ctk.CTkFrame, on_retry: Callable[[], None] + ) -> None: + """Handle retry button click.""" + error_frame.destroy() + if error_frame in self._message_widgets: + self._message_widgets.remove(error_frame) + on_retry() + + def remove_error(self, error_frame: ctk.CTkFrame) -> None: + """Remove an error display.""" + if error_frame.winfo_exists(): + error_frame.destroy() + if error_frame in self._message_widgets: + self._message_widgets.remove(error_frame) + + def add_message(self, message: Message) -> None: + """ + Add a Message object to the display. 
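+
+        Illustrative replay of a stored history (`history` is a hypothetical
+        list of Message objects):
+            for msg in history:
+                view.add_message(msg)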
+ + Args: + message: The Message to display + """ + if message.role == "user": + self.add_user_message(message.content or "", message.timestamp) + elif message.role == "assistant": + if message.tool_calls: + # Show AI message with tool calls + if message.content or message.reasoning_content: + self.add_ai_message( + message.content, message.reasoning_content, message.timestamp + ) + for tc in message.tool_calls: + self.add_tool_call(tc, tc.result) + else: + self.add_ai_message( + message.content, message.reasoning_content, message.timestamp + ) + # Tool messages are typically displayed as part of tool_calls + + def clear(self) -> None: + """Clear all messages.""" + for widget in self._message_widgets: + if widget.winfo_exists(): + widget.destroy() + self._message_widgets.clear() + self._streaming = StreamingState() + self._show_welcome() + + def _scroll_to_bottom(self) -> None: + """Scroll to the bottom of the message list.""" + self.scroll_frame.update_idletasks() + self.scroll_frame._parent_canvas.yview_moveto(1.0) diff --git a/src/proxy_app/ai_assistant/ui/model_selector.py b/src/proxy_app/ai_assistant/ui/model_selector.py new file mode 100644 index 0000000..ec7f36b --- /dev/null +++ b/src/proxy_app/ai_assistant/ui/model_selector.py @@ -0,0 +1,207 @@ +""" +Grouped Model Selector Widget. + +Dropdown for selecting LLM models, grouped by provider. +""" + +import customtkinter as ctk +from typing import Callable, Dict, List, Optional + +from .styles import ( + ACCENT_BLUE, + BG_HOVER, + BG_SECONDARY, + BG_TERTIARY, + BORDER_COLOR, + FONT_FAMILY, + FONT_SIZE_NORMAL, + TEXT_PRIMARY, + TEXT_SECONDARY, + get_font, +) + + +class ModelSelector(ctk.CTkFrame): + """ + Dropdown for selecting LLM models, grouped by provider. + + Features: + - Models grouped by provider (openai, gemini, anthropic, etc.) + - Search/filter capability (future) + - Displays current selection + - Callback on selection change + """ + + def __init__( + self, parent, on_model_changed: Optional[Callable[[str], None]] = None, **kwargs + ): + """ + Initialize the model selector. + + Args: + parent: Parent widget + on_model_changed: Callback when model selection changes + **kwargs: Additional frame arguments + """ + super().__init__(parent, fg_color="transparent", **kwargs) + + self._on_model_changed = on_model_changed + self._models: Dict[str, List[str]] = {} # provider -> [model_ids] + self._flat_models: List[str] = [] # All model IDs + self._current_model: Optional[str] = None + + self._create_widgets() + + def _create_widgets(self) -> None: + """Create the UI widgets.""" + # Label + self.label = ctk.CTkLabel( + self, + text="Model:", + font=get_font("normal"), + text_color=TEXT_SECONDARY, + ) + self.label.pack(side="left", padx=(0, 8)) + + # Dropdown + self.dropdown = ctk.CTkComboBox( + self, + values=["Loading..."], + font=get_font("normal"), + dropdown_font=get_font("normal"), + fg_color=BG_TERTIARY, + border_color=BORDER_COLOR, + button_color=BG_SECONDARY, + button_hover_color=BG_HOVER, + dropdown_fg_color=BG_SECONDARY, + dropdown_hover_color=BG_HOVER, + text_color=TEXT_PRIMARY, + dropdown_text_color=TEXT_PRIMARY, + width=280, + state="readonly", + command=self._on_selection, + ) + self.dropdown.pack(side="left") + self.dropdown.set("Loading...") + + # Bind mousewheel to dropdown for scrolling through options + self._bind_mousewheel() + + def set_models(self, models: Dict[str, List[str]]) -> None: + """ + Set the available models. 
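+
+        Minimal sketch (model names are invented; IDs that already contain a
+        provider prefix are kept as-is):
+            selector.set_models({
+                "openai": ["gpt-4o", "gpt-4o-mini"],
+                "gemini": ["gemini/gemini-1.5-pro"],
+            })
+            selector.get_selected_model()  # first entry in provider-sorted order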
+ + Args: + models: Dict of provider -> list of model IDs + """ + self._models = models + self._flat_models = [] + + # Build flat list with group headers + display_values = [] + + for provider in sorted(models.keys()): + provider_models = models[provider] + if not provider_models: + continue + + # Add models (full provider/model format) + for model_id in sorted(provider_models): + # Model ID might already include provider prefix + if "/" in model_id: + display_values.append(model_id) + self._flat_models.append(model_id) + else: + full_id = f"{provider}/{model_id}" + display_values.append(full_id) + self._flat_models.append(full_id) + + if display_values: + self.dropdown.configure(values=display_values, state="readonly") + + # Keep current selection if still valid + if self._current_model and self._current_model in self._flat_models: + self.dropdown.set(self._current_model) + else: + # Select first model + self._current_model = display_values[0] + self.dropdown.set(display_values[0]) + else: + self.dropdown.configure(values=["No models available"], state="disabled") + self.dropdown.set("No models available") + self._current_model = None + + def _on_selection(self, choice: str) -> None: + """Handle dropdown selection.""" + if choice and choice != "Loading..." and choice != "No models available": + self._current_model = choice + if self._on_model_changed: + self._on_model_changed(choice) + + def get_selected_model(self) -> Optional[str]: + """Get the currently selected model ID.""" + return self._current_model + + def set_selected_model(self, model_id: str) -> bool: + """ + Set the selected model. + + Args: + model_id: The model ID to select + + Returns: + True if model was found and selected + """ + if model_id in self._flat_models: + self._current_model = model_id + self.dropdown.set(model_id) + return True + return False + + def set_loading(self) -> None: + """Set the dropdown to loading state.""" + self.dropdown.configure(values=["Loading..."], state="disabled") + self.dropdown.set("Loading...") + + def set_error(self, message: str = "Failed to load models") -> None: + """Set the dropdown to error state.""" + self.dropdown.configure(values=[message], state="disabled") + self.dropdown.set(message) + + @property + def has_models(self) -> bool: + """Check if models are loaded.""" + return bool(self._flat_models) + + def _bind_mousewheel(self) -> None: + """Bind mousewheel events to cycle through models when dropdown is focused.""" + + def on_mousewheel(event): + if not self._flat_models: + return + + current = self._current_model + if current not in self._flat_models: + return + + current_idx = self._flat_models.index(current) + + # Scroll direction (Windows uses event.delta, Linux uses event.num) + if event.delta > 0 or event.num == 4: + # Scroll up - previous model + new_idx = max(0, current_idx - 1) + else: + # Scroll down - next model + new_idx = min(len(self._flat_models) - 1, current_idx + 1) + + if new_idx != current_idx: + new_model = self._flat_models[new_idx] + self._current_model = new_model + self.dropdown.set(new_model) + if self._on_model_changed: + self._on_model_changed(new_model) + + # Bind to the dropdown widget (works when hovering over it) + self.dropdown.bind("", on_mousewheel) # Windows + self.dropdown.bind("", on_mousewheel) # Linux scroll up + self.dropdown.bind("", on_mousewheel) # Linux scroll down diff --git a/src/proxy_app/ai_assistant/ui/styles.py b/src/proxy_app/ai_assistant/ui/styles.py new file mode 100644 index 0000000..c4c3d91 --- /dev/null +++ 
b/src/proxy_app/ai_assistant/ui/styles.py @@ -0,0 +1,188 @@ +""" +UI Styles and Constants for the AI Assistant. + +Imports colors from model_filter_gui.py for consistency +and defines assistant-specific styles. +""" + +# Import base colors from model_filter_gui +# Using the same color scheme for visual consistency +from ...model_filter_gui import ( + ACCENT_BLUE, + ACCENT_GREEN, + ACCENT_RED, + ACCENT_YELLOW, + BG_HOVER, + BG_PRIMARY, + BG_SECONDARY, + BG_TERTIARY, + BORDER_COLOR, + FONT_FAMILY, + FONT_SIZE_LARGE, + FONT_SIZE_NORMAL, + FONT_SIZE_SMALL, + FONT_SIZE_TITLE, + HIGHLIGHT_BG, + TEXT_MUTED, + TEXT_PRIMARY, + TEXT_SECONDARY, +) + +# Re-export base colors +__all__ = [ + # Base colors (from model_filter_gui) + "BG_PRIMARY", + "BG_SECONDARY", + "BG_TERTIARY", + "BG_HOVER", + "TEXT_PRIMARY", + "TEXT_SECONDARY", + "TEXT_MUTED", + "ACCENT_BLUE", + "ACCENT_GREEN", + "ACCENT_RED", + "ACCENT_YELLOW", + "BORDER_COLOR", + "HIGHLIGHT_BG", + "FONT_FAMILY", + "FONT_SIZE_SMALL", + "FONT_SIZE_NORMAL", + "FONT_SIZE_LARGE", + "FONT_SIZE_TITLE", + # Assistant-specific + "USER_MESSAGE_BG", + "AI_MESSAGE_BG", + "THINKING_BG", + "THINKING_TEXT", + "TOOL_BG", + "TOOL_SUCCESS_COLOR", + "TOOL_FAILURE_COLOR", + "ERROR_BG", + "MESSAGE_SPACING", + "MESSAGE_PADDING", + "CORNER_RADIUS", + "INPUT_MIN_HEIGHT", + "INPUT_MAX_HEIGHT", +] + +# ============================================================================ +# Assistant-Specific Colors +# ============================================================================ + +# Message backgrounds +USER_MESSAGE_BG = "#2a3f5f" # Blue-tinted, right-aligned +AI_MESSAGE_BG = BG_SECONDARY # Same as card backgrounds + +# Thinking section +THINKING_BG = "#1a1a2a" # Slightly darker +THINKING_TEXT = TEXT_MUTED # Muted text for thinking + +# Tool execution +TOOL_BG = "#1e2838" # Subtle background for tool displays +TOOL_SUCCESS_COLOR = ACCENT_GREEN # Checkmark color +TOOL_FAILURE_COLOR = ACCENT_RED # X color + +# Error display +ERROR_BG = "#3d2020" # Dark red tint for error messages + +# ============================================================================ +# Layout Constants +# ============================================================================ + +# Message display +MESSAGE_SPACING = 12 # Vertical spacing between messages +MESSAGE_PADDING = 12 # Internal padding for message boxes +CORNER_RADIUS = 8 # Border radius for message boxes + +# Input area +INPUT_MIN_HEIGHT = 55 # Minimum height (about 2 lines) +INPUT_MAX_HEIGHT = 200 # Maximum height before scrolling + +# ============================================================================ +# Font Configurations +# ============================================================================ + + +def get_font( + size: str = "normal", bold: bool = False, monospace: bool = False +) -> tuple: + """ + Get a font tuple for CTk widgets. 
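+
+    For example (sizes come from the imported FONT_SIZE_* constants):
+        get_font("small", monospace=True)  # ("Consolas", FONT_SIZE_SMALL)
+        get_font("title", bold=True)       # (FONT_FAMILY, FONT_SIZE_TITLE, "bold")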
+ + Args: + size: "small", "normal", "large", or "title" + bold: Whether to use bold weight + monospace: Whether to use monospace font + + Returns: + Tuple of (family, size, weight) + """ + family = "Consolas" if monospace else FONT_FAMILY + + sizes = { + "small": FONT_SIZE_SMALL, + "normal": FONT_SIZE_NORMAL, + "large": FONT_SIZE_LARGE, + "title": FONT_SIZE_TITLE, + } + font_size = sizes.get(size, FONT_SIZE_NORMAL) + + if bold: + return (family, font_size, "bold") + return (family, font_size) + + +# ============================================================================ +# Widget Style Helpers +# ============================================================================ + + +def apply_button_style(button, style: str = "primary") -> None: + """ + Apply a predefined style to a CTkButton. + + Args: + button: The CTkButton to style + style: "primary", "secondary", "danger", or "ghost" + """ + styles = { + "primary": { + "fg_color": ACCENT_BLUE, + "hover_color": "#3a8eef", + "text_color": TEXT_PRIMARY, + }, + "secondary": { + "fg_color": BG_SECONDARY, + "hover_color": BG_HOVER, + "text_color": TEXT_PRIMARY, + "border_width": 1, + "border_color": BORDER_COLOR, + }, + "danger": { + "fg_color": ACCENT_RED, + "hover_color": "#c0392b", + "text_color": TEXT_PRIMARY, + }, + "ghost": { + "fg_color": "transparent", + "hover_color": BG_HOVER, + "text_color": TEXT_SECONDARY, + }, + "success": { + "fg_color": ACCENT_GREEN, + "hover_color": "#27ae60", + "text_color": TEXT_PRIMARY, + }, + } + + if style in styles: + button.configure(**styles[style]) + + +def get_scrollbar_style() -> dict: + """Get the style configuration for scrollbars.""" + return { + "button_color": BG_HOVER, + "button_hover_color": ACCENT_BLUE, + "fg_color": BG_TERTIARY, + } diff --git a/src/proxy_app/ai_assistant/ui/thinking.py b/src/proxy_app/ai_assistant/ui/thinking.py new file mode 100644 index 0000000..d747dd5 --- /dev/null +++ b/src/proxy_app/ai_assistant/ui/thinking.py @@ -0,0 +1,227 @@ +""" +Collapsible Thinking Section Widget. + +Displays the AI's thinking/reasoning content in a collapsible panel. +Auto-collapses when content starts arriving, can be manually toggled. +""" + +import customtkinter as ctk +from typing import Optional + +from .styles import ( + BG_TERTIARY, + BORDER_COLOR, + FONT_FAMILY, + FONT_SIZE_SMALL, + TEXT_MUTED, + TEXT_SECONDARY, + THINKING_BG, + THINKING_TEXT, + get_font, +) + + +class ThinkingSection(ctk.CTkFrame): + """ + Collapsible section for displaying AI thinking/reasoning. + + Features: + - Click header to expand/collapse + - Auto-collapse when content starts arriving + - Streaming text support + - Muted styling to distinguish from main content + """ + + def __init__( + self, + parent, + initial_expanded: bool = True, + max_collapsed_preview: int = 100, + **kwargs, + ): + """ + Initialize the thinking section. 
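+
+        Streaming sketch (assumes a parent message bubble `bubble`):
+            section = ThinkingSection(bubble, initial_expanded=True)
+            section.pack(fill="x")
+            section.append_text("Checking which rules match gpt-4*...")
+            section.auto_collapse()  # typically called once answer content starts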
+
+        Args:
+            parent: Parent widget
+            initial_expanded: Whether to start expanded
+            max_collapsed_preview: Max characters to show when collapsed
+            **kwargs: Additional frame arguments
+        """
+        super().__init__(parent, fg_color=THINKING_BG, corner_radius=6, **kwargs)
+
+        self._expanded = initial_expanded
+        self._max_preview = max_collapsed_preview
+        self._full_text = ""
+        self._auto_collapsed = False
+
+        self._create_widgets()
+
+    def _create_widgets(self) -> None:
+        """Create the UI widgets."""
+        # Header (clickable)
+        self.header = ctk.CTkFrame(self, fg_color="transparent", cursor="hand2")
+        self.header.pack(fill="x", padx=8, pady=(6, 0))
+
+        # Expand/collapse indicator
+        self.indicator = ctk.CTkLabel(
+            self.header,
+            text="▼" if self._expanded else "▶",
+            font=get_font("small"),
+            text_color=TEXT_MUTED,
+            width=16,
+        )
+        self.indicator.pack(side="left")
+
+        # Title
+        self.title_label = ctk.CTkLabel(
+            self.header,
+            text="Thinking",
+            font=get_font("small"),
+            text_color=TEXT_MUTED,
+        )
+        self.title_label.pack(side="left", padx=(4, 0))
+
+        # Preview text (shown when collapsed)
+        self.preview_label = ctk.CTkLabel(
+            self.header,
+            text="",
+            font=get_font("small"),
+            text_color=TEXT_MUTED,
+            anchor="w",
+        )
+
+        # Content container
+        self.content_frame = ctk.CTkFrame(self, fg_color="transparent")
+        if self._expanded:
+            self.content_frame.pack(fill="both", expand=True, padx=8, pady=(4, 8))
+
+        # Text display - use CTkLabel for auto-sizing (no scrollbar)
+        self.text_display = ctk.CTkLabel(
+            self.content_frame,
+            text="",
+            font=get_font("small"),
+            text_color=THINKING_TEXT,
+            wraplength=450,  # Wrap text to fit container
+            justify="left",
+            anchor="nw",
+        )
+        self.text_display.pack(fill="both", expand=True)
+
+        # Bind click events
+        self.header.bind("<Button-1>", self._toggle)
+        self.indicator.bind("<Button-1>", self._toggle)
+        self.title_label.bind("<Button-1>", self._toggle)
+        self.preview_label.bind("<Button-1>", self._toggle)
+
+        # Bind mousewheel to stop propagation when over thinking content
+        self._bind_mousewheel_capture()
+
+    def _toggle(self, event=None) -> None:
+        """Toggle expanded/collapsed state."""
+        self._expanded = not self._expanded
+        self._auto_collapsed = False
+        self._update_display()
+
+    def _update_display(self) -> None:
+        """Update the display based on current state."""
+        if self._expanded:
+            self.indicator.configure(text="▼")
+            self.preview_label.pack_forget()
+            self.content_frame.pack(fill="both", expand=True, padx=8, pady=(4, 8))
+        else:
+            self.indicator.configure(text="▶")
+            self.content_frame.pack_forget()
+            # Show preview
+            preview = self._full_text[: self._max_preview]
+            if len(self._full_text) > self._max_preview:
+                preview += "..."
+            preview = preview.replace("\n", " ")
+            self.preview_label.configure(text=f" - {preview}" if preview else "")
+            self.preview_label.pack(side="left", padx=(8, 0), fill="x", expand=True)
+
+    def append_text(self, text: str) -> None:
+        """
+        Append text to the thinking content.
+
+        Args:
+            text: Text chunk to append
+        """
+        self._full_text += text
+
+        # Update label text
+        self.text_display.configure(text=self._full_text)
+
+        # Update preview if collapsed
+        if not self._expanded:
+            self._update_display()
+
+    def set_text(self, text: str) -> None:
+        """
+        Set the full thinking text.
+
+        Args:
+            text: Complete thinking text
+        """
+        self._full_text = text
+        self.text_display.configure(text=text)
+        self._update_display()
+
+    def auto_collapse(self) -> None:
+        """Auto-collapse when content starts arriving."""
+        if self._expanded and not self._auto_collapsed:
+            self._expanded = False
+            self._auto_collapsed = True
+            self._update_display()
+
+    def expand(self) -> None:
+        """Expand the section."""
+        if not self._expanded:
+            self._expanded = True
+            self._auto_collapsed = False
+            self._update_display()
+
+    def collapse(self) -> None:
+        """Collapse the section."""
+        if self._expanded:
+            self._expanded = False
+            self._update_display()
+
+    def clear(self) -> None:
+        """Clear all content."""
+        self._full_text = ""
+        self.text_display.configure(text="")
+        self._update_display()
+
+    @property
+    def is_expanded(self) -> bool:
+        """Check if currently expanded."""
+        return self._expanded
+
+    @property
+    def text(self) -> str:
+        """Get the full thinking text."""
+        return self._full_text
+
+    @property
+    def has_content(self) -> bool:
+        """Check if there is any thinking content."""
+        return bool(self._full_text.strip())
+
+    def _bind_mousewheel_capture(self) -> None:
+        """Bind mousewheel events to prevent scroll propagation when expanded."""
+
+        def on_mousewheel(event):
+            # Only capture scroll if expanded and has content
+            if self._expanded and self._full_text:
+                # Stop event propagation - don't scroll parent
+                return "break"
+            # Allow propagation to parent scroll
+            return None
+
+        # Bind to all child widgets
+        widgets = [self, self.content_frame, self.text_display, self.header]
+        for widget in widgets:
+            widget.bind("<MouseWheel>", on_mousewheel)  # Windows
+            widget.bind("<Button-4>", on_mousewheel)  # Linux scroll up
+            widget.bind("<Button-5>", on_mousewheel)  # Linux scroll down
diff --git a/src/proxy_app/model_filter_gui.py b/src/proxy_app/model_filter_gui.py
new file mode 100644
index 0000000..16e70db
--- /dev/null
+++ b/src/proxy_app/model_filter_gui.py
@@ -0,0 +1,3655 @@
+"""
+Model Filter GUI - Visual editor for model ignore/whitelist rules.
+
+A CustomTkinter application that provides a friendly interface for managing
+which models are available per provider through ignore lists and whitelists.
+ +Features: +- Two synchronized model lists showing all fetched models and their filtered status +- Color-coded rules with visual association to affected models +- Real-time filtering preview as you type patterns +- Click interactions to highlight rule-model relationships +- Right-click context menus for quick actions +- Comprehensive help documentation +""" + +import customtkinter as ctk +from tkinter import Menu +import asyncio +import threading +import os +import re +from pathlib import Path +from dataclasses import dataclass, field +from typing import List, Dict, Tuple, Optional, Callable, Set +from dotenv import load_dotenv, set_key, unset_key + + +# ════════════════════════════════════════════════════════════════════════════════ +# CONSTANTS & CONFIGURATION +# ════════════════════════════════════════════════════════════════════════════════ + +# Window settings +WINDOW_TITLE = "Model Filter Configuration" +WINDOW_DEFAULT_SIZE = "1000x750" +WINDOW_MIN_WIDTH = 600 +WINDOW_MIN_HEIGHT = 400 + +# Color scheme (dark mode) +BG_PRIMARY = "#1a1a2e" # Main background +BG_SECONDARY = "#16213e" # Card/panel background +BG_TERTIARY = "#0f0f1a" # Input fields, lists +BG_HOVER = "#1f2b47" # Hover state +BORDER_COLOR = "#2a2a4a" # Subtle borders +TEXT_PRIMARY = "#e8e8e8" # Main text +TEXT_SECONDARY = "#a0a0a0" # Muted text +TEXT_MUTED = "#666680" # Very muted text +ACCENT_BLUE = "#4a9eff" # Primary accent +ACCENT_GREEN = "#2ecc71" # Success/normal +ACCENT_RED = "#e74c3c" # Danger/ignore +ACCENT_YELLOW = "#f1c40f" # Warning + +# Status colors +NORMAL_COLOR = "#2ecc71" # Green - models not affected by any rule +HIGHLIGHT_BG = "#2a3a5a" # Background for highlighted items + +# Ignore rules - warm color progression (reds/oranges) +IGNORE_COLORS = [ + "#e74c3c", # Bright red + "#c0392b", # Dark red + "#e67e22", # Orange + "#d35400", # Dark orange + "#f39c12", # Gold + "#e91e63", # Pink + "#ff5722", # Deep orange + "#f44336", # Material red + "#ff6b6b", # Coral + "#ff8a65", # Light deep orange +] + +# Whitelist rules - cool color progression (blues/teals) +WHITELIST_COLORS = [ + "#3498db", # Blue + "#2980b9", # Dark blue + "#1abc9c", # Teal + "#16a085", # Dark teal + "#9b59b6", # Purple + "#8e44ad", # Dark purple + "#00bcd4", # Cyan + "#2196f3", # Material blue + "#64b5f6", # Light blue + "#4dd0e1", # Light cyan +] + +# Font configuration +FONT_FAMILY = "Segoe UI" +FONT_SIZE_SMALL = 11 +FONT_SIZE_NORMAL = 12 +FONT_SIZE_LARGE = 14 +FONT_SIZE_TITLE = 16 +FONT_SIZE_HEADER = 20 + + +# ════════════════════════════════════════════════════════════════════════════════ +# DATA CLASSES +# ════════════════════════════════════════════════════════════════════════════════ + + +@dataclass +class FilterRule: + """Represents a single filter rule (ignore or whitelist pattern).""" + + pattern: str + color: str + rule_type: str # 'ignore' or 'whitelist' + affected_count: int = 0 + affected_models: List[str] = field(default_factory=list) + + def __hash__(self): + return hash((self.pattern, self.rule_type)) + + def __eq__(self, other): + if not isinstance(other, FilterRule): + return False + return self.pattern == other.pattern and self.rule_type == other.rule_type + + +@dataclass +class ModelStatus: + """Status information for a single model.""" + + model_id: str + status: str # 'normal', 'ignored', 'whitelisted' + color: str + affecting_rule: Optional[FilterRule] = None + + @property + def display_name(self) -> str: + """Get the model name without provider prefix for display.""" + if "/" in self.model_id: + return 
self.model_id.split("/", 1)[1] + return self.model_id + + @property + def provider(self) -> str: + """Extract provider from model ID.""" + if "/" in self.model_id: + return self.model_id.split("/")[0] + return "" + + +# ════════════════════════════════════════════════════════════════════════════════ +# FILTER ENGINE +# ════════════════════════════════════════════════════════════════════════════════ + + +class FilterEngine: + """ + Core filtering logic with rule management. + + Handles pattern matching, rule storage, and status calculation. + Tracks changes for save/discard functionality. + Uses caching for performance with large model lists. + """ + + def __init__(self): + self.ignore_rules: List[FilterRule] = [] + self.whitelist_rules: List[FilterRule] = [] + self._ignore_color_index = 0 + self._whitelist_color_index = 0 + self._original_ignore_patterns: Set[str] = set() + self._original_whitelist_patterns: Set[str] = set() + self._current_provider: Optional[str] = None + + # Caching for performance + self._status_cache: Dict[str, ModelStatus] = {} + self._available_count_cache: Optional[Tuple[int, int]] = None + self._cache_valid: bool = False + + def _invalidate_cache(self): + """Mark cache as stale (call when rules change).""" + self._status_cache.clear() + self._available_count_cache = None + self._cache_valid = False + + def reset(self): + """Clear all rules and reset state.""" + self.ignore_rules.clear() + self.whitelist_rules.clear() + self._ignore_color_index = 0 + self._whitelist_color_index = 0 + self._original_ignore_patterns.clear() + self._original_whitelist_patterns.clear() + self._invalidate_cache() + + def _get_next_ignore_color(self) -> str: + """Get next color for ignore rules (cycles through palette).""" + color = IGNORE_COLORS[self._ignore_color_index % len(IGNORE_COLORS)] + self._ignore_color_index += 1 + return color + + def _get_next_whitelist_color(self) -> str: + """Get next color for whitelist rules (cycles through palette).""" + color = WHITELIST_COLORS[self._whitelist_color_index % len(WHITELIST_COLORS)] + self._whitelist_color_index += 1 + return color + + def add_ignore_rule(self, pattern: str) -> Optional[FilterRule]: + """Add a new ignore rule. Returns the rule if added, None if duplicate.""" + pattern = pattern.strip() + if not pattern: + return None + + # Check for duplicates + for rule in self.ignore_rules: + if rule.pattern == pattern: + return None + + rule = FilterRule( + pattern=pattern, color=self._get_next_ignore_color(), rule_type="ignore" + ) + self.ignore_rules.append(rule) + self._invalidate_cache() + return rule + + def add_whitelist_rule(self, pattern: str) -> Optional[FilterRule]: + """Add a new whitelist rule. Returns the rule if added, None if duplicate.""" + pattern = pattern.strip() + if not pattern: + return None + + # Check for duplicates + for rule in self.whitelist_rules: + if rule.pattern == pattern: + return None + + rule = FilterRule( + pattern=pattern, + color=self._get_next_whitelist_color(), + rule_type="whitelist", + ) + self.whitelist_rules.append(rule) + self._invalidate_cache() + return rule + + def remove_ignore_rule(self, pattern: str) -> bool: + """Remove an ignore rule by pattern. Returns True if removed.""" + for i, rule in enumerate(self.ignore_rules): + if rule.pattern == pattern: + self.ignore_rules.pop(i) + self._invalidate_cache() + return True + return False + + def remove_whitelist_rule(self, pattern: str) -> bool: + """Remove a whitelist rule by pattern. 
Returns True if removed.""" + for i, rule in enumerate(self.whitelist_rules): + if rule.pattern == pattern: + self.whitelist_rules.pop(i) + self._invalidate_cache() + return True + return False + + def _pattern_matches(self, model_id: str, pattern: str) -> bool: + """ + Check if a pattern matches a model ID. + + Supports: + - Exact match: "gpt-4" matches only "gpt-4" + - Prefix wildcard: "gpt-4*" matches "gpt-4", "gpt-4-turbo", etc. + - Match all: "*" matches everything + """ + # Extract model name without provider prefix + if "/" in model_id: + provider_model_name = model_id.split("/", 1)[1] + else: + provider_model_name = model_id + + if pattern == "*": + return True + elif pattern.endswith("*"): + prefix = pattern[:-1] + return provider_model_name.startswith(prefix) or model_id.startswith(prefix) + else: + # Exact match against full ID or provider model name + return model_id == pattern or provider_model_name == pattern + + def pattern_is_covered_by(self, new_pattern: str, existing_pattern: str) -> bool: + """ + Check if new_pattern is already covered by existing_pattern. + + A pattern A is covered by pattern B if every model that would match A + would also match B. + + Examples: + - "gpt-4" is covered by "gpt-4*" (prefix covers exact) + - "gpt-4-turbo" is covered by "gpt-4*" (prefix covers longer) + - "gpt-4*" is covered by "gpt-*" (broader prefix covers narrower) + - Anything is covered by "*" (match-all covers everything) + - "gpt-4" is covered by "gpt-4" (exact duplicate) + """ + # Exact duplicate + if new_pattern == existing_pattern: + return True + + # Existing is wildcard-all - covers everything + if existing_pattern == "*": + return True + + # If existing is a prefix wildcard + if existing_pattern.endswith("*"): + existing_prefix = existing_pattern[:-1] + + # New is exact match - check if it starts with existing prefix + if not new_pattern.endswith("*"): + return new_pattern.startswith(existing_prefix) + + # New is also a prefix wildcard - check if new prefix starts with existing + new_prefix = new_pattern[:-1] + return new_prefix.startswith(existing_prefix) + + # Existing is exact match - only covers exact duplicate (already handled) + return False + + def is_pattern_covered(self, new_pattern: str, rule_type: str) -> bool: + """ + Check if a new pattern is already covered by any existing rule of the same type. + """ + rules = self.ignore_rules if rule_type == "ignore" else self.whitelist_rules + for rule in rules: + if self.pattern_is_covered_by(new_pattern, rule.pattern): + return True + return False + + def get_covered_patterns(self, new_pattern: str, rule_type: str) -> List[str]: + """ + Get list of existing patterns that would be covered (made redundant) + by adding new_pattern. + + Used for smart merge: when adding a broader pattern, remove the + narrower patterns it covers. + """ + rules = self.ignore_rules if rule_type == "ignore" else self.whitelist_rules + covered = [] + for rule in rules: + if self.pattern_is_covered_by(rule.pattern, new_pattern): + # The existing rule would be covered by the new pattern + covered.append(rule.pattern) + return covered + + def _compute_status(self, model_id: str) -> ModelStatus: + """ + Compute the status of a model based on current rules (no caching). 
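+
+        Worked example (hypothetical rules: ignore "*", whitelist "gpt-4o"):
+            engine.get_model_status("openai/gpt-4o").status         # "whitelisted"
+            engine.get_model_status("openai/gpt-3.5-turbo").status  # "ignored"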
+ + Priority: Whitelist > Ignore > Normal + """ + # Check whitelist first (takes priority) + for rule in self.whitelist_rules: + if self._pattern_matches(model_id, rule.pattern): + return ModelStatus( + model_id=model_id, + status="whitelisted", + color=rule.color, + affecting_rule=rule, + ) + + # Then check ignore + for rule in self.ignore_rules: + if self._pattern_matches(model_id, rule.pattern): + return ModelStatus( + model_id=model_id, + status="ignored", + color=rule.color, + affecting_rule=rule, + ) + + # Default: normal + return ModelStatus( + model_id=model_id, status="normal", color=NORMAL_COLOR, affecting_rule=None + ) + + def get_model_status(self, model_id: str) -> ModelStatus: + """Get status for a model (uses cache if available).""" + if model_id in self._status_cache: + return self._status_cache[model_id] + return self._compute_status(model_id) + + def _rebuild_cache(self, models: List[str]): + """Rebuild the entire status cache in one efficient pass.""" + self._status_cache.clear() + + # Reset rule counts + for rule in self.ignore_rules + self.whitelist_rules: + rule.affected_count = 0 + rule.affected_models = [] + + available = 0 + for model_id in models: + status = self._compute_status(model_id) + self._status_cache[model_id] = status + + if status.affecting_rule: + status.affecting_rule.affected_count += 1 + status.affecting_rule.affected_models.append(model_id) + + if status.status != "ignored": + available += 1 + + self._available_count_cache = (available, len(models)) + self._cache_valid = True + + def get_all_statuses(self, models: List[str]) -> List[ModelStatus]: + """Get status for all models (rebuilds cache if invalid).""" + if not self._cache_valid: + self._rebuild_cache(models) + return [self._status_cache.get(m, self._compute_status(m)) for m in models] + + def update_affected_counts(self, models: List[str]): + """Update the affected_count and affected_models for all rules.""" + # This now just ensures cache is valid - counts are updated in _rebuild_cache + if not self._cache_valid: + self._rebuild_cache(models) + + def get_available_count(self, models: List[str]) -> Tuple[int, int]: + """Returns (available_count, total_count) from cache.""" + if not self._cache_valid: + self._rebuild_cache(models) + return self._available_count_cache or (0, 0) + + def preview_pattern( + self, pattern: str, rule_type: str, models: List[str] + ) -> List[str]: + """ + Preview which models would be affected by a pattern without adding it. + Returns list of affected model IDs. 
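+
+        Sketch (model IDs are invented):
+            engine.preview_pattern("o1*", "ignore",
+                                   ["openai/gpt-4o", "openai/o1-preview"])
+            # -> ["openai/o1-preview"]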
+ """ + affected = [] + pattern = pattern.strip() + if not pattern: + return affected + + for model_id in models: + if self._pattern_matches(model_id, pattern): + affected.append(model_id) + + return affected + + def load_from_env(self, provider: str): + """Load ignore/whitelist rules for a provider from environment.""" + self.reset() + self._current_provider = provider + load_dotenv(override=True) + + # Load ignore list + ignore_key = f"IGNORE_MODELS_{provider.upper()}" + ignore_value = os.getenv(ignore_key, "") + if ignore_value: + patterns = [p.strip() for p in ignore_value.split(",") if p.strip()] + for pattern in patterns: + self.add_ignore_rule(pattern) + self._original_ignore_patterns = set(patterns) + + # Load whitelist + whitelist_key = f"WHITELIST_MODELS_{provider.upper()}" + whitelist_value = os.getenv(whitelist_key, "") + if whitelist_value: + patterns = [p.strip() for p in whitelist_value.split(",") if p.strip()] + for pattern in patterns: + self.add_whitelist_rule(pattern) + self._original_whitelist_patterns = set(patterns) + + def save_to_env(self, provider: str) -> bool: + """ + Save current rules to .env file. + Returns True if successful. + """ + env_path = Path.cwd() / ".env" + + try: + ignore_key = f"IGNORE_MODELS_{provider.upper()}" + whitelist_key = f"WHITELIST_MODELS_{provider.upper()}" + + # Save ignore patterns + ignore_patterns = [rule.pattern for rule in self.ignore_rules] + if ignore_patterns: + set_key(str(env_path), ignore_key, ",".join(ignore_patterns)) + else: + # Remove the key if no patterns + unset_key(str(env_path), ignore_key) + + # Save whitelist patterns + whitelist_patterns = [rule.pattern for rule in self.whitelist_rules] + if whitelist_patterns: + set_key(str(env_path), whitelist_key, ",".join(whitelist_patterns)) + else: + unset_key(str(env_path), whitelist_key) + + # Update original state + self._original_ignore_patterns = set(ignore_patterns) + self._original_whitelist_patterns = set(whitelist_patterns) + + return True + except Exception as e: + print(f"Error saving to .env: {e}") + return False + + def has_unsaved_changes(self) -> bool: + """Check if current rules differ from saved state.""" + current_ignore = set(rule.pattern for rule in self.ignore_rules) + current_whitelist = set(rule.pattern for rule in self.whitelist_rules) + + return ( + current_ignore != self._original_ignore_patterns + or current_whitelist != self._original_whitelist_patterns + ) + + def discard_changes(self): + """Reload rules from environment, discarding unsaved changes.""" + if self._current_provider: + self.load_from_env(self._current_provider) + + +# ════════════════════════════════════════════════════════════════════════════════ +# MODEL FETCHER +# ════════════════════════════════════════════════════════════════════════════════ + +# Global cache for fetched models (persists across provider switches) +_model_cache: Dict[str, List[str]] = {} + + +class ModelFetcher: + """ + Handles async model fetching from providers. + + Runs fetching in a background thread to avoid blocking the GUI. + Includes caching to avoid refetching on every provider switch. + """ + + @staticmethod + def get_cached_models(provider: str) -> Optional[List[str]]: + """Get cached models for a provider, if available.""" + return _model_cache.get(provider) + + @staticmethod + def clear_cache(provider: Optional[str] = None): + """Clear model cache. 
If provider specified, only clear that provider.""" + if provider: + _model_cache.pop(provider, None) + else: + _model_cache.clear() + + @staticmethod + def get_available_providers() -> List[str]: + """Get list of providers that have credentials configured.""" + providers = set() + load_dotenv(override=True) + + # Scan environment for API keys (handles numbered keys like GEMINI_API_KEY_1) + for key in os.environ: + if "_API_KEY" in key and "PROXY_API_KEY" not in key: + # Extract provider: NVIDIA_NIM_API_KEY_1 -> nvidia_nim + provider = key.split("_API_KEY")[0].lower() + providers.add(provider) + + # Check for OAuth providers + oauth_dir = Path("oauth_creds") + if oauth_dir.exists(): + for file in oauth_dir.glob("*_oauth_*.json"): + provider = file.name.split("_oauth_")[0] + providers.add(provider) + + return sorted(list(providers)) + + @staticmethod + def _find_credential(provider: str) -> Optional[str]: + """Find a credential for a provider (handles numbered keys).""" + load_dotenv(override=True) + provider_upper = provider.upper() + + # Try exact match first (e.g., GEMINI_API_KEY) + exact_key = f"{provider_upper}_API_KEY" + if os.getenv(exact_key): + return os.getenv(exact_key) + + # Look for numbered keys (e.g., GEMINI_API_KEY_1, NVIDIA_NIM_API_KEY_1) + for key, value in os.environ.items(): + if key.startswith(f"{provider_upper}_API_KEY") and value: + return value + + # Check for OAuth credentials + oauth_dir = Path("oauth_creds") + if oauth_dir.exists(): + oauth_files = list(oauth_dir.glob(f"{provider}_oauth_*.json")) + if oauth_files: + return str(oauth_files[0]) + + return None + + @staticmethod + async def _fetch_models_async(provider: str) -> Tuple[List[str], Optional[str]]: + """ + Async implementation of model fetching. + Returns: (models_list, error_message_or_none) + """ + try: + import httpx + from rotator_library.providers import PROVIDER_PLUGINS + + # Get credential + credential = ModelFetcher._find_credential(provider) + if not credential: + return [], f"No credentials found for '{provider}'" + + # Get provider class + provider_class = PROVIDER_PLUGINS.get(provider.lower()) + if not provider_class: + return [], f"Unknown provider: '{provider}'" + + # Fetch models + async with httpx.AsyncClient(timeout=30.0) as client: + instance = provider_class() + models = await instance.get_models(credential, client) + return models, None + + except ImportError as e: + return [], f"Import error: {e}" + except Exception as e: + return [], f"Failed to fetch: {str(e)}" + + @staticmethod + def fetch_models( + provider: str, + on_success: Callable[[List[str]], None], + on_error: Callable[[str], None], + on_start: Optional[Callable[[], None]] = None, + force_refresh: bool = False, + ): + """ + Fetch models in a background thread. 
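+
+        Callback sketch (provider name and callbacks are placeholders; on a
+        fresh fetch they run on the worker thread, so GUI updates typically
+        need to re-enter Tk via widget.after(...)):
+            ModelFetcher.fetch_models(
+                "openai",
+                on_success=lambda models: print(len(models), "models"),
+                on_error=print,
+            )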
+ + Args: + provider: Provider name (e.g., 'openai', 'gemini') + on_success: Callback with list of model IDs + on_error: Callback with error message + on_start: Optional callback when fetching starts + force_refresh: If True, bypass cache and fetch fresh + """ + # Check cache first (unless force refresh) + if not force_refresh: + cached = ModelFetcher.get_cached_models(provider) + if cached is not None: + on_success(cached) + return + + def run_fetch(): + if on_start: + on_start() + + try: + # Run async fetch in new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + models, error = loop.run_until_complete( + ModelFetcher._fetch_models_async(provider) + ) + # Clean up any pending tasks to avoid warnings + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + if pending: + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + if error: + on_error(error) + else: + # Cache the results + _model_cache[provider] = models + on_success(models) + + except Exception as e: + on_error(str(e)) + + thread = threading.Thread(target=run_fetch, daemon=True) + thread.start() + + +# ════════════════════════════════════════════════════════════════════════════════ +# HELP WINDOW +# ════════════════════════════════════════════════════════════════════════════════ + + +class HelpWindow(ctk.CTkToplevel): + """ + Modal help popup with comprehensive filtering documentation. + Uses CTkTextbox for proper scrolling with dark theme styling. + """ + + def __init__(self, parent): + super().__init__(parent) + + self.title("Help - Model Filtering") + self.geometry("700x600") + self.minsize(600, 500) + + # Make modal + self.transient(parent) + self.grab_set() + + # Configure appearance + self.configure(fg_color=BG_PRIMARY) + + # Build content + self._create_content() + + # Center on parent + self.update_idletasks() + x = parent.winfo_x() + (parent.winfo_width() - self.winfo_width()) // 2 + y = parent.winfo_y() + (parent.winfo_height() - self.winfo_height()) // 2 + self.geometry(f"+{x}+{y}") + + # Focus + self.focus_force() + + # Bind escape to close + self.bind("", lambda e: self.destroy()) + + def _create_content(self): + """Build the help content using CTkTextbox for proper scrolling.""" + # Main container + main_frame = ctk.CTkFrame(self, fg_color="transparent") + main_frame.pack(fill="both", expand=True, padx=20, pady=(20, 10)) + + # Use CTkTextbox - CustomTkinter's styled text widget with built-in scrolling + self.text_box = ctk.CTkTextbox( + main_frame, + font=(FONT_FAMILY, FONT_SIZE_NORMAL), + fg_color=BG_SECONDARY, + text_color=TEXT_SECONDARY, + corner_radius=8, + wrap="word", + activate_scrollbars=True, + ) + self.text_box.pack(fill="both", expand=True) + + # Configure text tags for formatting + # Access the underlying tk.Text widget for tag configuration + text_widget = self.text_box._textbox + + text_widget.tag_configure( + "title", + font=(FONT_FAMILY, FONT_SIZE_HEADER, "bold"), + foreground=TEXT_PRIMARY, + spacing1=5, + spacing3=15, + ) + text_widget.tag_configure( + "section_title", + font=(FONT_FAMILY, FONT_SIZE_LARGE, "bold"), + foreground=ACCENT_BLUE, + spacing1=20, + spacing3=8, + ) + text_widget.tag_configure( + "separator", + font=(FONT_FAMILY, 6), + foreground=BORDER_COLOR, + spacing3=5, + ) + text_widget.tag_configure( + "content", + font=(FONT_FAMILY, FONT_SIZE_NORMAL), + foreground=TEXT_SECONDARY, + spacing1=2, + spacing3=5, + lmargin1=5, + 
lmargin2=5, + ) + + # Insert content + self._insert_help_content() + + # Make read-only by disabling + self.text_box.configure(state="disabled") + + # Bind mouse wheel for faster scrolling on the internal canvas + self.text_box.bind("", self._on_mousewheel) + # Also bind on the textbox's internal widget + self.text_box._textbox.bind("", self._on_mousewheel) + + # Close button at bottom + btn_frame = ctk.CTkFrame(self, fg_color="transparent") + btn_frame.pack(fill="x", padx=20, pady=(10, 15)) + + close_btn = ctk.CTkButton( + btn_frame, + text="Got it!", + font=(FONT_FAMILY, FONT_SIZE_NORMAL, "bold"), + fg_color=ACCENT_BLUE, + hover_color="#3a8aee", + height=40, + width=120, + command=self.destroy, + ) + close_btn.pack() + + def _on_mousewheel(self, event): + """Handle mouse wheel with faster scrolling.""" + # CTkTextbox uses _textbox internally + self.text_box._textbox.yview_scroll(-1 * (event.delta // 40), "units") + return "break" + + def _insert_help_content(self): + """Insert all help text with formatting.""" + # Access internal text widget for inserting with tags + text_widget = self.text_box._textbox + + # Title + text_widget.insert("end", "📖 Model Filtering Guide\n", "title") + + # Sections with emojis + sections = [ + ( + "🎯 Overview", + """Model filtering allows you to control which models are available through your proxy for each provider. + +• Use the IGNORE list to block specific models +• Use the WHITELIST to ensure specific models are always available +• Whitelist ALWAYS takes priority over Ignore""", + ), + ( + "⚖️ Filtering Priority", + """When a model is checked, the following order is used: + +1. WHITELIST CHECK + If the model matches any whitelist pattern → AVAILABLE + (Whitelist overrides everything else) + +2. IGNORE CHECK + If the model matches any ignore pattern → BLOCKED + +3. DEFAULT + If no patterns match → AVAILABLE""", + ), + ( + "✏️ Pattern Syntax", + """Three types of patterns are supported: + +EXACT MATCH + Pattern: gpt-4 + Matches: only "gpt-4", nothing else + +PREFIX WILDCARD + Pattern: gpt-4* + Matches: "gpt-4", "gpt-4-turbo", "gpt-4-preview", etc. 
+ +MATCH ALL + Pattern: * + Matches: every model for this provider""", + ), + ( + "💡 Common Patterns", + """BLOCK ALL, ALLOW SPECIFIC: + Ignore: * + Whitelist: gpt-4o, gpt-4o-mini + Result: Only gpt-4o and gpt-4o-mini available + +BLOCK PREVIEW MODELS: + Ignore: *-preview, *-preview* + Result: All preview variants blocked + +BLOCK SPECIFIC SERIES: + Ignore: o1*, dall-e* + Result: All o1 and DALL-E models blocked + +ALLOW ONLY LATEST: + Ignore: * + Whitelist: *-latest + Result: Only models ending in "-latest" available""", + ), + ( + "🖱️ Interface Guide", + """PROVIDER DROPDOWN + Select which provider to configure + +MODEL LISTS + • Left list: All fetched models (unfiltered) + • Right list: Same models with colored status + • Green = Available (normal) + • Red/Orange tones = Blocked (ignored) + • Blue/Teal tones = Whitelisted + +SEARCH BOX + Filter both lists to find specific models quickly + +CLICKING MODELS + • Left-click: Highlight the rule affecting this model + • Right-click: Context menu with quick actions + +CLICKING RULES + • Highlights all models affected by that rule + • Shows which models will be blocked/allowed + +RULE INPUT (Merge Mode) + • Enter patterns separated by commas + • Only adds patterns not covered by existing rules + • Press Add or Enter to create rules + +IMPORT BUTTON (Replace Mode) + • Replaces ALL existing rules with imported ones + • Paste comma-separated patterns + +DELETE RULES + • Click the × button on any rule to remove it""", + ), + ( + "⌨️ Keyboard Shortcuts", + """Ctrl+S Save changes +Ctrl+R Refresh models from provider +Ctrl+F Focus search box +F1 Open this help window +Escape Clear search / Close dialogs""", + ), + ( + "💾 Saving Changes", + """Changes are saved to your .env file in this format: + + IGNORE_MODELS_OPENAI=pattern1,pattern2* + WHITELIST_MODELS_OPENAI=specific-model + +Click "Save" to persist changes, or "Discard" to revert. 
+# ════════════════════════════════════════════════════════════════════════════════
+# CUSTOM DIALOG
+# ════════════════════════════════════════════════════════════════════════════════
+
+
+class UnsavedChangesDialog(ctk.CTkToplevel):
+    """Modal dialog for unsaved changes confirmation."""
+
+    def __init__(self, parent):
+        super().__init__(parent)
+
+        self.result: Optional[str] = None  # 'save', 'discard', 'cancel'
+
+        self.title("Unsaved Changes")
+        self.geometry("400x180")
+        self.resizable(False, False)
+
+        # Make modal
+        self.transient(parent)
+        self.grab_set()
+
+        # Configure appearance
+        self.configure(fg_color=BG_PRIMARY)
+
+        # Build content
+        self._create_content()
+
+        # Center on parent
+        self.update_idletasks()
+        x = parent.winfo_x() + (parent.winfo_width() - self.winfo_width()) // 2
+        y = parent.winfo_y() + (parent.winfo_height() - self.winfo_height()) // 2
+        self.geometry(f"+{x}+{y}")
+
+        # Focus
+        self.focus_force()
+
+        # Bind escape to cancel
+        self.bind("<Escape>", lambda e: self._on_cancel())
+
+        # Handle window close
+        self.protocol("WM_DELETE_WINDOW", self._on_cancel)
+
+    def _create_content(self):
+        """Build dialog content."""
+        # Icon and message
+        msg_frame = ctk.CTkFrame(self, fg_color="transparent")
+        msg_frame.pack(fill="x", padx=30, pady=(25, 15))
+
+        icon = ctk.CTkLabel(
+            msg_frame, text="⚠️", font=(FONT_FAMILY, 32), text_color=ACCENT_YELLOW
+        )
+        icon.pack(side="left", padx=(0, 15))
+
+        text_frame = ctk.CTkFrame(msg_frame, fg_color="transparent")
+        text_frame.pack(side="left", fill="x", expand=True)
+
+        title = ctk.CTkLabel(
+            text_frame,
+            text="Unsaved Changes",
+            font=(FONT_FAMILY, FONT_SIZE_LARGE, "bold"),
+            text_color=TEXT_PRIMARY,
+            anchor="w",
+        )
+        title.pack(anchor="w")
+
+        subtitle = ctk.CTkLabel(
+            text_frame,
+            text="You have unsaved filter changes.\nWhat would you like to do?",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            text_color=TEXT_SECONDARY,
+            anchor="w",
+            justify="left",
+        )
+        subtitle.pack(anchor="w")
+
+        # Buttons
+        btn_frame = ctk.CTkFrame(self, fg_color="transparent")
+        btn_frame.pack(fill="x", padx=30, pady=(10, 25))
+
+        cancel_btn = ctk.CTkButton(
+            btn_frame,
+            text="Cancel",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=BG_SECONDARY,
+            hover_color=BG_HOVER,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            width=100,
+            command=self._on_cancel,
+        )
+        cancel_btn.pack(side="right", padx=(10, 0))
+
+        discard_btn = ctk.CTkButton(
+            btn_frame,
+            text="Discard",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=ACCENT_RED,
+            hover_color="#c0392b",
+            width=100,
+            command=self._on_discard,
+        )
+        discard_btn.pack(side="right", padx=(10, 0))
+
+        save_btn = ctk.CTkButton(
+            btn_frame,
+            text="Save",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=ACCENT_GREEN,
+            hover_color="#27ae60",
+            width=100,
+            command=self._on_save,
+        )
+        save_btn.pack(side="right")
+
+    def _on_save(self):
+        self.result = "save"
+        self.destroy()
+
+    def _on_discard(self):
+        self.result = "discard"
+        self.destroy()
+
+    def _on_cancel(self):
+        self.result = "cancel"
+        self.destroy()
+
+    def show(self) -> Optional[str]:
+        """Show dialog and return result."""
+        self.wait_window()
+        return self.result
+
+
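+# ------------------------------------------------------------------------------
+# Illustrative usage (hypothetical caller, not part of this file): because the
+# dialog is modal and show() blocks via wait_window(), the expected pattern is
+# to branch on the returned string, e.g. in the main window's close handler:
+#
+#     def _on_close(self):
+#         if self.has_unsaved_changes:          # hypothetical attribute
+#             choice = UnsavedChangesDialog(self).show()
+#             if choice == "cancel" or choice is None:
+#                 return                        # keep the window open
+#             if choice == "save":
+#                 self._save_to_env()           # hypothetical save routine
+#         self.destroy()
+# ------------------------------------------------------------------------------
+
+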
+class ImportRulesDialog(ctk.CTkToplevel):
+    """Modal dialog for importing rules from comma-separated text."""
+
+    def __init__(self, parent, rule_type: str):
+        super().__init__(parent)
+
+        self.result: Optional[List[str]] = None
+        self.rule_type = rule_type
+
+        title_text = (
+            "Import Ignore Rules" if rule_type == "ignore" else "Import Whitelist Rules"
+        )
+        self.title(title_text)
+        self.geometry("500x300")
+        self.minsize(400, 250)
+
+        # Make modal
+        self.transient(parent)
+        self.grab_set()
+
+        # Configure appearance
+        self.configure(fg_color=BG_PRIMARY)
+
+        # Build content
+        self._create_content()
+
+        # Center on parent
+        self.update_idletasks()
+        x = parent.winfo_x() + (parent.winfo_width() - self.winfo_width()) // 2
+        y = parent.winfo_y() + (parent.winfo_height() - self.winfo_height()) // 2
+        self.geometry(f"+{x}+{y}")
+
+        # Focus
+        self.focus_force()
+        self.text_box.focus_set()
+
+        # Bind escape to cancel
+        self.bind("<Escape>", lambda e: self._on_cancel())
+
+        # Handle window close
+        self.protocol("WM_DELETE_WINDOW", self._on_cancel)
+
+    def _create_content(self):
+        """Build dialog content."""
+        # Instructions at TOP
+        instruction_frame = ctk.CTkFrame(self, fg_color="transparent")
+        instruction_frame.pack(fill="x", padx=20, pady=(15, 10))
+
+        instruction = ctk.CTkLabel(
+            instruction_frame,
+            text="Paste comma-separated patterns below (will REPLACE all existing rules):",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            text_color=TEXT_PRIMARY,
+            anchor="w",
+        )
+        instruction.pack(anchor="w")
+
+        example = ctk.CTkLabel(
+            instruction_frame,
+            text="Example: gpt-4*, claude-3*, model-name",
+            font=(FONT_FAMILY, FONT_SIZE_SMALL),
+            text_color=TEXT_MUTED,
+            anchor="w",
+        )
+        example.pack(anchor="w")
+
+        # Buttons at BOTTOM - pack BEFORE textbox to reserve space
+        btn_frame = ctk.CTkFrame(self, fg_color="transparent", height=50)
+        btn_frame.pack(side="bottom", fill="x", padx=20, pady=(10, 15))
+        btn_frame.pack_propagate(False)
+
+        cancel_btn = ctk.CTkButton(
+            btn_frame,
+            text="Cancel",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=BG_SECONDARY,
+            hover_color=BG_HOVER,
+            border_width=1,
+            border_color=BORDER_COLOR,
+            width=100,
+            height=32,
+            command=self._on_cancel,
+        )
+        cancel_btn.pack(side="right", padx=(10, 0))
+
+        import_btn = ctk.CTkButton(
+            btn_frame,
+            text="Replace All",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL, "bold"),
+            fg_color=ACCENT_BLUE,
+            hover_color="#3a8aee",
+            width=110,
+            height=32,
+            command=self._on_import,
+        )
+        import_btn.pack(side="right")
+
+        # Text box fills MIDDLE space - pack LAST
+        self.text_box = ctk.CTkTextbox(
+            self,
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=BG_TERTIARY,
+            border_color=BORDER_COLOR,
+            border_width=1,
+            text_color=TEXT_PRIMARY,
+            wrap="word",
+        )
+        self.text_box.pack(fill="both", expand=True, padx=20, pady=(0, 0))
+
+        # Bind Ctrl+Enter to import
+        self.text_box.bind("<Control-Return>", lambda e: self._on_import())
+
+    def _on_import(self):
+        """Parse and return the patterns."""
+        text = self.text_box.get("1.0", "end").strip()
+        if text:
+            # Parse comma-separated patterns
+            patterns = [p.strip() for p in text.split(",") if p.strip()]
+            self.result = patterns
+        else:
+            self.result = []
+        self.destroy()
+
+    def _on_cancel(self):
+        self.result = None
+        self.destroy()
+
+    def show(self) -> Optional[List[str]]:
+        """Show dialog and return list of patterns, or None if cancelled."""
+        self.wait_window()
+        return self.result
+
+
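+# ------------------------------------------------------------------------------
+# Illustrative sketch (hypothetical helper, not part of this file): one way a
+# caller could turn the patterns returned by ImportRulesDialog into the
+# added/skipped counts that ImportResultDialog displays.  Replace mode simply
+# overwrites the rule list; merge mode here uses an exact-duplicate check,
+# whereas the GUI's merge input may instead skip patterns already covered by
+# existing rules:
+#
+#     def apply_import(current_rules, patterns, is_replace):
+#         if is_replace:
+#             new_rules = list(dict.fromkeys(patterns))   # replace, dedupe, keep order
+#             return new_rules, len(new_rules), 0
+#         added = [p for p in patterns if p not in current_rules]
+#         skipped = len(patterns) - len(added)            # exact duplicates
+#         return current_rules + added, len(added), skipped
+#
+#     # rules, added, skipped = apply_import(rules, dialog.show() or [], True)
+#     # ImportResultDialog(parent, added, skipped, is_replace=True)
+# ------------------------------------------------------------------------------
+
+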
+class ImportResultDialog(ctk.CTkToplevel):
+    """Simple dialog showing import results."""
+
+    def __init__(self, parent, added: int, skipped: int, is_replace: bool = False):
+        super().__init__(parent)
+
+        self.title("Import Complete")
+        self.geometry("380x160")
+        self.resizable(False, False)
+
+        # Make modal
+        self.transient(parent)
+        self.grab_set()
+
+        # Configure appearance
+        self.configure(fg_color=BG_PRIMARY)
+
+        # Build content
+        self._create_content(added, skipped, is_replace)
+
+        # Center on parent
+        self.update_idletasks()
+        x = parent.winfo_x() + (parent.winfo_width() - self.winfo_width()) // 2
+        y = parent.winfo_y() + (parent.winfo_height() - self.winfo_height()) // 2
+        self.geometry(f"+{x}+{y}")
+
+        # Focus
+        self.focus_force()
+
+        # Bind escape and enter to close
+        self.bind("<Escape>", lambda e: self.destroy())
+        self.bind("<Return>", lambda e: self.destroy())
+
+    def _create_content(self, added: int, skipped: int, is_replace: bool):
+        """Build dialog content."""
+        # Icon and message
+        msg_frame = ctk.CTkFrame(self, fg_color="transparent")
+        msg_frame.pack(fill="x", padx=30, pady=(25, 15))
+
+        icon = ctk.CTkLabel(
+            msg_frame,
+            text="✅" if added > 0 else "ℹ️",
+            font=(FONT_FAMILY, 28),
+            text_color=ACCENT_GREEN if added > 0 else ACCENT_BLUE,
+        )
+        icon.pack(side="left", padx=(0, 15))
+
+        text_frame = ctk.CTkFrame(msg_frame, fg_color="transparent")
+        text_frame.pack(side="left", fill="x", expand=True)
+
+        # Title text differs based on mode
+        if is_replace:
+            if added > 0:
+                added_text = f"Replaced with {added} rule{'s' if added != 1 else ''}"
+            else:
+                added_text = "All rules cleared"
+        else:
+            if added > 0:
+                added_text = f"Added {added} rule{'s' if added != 1 else ''}"
+            else:
+                added_text = "No new rules added"
+
+        title = ctk.CTkLabel(
+            text_frame,
+            text=added_text,
+            font=(FONT_FAMILY, FONT_SIZE_LARGE, "bold"),
+            text_color=TEXT_PRIMARY,
+            anchor="w",
+        )
+        title.pack(anchor="w")
+
+        # Subtitle for skipped/duplicates
+        if skipped > 0:
+            skip_text = f"{skipped} duplicate{'s' if skipped != 1 else ''} skipped"
+            subtitle = ctk.CTkLabel(
+                text_frame,
+                text=skip_text,
+                font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+                text_color=TEXT_MUTED,
+                anchor="w",
+            )
+            subtitle.pack(anchor="w")
+
+        # OK button
+        btn_frame = ctk.CTkFrame(self, fg_color="transparent")
+        btn_frame.pack(fill="x", padx=30, pady=(0, 20))
+
+        ok_btn = ctk.CTkButton(
+            btn_frame,
+            text="OK",
+            font=(FONT_FAMILY, FONT_SIZE_NORMAL),
+            fg_color=ACCENT_BLUE,
+            hover_color="#3a8aee",
+            width=80,
+            command=self.destroy,
+        )
+        ok_btn.pack(side="right")
+
+
+# ════════════════════════════════════════════════════════════════════════════════
+# TOOLTIP
+# ════════════════════════════════════════════════════════════════════════════════
+
+
+class ToolTip:
+    """Simple tooltip implementation for CustomTkinter widgets."""
+
+    def __init__(self, widget, text: str, delay: int = 500):
+        self.widget = widget
+        self.text = text
+        self.delay = delay
+        self.tooltip_window = None
+        self.after_id = None
+
+        widget.bind("<Enter>", self._schedule_show)
+        widget.bind("<Leave>", self._hide)
+        widget.bind("