diff --git a/cmd/testai/main-testai.go b/cmd/testai/main-testai.go index eace7ca61a..8e8fcdb3eb 100644 --- a/cmd/testai/main-testai.go +++ b/cmd/testai/main-testai.go @@ -24,8 +24,9 @@ import ( var testSchemaJSON string const ( - DefaultAnthropicModel = "claude-sonnet-4-5" - DefaultOpenAIModel = "gpt-5.1" + DefaultAnthropicModel = "claude-sonnet-4-5" + DefaultOpenAIModel = "gpt-5.1" + DefaultOpenRouterModel = "mistralai/mistral-small-3.2-24b-instruct" ) // TestResponseWriter implements http.ResponseWriter and additional interfaces for testing @@ -113,7 +114,7 @@ func testOpenAI(ctx context.Context, model, message string, tools []uctypes.Tool } opts := &uctypes.AIOptsType{ - APIType: aiusechat.APIType_OpenAI, + APIType: uctypes.APIType_OpenAIResponses, APIToken: apiKey, Model: model, MaxTokens: 4096, @@ -155,6 +156,106 @@ func testOpenAI(ctx context.Context, model, message string, tools []uctypes.Tool } } +func testOpenAIComp(ctx context.Context, model, message string, tools []uctypes.ToolDefinition) { + apiKey := os.Getenv("OPENAI_APIKEY") + if apiKey == "" { + fmt.Println("Error: OPENAI_APIKEY environment variable not set") + os.Exit(1) + } + + opts := &uctypes.AIOptsType{ + APIType: uctypes.APIType_OpenAIChat, + APIToken: apiKey, + BaseURL: "https://api.openai.com/v1/chat/completions", + Model: model, + MaxTokens: 4096, + ThinkingLevel: uctypes.ThinkingLevelMedium, + } + + chatID := uuid.New().String() + + aiMessage := &uctypes.AIMessage{ + MessageId: uuid.New().String(), + Parts: []uctypes.AIMessagePart{ + { + Type: uctypes.AIMessagePartTypeText, + Text: message, + }, + }, + } + + fmt.Printf("Testing OpenAI Completions API with WaveAIPostMessageWrap, model: %s\n", model) + fmt.Printf("Message: %s\n", message) + fmt.Printf("Chat ID: %s\n", chatID) + fmt.Println("---") + + testWriter := &TestResponseWriter{} + sseHandler := sse.MakeSSEHandlerCh(testWriter, ctx) + defer sseHandler.Close() + + chatOpts := uctypes.WaveChatOpts{ + ChatId: chatID, + ClientId: uuid.New().String(), + Config: *opts, + Tools: tools, + SystemPrompt: []string{"You are a helpful assistant. Be concise and clear in your responses."}, + } + err := aiusechat.WaveAIPostMessageWrap(ctx, sseHandler, aiMessage, chatOpts) + if err != nil { + fmt.Printf("OpenAI Completions API streaming error: %v\n", err) + } +} + +func testOpenRouter(ctx context.Context, model, message string, tools []uctypes.ToolDefinition) { + apiKey := os.Getenv("OPENROUTER_APIKEY") + if apiKey == "" { + fmt.Println("Error: OPENROUTER_APIKEY environment variable not set") + os.Exit(1) + } + + opts := &uctypes.AIOptsType{ + APIType: uctypes.APIType_OpenAIChat, + APIToken: apiKey, + BaseURL: "https://openrouter.ai/api/v1/chat/completions", + Model: model, + MaxTokens: 4096, + ThinkingLevel: uctypes.ThinkingLevelMedium, + } + + chatID := uuid.New().String() + + aiMessage := &uctypes.AIMessage{ + MessageId: uuid.New().String(), + Parts: []uctypes.AIMessagePart{ + { + Type: uctypes.AIMessagePartTypeText, + Text: message, + }, + }, + } + + fmt.Printf("Testing OpenRouter with WaveAIPostMessageWrap, model: %s\n", model) + fmt.Printf("Message: %s\n", message) + fmt.Printf("Chat ID: %s\n", chatID) + fmt.Println("---") + + testWriter := &TestResponseWriter{} + sseHandler := sse.MakeSSEHandlerCh(testWriter, ctx) + defer sseHandler.Close() + + chatOpts := uctypes.WaveChatOpts{ + ChatId: chatID, + ClientId: uuid.New().String(), + Config: *opts, + Tools: tools, + SystemPrompt: []string{"You are a helpful assistant. 
Be concise and clear in your responses."}, + } + err := aiusechat.WaveAIPostMessageWrap(ctx, sseHandler, aiMessage, chatOpts) + if err != nil { + fmt.Printf("OpenRouter streaming error: %v\n", err) + } +} + func testAnthropic(ctx context.Context, model, message string, tools []uctypes.ToolDefinition) { apiKey := os.Getenv("ANTHROPIC_APIKEY") if apiKey == "" { @@ -163,7 +264,7 @@ func testAnthropic(ctx context.Context, model, message string, tools []uctypes.T } opts := &uctypes.AIOptsType{ - APIType: aiusechat.APIType_Anthropic, + APIType: uctypes.APIType_AnthropicMessages, APIToken: apiKey, Model: model, MaxTokens: 4096, @@ -217,33 +318,46 @@ func testT2(ctx context.Context) { testOpenAI(ctx, DefaultOpenAIModel, "what is 2+2+8, use the provider adder tool", tools) } +func testT3(ctx context.Context) { + testOpenAIComp(ctx, "gpt-4o", "what is 2+2? please be brief", nil) +} + func printUsage() { - fmt.Println("Usage: go run main-testai.go [--anthropic] [--tools] [--model ] [message]") + fmt.Println("Usage: go run main-testai.go [--anthropic|--openaicomp|--openrouter] [--tools] [--model ] [message]") fmt.Println("Examples:") fmt.Println(" go run main-testai.go 'What is 2+2?'") fmt.Println(" go run main-testai.go --model o4-mini 'What is 2+2?'") fmt.Println(" go run main-testai.go --anthropic 'What is 2+2?'") fmt.Println(" go run main-testai.go --anthropic --model claude-3-5-sonnet-20241022 'What is 2+2?'") + fmt.Println(" go run main-testai.go --openaicomp --model gpt-4o 'What is 2+2?'") + fmt.Println(" go run main-testai.go --openrouter 'What is 2+2?'") + fmt.Println(" go run main-testai.go --openrouter --model anthropic/claude-3.5-sonnet 'What is 2+2?'") fmt.Println(" go run main-testai.go --tools 'Help me configure GitHub Actions monitoring'") fmt.Println("") fmt.Println("Default models:") fmt.Printf(" OpenAI: %s\n", DefaultOpenAIModel) fmt.Printf(" Anthropic: %s\n", DefaultAnthropicModel) + fmt.Printf(" OpenAI Completions: gpt-4o\n") + fmt.Printf(" OpenRouter: %s\n", DefaultOpenRouterModel) fmt.Println("") fmt.Println("Environment variables:") fmt.Println(" OPENAI_APIKEY (for OpenAI models)") fmt.Println(" ANTHROPIC_APIKEY (for Anthropic models)") + fmt.Println(" OPENROUTER_APIKEY (for OpenRouter models)") } func main() { - var anthropic, tools, help, t1, t2 bool + var anthropic, openaicomp, openrouter, tools, help, t1, t2, t3 bool var model string flag.BoolVar(&anthropic, "anthropic", false, "Use Anthropic API instead of OpenAI") + flag.BoolVar(&openaicomp, "openaicomp", false, "Use OpenAI Completions API") + flag.BoolVar(&openrouter, "openrouter", false, "Use OpenRouter API") flag.BoolVar(&tools, "tools", false, "Enable GitHub Actions Monitor tools for testing") - flag.StringVar(&model, "model", "", fmt.Sprintf("AI model to use (defaults: %s for OpenAI, %s for Anthropic)", DefaultOpenAIModel, DefaultAnthropicModel)) + flag.StringVar(&model, "model", "", fmt.Sprintf("AI model to use (defaults: %s for OpenAI, %s for Anthropic, %s for OpenRouter)", DefaultOpenAIModel, DefaultAnthropicModel, DefaultOpenRouterModel)) flag.BoolVar(&help, "help", false, "Show usage information") flag.BoolVar(&t1, "t1", false, fmt.Sprintf("Run preset T1 test (%s with 'what is 2+2')", DefaultAnthropicModel)) flag.BoolVar(&t2, "t2", false, fmt.Sprintf("Run preset T2 test (%s with 'what is 2+2')", DefaultOpenAIModel)) + flag.BoolVar(&t3, "t3", false, "Run preset T3 test (OpenAI Completions API with gpt-4o)") flag.Parse() if help { @@ -262,11 +376,19 @@ func main() { testT2(ctx) return } + if t3 { + testT3(ctx) + 
return + } // Set default model based on API type if not provided if model == "" { if anthropic { model = DefaultAnthropicModel + } else if openaicomp { + model = "gpt-4o" + } else if openrouter { + model = DefaultOpenRouterModel } else { model = DefaultOpenAIModel } @@ -285,6 +407,10 @@ func main() { if anthropic { testAnthropic(ctx, model, message, toolDefs) + } else if openaicomp { + testOpenAIComp(ctx, model, message, toolDefs) + } else if openrouter { + testOpenRouter(ctx, model, message, toolDefs) } else { testOpenAI(ctx, model, message, toolDefs) } diff --git a/frontend/app/aipanel/aimode.tsx b/frontend/app/aipanel/aimode.tsx new file mode 100644 index 0000000000..d5ec9d3063 --- /dev/null +++ b/frontend/app/aipanel/aimode.tsx @@ -0,0 +1,116 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +import { atoms } from "@/app/store/global"; +import { cn, makeIconClass } from "@/util/util"; +import { useAtomValue } from "jotai"; +import { memo, useRef, useState } from "react"; +import { WaveAIModel } from "./waveai-model"; + +export const AIModeDropdown = memo(() => { + const model = WaveAIModel.getInstance(); + const aiMode = useAtomValue(model.currentAIMode); + const aiModeConfigs = useAtomValue(model.aiModeConfigs); + const rateLimitInfo = useAtomValue(atoms.waveAIRateLimitInfoAtom); + const [isOpen, setIsOpen] = useState(false); + const dropdownRef = useRef(null); + + const hasPremium = !rateLimitInfo || rateLimitInfo.unknown || rateLimitInfo.preq > 0; + const hideQuick = model.inBuilder && hasPremium; + + const sortedConfigs = Object.entries(aiModeConfigs) + .map(([mode, config]) => ({ mode, ...config })) + .sort((a, b) => { + const orderDiff = (a["display:order"] || 0) - (b["display:order"] || 0); + if (orderDiff !== 0) return orderDiff; + return (a["display:name"] || "").localeCompare(b["display:name"] || ""); + }) + .filter((config) => !(hideQuick && config.mode === "waveai@quick")); + + const handleSelect = (mode: string) => { + const config = aiModeConfigs[mode]; + if (!config) return; + if (!hasPremium && config["waveai:premium"]) { + return; + } + model.setAIMode(mode); + setIsOpen(false); + }; + + let currentMode = aiMode || "waveai@balanced"; + const currentConfig = aiModeConfigs[currentMode]; + if (currentConfig) { + if (!hasPremium && currentConfig["waveai:premium"]) { + currentMode = "waveai@quick"; + } + if (hideQuick && currentMode === "waveai@quick") { + currentMode = "waveai@balanced"; + } + } + + const displayConfig = aiModeConfigs[currentMode] || { + "display:name": "? Unknown", + "display:icon": "question", + }; + + return ( +
+ + + {isOpen && ( + <> +
setIsOpen(false)} /> +
+ {sortedConfigs.map((config, index) => { + const isFirst = index === 0; + const isLast = index === sortedConfigs.length - 1; + const isDisabled = !hasPremium && config["waveai:premium"]; + const isSelected = currentMode === config.mode; + return ( + + ); + })} +
+ + )} +
+ ); +}); + +AIModeDropdown.displayName = "AIModeDropdown"; diff --git a/frontend/app/aipanel/aipanel-contextmenu.ts b/frontend/app/aipanel/aipanel-contextmenu.ts index b7a7f718d4..05060b5e64 100644 --- a/frontend/app/aipanel/aipanel-contextmenu.ts +++ b/frontend/app/aipanel/aipanel-contextmenu.ts @@ -41,45 +41,45 @@ export async function handleWaveAIContextMenu(e: React.MouseEvent, showCopy: boo const rateLimitInfo = globalStore.get(atoms.waveAIRateLimitInfoAtom); const hasPremium = !rateLimitInfo || rateLimitInfo.unknown || rateLimitInfo.preq > 0; - const currentThinkingMode = rtInfo?.["waveai:thinkingmode"] ?? (hasPremium ? "balanced" : "quick"); + const currentAIMode = rtInfo?.["waveai:mode"] ?? (hasPremium ? "waveai@balanced" : "waveai@quick"); const defaultTokens = model.inBuilder ? 24576 : 4096; const currentMaxTokens = rtInfo?.["waveai:maxoutputtokens"] ?? defaultTokens; - const thinkingModeSubmenu: ContextMenuItem[] = [ + const aiModeSubmenu: ContextMenuItem[] = [ { label: "Quick (gpt-5-mini)", type: "checkbox", - checked: currentThinkingMode === "quick", + checked: currentAIMode === "waveai@quick", click: () => { RpcApi.SetRTInfoCommand(TabRpcClient, { oref: model.orefContext, - data: { "waveai:thinkingmode": "quick" }, + data: { "waveai:mode": "waveai@quick" }, }); }, }, { label: hasPremium ? "Balanced (gpt-5.1, low thinking)" : "Balanced (premium)", type: "checkbox", - checked: currentThinkingMode === "balanced", + checked: currentAIMode === "waveai@balanced", enabled: hasPremium, click: () => { if (!hasPremium) return; RpcApi.SetRTInfoCommand(TabRpcClient, { oref: model.orefContext, - data: { "waveai:thinkingmode": "balanced" }, + data: { "waveai:mode": "waveai@balanced" }, }); }, }, { label: hasPremium ? "Deep (gpt-5.1, full thinking)" : "Deep (premium)", type: "checkbox", - checked: currentThinkingMode === "deep", + checked: currentAIMode === "waveai@deep", enabled: hasPremium, click: () => { if (!hasPremium) return; RpcApi.SetRTInfoCommand(TabRpcClient, { oref: model.orefContext, - data: { "waveai:thinkingmode": "deep" }, + data: { "waveai:mode": "waveai@deep" }, }); }, }, @@ -164,8 +164,8 @@ export async function handleWaveAIContextMenu(e: React.MouseEvent, showCopy: boo } menu.push({ - label: "Thinking Mode", - submenu: thinkingModeSubmenu, + label: "AI Mode", + submenu: aiModeSubmenu, }); menu.push({ diff --git a/frontend/app/aipanel/aipanel.tsx b/frontend/app/aipanel/aipanel.tsx index 79ae04fcc1..062fc2f559 100644 --- a/frontend/app/aipanel/aipanel.tsx +++ b/frontend/app/aipanel/aipanel.tsx @@ -16,12 +16,12 @@ import { memo, useCallback, useEffect, useRef, useState } from "react"; import { useDrop } from "react-dnd"; import { formatFileSizeError, isAcceptableFile, validateFileSize } from "./ai-utils"; import { AIDroppedFiles } from "./aidroppedfiles"; +import { AIModeDropdown } from "./aimode"; import { AIPanelHeader } from "./aipanelheader"; import { AIPanelInput } from "./aipanelinput"; import { AIPanelMessages } from "./aipanelmessages"; import { AIRateLimitStrip } from "./airatelimitstrip"; import { TelemetryRequiredMessage } from "./telemetryrequired"; -import { ThinkingLevelDropdown } from "./thinkingmode"; import { WaveAIModel } from "./waveai-model"; const AIBlockMask = memo(() => { @@ -246,6 +246,8 @@ const AIPanelComponentInner = memo(() => { model.registerUseChatData(sendMessage, setMessages, status, stop); // console.log("AICHAT messages", messages); + (window as any).aichatmessages = messages; + (window as any).aichatstatus = status; const handleKeyDown = 
(waveEvent: WaveKeyboardEvent): boolean => { if (checkKeyPressed(waveEvent, "Cmd:k")) { @@ -498,7 +500,7 @@ const AIPanelComponentInner = memo(() => { onContextMenu={(e) => handleWaveAIContextMenu(e, true)} >
- <ThinkingLevelDropdown /> + <AIModeDropdown />
{model.inBuilder ? : }
diff --git a/frontend/app/aipanel/aipanelmessages.tsx b/frontend/app/aipanel/aipanelmessages.tsx index a32e3936b4..3d3ae0d912 100644 --- a/frontend/app/aipanel/aipanelmessages.tsx +++ b/frontend/app/aipanel/aipanelmessages.tsx @@ -4,7 +4,7 @@ import { useAtomValue } from "jotai"; import { memo, useEffect, useRef } from "react"; import { AIMessage } from "./aimessage"; -import { ThinkingLevelDropdown } from "./thinkingmode"; +import { AIModeDropdown } from "./aimode"; import { WaveAIModel } from "./waveai-model"; interface AIPanelMessagesProps { @@ -45,13 +45,13 @@ export const AIPanelMessages = memo(({ messages, status, onContextMenu }: AIPane useEffect(() => { const wasStreaming = prevStatusRef.current === "streaming"; const isNowNotStreaming = status !== "streaming"; - + if (wasStreaming && isNowNotStreaming) { requestAnimationFrame(() => { scrollToBottom(); }); } - + prevStatusRef.current = status; }, [status]); @@ -62,7 +62,7 @@ export const AIPanelMessages = memo(({ messages, status, onContextMenu }: AIPane onContextMenu={onContextMenu} >
- <ThinkingLevelDropdown /> + <AIModeDropdown />
{messages.map((message, index) => { const isLastMessage = index === messages.length - 1; diff --git a/frontend/app/aipanel/aitypes.ts b/frontend/app/aipanel/aitypes.ts index a1192ec7ed..cc3c73d224 100644 --- a/frontend/app/aipanel/aitypes.ts +++ b/frontend/app/aipanel/aitypes.ts @@ -4,14 +4,14 @@ import { ChatRequestOptions, FileUIPart, UIMessage, UIMessagePart } from "ai"; type WaveUIDataTypes = { - // pkg/aiusechat/uctypes/usechat-types.go UIMessageDataUserFile + // pkg/aiusechat/uctypes/uctypes.go UIMessageDataUserFile userfile: { filename: string; size: number; mimetype: string; previewurl?: string; }; - // pkg/aiusechat/uctypes/usechat-types.go UIMessageDataToolUse + // pkg/aiusechat/uctypes/uctypes.go UIMessageDataToolUse tooluse: { toolcallid: string; toolname: string; diff --git a/frontend/app/aipanel/thinkingmode.tsx b/frontend/app/aipanel/thinkingmode.tsx deleted file mode 100644 index 1e0fb76be7..0000000000 --- a/frontend/app/aipanel/thinkingmode.tsx +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2025, Command Line Inc. -// SPDX-License-Identifier: Apache-2.0 - -import { atoms } from "@/app/store/global"; -import { cn } from "@/util/util"; -import { useAtomValue } from "jotai"; -import { memo, useRef, useState } from "react"; -import { WaveAIModel } from "./waveai-model"; - -type ThinkingMode = "quick" | "balanced" | "deep"; - -interface ThinkingModeMetadata { - icon: string; - name: string; - desc: string; - premium: boolean; -} - -const ThinkingModeData: Record = { - quick: { - icon: "fa-bolt", - name: "Quick", - desc: "Fastest responses (gpt-5-mini)", - premium: false, - }, - balanced: { - icon: "fa-sparkles", - name: "Balanced", - desc: "Good mix of speed and accuracy\n(gpt-5.1 with minimal thinking)", - premium: true, - }, - deep: { - icon: "fa-lightbulb", - name: "Deep", - desc: "Slower but most capable\n(gpt-5.1 with full reasoning)", - premium: true, - }, -}; - -export const ThinkingLevelDropdown = memo(() => { - const model = WaveAIModel.getInstance(); - const thinkingMode = useAtomValue(model.thinkingMode); - const rateLimitInfo = useAtomValue(atoms.waveAIRateLimitInfoAtom); - const [isOpen, setIsOpen] = useState(false); - const dropdownRef = useRef(null); - - const hasPremium = !rateLimitInfo || rateLimitInfo.unknown || rateLimitInfo.preq > 0; - const hideQuick = model.inBuilder && hasPremium; - - const handleSelect = (mode: ThinkingMode) => { - const metadata = ThinkingModeData[mode]; - if (!hasPremium && metadata.premium) { - return; - } - model.setThinkingMode(mode); - setIsOpen(false); - }; - - let currentMode = (thinkingMode as ThinkingMode) || "balanced"; - const currentMetadata = ThinkingModeData[currentMode]; - if (!hasPremium && currentMetadata.premium) { - currentMode = "quick"; - } - if (hideQuick && currentMode === "quick") { - currentMode = "balanced"; - } - - return ( -
- - - {isOpen && ( - <> -
setIsOpen(false)} /> -
- {(Object.keys(ThinkingModeData) as ThinkingMode[]) - .filter((mode) => !(hideQuick && mode === "quick")) - .map((mode, index, filteredModes) => { - const metadata = ThinkingModeData[mode]; - const isFirst = index === 0; - const isLast = index === filteredModes.length - 1; - const isDisabled = !hasPremium && metadata.premium; - const isSelected = currentMode === mode; - return ( - - ); - })} -
- - )} -
- ); -}); - -ThinkingLevelDropdown.displayName = "ThinkingLevelDropdown"; diff --git a/frontend/app/aipanel/waveai-model.tsx b/frontend/app/aipanel/waveai-model.tsx index 7af0914e88..34e11ec5ce 100644 --- a/frontend/app/aipanel/waveai-model.tsx +++ b/frontend/app/aipanel/waveai-model.tsx @@ -57,7 +57,8 @@ export class WaveAIModel { widgetAccessAtom!: jotai.Atom; droppedFiles: jotai.PrimitiveAtom = jotai.atom([]); chatId!: jotai.PrimitiveAtom; - thinkingMode: jotai.PrimitiveAtom = jotai.atom("balanced"); + currentAIMode: jotai.PrimitiveAtom = jotai.atom("waveai@balanced"); + aiModeConfigs!: jotai.Atom>; errorMessage: jotai.PrimitiveAtom = jotai.atom(null) as jotai.PrimitiveAtom; modelAtom!: jotai.Atom; containerWidth: jotai.PrimitiveAtom = jotai.atom(0); @@ -82,6 +83,11 @@ export class WaveAIModel { const modelMetaAtom = getOrefMetaKeyAtom(this.orefContext, "waveai:model"); return get(modelMetaAtom) ?? "gpt-5.1"; }); + this.aiModeConfigs = jotai.atom((get) => { + const fullConfig = get(atoms.fullConfigAtom); + return fullConfig?.waveai ?? {}; + }); + this.widgetAccessAtom = jotai.atom((get) => { if (this.inBuilder) { @@ -337,11 +343,11 @@ export class WaveAIModel { }); } - setThinkingMode(mode: string) { - globalStore.set(this.thinkingMode, mode); + setAIMode(mode: string) { + globalStore.set(this.currentAIMode, mode); RpcApi.SetRTInfoCommand(TabRpcClient, { oref: this.orefContext, - data: { "waveai:thinkingmode": mode }, + data: { "waveai:mode": mode }, }); } @@ -359,8 +365,8 @@ export class WaveAIModel { } globalStore.set(this.chatId, chatIdValue); - const thinkingModeValue = rtInfo?.["waveai:thinkingmode"] ?? "balanced"; - globalStore.set(this.thinkingMode, thinkingModeValue); + const aiModeValue = rtInfo?.["waveai:mode"] ?? "waveai@balanced"; + globalStore.set(this.currentAIMode, aiModeValue); try { const chatData = await RpcApi.GetWaveAIChatCommand(TabRpcClient, { chatid: chatIdValue }); diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 8b80fe62af..d00e629b4a 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -13,6 +13,25 @@ declare global { data64: string; }; + // wconfig.AIModeConfigType + type AIModeConfigType = { + "display:name": string; + "display:order"?: number; + "display:icon": string; + "display:shortdesc"?: string; + "display:description": string; + "ai:apitype": string; + "ai:model": string; + "ai:thinkinglevel": string; + "ai:baseurl"?: string; + "ai:apiversion"?: string; + "ai:apitoken"?: string; + "ai:apitokensecretname"?: string; + "ai:capabilities"?: string[]; + "waveai:cloud"?: boolean; + "waveai:premium": boolean; + }; + // wshrpc.ActivityDisplayType type ActivityDisplayType = { width: number; @@ -750,6 +769,7 @@ declare global { termthemes: {[key: string]: TermThemeType}; connections: {[key: string]: ConnKeywords}; bookmarks: {[key: string]: WebBookmark}; + waveai: {[key: string]: AIModeConfigType}; configerrors: ConfigError[]; }; @@ -930,7 +950,7 @@ declare global { "builder:appid"?: string; "builder:env"?: {[key: string]: string}; "waveai:chatid"?: string; - "waveai:thinkingmode"?: string; + "waveai:mode"?: string; "waveai:maxoutputtokens"?: number; }; @@ -1240,7 +1260,7 @@ declare global { "waveai:requestdurms"?: number; "waveai:widgetaccess"?: boolean; "waveai:thinkinglevel"?: string; - "waveai:thinkingmode"?: string; + "waveai:mode"?: string; "waveai:feedback"?: "good" | "bad"; "waveai:action"?: string; $set?: TEventUserProps; diff --git a/go.mod b/go.mod index d6339d50a1..b165226881 100644 --- a/go.mod +++ 
b/go.mod @@ -81,6 +81,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/outrigdev/goid v0.3.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sirupsen/logrus v1.9.3 // indirect diff --git a/go.sum b/go.sum index fbc5bc2d2f..e44a38bfdd 100644 --- a/go.sum +++ b/go.sum @@ -146,6 +146,8 @@ github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuE github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/outrigdev/goid v0.3.0 h1:t/otQD3EXc45cLtQVPUnNgEyRaTQA4cPeu3qVcrsIws= +github.com/outrigdev/goid v0.3.0/go.mod h1:hEH7f27ypN/GHWt/7gvkRoFYR0LZizfUBIAbak4neVE= github.com/photostorm/pty v1.1.19-0.20230903182454-31354506054b h1:cLGKfKb1uk0hxI0Q8L83UAJPpeJ+gSpn3cCU/tjd3eg= github.com/photostorm/pty v1.1.19-0.20230903182454-31354506054b/go.mod h1:KO+FcPtyLAiRC0hJwreJVvfwc7vnNz77UxBTIGHdPVk= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= diff --git a/pkg/aiusechat/aiutil/aiutil.go b/pkg/aiusechat/aiutil/aiutil.go index fb9f8bb517..0fd4854469 100644 --- a/pkg/aiusechat/aiutil/aiutil.go +++ b/pkg/aiusechat/aiutil/aiutil.go @@ -5,6 +5,7 @@ package aiutil import ( "bytes" + "context" "crypto/sha256" "encoding/base64" "encoding/hex" @@ -12,9 +13,12 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wcore" + "github.com/wavetermdev/waveterm/pkg/web/sse" ) // ExtractXmlAttribute extracts an attribute value from an XML-like tag. 
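Reviewer note on the aiutil hunk that follows: `IsOpenAIReasoningModel` exists so the new chat-completions request builder (later in this diff) can decide between the legacy `max_tokens` field and `max_completion_tokens`, which the newer reasoning models require on that endpoint. A minimal standalone sketch of that branching (the helper body is copied from the hunk below; note the `gpt-5.1` prefix check there is already subsumed by the `gpt-5` check):

```go
package main

import (
	"fmt"
	"strings"
)

// isOpenAIReasoningModel mirrors the aiutil helper added below: reasoning-capable
// model families are detected by prefix.
func isOpenAIReasoningModel(model string) bool {
	m := strings.ToLower(model)
	return strings.HasPrefix(m, "o1") ||
		strings.HasPrefix(m, "o3") ||
		strings.HasPrefix(m, "o4") ||
		strings.HasPrefix(m, "gpt-5")
}

func main() {
	// Reasoning models take max_completion_tokens on the chat completions API;
	// older chat models still use the legacy max_tokens field.
	for _, model := range []string{"gpt-4o", "o4-mini", "gpt-5.1"} {
		field := "max_tokens"
		if isOpenAIReasoningModel(model) {
			field = "max_completion_tokens"
		}
		fmt.Printf("%-8s -> %s\n", model, field)
	}
}
```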
@@ -180,3 +184,90 @@ func JsonEncodeRequestBody(reqBody any) (bytes.Buffer, error) { } return buf, nil } + +func IsOpenAIReasoningModel(model string) bool { + m := strings.ToLower(model) + return strings.HasPrefix(m, "o1") || + strings.HasPrefix(m, "o3") || + strings.HasPrefix(m, "o4") || + strings.HasPrefix(m, "gpt-5") || + strings.HasPrefix(m, "gpt-5.1") +} + +// CreateToolUseData creates a UIMessageDataToolUse from tool call information +func CreateToolUseData(toolCallID, toolName string, arguments string, chatOpts uctypes.WaveChatOpts) uctypes.UIMessageDataToolUse { + toolUseData := uctypes.UIMessageDataToolUse{ + ToolCallId: toolCallID, + ToolName: toolName, + Status: uctypes.ToolUseStatusPending, + } + + toolDef := chatOpts.GetToolDefinition(toolName) + if toolDef == nil { + toolUseData.Status = uctypes.ToolUseStatusError + toolUseData.ErrorMessage = "tool not found" + return toolUseData + } + + var parsedArgs any + if err := json.Unmarshal([]byte(arguments), &parsedArgs); err != nil { + toolUseData.Status = uctypes.ToolUseStatusError + toolUseData.ErrorMessage = fmt.Sprintf("failed to parse tool arguments: %v", err) + return toolUseData + } + + if toolDef.ToolCallDesc != nil { + toolUseData.ToolDesc = toolDef.ToolCallDesc(parsedArgs, nil, nil) + } + + if toolDef.ToolApproval != nil { + toolUseData.Approval = toolDef.ToolApproval(parsedArgs) + } + + if chatOpts.TabId != "" { + if argsMap, ok := parsedArgs.(map[string]any); ok { + if widgetId, ok := argsMap["widget_id"].(string); ok && widgetId != "" { + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + fullBlockId, err := wcore.ResolveBlockIdFromPrefix(ctx, chatOpts.TabId, widgetId) + if err == nil { + toolUseData.BlockId = fullBlockId + } + } + } + } + + return toolUseData +} + + +// SendToolProgress sends tool progress updates via SSE if the tool has a progress descriptor +func SendToolProgress(toolCallID, toolName string, jsonData []byte, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, usePartialParse bool) { + toolDef := chatOpts.GetToolDefinition(toolName) + if toolDef == nil || toolDef.ToolProgressDesc == nil { + return + } + + var parsedJSON any + var err error + if usePartialParse { + parsedJSON, err = utilfn.ParsePartialJson(jsonData) + } else { + err = json.Unmarshal(jsonData, &parsedJSON) + } + if err != nil { + return + } + + statusLines, err := toolDef.ToolProgressDesc(parsedJSON) + if err != nil { + return + } + + progressData := &uctypes.UIMessageDataToolProgress{ + ToolCallId: toolCallID, + ToolName: toolName, + StatusLines: statusLines, + } + _ = sseHandler.AiMsgData("data-toolprogress", "progress-"+toolCallID, progressData) +} diff --git a/pkg/aiusechat/anthropic/anthropic-backend.go b/pkg/aiusechat/anthropic/anthropic-backend.go index 345d30bcdd..987b8c117e 100644 --- a/pkg/aiusechat/anthropic/anthropic-backend.go +++ b/pkg/aiusechat/anthropic/anthropic-backend.go @@ -56,7 +56,7 @@ func (m *anthropicChatMessage) GetUsage() *uctypes.AIUsage { } return &uctypes.AIUsage{ - APIType: "anthropic", + APIType: uctypes.APIType_AnthropicMessages, Model: m.Usage.Model, InputTokens: m.Usage.InputTokens, OutputTokens: m.Usage.OutputTokens, diff --git a/pkg/aiusechat/anthropic/anthropic-convertmessage.go b/pkg/aiusechat/anthropic/anthropic-convertmessage.go index 0daf9f99b9..e8a64f3246 100644 --- a/pkg/aiusechat/anthropic/anthropic-convertmessage.go +++ b/pkg/aiusechat/anthropic/anthropic-convertmessage.go @@ -171,7 +171,7 @@ func buildAnthropicHTTPRequest(ctx 
context.Context, msgs []anthropicInputMessage req.Header.Set("anthropic-version", apiVersion) req.Header.Set("accept", "text/event-stream") req.Header.Set("X-Wave-ClientId", chatOpts.ClientId) - req.Header.Set("X-Wave-APIType", "anthropic") + req.Header.Set("X-Wave-APIType", uctypes.APIType_AnthropicMessages) return req, nil } @@ -795,8 +795,8 @@ func ConvertToolResultsToAnthropicChatMessage(toolResults []uctypes.AIToolResult // ConvertAIChatToUIChat converts an AIChat to a UIChat for Anthropic func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { - if aiChat.APIType != "anthropic" { - return nil, fmt.Errorf("APIType must be 'anthropic', got '%s'", aiChat.APIType) + if aiChat.APIType != uctypes.APIType_AnthropicMessages { + return nil, fmt.Errorf("APIType must be '%s', got '%s'", uctypes.APIType_AnthropicMessages, aiChat.APIType) } uiMessages := make([]uctypes.UIMessage, 0, len(aiChat.NativeMessages)) diff --git a/pkg/aiusechat/openai/openai-backend.go b/pkg/aiusechat/openai/openai-backend.go index cced0dd06d..eb3ac08ee2 100644 --- a/pkg/aiusechat/openai/openai-backend.go +++ b/pkg/aiusechat/openai/openai-backend.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -17,11 +16,11 @@ import ( "github.com/google/uuid" "github.com/launchdarkly/eventsource" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/logutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" - "github.com/wavetermdev/waveterm/pkg/wcore" "github.com/wavetermdev/waveterm/pkg/web/sse" ) @@ -150,7 +149,7 @@ func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage { return nil } return &uctypes.AIUsage{ - APIType: "openai", + APIType: uctypes.APIType_OpenAIResponses, Model: m.Usage.Model, InputTokens: m.Usage.InputTokens, OutputTokens: m.Usage.OutputTokens, @@ -396,8 +395,7 @@ type openaiBlockState struct { } type openaiStreamingState struct { - blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming - toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key + blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming msgID string model string stepStarted bool @@ -407,7 +405,7 @@ type openaiStreamingState struct { // ---------- Public entrypoint ---------- -func UpdateToolUseData(chatId string, callId string, newToolUseData *uctypes.UIMessageDataToolUse) error { +func UpdateToolUseData(chatId string, callId string, newToolUseData uctypes.UIMessageDataToolUse) error { chat := chatstore.DefaultChatStore.Get(chatId) if chat == nil { return fmt.Errorf("chat not found: %s", chatId) @@ -422,7 +420,7 @@ func UpdateToolUseData(chatId string, callId string, newToolUseData *uctypes.UIM if chatMsg.FunctionCall != nil && chatMsg.FunctionCall.CallId == callId { updatedMsg := *chatMsg updatedFunctionCall := *chatMsg.FunctionCall - updatedFunctionCall.ToolUseData = newToolUseData + updatedFunctionCall.ToolUseData = &newToolUseData updatedMsg.FunctionCall = &updatedFunctionCall aiOpts := &uctypes.AIOptsType{ @@ -592,9 +590,8 @@ func parseOpenAIHTTPError(resp *http.Response) error { func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decoder *eventsource.Decoder, cont *uctypes.WaveContinueResponse, chatOpts uctypes.WaveChatOpts) (*uctypes.WaveStopReason, []*OpenAIChatMessage) { // Per-response state state := &openaiStreamingState{ - blockMap: 
map[string]*openaiBlockState{}, - toolUseData: map[string]*uctypes.UIMessageDataToolUse{}, - chatOpts: chatOpts, + blockMap: map[string]*openaiBlockState{}, + chatOpts: chatOpts, } var rtnStopReason *uctypes.WaveStopReason @@ -862,8 +859,7 @@ func handleOpenAIEvent( } if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse { st.partialJSON = append(st.partialJSON, []byte(ev.Delta)...) - toolDef := state.chatOpts.GetToolDefinition(st.toolName) - sendToolProgress(st, toolDef, sse, st.partialJSON, true) + aiutil.SendToolProgress(st.toolCallID, st.toolName, st.partialJSON, state.chatOpts, sse, true) } return nil, nil @@ -876,10 +872,7 @@ func handleOpenAIEvent( // Get the function call info from the block state if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse { - toolDef := state.chatOpts.GetToolDefinition(st.toolName) - toolUseData := createToolUseData(st.toolCallID, st.toolName, toolDef, ev.Arguments, state.chatOpts) - state.toolUseData[st.toolCallID] = toolUseData - sendToolProgress(st, toolDef, sse, []byte(ev.Arguments), false) + aiutil.SendToolProgress(st.toolCallID, st.toolName, []byte(ev.Arguments), state.chatOpts, sse, false) } return nil, nil @@ -936,76 +929,6 @@ func handleOpenAIEvent( } } -func sendToolProgress(st *openaiBlockState, toolDef *uctypes.ToolDefinition, sse *sse.SSEHandlerCh, jsonData []byte, usePartialParse bool) { - if toolDef == nil || toolDef.ToolProgressDesc == nil { - return - } - var parsedJSON any - var err error - if usePartialParse { - parsedJSON, err = utilfn.ParsePartialJson(jsonData) - } else { - err = json.Unmarshal(jsonData, &parsedJSON) - } - if err != nil { - return - } - statusLines, err := toolDef.ToolProgressDesc(parsedJSON) - if err != nil { - return - } - progressData := &uctypes.UIMessageDataToolProgress{ - ToolCallId: st.toolCallID, - ToolName: st.toolName, - StatusLines: statusLines, - } - _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData) -} - -func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinition, arguments string, chatOpts uctypes.WaveChatOpts) *uctypes.UIMessageDataToolUse { - toolUseData := &uctypes.UIMessageDataToolUse{ - ToolCallId: toolCallID, - ToolName: toolName, - Status: uctypes.ToolUseStatusPending, - } - - if toolDef == nil { - toolUseData.Status = uctypes.ToolUseStatusError - toolUseData.ErrorMessage = "tool not found" - return toolUseData - } - - var parsedArgs any - if err := json.Unmarshal([]byte(arguments), &parsedArgs); err != nil { - toolUseData.Status = uctypes.ToolUseStatusError - toolUseData.ErrorMessage = fmt.Sprintf("failed to parse tool arguments: %v", err) - return toolUseData - } - - if toolDef.ToolCallDesc != nil { - toolUseData.ToolDesc = toolDef.ToolCallDesc(parsedArgs, nil, nil) - } - - if toolDef.ToolApproval != nil { - toolUseData.Approval = toolDef.ToolApproval(parsedArgs) - } - - if chatOpts.TabId != "" { - if argsMap, ok := parsedArgs.(map[string]any); ok { - if widgetId, ok := argsMap["widget_id"].(string); ok && widgetId != "" { - ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) - defer cancelFn() - fullBlockId, err := wcore.ResolveBlockIdFromPrefix(ctx, chatOpts.TabId, widgetId) - if err == nil { - toolUseData.BlockId = fullBlockId - } - } - } - } - - return toolUseData -} - // extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response func extractMessageAndToolsFromResponse(resp openaiResponse, state 
*openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) { var messageContent []OpenAIMessageContent @@ -1040,13 +963,6 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStream } } - // Attach UIToolUseData if available - if data, ok := state.toolUseData[outputItem.CallId]; ok { - toolCall.ToolUseData = data - } else { - log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId) - } - toolCalls = append(toolCalls, toolCall) // Create separate FunctionCall message @@ -1054,18 +970,13 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStream if outputItem.Arguments != "" { argsStr = outputItem.Arguments } - var toolUseDataPtr *uctypes.UIMessageDataToolUse - if data, ok := state.toolUseData[outputItem.CallId]; ok { - toolUseDataPtr = data - } functionCallMsg := &OpenAIChatMessage{ MessageId: uuid.New().String(), FunctionCall: &OpenAIFunctionCallInput{ - Type: "function_call", - CallId: outputItem.CallId, - Name: outputItem.Name, - Arguments: argsStr, - ToolUseData: toolUseDataPtr, + Type: "function_call", + CallId: outputItem.CallId, + Name: outputItem.Name, + Arguments: argsStr, }, } messages = append(messages, functionCallMsg) diff --git a/pkg/aiusechat/openai/openai-convertmessage.go b/pkg/aiusechat/openai/openai-convertmessage.go index 156c635887..70b6f31aa6 100644 --- a/pkg/aiusechat/openai/openai-convertmessage.go +++ b/pkg/aiusechat/openai/openai-convertmessage.go @@ -299,7 +299,7 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes. req.Header.Set("X-Wave-ChatId", chatOpts.ChatId) } req.Header.Set("X-Wave-Version", wavebase.WaveVersion) - req.Header.Set("X-Wave-APIType", "openai") + req.Header.Set("X-Wave-APIType", uctypes.APIType_OpenAIResponses) req.Header.Set("X-Wave-RequestType", chatOpts.GetWaveRequestType()) return req, nil @@ -519,8 +519,8 @@ func (m *OpenAIChatMessage) convertToUIMessage() *uctypes.UIMessage { // ConvertAIChatToUIChat converts an AIChat to a UIChat for OpenAI func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { - if aiChat.APIType != "openai" { - return nil, fmt.Errorf("APIType must be 'openai', got '%s'", aiChat.APIType) + if aiChat.APIType != uctypes.APIType_OpenAIResponses { + return nil, fmt.Errorf("APIType must be '%s', got '%s'", uctypes.APIType_OpenAIResponses, aiChat.APIType) } uiMessages := make([]uctypes.UIMessage, 0, len(aiChat.NativeMessages)) for i, nativeMsg := range aiChat.NativeMessages { diff --git a/pkg/aiusechat/openaichat/openaichat-backend.go b/pkg/aiusechat/openaichat/openaichat-backend.go new file mode 100644 index 0000000000..4eb6217421 --- /dev/null +++ b/pkg/aiusechat/openaichat/openaichat-backend.go @@ -0,0 +1,257 @@ +// Copyright 2025, Command Line Inc. 
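Aside on the APIType changes above: the string literals "anthropic" and "openai" are replaced by named uctypes constants, and the new backend below introduces APIType_OpenAIChat. The constant definitions themselves are not part of this diff; a hypothetical sketch, with values inferred from the literals they replace:

```go
package uctypes

// Hypothetical definitions: this diff references these constants but does not
// include them. The first two values are inferred from the replaced string
// literals; the APIType_OpenAIChat value is purely an assumption and the real
// pkg/aiusechat/uctypes declarations may differ.
const (
	APIType_AnthropicMessages = "anthropic"
	APIType_OpenAIResponses   = "openai"
	APIType_OpenAIChat        = "openai-chat" // assumed value, not shown in the diff
)
```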
+// SPDX-License-Identifier: Apache-2.0 + +package openaichat + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "github.com/launchdarkly/eventsource" + "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/web/sse" +) + +// RunChatStep executes a chat step using the chat completions API +func RunChatStep( + ctx context.Context, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, []*StoredChatMessage, *uctypes.RateLimitInfo, error) { + if sseHandler == nil { + return nil, nil, nil, errors.New("sse handler is nil") + } + + chat := chatstore.DefaultChatStore.Get(chatOpts.ChatId) + if chat == nil { + return nil, nil, nil, fmt.Errorf("chat not found: %s", chatOpts.ChatId) + } + + if chatOpts.Config.TimeoutMs > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(chatOpts.Config.TimeoutMs)*time.Millisecond) + defer cancel() + } + + // Convert stored messages to chat completions format. The system prompt is + // prepended by buildChatHTTPRequest, so it must not be added here as well + // (adding it in both places would send it twice). + var messages []ChatRequestMessage + + // Convert native messages + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*StoredChatMessage) + if !ok { + return nil, nil, nil, fmt.Errorf("expected StoredChatMessage, got %T", genMsg) + } + messages = append(messages, *chatMsg.Message.clean()) + } + + req, err := buildChatHTTPRequest(ctx, messages, chatOpts) + if err != nil { + return nil, nil, nil, err + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, nil, nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, nil, nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + // Setup SSE if this is a new request (not a continuation) + if cont == nil { + if err := sseHandler.SetupSSE(); err != nil { + return nil, nil, nil, fmt.Errorf("failed to setup SSE: %w", err) + } + } + + // Stream processing + stopReason, assistantMsg, err := processChatStream(ctx, resp.Body, sseHandler, chatOpts, cont) + if err != nil { + return nil, nil, nil, err + } + + return stopReason, []*StoredChatMessage{assistantMsg}, nil, nil +} + +func processChatStream( + ctx context.Context, + body io.Reader, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, *StoredChatMessage, error) { + decoder := eventsource.NewDecoder(body) + var textBuilder strings.Builder + msgID := uuid.New().String() + textID := uuid.New().String() + var finishReason string + textStarted := false + var toolCallsInProgress []ToolCall + + if cont == nil { + _ = sseHandler.AiMsgStart(msgID) + } + _ = sseHandler.AiMsgStartStep() + + for { + if err := ctx.Err(); err != nil { + _ = sseHandler.AiMsgError("request cancelled") + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "cancelled", + ErrorText: "request cancelled", + }, nil, err + } + + event, err := decoder.Decode() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + _ = 
sseHandler.AiMsgError(err.Error()) + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindError, + ErrorType: "stream", + ErrorText: err.Error(), + }, nil, fmt.Errorf("stream decode error: %w", err) + } + + data := event.Data() + if data == "[DONE]" { + break + } + + var chunk StreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + log.Printf("openaichat: failed to parse chunk: %v\n", err) + continue + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + if choice.Delta.Content != "" { + if !textStarted { + _ = sseHandler.AiMsgTextStart(textID) + textStarted = true + } + textBuilder.WriteString(choice.Delta.Content) + _ = sseHandler.AiMsgTextDelta(textID, choice.Delta.Content) + } + + if len(choice.Delta.ToolCalls) > 0 { + for _, tcDelta := range choice.Delta.ToolCalls { + idx := tcDelta.Index + for len(toolCallsInProgress) <= idx { + toolCallsInProgress = append(toolCallsInProgress, ToolCall{}) + } + + tc := &toolCallsInProgress[idx] + if tcDelta.ID != "" { + tc.ID = tcDelta.ID + } + if tcDelta.Type != "" { + tc.Type = tcDelta.Type + } + if tcDelta.Function != nil { + if tcDelta.Function.Name != "" { + tc.Function.Name = tcDelta.Function.Name + } + if tcDelta.Function.Arguments != "" { + tc.Function.Arguments += tcDelta.Function.Arguments + } + } + } + } + + if choice.FinishReason != nil && *choice.FinishReason != "" { + finishReason = *choice.FinishReason + } + } + + stopKind := uctypes.StopKindDone + if finishReason == "length" { + stopKind = uctypes.StopKindMaxTokens + } else if finishReason == "tool_calls" { + stopKind = uctypes.StopKindToolUse + } + + var validToolCalls []ToolCall + for _, tc := range toolCallsInProgress { + if tc.ID != "" && tc.Function.Name != "" { + validToolCalls = append(validToolCalls, tc) + } + } + + var waveToolCalls []uctypes.WaveToolCall + if len(validToolCalls) > 0 { + for _, tc := range validToolCalls { + var inputJSON any + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &inputJSON); err != nil { + log.Printf("openaichat: failed to parse tool call arguments: %v\n", err) + continue + } + } + waveToolCalls = append(waveToolCalls, uctypes.WaveToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Input: inputJSON, + }) + } + } + + stopReason := &uctypes.WaveStopReason{ + Kind: stopKind, + RawReason: finishReason, + ToolCalls: waveToolCalls, + } + + assistantMsg := &StoredChatMessage{ + MessageId: msgID, + Message: ChatRequestMessage{ + Role: "assistant", + }, + } + + if len(validToolCalls) > 0 { + assistantMsg.Message.ToolCalls = validToolCalls + } else { + assistantMsg.Message.Content = textBuilder.String() + } + + if textStarted { + _ = sseHandler.AiMsgTextEnd(textID) + } + _ = sseHandler.AiMsgFinishStep() + if stopKind != uctypes.StopKindToolUse { + _ = sseHandler.AiMsgFinish(finishReason, nil) + } + + return stopReason, assistantMsg, nil +} diff --git a/pkg/aiusechat/openaichat/openaichat-convertmessage.go b/pkg/aiusechat/openaichat/openaichat-convertmessage.go new file mode 100644 index 0000000000..d26c57884f --- /dev/null +++ b/pkg/aiusechat/openaichat/openaichat-convertmessage.go @@ -0,0 +1,346 @@ +// Copyright 2025, Command Line Inc. 
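For context on processChatStream above: the chat completions stream delivers tool calls as indexed deltas, with the id and function name arriving only on the first chunk and the argument JSON split across chunks. A self-contained sketch of that accumulation, using simplified stand-ins for the types in openaichat-types.go:

```go
package main

import "fmt"

// Trimmed-down stand-ins for the stream types in openaichat-types.go.
type toolFunctionDelta struct {
	Name      string
	Arguments string
}

type toolCallDelta struct {
	Index    int
	ID       string
	Function *toolFunctionDelta
}

type toolCall struct {
	ID   string
	Name string
	Args string
}

// accumulate applies one streamed delta, mirroring the loop in processChatStream:
// the slice grows to cover the delta's index, the ID and name arrive on the
// first chunk, and argument JSON is appended across chunks.
func accumulate(calls []toolCall, d toolCallDelta) []toolCall {
	for len(calls) <= d.Index {
		calls = append(calls, toolCall{})
	}
	tc := &calls[d.Index]
	if d.ID != "" {
		tc.ID = d.ID
	}
	if d.Function != nil {
		if d.Function.Name != "" {
			tc.Name = d.Function.Name
		}
		tc.Args += d.Function.Arguments
	}
	return calls
}

func main() {
	var calls []toolCall
	// A tool call typically streams as a first chunk with id/name, then fragments.
	calls = accumulate(calls, toolCallDelta{Index: 0, ID: "call_1", Function: &toolFunctionDelta{Name: "read_dir", Arguments: `{"pa`}})
	calls = accumulate(calls, toolCallDelta{Index: 0, Function: &toolFunctionDelta{Arguments: `th":"/tmp"}`}})
	fmt.Printf("%+v\n", calls[0]) // {ID:call_1 Name:read_dir Args:{"path":"/tmp"}}
}
```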
+// SPDX-License-Identifier: Apache-2.0 + +package openaichat + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" + "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/wavebase" +) + +const ( + OpenAIChatDefaultMaxTokens = 4096 +) + +// appendToLastUserMessage appends text to the last user message in the messages slice +func appendToLastUserMessage(messages []ChatRequestMessage, text string) { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + messages[i].Content += "\n\n" + text + break + } + } +} + +// convertToolDefinitions converts Wave ToolDefinitions to OpenAI format +// Only includes tools whose required capabilities are met +func convertToolDefinitions(waveTools []uctypes.ToolDefinition, capabilities []string) []ToolDefinition { + if len(waveTools) == 0 { + return nil + } + + openaiTools := make([]ToolDefinition, 0, len(waveTools)) + for _, waveTool := range waveTools { + if !waveTool.HasRequiredCapabilities(capabilities) { + continue + } + openaiTool := ToolDefinition{ + Type: "function", + Function: ToolFunctionDef{ + Name: waveTool.Name, + Description: waveTool.Description, + Parameters: waveTool.InputSchema, + }, + } + openaiTools = append(openaiTools, openaiTool) + } + return openaiTools +} + +// buildChatHTTPRequest creates an HTTP request for the OpenAI chat completions API +func buildChatHTTPRequest(ctx context.Context, messages []ChatRequestMessage, chatOpts uctypes.WaveChatOpts) (*http.Request, error) { + opts := chatOpts.Config + + if opts.Model == "" { + return nil, errors.New("opts.model is required") + } + if opts.BaseURL == "" { + return nil, errors.New("BaseURL is required") + } + + maxTokens := opts.MaxTokens + if maxTokens <= 0 { + maxTokens = OpenAIChatDefaultMaxTokens + } + + finalMessages := messages + if len(chatOpts.SystemPrompt) > 0 { + systemMessage := ChatRequestMessage{ + Role: "system", + Content: strings.Join(chatOpts.SystemPrompt, "\n\n"), + } + finalMessages = append([]ChatRequestMessage{systemMessage}, messages...) + } + + // injected data + if chatOpts.TabState != "" { + appendToLastUserMessage(finalMessages, chatOpts.TabState) + } + if chatOpts.PlatformInfo != "" { + appendToLastUserMessage(finalMessages, "\n"+chatOpts.PlatformInfo+"\n") + } + + reqBody := &ChatRequest{ + Model: opts.Model, + Messages: finalMessages, + Stream: true, + } + + if aiutil.IsOpenAIReasoningModel(opts.Model) { + reqBody.MaxCompletionTokens = maxTokens + } else { + reqBody.MaxTokens = maxTokens + } + + // Add tool definitions if tools capability is available and tools exist + var allTools []uctypes.ToolDefinition + if opts.HasCapability(uctypes.AICapabilityTools) { + allTools = append(allTools, chatOpts.Tools...) + allTools = append(allTools, chatOpts.TabTools...) 
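For reference, convertToolDefinitions (above) wraps each Wave tool in the function-typed envelope the chat completions API expects, passing the input JSON schema through unchanged as `parameters`. A small sketch of the resulting wire shape (struct definitions trimmed from openaichat-types.go; the read_dir schema here is illustrative only):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Minimal copies of the request-side types from openaichat-types.go.
type ToolFunctionDef struct {
	Name        string         `json:"name"`
	Description string         `json:"description,omitempty"`
	Parameters  map[string]any `json:"parameters,omitempty"`
}

type ToolDefinition struct {
	Type     string          `json:"type"`
	Function ToolFunctionDef `json:"function"`
}

func main() {
	// Each Wave tool becomes a {"type":"function", ...} entry; the JSON schema
	// is carried verbatim in the function's parameters field.
	tool := ToolDefinition{
		Type: "function",
		Function: ToolFunctionDef{
			Name:        "read_dir",
			Description: "Read a directory listing",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"path": map[string]any{"type": "string"},
				},
				"required": []string{"path"},
			},
		},
	}
	out, _ := json.MarshalIndent(tool, "", "  ")
	fmt.Println(string(out))
}
```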
+ if len(allTools) > 0 { + reqBody.Tools = convertToolDefinitions(allTools, opts.Capabilities) + } + } + + if wavebase.IsDevMode() { + log.Printf("openaichat: model %s, messages: %d, tools: %d\n", opts.Model, len(messages), len(allTools)) + } + + buf, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, opts.BaseURL, bytes.NewReader(buf)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + if opts.APIToken != "" { + req.Header.Set("Authorization", "Bearer "+opts.APIToken) + } + req.Header.Set("Accept", "text/event-stream") + if chatOpts.ClientId != "" { + req.Header.Set("X-Wave-ClientId", chatOpts.ClientId) + } + if chatOpts.ChatId != "" { + req.Header.Set("X-Wave-ChatId", chatOpts.ChatId) + } + req.Header.Set("X-Wave-Version", wavebase.WaveVersion) + req.Header.Set("X-Wave-APIType", uctypes.APIType_OpenAIChat) + req.Header.Set("X-Wave-RequestType", chatOpts.GetWaveRequestType()) + + return req, nil +} + +// ConvertAIMessageToStoredChatMessage converts an AIMessage to StoredChatMessage +// These messages are ALWAYS role "user" +func ConvertAIMessageToStoredChatMessage(aiMsg uctypes.AIMessage) (*StoredChatMessage, error) { + if err := aiMsg.Validate(); err != nil { + return nil, fmt.Errorf("invalid AIMessage: %w", err) + } + + var textBuilder strings.Builder + firstText := true + for _, part := range aiMsg.Parts { + var partText string + + switch { + case part.Type == uctypes.AIMessagePartTypeText: + partText = part.Text + + case part.MimeType == "text/plain": + textData, err := aiutil.ExtractTextData(part.Data, part.URL) + if err != nil { + log.Printf("openaichat: error extracting text data for %s: %v\n", part.FileName, err) + continue + } + partText = aiutil.FormatAttachedTextFile(part.FileName, textData) + + case part.MimeType == "directory": + if len(part.Data) == 0 { + log.Printf("openaichat: directory listing part missing data for %s\n", part.FileName) + continue + } + partText = aiutil.FormatAttachedDirectoryListing(part.FileName, string(part.Data)) + + default: + continue + } + + if partText != "" { + if !firstText { + textBuilder.WriteString("\n\n") + } + textBuilder.WriteString(partText) + firstText = false + } + } + + return &StoredChatMessage{ + MessageId: aiMsg.MessageId, + Message: ChatRequestMessage{ + Role: "user", + Content: textBuilder.String(), + }, + }, nil +} + +// ConvertToolResultsToNativeChatMessage converts tool results to OpenAI tool messages +func ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { + if len(toolResults) == 0 { + return nil, nil + } + + messages := make([]uctypes.GenAIMessage, 0, len(toolResults)) + for _, toolResult := range toolResults { + var content string + if toolResult.ErrorText != "" { + content = fmt.Sprintf("Error: %s", toolResult.ErrorText) + } else { + content = toolResult.Text + } + + msg := &StoredChatMessage{ + MessageId: toolResult.ToolUseID, + Message: ChatRequestMessage{ + Role: "tool", + ToolCallID: toolResult.ToolUseID, + Name: toolResult.ToolName, + Content: content, + }, + } + messages = append(messages, msg) + } + + return messages, nil +} + +// ConvertAIChatToUIChat converts stored chat to UI format +func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { + uiChat := &uctypes.UIChat{ + ChatId: aiChat.ChatId, + APIType: aiChat.APIType, + Model: aiChat.Model, + APIVersion: aiChat.APIVersion, + Messages: 
make([]uctypes.UIMessage, 0, len(aiChat.NativeMessages)), + } + + for _, genMsg := range aiChat.NativeMessages { + chatMsg, ok := genMsg.(*StoredChatMessage) + if !ok { + continue + } + + var parts []uctypes.UIMessagePart + + // Add text content if present + if chatMsg.Message.Content != "" { + parts = append(parts, uctypes.UIMessagePart{ + Type: "text", + Text: chatMsg.Message.Content, + }) + } + + // Add tool calls if present (assistant requesting tool use) + if len(chatMsg.Message.ToolCalls) > 0 { + for _, toolCall := range chatMsg.Message.ToolCalls { + if toolCall.Type != "function" { + continue + } + + // Only add if ToolUseData is available + if toolCall.ToolUseData != nil { + parts = append(parts, uctypes.UIMessagePart{ + Type: "data-tooluse", + ID: toolCall.ID, + Data: *toolCall.ToolUseData, + }) + } + } + } + + // Tool result messages (role "tool") are not converted to UIMessage + if chatMsg.Message.Role == "tool" && chatMsg.Message.ToolCallID != "" { + continue + } + + // Skip messages with no parts + if len(parts) == 0 { + continue + } + + uiMsg := uctypes.UIMessage{ + ID: chatMsg.MessageId, + Role: chatMsg.Message.Role, + Parts: parts, + } + + uiChat.Messages = append(uiChat.Messages, uiMsg) + } + + return uiChat, nil +} + +// GetFunctionCallInputByToolCallId searches for a tool call by ID in the chat history +func GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput { + for _, genMsg := range aiChat.NativeMessages { + chatMsg, ok := genMsg.(*StoredChatMessage) + if !ok { + continue + } + idx := chatMsg.Message.FindToolCallIndex(toolCallId) + if idx == -1 { + continue + } + toolCall := chatMsg.Message.ToolCalls[idx] + return &uctypes.AIFunctionCallInput{ + CallId: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + ToolUseData: toolCall.ToolUseData, + } + } + return nil +} + +// UpdateToolUseData updates the ToolUseData for a specific tool call in the chat history +func UpdateToolUseData(chatId string, callId string, newToolUseData uctypes.UIMessageDataToolUse) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*StoredChatMessage) + if !ok { + continue + } + idx := chatMsg.Message.FindToolCallIndex(callId) + if idx == -1 { + continue + } + updatedMsg := chatMsg.Copy() + updatedMsg.Message.ToolCalls[idx].ToolUseData = &newToolUseData + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg) + } + + return fmt.Errorf("tool call with callId %s not found in chat %s", callId, chatId) +} diff --git a/pkg/aiusechat/openaichat/openaichat-types.go b/pkg/aiusechat/openaichat/openaichat-types.go new file mode 100644 index 0000000000..f0bcc41614 --- /dev/null +++ b/pkg/aiusechat/openaichat/openaichat-types.go @@ -0,0 +1,171 @@ +// Copyright 2025, Command Line Inc. 
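The conversation shape produced by ConvertToolResultsToNativeChatMessage above, sketched with a trimmed message struct: each result goes back as a role "tool" message whose tool_call_id ties it to the assistant's earlier tool_calls entry, with errors folded into the content string.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in matching ChatRequestMessage's JSON shape.
type msg struct {
	Role       string `json:"role"`
	Content    string `json:"content,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
}

func main() {
	// One role:"tool" message per tool result; errors become "Error: ..." content,
	// mirroring the converter above.
	results := []msg{
		{Role: "tool", ToolCallID: "call_1", Name: "read_dir", Content: `{"entries":[]}`},
		{Role: "tool", ToolCallID: "call_2", Name: "read_text_file", Content: "Error: file not found"},
	}
	out, _ := json.MarshalIndent(results, "", "  ")
	fmt.Println(string(out))
}
```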
+// SPDX-License-Identifier: Apache-2.0 + +package openaichat + +import ( + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" +) + +// OpenAI Chat Completions API types (simplified) + +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatRequestMessage `json:"messages"` + Stream bool `json:"stream"` + MaxTokens int `json:"max_tokens,omitempty"` // legacy + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // newer + Temperature float64 `json:"temperature,omitempty"` + Tools []ToolDefinition `json:"tools,omitempty"` // if you use tools + ToolChoice any `json:"tool_choice,omitempty"` // "auto", "none", or struct +} + +type ChatRequestMessage struct { + Role string `json:"role"` // "system","user","assistant","tool" + Content string `json:"content,omitempty"` // normal text messages + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message + ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool" + Name string `json:"name,omitempty"` // tool name on role:"tool" +} + +func (cm *ChatRequestMessage) clean() *ChatRequestMessage { + if len(cm.ToolCalls) == 0 { + return cm + } + rtn := *cm + rtn.ToolCalls = make([]ToolCall, len(cm.ToolCalls)) + for i, tc := range cm.ToolCalls { + rtn.ToolCalls[i] = *tc.clean() + } + return &rtn +} + +func (cm *ChatRequestMessage) FindToolCallIndex(toolCallId string) int { + for i, tc := range cm.ToolCalls { + if tc.ID == toolCallId { + return i + } + } + return -1 +} + +type ToolDefinition struct { + Type string `json:"type"` // "function" + Function ToolFunctionDef `json:"function"` +} + +type ToolFunctionDef struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` // or jsonschema struct +} + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` // "function" + Function ToolFunctionCall `json:"function"` + ToolUseData *uctypes.UIMessageDataToolUse `json:"toolusedata,omitempty"` // Internal field (must be cleaned before sending to API) +} + +func (tc *ToolCall) clean() *ToolCall { + if tc.ToolUseData == nil { + return tc + } + rtn := *tc + rtn.ToolUseData = nil + return &rtn +} + +type ToolFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` // raw JSON string +} + +type StreamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []StreamChoice `json:"choices"` +} + +type StreamChoice struct { + Index int `json:"index"` + Delta ContentDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` // "stop", "length" | "tool_calls" | "content_filter" +} + +// This is the important part: +type ContentDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"` +} + +type ToolCallDelta struct { + Index int `json:"index"` + ID string `json:"id,omitempty"` // only on first chunk + Type string `json:"type,omitempty"` // "function" + Function *ToolFunctionDelta `json:"function,omitempty"` +} + +type ToolFunctionDelta struct { + Name string `json:"name,omitempty"` // only on first chunk + Arguments string `json:"arguments,omitempty"` // streamed, append across chunks +} + +// StoredChatMessage is the stored message type +type StoredChatMessage struct { + MessageId string `json:"messageid"` + Message ChatRequestMessage `json:"message"` + Usage *ChatUsage 
`json:"usage,omitempty"` +} + +type ChatUsage struct { + Model string `json:"model,omitempty"` + InputTokens int `json:"prompt_tokens,omitempty"` + OutputTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +func (m *StoredChatMessage) GetMessageId() string { + return m.MessageId +} + +func (m *StoredChatMessage) GetRole() string { + return m.Message.Role +} + +func (m *StoredChatMessage) GetUsage() *uctypes.AIUsage { + if m.Usage == nil { + return nil + } + return &uctypes.AIUsage{ + APIType: uctypes.APIType_OpenAIChat, + Model: m.Usage.Model, + InputTokens: m.Usage.InputTokens, + OutputTokens: m.Usage.OutputTokens, + } +} + +func (m *StoredChatMessage) Copy() *StoredChatMessage { + if m == nil { + return nil + } + copied := *m + if len(m.Message.ToolCalls) > 0 { + copied.Message.ToolCalls = make([]ToolCall, len(m.Message.ToolCalls)) + for i, tc := range m.Message.ToolCalls { + copied.Message.ToolCalls[i] = tc + if tc.ToolUseData != nil { + toolUseDataCopy := *tc.ToolUseData + copied.Message.ToolCalls[i].ToolUseData = &toolUseDataCopy + } + } + } + if m.Usage != nil { + usageCopy := *m.Usage + copied.Usage = &usageCopy + } + return &copied +} diff --git a/pkg/aiusechat/tools_readdir.go b/pkg/aiusechat/tools_readdir.go index 4b90d664c0..da7d568f84 100644 --- a/pkg/aiusechat/tools_readdir.go +++ b/pkg/aiusechat/tools_readdir.go @@ -6,6 +6,7 @@ package aiusechat import ( "fmt" "os" + "path/filepath" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/fileutil" @@ -63,6 +64,10 @@ func verifyReadDirInput(input any, toolUseData *uctypes.UIMessageDataToolUse) er return fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return fmt.Errorf("path must be absolute, got relative path: %s", params.Path) + } + fileInfo, err := os.Stat(expandedPath) if err != nil { return fmt.Errorf("failed to stat path: %w", err) @@ -81,6 +86,15 @@ func readDirCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, return nil, err } + expandedPath, err := wavebase.ExpandHomeDir(params.Path) + if err != nil { + return nil, fmt.Errorf("failed to expand path: %w", err) + } + + if !filepath.IsAbs(expandedPath) { + return nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Path) + } + result, err := fileutil.ReadDir(params.Path, *params.MaxEntries) if err != nil { return nil, err @@ -118,7 +132,7 @@ func GetReadDirToolDefinition() uctypes.ToolDefinition { "properties": map[string]any{ "path": map[string]any{ "type": "string", - "description": "Path to the directory to read. Supports '~' for the user's home directory.", + "description": "Absolute path to the directory to read. Supports '~' for the user's home directory. 
Relative paths are not supported.", }, "max_entries": map[string]any{ "type": "integer", diff --git a/pkg/aiusechat/tools_readfile.go b/pkg/aiusechat/tools_readfile.go index 423333c831..eecc2385b6 100644 --- a/pkg/aiusechat/tools_readfile.go +++ b/pkg/aiusechat/tools_readfile.go @@ -208,6 +208,10 @@ func verifyReadTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolUs return fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + if blocked, reason := isBlockedFile(expandedPath); blocked { return fmt.Errorf("access denied: potentially sensitive file: %s", reason) } @@ -237,6 +241,10 @@ func readTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) return nil, fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + if blocked, reason := isBlockedFile(expandedPath); blocked { return nil, fmt.Errorf("access denied: potentially sensitive file: %s", reason) } @@ -328,7 +336,7 @@ func GetReadTextFileToolDefinition() uctypes.ToolDefinition { "properties": map[string]any{ "filename": map[string]any{ "type": "string", - "description": "Path to the file to read. Supports '~' for the user's home directory.", + "description": "Absolute path to the file to read. Supports '~' for the user's home directory. Relative paths are not supported.", }, "origin": map[string]any{ "type": "string", diff --git a/pkg/aiusechat/tools_screenshot.go b/pkg/aiusechat/tools_screenshot.go index 4c924db292..9df5a18f0e 100644 --- a/pkg/aiusechat/tools_screenshot.go +++ b/pkg/aiusechat/tools_screenshot.go @@ -67,6 +67,7 @@ func GetCaptureScreenshotToolDefinition(tabId string) uctypes.ToolDefinition { "required": []string{"widget_id"}, "additionalProperties": false, }, + RequiredCapabilities: []string{uctypes.AICapabilityImages}, ToolCallDesc: func(input any, output any, toolUseData *uctypes.UIMessageDataToolUse) string { inputMap, ok := input.(map[string]any) if !ok { diff --git a/pkg/aiusechat/tools_writefile.go b/pkg/aiusechat/tools_writefile.go index 2c830fd64c..d554cfab09 100644 --- a/pkg/aiusechat/tools_writefile.go +++ b/pkg/aiusechat/tools_writefile.go @@ -112,6 +112,10 @@ func verifyWriteTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolU return fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + contentsBytes := []byte(params.Contents) if utilfn.HasBinaryData(contentsBytes) { return fmt.Errorf("contents appear to contain binary data") @@ -137,6 +141,10 @@ func writeTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) return nil, fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + contentsBytes := []byte(params.Contents) if utilfn.HasBinaryData(contentsBytes) { return nil, fmt.Errorf("contents appear to contain binary data") @@ -184,7 +192,7 @@ func GetWriteTextFileToolDefinition() uctypes.ToolDefinition { "properties": map[string]any{ "filename": map[string]any{ "type": "string", - "description": "Path to the file to write. Supports '~' for the user's home directory.", + "description": "Absolute path to the file to write. Supports '~' for the user's home directory. 
Relative paths are not supported.", }, "contents": map[string]any{ "type": "string", @@ -247,6 +255,10 @@ func verifyEditTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolUs return fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + _, err = validateTextFile(expandedPath, "edit", true) if err != nil { return err @@ -269,6 +281,10 @@ func EditTextFileDryRun(input any, fileOverride string) ([]byte, []byte, error) return nil, nil, fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return nil, nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + _, err = validateTextFile(expandedPath, "edit", true) if err != nil { return nil, nil, err @@ -303,6 +319,10 @@ func editTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) return nil, fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + _, err = validateTextFile(expandedPath, "edit", true) if err != nil { return nil, err @@ -340,7 +360,7 @@ func GetEditTextFileToolDefinition() uctypes.ToolDefinition { "properties": map[string]any{ "filename": map[string]any{ "type": "string", - "description": "Path to the file to edit. Supports '~' for the user's home directory.", + "description": "Absolute path to the file to edit. Supports '~' for the user's home directory. Relative paths are not supported.", }, "edits": map[string]any{ "type": "array", @@ -422,6 +442,10 @@ func verifyDeleteTextFileInput(input any, toolUseData *uctypes.UIMessageDataTool return fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + _, err = validateTextFile(expandedPath, "delete", true) if err != nil { return err @@ -442,6 +466,10 @@ func deleteTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse return nil, fmt.Errorf("failed to expand path: %w", err) } + if !filepath.IsAbs(expandedPath) { + return nil, fmt.Errorf("path must be absolute, got relative path: %s", params.Filename) + } + _, err = validateTextFile(expandedPath, "delete", true) if err != nil { return nil, err @@ -476,7 +504,7 @@ func GetDeleteTextFileToolDefinition() uctypes.ToolDefinition { "properties": map[string]any{ "filename": map[string]any{ "type": "string", - "description": "Path to the file to delete. Supports '~' for the user's home directory.", + "description": "Absolute path to the file to delete. Supports '~' for the user's home directory. 
Relative paths are not supported.", }, }, "required": []string{"filename"}, diff --git a/pkg/aiusechat/uctypes/usechat-types.go b/pkg/aiusechat/uctypes/uctypes.go similarity index 84% rename from pkg/aiusechat/uctypes/usechat-types.go rename to pkg/aiusechat/uctypes/uctypes.go index 4154ceacf0..9fc6a73f53 100644 --- a/pkg/aiusechat/uctypes/usechat-types.go +++ b/pkg/aiusechat/uctypes/uctypes.go @@ -6,6 +6,7 @@ package uctypes import ( "fmt" "net/url" + "slices" "strings" ) @@ -14,6 +15,12 @@ const DefaultAnthropicModel = "claude-sonnet-4-5" const DefaultOpenAIModel = "gpt-5-mini" const PremiumOpenAIModel = "gpt-5.1" +const ( + APIType_AnthropicMessages = "anthropic-messages" + APIType_OpenAIResponses = "openai-responses" + APIType_OpenAIChat = "openai-chat" +) + type UseChatRequest struct { Messages []UIMessage `json:"messages"` } @@ -78,13 +85,14 @@ type UIMessageDataUserFile struct { // ToolDefinition represents a tool that can be used by the AI model type ToolDefinition struct { - Name string `json:"name"` - DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped) - Description string `json:"description"` - ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped) - ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback") - InputSchema map[string]any `json:"input_schema"` - Strict bool `json:"strict,omitempty"` + Name string `json:"name"` + DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped) + Description string `json:"description"` + ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped) + ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback") + InputSchema map[string]any `json:"input_schema"` + Strict bool `json:"strict,omitempty"` + RequiredCapabilities []string `json:"requiredcapabilities,omitempty"` ToolTextCallback func(any) (string, error) `json:"-"` ToolAnyCallback func(any, *UIMessageDataToolUse) (any, error) `json:"-"` // *UIMessageDataToolUse will NOT be nil @@ -114,6 +122,18 @@ func (td *ToolDefinition) Desc() string { return td.Description } +func (td *ToolDefinition) HasRequiredCapabilities(capabilities []string) bool { + if td == nil || len(td.RequiredCapabilities) == 0 { + return true + } + for _, reqCap := range td.RequiredCapabilities { + if !slices.Contains(capabilities, reqCap) { + return false + } + } + return true +} + //------------------ // Wave specific types, stop reasons, tool calls, config // these are used internally to coordinate the calls/steps @@ -125,9 +145,9 @@ const ( ) const ( - ThinkingModeQuick = "quick" - ThinkingModeBalanced = "balanced" - ThinkingModeDeep = "deep" + AIModeQuick = "waveai@quick" + AIModeBalanced = "waveai@balanced" + AIModeDeep = "waveai@deep" ) const ( @@ -136,6 +156,12 @@ const ( ToolUseStatusCompleted = "completed" ) +const ( + AICapabilityTools = "tools" + AICapabilityImages = "images" + AICapabilityPdfs = "pdfs" +) + const ( ApprovalNeedsApproval = "needs-approval" ApprovalUserApproved = "user-approved" @@ -144,6 +170,28 @@ const ( ApprovalAutoApproved = "auto-approved" ) +type AIModeConfig struct { + Mode string `json:"mode"` + DisplayName string `json:"display:name"` + DisplayOrder float64 `json:"display:order,omitempty"` + DisplayIcon string `json:"display:icon"` + APIType string `json:"apitype"` + Model string `json:"model"` + 
ThinkingLevel string `json:"thinkinglevel"` + BaseURL string `json:"baseurl,omitempty"` + WaveAICloud bool `json:"waveaicloud,omitempty"` + APIVersion string `json:"apiversion,omitempty"` + APIToken string `json:"apitoken,omitempty"` + APITokenSecretName string `json:"apitokensecretname,omitempty"` + Premium bool `json:"premium"` + Description string `json:"description"` + Capabilities []string `json:"capabilities,omitempty"` +} + +func (c *AIModeConfig) HasCapability(cap string) bool { + return slices.Contains(c.Capabilities, cap) +} + // when updating this struct, also modify frontend/app/aipanel/aitypes.ts WaveUIDataTypes.tooluse type UIMessageDataToolUse struct { ToolCallId string `json:"toolcallid"` @@ -206,17 +254,18 @@ type WaveContinueResponse struct { // Wave Specific AI opts for configuration type AIOptsType struct { - APIType string `json:"apitype,omitempty"` - Model string `json:"model"` - APIToken string `json:"apitoken"` - OrgID string `json:"orgid,omitempty"` - APIVersion string `json:"apiversion,omitempty"` - BaseURL string `json:"baseurl,omitempty"` - ProxyURL string `json:"proxyurl,omitempty"` - MaxTokens int `json:"maxtokens,omitempty"` - TimeoutMs int `json:"timeoutms,omitempty"` - ThinkingLevel string `json:"thinkinglevel,omitempty"` // ThinkingLevelLow, ThinkingLevelMedium, or ThinkingLevelHigh - ThinkingMode string `json:"thinkingmode,omitempty"` // quick, balanced, or deep + APIType string `json:"apitype,omitempty"` + Model string `json:"model"` + APIToken string `json:"apitoken"` + OrgID string `json:"orgid,omitempty"` + APIVersion string `json:"apiversion,omitempty"` + BaseURL string `json:"baseurl,omitempty"` + ProxyURL string `json:"proxyurl,omitempty"` + MaxTokens int `json:"maxtokens,omitempty"` + TimeoutMs int `json:"timeoutms,omitempty"` + ThinkingLevel string `json:"thinkinglevel,omitempty"` // ThinkingLevelLow, ThinkingLevelMedium, or ThinkingLevelHigh + AIMode string `json:"aimode,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` } func (opts AIOptsType) IsWaveProxy() bool { @@ -227,6 +276,10 @@ func (opts AIOptsType) IsPremiumModel() bool { return opts.Model == "gpt-5" || opts.Model == "gpt-5.1" || strings.Contains(opts.Model, "claude-sonnet") } +func (opts AIOptsType) HasCapability(cap string) bool { + return slices.Contains(opts.Capabilities, cap) +} + type AIChat struct { ChatId string `json:"chatid"` APIType string `json:"apitype"` @@ -262,7 +315,7 @@ type AIMetrics struct { RequestDuration int `json:"requestduration"` // ms WidgetAccess bool `json:"widgetaccess"` ThinkingLevel string `json:"thinkinglevel,omitempty"` - ThinkingMode string `json:"thinkingmode,omitempty"` + AIMode string `json:"aimode,omitempty"` } type AIFunctionCallInput struct { @@ -559,7 +612,7 @@ func AreModelsCompatible(apiType, model1, model2 string) bool { return true } - if apiType == "openai" { + if apiType == APIType_OpenAIResponses { gpt5Models := map[string]bool{ "gpt-5.1": true, "gpt-5": true, diff --git a/pkg/aiusechat/usechat-backend.go b/pkg/aiusechat/usechat-backend.go index adebb11282..528cd3af5c 100644 --- a/pkg/aiusechat/usechat-backend.go +++ b/pkg/aiusechat/usechat-backend.go @@ -9,6 +9,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/aiusechat/anthropic" "github.com/wavetermdev/waveterm/pkg/aiusechat/openai" + "github.com/wavetermdev/waveterm/pkg/aiusechat/openaichat" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/web/sse" ) @@ -28,7 +29,7 @@ type UseChatBackend interface { // UpdateToolUseData 
updates the tool use data for a specific tool call in the chat. // This is used to update the UI state for tool execution (approval status, results, etc.) - UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error + UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error // ConvertToolResultsToNativeChatMessage converts tool execution results into native chat messages // that can be sent back to the AI backend. Returns a slice of messages (some backends may @@ -51,14 +52,17 @@ type UseChatBackend interface { // Compile-time interface checks var _ UseChatBackend = (*openaiResponsesBackend)(nil) +var _ UseChatBackend = (*openaiCompletionsBackend)(nil) var _ UseChatBackend = (*anthropicBackend)(nil) // GetBackendByAPIType returns the appropriate UseChatBackend implementation for the given API type func GetBackendByAPIType(apiType string) (UseChatBackend, error) { switch apiType { - case APIType_OpenAI: + case uctypes.APIType_OpenAIResponses: return &openaiResponsesBackend{}, nil - case APIType_Anthropic: + case uctypes.APIType_OpenAIChat: + return &openaiCompletionsBackend{}, nil + case uctypes.APIType_AnthropicMessages: return &anthropicBackend{}, nil default: return nil, fmt.Errorf("unsupported API type: %s", apiType) @@ -82,7 +86,7 @@ func (b *openaiResponsesBackend) RunChatStep( return stopReason, genMsgs, rateLimitInfo, err } -func (b *openaiResponsesBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error { +func (b *openaiResponsesBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { return openai.UpdateToolUseData(chatId, toolCallId, toolUseData) } @@ -119,6 +123,43 @@ func (b *openaiResponsesBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (* return openai.ConvertAIChatToUIChat(aiChat) } +// openaiCompletionsBackend implements UseChatBackend for OpenAI Completions API +type openaiCompletionsBackend struct{} + +func (b *openaiCompletionsBackend) RunChatStep( + ctx context.Context, + sseHandler *sse.SSEHandlerCh, + chatOpts uctypes.WaveChatOpts, + cont *uctypes.WaveContinueResponse, +) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, *uctypes.RateLimitInfo, error) { + stopReason, msgs, rateLimitInfo, err := openaichat.RunChatStep(ctx, sseHandler, chatOpts, cont) + var genMsgs []uctypes.GenAIMessage + for _, msg := range msgs { + genMsgs = append(genMsgs, msg) + } + return stopReason, genMsgs, rateLimitInfo, err +} + +func (b *openaiCompletionsBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { + return openaichat.UpdateToolUseData(chatId, toolCallId, toolUseData) +} + +func (b *openaiCompletionsBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { + return openaichat.ConvertToolResultsToNativeChatMessage(toolResults) +} + +func (b *openaiCompletionsBackend) ConvertAIMessageToNativeChatMessage(message uctypes.AIMessage) (uctypes.GenAIMessage, error) { + return openaichat.ConvertAIMessageToStoredChatMessage(message) +} + +func (b *openaiCompletionsBackend) GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput { + return openaichat.GetFunctionCallInputByToolCallId(aiChat, toolCallId) +} + +func (b *openaiCompletionsBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { + return 
openaichat.ConvertAIChatToUIChat(aiChat) +} + // anthropicBackend implements UseChatBackend for Anthropic API type anthropicBackend struct{} @@ -132,7 +173,7 @@ func (b *anthropicBackend) RunChatStep( return stopReason, []uctypes.GenAIMessage{msg}, rateLimitInfo, err } -func (b *anthropicBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error { +func (b *anthropicBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { return fmt.Errorf("UpdateToolUseData not implemented for anthropic backend") } diff --git a/pkg/aiusechat/usechat-mode.go b/pkg/aiusechat/usechat-mode.go new file mode 100644 index 0000000000..fe5bd2d786 --- /dev/null +++ b/pkg/aiusechat/usechat-mode.go @@ -0,0 +1,43 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiusechat + +import ( + "fmt" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/wconfig" +) + +func resolveAIMode(requestedMode string, premium bool) (string, *wconfig.AIModeConfigType, error) { + mode := requestedMode + if mode == "" { + mode = uctypes.AIModeBalanced + } + + config, err := getAIModeConfig(mode) + if err != nil { + return "", nil, err + } + + // non-premium users are limited to the quick mode on the Wave AI cloud endpoint + if config.WaveAICloud && !premium { + mode = uctypes.AIModeQuick + config, err = getAIModeConfig(mode) + if err != nil { + return "", nil, err + } + } + + return mode, config, nil +} + +func getAIModeConfig(aiMode string) (*wconfig.AIModeConfigType, error) { + fullConfig := wconfig.GetWatcher().GetFullConfig() + config, ok := fullConfig.WaveAIModes[aiMode] + if !ok { + return nil, fmt.Errorf("invalid AI mode: %s", aiMode) + } + + return &config, nil +} diff --git a/pkg/aiusechat/usechat-prompts.go b/pkg/aiusechat/usechat-prompts.go new file mode 100644 index 0000000000..b8bcb7aa03 --- /dev/null +++ b/pkg/aiusechat/usechat-prompts.go @@ -0,0 +1,61 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiusechat + +import "strings" + +var SystemPromptText = strings.Join([]string{ + `You are Wave AI, an intelligent assistant embedded within Wave Terminal, a modern terminal application with graphical widgets.`, + `You appear as a pull-out panel on the left side of a tab, with the tab's widgets laid out on the right.`, + `Widget context is provided as informational only.`, + `Do NOT assume any API access or ability to interact with the widgets except via tools provided (note that some widgets may expose NO tools, so their context is informational only).`, +}, " ") + +var SystemPromptText_OpenAI = strings.Join([]string{ + `You are Wave AI, an assistant embedded in Wave Terminal (a terminal with graphical widgets).`, + `You appear as a pull-out panel on the left; widgets are on the right.`, + + // Capabilities & truthfulness + `Tools define your only capabilities. If a capability is not provided by a tool, you cannot do it. Never fabricate data or pretend to call tools. If you lack data or access, say so directly and suggest the next best step.`, + `Use read-only tools (capture_screenshot, read_text_file, read_dir, term_get_scrollback) automatically whenever they help answer the user's request. When a user clearly expresses intent to modify something (write/edit/delete files), call the corresponding tool directly.`, + + // Crisp behavior + `Be concise and direct. Prefer determinism over speculation. 
If a brief clarifying question eliminates guesswork, ask it.`, + + // Attached text files + `User-attached text files may appear inline as \ncontent\n.`, + `User-attached directories use the tag JSON DirInfo.`, + `If multiple attached files exist, treat each as a separate source file with its own file_name.`, + `When the user refers to these files, use their inline content directly; do NOT call any read_text_file or file-access tools to re-read them unless asked.`, + + // Output & formatting + `When presenting commands or any runnable multi-line code, always use fenced Markdown code blocks.`, + `Use an appropriate language hint after the opening fence (e.g., "bash" for shell commands, "go" for Go, "json" for JSON).`, + `For shell commands, do NOT prefix lines with "$" or shell prompts. Use placeholders in ALL_CAPS (e.g., PROJECT_ID) and explain them once after the block if needed.`, + "Reserve inline code (single backticks) for short references like command names (`grep`, `less`), flags, env vars, file paths, or tiny snippets not meant to be executed.", + `You may use Markdown (lists, tables, bold/italics) to improve readability.`, + `Never comment on or justify your formatting choices; just follow these rules.`, + `When generating code or command blocks, try to keep lines under ~100 characters wide where practical (soft wrap; do not break tokens mid-word). Favor indentation and short variable names to stay compact, but correctness always takes priority.`, + + // Safety & limits + `If a request would execute dangerous or destructive actions, warn briefly and provide a safer alternative.`, + `If output is very long, prefer a brief summary plus a copy-ready fenced block or offer a follow-up chunking strategy.`, + + `You can write and edit local files on disk using available tools, but you cannot execute shell commands, run code in the terminal, or read/write remote files.`, + `If the user asks you to execute commands, run code, or deal with remote files, say that these features aren't available yet AND make sure to say that they are coming soon (stay tuned for updates).`, + `Instead, show them exactly what command or code they could copy-paste to run manually.`, + + // Final reminder + `You have NO API access to widgets or Wave unless provided via an explicit tool.`, +}, " ") + +var SystemPromptText_StrictToolAddOn = `## Tool Call Rules (STRICT) + +When you decide a file write/edit tool call is needed: + +- Output ONLY the tool call. +- Do NOT include any explanation, summary, or file content in the chat. +- Do NOT echo the file content before or after the tool call. +- After the tool call result is returned, respond ONLY with what the user directly asked for. If they did not ask to see the file content, do NOT show it. 
+` diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go index e5866bcaf4..477b0001c4 100644 --- a/pkg/aiusechat/usechat.go +++ b/pkg/aiusechat/usechat.go @@ -12,13 +12,16 @@ import ( "net/http" "os" "os/user" + "regexp" "strings" "sync" "time" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/secretstore" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata" "github.com/wavetermdev/waveterm/pkg/util/ds" @@ -32,14 +35,10 @@ import ( "github.com/wavetermdev/waveterm/pkg/wstore" ) -const ( - APIType_Anthropic = "anthropic" - APIType_OpenAI = "openai" -) - -const DefaultAPI = APIType_OpenAI +const DefaultAPI = uctypes.APIType_OpenAIResponses const DefaultMaxTokens = 4 * 1024 const BuilderMaxTokens = 24 * 1024 +const WaveAIEndpointEnvName = "WAVETERM_WAVEAI_ENDPOINT" var ( globalRateLimitInfo = &uctypes.RateLimitInfo{Unknown: true} @@ -49,112 +48,68 @@ var ( activeChats = ds.MakeSyncMap[bool]() // key is chatid ) -var SystemPromptText = strings.Join([]string{ - `You are Wave AI, an intelligent assistant embedded within Wave Terminal, a modern terminal application with graphical widgets.`, - `You appear as a pull-out panel on the left side of a tab, with the tab's widgets laid out on the right.`, - `Widget context is provided as informationa only.`, - `Do NOT assume any API access or ability to interact with the widgets except via tools provided (note that some widgets may expose NO tools, so their context is informational only).`, -}, " ") - -var SystemPromptText_OpenAI = strings.Join([]string{ - `You are Wave AI, an assistant embedded in Wave Terminal (a terminal with graphical widgets).`, - `You appear as a pull-out panel on the left; widgets are on the right.`, - - // Capabilities & truthfulness - `Tools define your only capabilities. If a capability is not provided by a tool, you cannot do it.`, - `Context from widgets is read-only unless a tool explicitly grants interaction.`, - `Never fabricate data. If you lack data or access, say so and offer the next best step (e.g., suggest enabling a tool).`, - - // Crisp behavior - `Be concise and direct. Prefer determinism over speculation. If a brief clarifying question eliminates guesswork, ask it.`, - - // Attached text files - `User-attached text files may appear inline as \ncontent\n.`, - `User-attached directories use the tag JSON DirInfo.`, - `If multiple attached files exist, treat each as a separate source file with its own file_name.`, - `When the user refers to these files, use their inline content directly; do NOT call any read_text_file or file-access tools to re-read them unless asked.`, - - // Output & formatting - `When presenting commands or any runnable multi-line code, always use fenced Markdown code blocks.`, - `Use an appropriate language hint after the opening fence (e.g., "bash" for shell commands, "go" for Go, "json" for JSON).`, - `For shell commands, do NOT prefix lines with "$" or shell prompts. 
Use placeholders in ALL_CAPS (e.g., PROJECT_ID) and explain them once after the block if needed.`, - "Reserve inline code (single backticks) for short references like command names (`grep`, `less`), flags, env vars, file paths, or tiny snippets not meant to be executed.", - `You may use Markdown (lists, tables, bold/italics) to improve readability.`, - `Never comment on or justify your formatting choices; just follow these rules.`, - `When generating code or command blocks, try to keep lines under ~100 characters wide where practical (soft wrap; do not break tokens mid-word). Favor indentation and short variable names to stay compact, but correctness always takes priority.`, - - // Safety & limits - `If a request would execute dangerous or destructive actions, warn briefly and provide a safer alternative.`, - `If output is very long, prefer a brief summary plus a copy-ready fenced block or offer a follow-up chunking strategy.`, - - `You can write and edit local files on disk using available tools, but you cannot execute shell commands, run code in the terminal, or read/write remote files.`, - `If the user asks you to execute commands or run code, or deal with remote files say that these features aren't available yet AND make sure to say that they are coming soon (stay tuned for updates).`, - `Instead, show them exactly what command or code they could copy-paste to run manually.`, - - // Final reminder - `You have NO API access to widgets or Wave unless provided via an explicit tool.`, -}, " ") - -func getWaveAISettings(premium bool, builderMode bool, rtInfo *waveobj.ObjRTInfo) (*uctypes.AIOptsType, error) { - baseUrl := uctypes.DefaultAIEndpoint - if os.Getenv("WAVETERM_WAVEAI_ENDPOINT") != "" { - baseUrl = os.Getenv("WAVETERM_WAVEAI_ENDPOINT") +// strictToolModelRe matches model families that need the strict tool-call rules; compiled once at package init +var strictToolModelRe = regexp.MustCompile(`(?i)\b(mistral|o?llama|qwen|mixtral|yi|phi|deepseek)\b`) + +func getSystemPrompt(apiType string, model string, isBuilder bool) []string { + if isBuilder { + return []string{} + } + basePrompt := SystemPromptText_OpenAI + // smaller open-weight models tend to echo file contents around tool calls; give them the strict add-on + if strictToolModelRe.MatchString(model) { + return []string{basePrompt, SystemPromptText_StrictToolAddOn} + } + return []string{basePrompt} +} + +func getWaveAISettings(premium bool, builderMode bool, rtInfo waveobj.ObjRTInfo) (*uctypes.AIOptsType, error) { maxTokens := DefaultMaxTokens if builderMode { maxTokens = BuilderMaxTokens } - if rtInfo != nil && rtInfo.WaveAIMaxOutputTokens > 0 { + if rtInfo.WaveAIMaxOutputTokens > 0 { maxTokens = rtInfo.WaveAIMaxOutputTokens } - var thinkingMode string - if premium { - thinkingMode = uctypes.ThinkingModeBalanced - if rtInfo != nil && rtInfo.WaveAIThinkingMode != "" { - thinkingMode = rtInfo.WaveAIThinkingMode + aiMode, config, err := resolveAIMode(rtInfo.WaveAIMode, premium) + if err != nil { + return nil, err + } + apiToken := config.APIToken + if apiToken == "" && config.APITokenSecretName != "" { + secret, exists, err := secretstore.GetSecret(config.APITokenSecretName) + if err != nil { + return nil, fmt.Errorf("failed to retrieve secret %s: %w", config.APITokenSecretName, err) + } + if !exists || secret == "" { + return nil, fmt.Errorf("secret %s not found or empty", config.APITokenSecretName) + } + apiToken = secret + } + + var baseUrl string + if config.WaveAICloud { + baseUrl = uctypes.DefaultAIEndpoint + if os.Getenv(WaveAIEndpointEnvName) != "" { + baseUrl = os.Getenv(WaveAIEndpointEnvName) } + } else if config.BaseURL != "" { + baseUrl = config.BaseURL } else { - thinkingMode = 
uctypes.ThinkingModeQuick - } - if DefaultAPI == APIType_Anthropic { - thinkingLevel := uctypes.ThinkingLevelMedium - return &uctypes.AIOptsType{ - APIType: APIType_Anthropic, - Model: uctypes.DefaultAnthropicModel, - MaxTokens: maxTokens, - ThinkingLevel: thinkingLevel, - ThinkingMode: thinkingMode, - BaseURL: baseUrl, - }, nil - } else if DefaultAPI == APIType_OpenAI { - var model string - var thinkingLevel string - - switch thinkingMode { - case uctypes.ThinkingModeQuick: - model = uctypes.DefaultOpenAIModel - thinkingLevel = uctypes.ThinkingLevelLow - case uctypes.ThinkingModeBalanced: - model = uctypes.PremiumOpenAIModel - thinkingLevel = uctypes.ThinkingLevelLow - case uctypes.ThinkingModeDeep: - model = uctypes.PremiumOpenAIModel - thinkingLevel = uctypes.ThinkingLevelMedium - default: - model = uctypes.PremiumOpenAIModel - thinkingLevel = uctypes.ThinkingLevelLow - } - - return &uctypes.AIOptsType{ - APIType: APIType_OpenAI, - Model: model, - MaxTokens: maxTokens, - ThinkingLevel: thinkingLevel, - ThinkingMode: thinkingMode, - BaseURL: baseUrl, - }, nil - } - return nil, fmt.Errorf("invalid API type: %s", DefaultAPI) + return nil, fmt.Errorf("no BaseURL configured for AI mode %s", aiMode) + } + + opts := &uctypes.AIOptsType{ + APIType: config.APIType, + Model: config.Model, + MaxTokens: maxTokens, + ThinkingLevel: config.ThinkingLevel, + AIMode: aiMode, + BaseURL: baseUrl, + Capabilities: config.Capabilities, + } + if apiToken != "" { + opts.APIToken = apiToken + } + return opts, nil } func shouldUseChatCompletionsAPI(model string) bool { @@ -203,7 +158,7 @@ func GetGlobalRateLimit() *uctypes.RateLimitInfo { } func runAIChatStep(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, error) { - if chatOpts.Config.APIType == APIType_OpenAI && shouldUseChatCompletionsAPI(chatOpts.Config.Model) { + if chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses && shouldUseChatCompletionsAPI(chatOpts.Config.Model) { return nil, nil, fmt.Errorf("Chat completions API not available (must use newer OpenAI models)") } stopReason, messages, rateLimitInfo, err := backend.RunChatStep(ctx, sseHandler, chatOpts, cont) @@ -236,7 +191,7 @@ func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage { return usage } -func updateToolUseDataInChat(backend UseChatBackend, chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData *uctypes.UIMessageDataToolUse) { +func updateToolUseDataInChat(backend UseChatBackend, chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData uctypes.UIMessageDataToolUse) { if err := backend.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData); err != nil { log.Printf("failed to update tool use data in chat: %v\n", err) } @@ -276,7 +231,7 @@ func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCa } // ToolVerifyInput can modify the toolusedata. re-send it here. 
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) - updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData) + updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData) } if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval { @@ -305,7 +260,7 @@ func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCa // this still happens here because we need to update the FE to say the tool call was approved _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) - updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData) + updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData) } toolCall.ToolUseData.RunTs = time.Now().UnixMilli() @@ -341,7 +296,7 @@ func processToolCall(backend UseChatBackend, toolCall uctypes.WaveToolCall, chat if toolCall.ToolUseData != nil { _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) - updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData) + updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData) } return result @@ -353,17 +308,27 @@ func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason defer activeToolMap.Delete(toolCall.ID) } - // Send all data-tooluse packets at the beginning - for _, toolCall := range stopReason.ToolCalls { - if toolCall.ToolUseData != nil { - log.Printf("AI data-tooluse %s\n", toolCall.ID) - _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData) - updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData) - if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil { - chatOpts.RegisterToolApproval(toolCall.ID) + // Create and send all data-tooluse packets at the beginning + for i := range stopReason.ToolCalls { + toolCall := &stopReason.ToolCalls[i] + // Create toolUseData from the tool call input + var argsJSON string + if toolCall.Input != nil { + argsBytes, err := json.Marshal(toolCall.Input) + if err == nil { + argsJSON = string(argsBytes) } } + toolUseData := aiutil.CreateToolUseData(toolCall.ID, toolCall.Name, argsJSON, chatOpts) + stopReason.ToolCalls[i].ToolUseData = &toolUseData + log.Printf("AI data-tooluse %s\n", toolCall.ID) + _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, toolUseData) + updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolUseData) + if toolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil { + chatOpts.RegisterToolApproval(toolCall.ID) + } } + // At this point, all ToolCalls are guaranteed to have non-nil ToolUseData var toolResults []uctypes.AIToolResult for _, toolCall := range stopReason.ToolCalls { @@ -389,8 +354,8 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha stepNum := chatstore.DefaultChatStore.CountUserMessages(chatOpts.ChatId) metrics := &uctypes.AIMetrics{ - ChatId: chatOpts.ChatId, - StepNum: stepNum, + ChatId: chatOpts.ChatId, + StepNum: stepNum, Usage: uctypes.AIUsage{ APIType: chatOpts.Config.APIType, Model: chatOpts.Config.Model, @@ -398,7 +363,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha WidgetAccess: chatOpts.WidgetAccess, ToolDetail: make(map[string]int), ThinkingLevel: chatOpts.Config.ThinkingLevel, - ThinkingMode: chatOpts.Config.ThinkingMode, + AIMode: chatOpts.Config.AIMode, } firstStep := true var cont 
*uctypes.WaveContinueResponse @@ -419,7 +384,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha chatOpts.PlatformInfo = platformInfo } } - stopReason, rtnMessage, err := runAIChatStep(ctx, sseHandler, backend, chatOpts, cont) + stopReason, rtnMessages, err := runAIChatStep(ctx, sseHandler, backend, chatOpts, cont) metrics.RequestCount++ if chatOpts.Config.IsPremiumModel() { metrics.PremiumReqCount++ @@ -427,8 +392,8 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha if chatOpts.Config.IsWaveProxy() { metrics.ProxyReqCount++ } - if len(rtnMessage) > 0 { - usage := getUsage(rtnMessage) + if len(rtnMessages) > 0 { + usage := getUsage(rtnMessages) log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount) metrics.Usage.InputTokens += usage.InputTokens metrics.Usage.OutputTokens += usage.OutputTokens @@ -447,14 +412,14 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha _ = sseHandler.AiMsgFinish("", nil) break } - for _, msg := range rtnMessage { + for _, msg := range rtnMessages { if msg != nil { chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg) } } firstStep = false - if stopReason != nil && stopReason.Kind == uctypes.StopKindPremiumRateLimit && chatOpts.Config.APIType == APIType_OpenAI && chatOpts.Config.Model == uctypes.PremiumOpenAIModel { - log.Printf("Premium rate limit hit with gpt-5.1, switching to gpt-5-mini\n") + if stopReason != nil && stopReason.Kind == uctypes.StopKindPremiumRateLimit && chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses && chatOpts.Config.Model == uctypes.PremiumOpenAIModel { + log.Printf("Premium rate limit hit with %s, switching to %s\n", uctypes.PremiumOpenAIModel, uctypes.DefaultOpenAIModel) cont = &uctypes.WaveContinueResponse{ Model: uctypes.DefaultOpenAIModel, ContinueFromKind: uctypes.StopKindPremiumRateLimit, @@ -597,7 +562,7 @@ func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) { WaveAIRequestDurMs: metrics.RequestDuration, WaveAIWidgetAccess: metrics.WidgetAccess, WaveAIThinkingLevel: metrics.ThinkingLevel, - WaveAIThinkingMode: metrics.ThinkingMode, + WaveAIMode: metrics.AIMode, }) _ = telemetry.RecordTEvent(ctx, event) } @@ -645,11 +610,14 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) { oref := waveobj.MakeORef(waveobj.OType_Builder, req.BuilderId) rtInfo = wstore.GetRTInfo(oref) } + if rtInfo == nil { + rtInfo = &waveobj.ObjRTInfo{} + } // Get WaveAI settings premium := shouldUsePremium() builderMode := req.BuilderId != "" - aiOpts, err := getWaveAISettings(premium, builderMode, rtInfo) + aiOpts, err := getWaveAISettings(premium, builderMode, *rtInfo) if err != nil { http.Error(w, fmt.Sprintf("WaveAI configuration error: %v", err), http.StatusInternalServerError) return @@ -673,15 +641,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) { BuilderId: req.BuilderId, BuilderAppId: req.BuilderAppId, } - if chatOpts.Config.APIType == APIType_OpenAI { - if chatOpts.BuilderId != "" { - chatOpts.SystemPrompt = []string{} - } else { - chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI} - } - } else { - chatOpts.SystemPrompt = []string{SystemPromptText} - } + chatOpts.SystemPrompt = getSystemPrompt(chatOpts.Config.APIType, chatOpts.Config.Model, chatOpts.BuilderId != "") if req.TabId != "" { chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, string, error) { 
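The capability flags threaded through this change only pay off if tool lists are checked against the resolved mode before a request is built. Below is a minimal sketch of that gating, using the HasRequiredCapabilities predicate and the Capabilities field added to uctypes above; filterToolsByCapabilities is a hypothetical helper name, since this diff adds the predicate but does not show its call site.

package aiusechat

import "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"

// filterToolsByCapabilities (hypothetical helper) keeps only the tools whose
// RequiredCapabilities are all present in the resolved mode's capability list.
func filterToolsByCapabilities(tools []uctypes.ToolDefinition, opts *uctypes.AIOptsType) []uctypes.ToolDefinition {
	filtered := make([]uctypes.ToolDefinition, 0, len(tools))
	for _, td := range tools {
		// HasRequiredCapabilities returns true when a tool declares no
		// requirements, so capability-free tools always pass through.
		if td.HasRequiredCapabilities(opts.Capabilities) {
			filtered = append(filtered, td)
		}
	}
	return filtered
}

With the default modes declaring "tools", "images", and "pdfs", capture_screenshot (which now requires the images capability) passes; a text-only mode would drop it rather than offer the model a tool it cannot use.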
diff --git a/pkg/telemetry/telemetrydata/telemetrydata.go b/pkg/telemetry/telemetrydata/telemetrydata.go index 79ec3d6941..7dd7bffdb9 100644 --- a/pkg/telemetry/telemetrydata/telemetrydata.go +++ b/pkg/telemetry/telemetrydata/telemetrydata.go @@ -147,7 +147,7 @@ type TEventProps struct { WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"` WaveAIThinkingLevel string `json:"waveai:thinkinglevel,omitempty"` - WaveAIThinkingMode string `json:"waveai:thinkingmode,omitempty"` + WaveAIMode string `json:"waveai:mode,omitempty"` WaveAIFeedback string `json:"waveai:feedback,omitempty" tstype:"\"good\" | \"bad\""` WaveAIAction string `json:"waveai:action,omitempty"` diff --git a/pkg/waveobj/objrtinfo.go b/pkg/waveobj/objrtinfo.go index ff88f7090c..77dadf9985 100644 --- a/pkg/waveobj/objrtinfo.go +++ b/pkg/waveobj/objrtinfo.go @@ -22,6 +22,6 @@ type ObjRTInfo struct { BuilderEnv map[string]string `json:"builder:env,omitempty"` WaveAIChatId string `json:"waveai:chatid,omitempty"` - WaveAIThinkingMode string `json:"waveai:thinkingmode,omitempty"` + WaveAIMode string `json:"waveai:mode,omitempty"` WaveAIMaxOutputTokens int `json:"waveai:maxoutputtokens,omitempty"` } diff --git a/pkg/wconfig/defaultconfig/waveai.json b/pkg/wconfig/defaultconfig/waveai.json new file mode 100644 index 0000000000..03e51f3e64 --- /dev/null +++ b/pkg/wconfig/defaultconfig/waveai.json @@ -0,0 +1,41 @@ +{ + "waveai@quick": { + "display:name": "Quick", + "display:order": -3, + "display:icon": "bolt", + "display:shortdesc": "gpt-5-mini", + "display:description": "Fastest responses (gpt-5-mini)", + "ai:apitype": "openai-responses", + "ai:model": "gpt-5-mini", + "ai:thinkinglevel": "low", + "ai:capabilities": ["tools", "images", "pdfs"], + "waveai:cloud": true, + "waveai:premium": false + }, + "waveai@balanced": { + "display:name": "Balanced", + "display:order": -2, + "display:icon": "sparkles", + "display:shortdesc": "gpt-5.1, low thinking", + "display:description": "Good mix of speed and accuracy\n(gpt-5.1 with minimal thinking)", + "ai:apitype": "openai-responses", + "ai:model": "gpt-5.1", + "ai:thinkinglevel": "low", + "ai:capabilities": ["tools", "images", "pdfs"], + "waveai:cloud": true, + "waveai:premium": true + }, + "waveai@deep": { + "display:name": "Deep", + "display:order": -1, + "display:icon": "lightbulb", + "display:shortdesc": "gpt-5.1, full thinking", + "display:description": "Slower but most capable\n(gpt-5.1 with full reasoning)", + "ai:apitype": "openai-responses", + "ai:model": "gpt-5.1", + "ai:thinkinglevel": "medium", + "ai:capabilities": ["tools", "images", "pdfs"], + "waveai:cloud": true, + "waveai:premium": true + } +} diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index 4de30cbb6d..c493cf49d5 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -257,6 +257,24 @@ type WebBookmark struct { DisplayOrder float64 `json:"display:order,omitempty"` } +type AIModeConfigType struct { + DisplayName string `json:"display:name"` + DisplayOrder float64 `json:"display:order,omitempty"` + DisplayIcon string `json:"display:icon"` + DisplayShortDesc string `json:"display:shortdesc,omitempty"` + DisplayDescription string `json:"display:description"` + APIType string `json:"ai:apitype"` + Model string `json:"ai:model"` + ThinkingLevel string `json:"ai:thinkinglevel"` + BaseURL string `json:"ai:baseurl,omitempty"` + APIVersion string `json:"ai:apiversion,omitempty"` + APIToken string 
`json:"ai:apitoken,omitempty"` + APITokenSecretName string `json:"ai:apitokensecretname,omitempty"` + Capabilities []string `json:"ai:capabilities,omitempty"` + WaveAICloud bool `json:"waveai:cloud,omitempty"` + WaveAIPremium bool `json:"waveai:premium,omitempty"` +} + type FullConfigType struct { Settings SettingsType `json:"settings" merge:"meta"` MimeTypes map[string]MimeTypeConfigType `json:"mimetypes"` @@ -266,8 +284,10 @@ type FullConfigType struct { TermThemes map[string]TermThemeType `json:"termthemes"` Connections map[string]ConnKeywords `json:"connections"` Bookmarks map[string]WebBookmark `json:"bookmarks"` + WaveAIModes map[string]AIModeConfigType `json:"waveai"` ConfigErrors []ConfigError `json:"configerrors" configfile:"-"` } + type ConnKeywords struct { ConnWshEnabled *bool `json:"conn:wshenabled,omitempty"` ConnAskBeforeWshInstall *bool `json:"conn:askbeforewshinstall,omitempty"` diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 131ac51e60..6f6c2afc7a 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -580,12 +580,10 @@ func (ws *WshServer) EventReadHistoryCommand(ctx context.Context, data wshrpc.Co } func (ws *WshServer) SetConfigCommand(ctx context.Context, data wshrpc.MetaSettingsType) error { - log.Printf("SETCONFIG: %v\n", data) return wconfig.SetBaseConfigValue(data.MetaMapType) } func (ws *WshServer) SetConnectionsConfigCommand(ctx context.Context, data wshrpc.ConnConfigRequest) error { - log.Printf("SET CONNECTIONS CONFIG: %v\n", data) return wconfig.SetConnectionsConfigValue(data.Host, data.MetaMapType) }