diff --git a/packages/cedar-os/src/store/agentConnection/AgentConnectionTypes.ts b/packages/cedar-os/src/store/agentConnection/AgentConnectionTypes.ts
index 027e0344..94fb32ad 100644
--- a/packages/cedar-os/src/store/agentConnection/AgentConnectionTypes.ts
+++ b/packages/cedar-os/src/store/agentConnection/AgentConnectionTypes.ts
@@ -50,7 +50,15 @@ export type StreamEvent =
 	| { type: 'error'; error: Error }
 	| { type: 'metadata'; data: unknown };
 
+export type VoiceStreamEvent =
+	| StreamEvent
+	| { type: 'transcription'; transcription: string }
+	| { type: 'audio'; audioData: string; audioFormat?: string };
+
 export type StreamHandler = (event: StreamEvent) => void | Promise<void>;
+export type VoiceStreamHandler = (
+	event: VoiceStreamEvent
+) => void | Promise<void>;
 
 export interface StreamResponse {
 	abort: () => void;
@@ -177,6 +185,11 @@ export interface ProviderImplementation<
 		handler: StreamHandler
 	) => StreamResponse;
 	voiceLLM: (params: VoiceParams, config: TConfig) => Promise<VoiceLLMResponse>;
+	voiceStreamLLM?: (
+		params: VoiceParams,
+		config: TConfig,
+		handler: VoiceStreamHandler
+	) => StreamResponse;
 	handleResponse: (response: Response) => Promise<LLMResponse>;
 	handleStreamResponse: (chunk: string) => StreamEvent;
 }
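Because `VoiceStreamEvent` widens `StreamEvent` instead of replacing it, existing handlers keep working and voice-aware handlers narrow on `type` as usual. A minimal consumer sketch, assuming the new types are re-exported from the package root the way the existing stream types are (an assumption, not shown in this diff):

```ts
import type { VoiceStreamEvent, VoiceStreamHandler } from 'cedar-os';

// Narrows the discriminated union the same way chat StreamHandlers do.
const logVoiceEvents: VoiceStreamHandler = (event: VoiceStreamEvent) => {
	switch (event.type) {
		case 'transcription':
			console.log('user said:', event.transcription);
			break;
		case 'chunk':
			console.log('agent text:', event.content);
			break;
		case 'audio':
			// audioData is expected to be a base64 payload; format may be absent
			console.log('audio payload:', event.audioFormat ?? 'unknown format');
			break;
		case 'error':
			console.error(event.error);
			break;
		// 'done', 'object', and 'metadata' need no special casing here
	}
};
```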
diff --git a/packages/cedar-os/src/store/agentConnection/agentConnectionSlice.ts b/packages/cedar-os/src/store/agentConnection/agentConnectionSlice.ts
index be11a83d..62584414 100644
--- a/packages/cedar-os/src/store/agentConnection/agentConnectionSlice.ts
+++ b/packages/cedar-os/src/store/agentConnection/agentConnectionSlice.ts
@@ -19,7 +19,10 @@ import type {
 	ResponseProcessor,
 	ResponseProcessorRegistry,
 	StructuredResponseType,
+	VoiceStreamHandler,
+	VoiceStreamEvent,
 } from '@/store/agentConnection/AgentConnectionTypes';
+
 import { useCedarStore } from '@/store/CedarStore';
 import { getCedarState } from '@/store/CedarStore';
 import { sanitizeJson } from '@/utils/sanitizeJson';
@@ -90,8 +93,12 @@ export interface AgentConnectionSlice {
 		handler: StreamHandler
 	) => StreamResponse;
 
-	// Voice LLM method
+	// Voice LLM methods
 	voiceLLM: (params: VoiceParams) => Promise<VoiceLLMResponse>;
+	voiceStreamLLM: (
+		params: VoiceParams,
+		handler: VoiceStreamHandler
+	) => StreamResponse;
 
 	// High-level methods that use callLLM/streamLLM
 	sendMessage: (params?: SendMessageParams) => Promise<void>;
@@ -125,7 +132,7 @@ export interface AgentConnectionSlice {
 // Create a typed version of the slice that knows about the provider
 export type TypedAgentConnectionSlice<T extends ProviderConfig> = Omit<
 	AgentConnectionSlice,
-	'callLLM' | 'streamLLM' | 'callLLMStructured' | 'voiceLLM'
+	'callLLM' | 'streamLLM' | 'callLLMStructured' | 'voiceLLM' | 'voiceStreamLLM'
 > & {
 	callLLM: (params: GetParamsForConfig<T>) => Promise<LLMResponse>;
 	callLLMStructured: (
@@ -136,6 +143,10 @@ export type TypedAgentConnectionSlice<T extends ProviderConfig> = Omit<
 		handler: StreamHandler
 	) => StreamResponse;
 	voiceLLM: (params: VoiceParams) => Promise<VoiceLLMResponse>;
+	voiceStreamLLM: (
+		params: VoiceParams,
+		handler: VoiceStreamHandler
+	) => StreamResponse;
 };
 
 export const improvePrompt = async (
@@ -443,6 +454,141 @@ export const createAgentConnectionSlice: StateCreator<
 		}
 	},
 
+	// Voice streaming LLM method
+	voiceStreamLLM: (params: VoiceParams, handler: VoiceStreamHandler) => {
+		const config = get().providerConfig;
+		if (!config) {
+			throw new Error('No LLM provider configured');
+		}
+
+		// Augment params for the Mastra provider to include resourceId & threadId
+		let voiceParams: VoiceParams = params;
+		if (config.provider === 'mastra') {
+			const resourceId = getCedarState('userId') as string | undefined;
+			const threadId = getCedarState('threadId') as string | undefined;
+			voiceParams = {
+				...params,
+				resourceId,
+				threadId,
+			} as typeof voiceParams;
+		}
+
+		// Log the stream start
+		const streamId = get().logStreamStart(voiceParams, config.provider);
+
+		const provider = getProviderImplementation(config);
+		const abortController = new AbortController();
+
+		set({ currentAbortController: abortController, isStreaming: true });
+
+		// Wrap the handler to log stream events
+		const wrappedHandler: VoiceStreamHandler = (event: VoiceStreamEvent) => {
+			if (event.type === 'chunk') {
+				get().logStreamChunk(streamId, event.content);
+			} else if (event.type === 'done') {
+				get().logStreamEnd(streamId, event.completedItems);
+			} else if (event.type === 'error') {
+				get().logAgentError(streamId, event.error);
+			} else if (event.type === 'object') {
+				get().logStreamObject(streamId, event.object);
+			} else if (event.type === 'transcription') {
+				get().logStreamChunk(streamId, event.transcription);
+			} else if (event.type === 'audio') {
+				// Log audio event (could be enhanced with audio-specific logging)
+				get().logStreamChunk(
+					streamId,
+					`[AUDIO: ${event.audioFormat || 'unknown'}]`
+				);
+			}
+			handler(event);
+		};
+
+		// Check whether the provider has voiceStreamLLM; fall back to voiceLLM if not
+		const providerImplementation = provider as unknown as Record<
+			string,
+			unknown
+		>;
+		if (typeof providerImplementation.voiceStreamLLM === 'function') {
+			// Provider supports voice streaming
+			const originalResponse = providerImplementation.voiceStreamLLM(
+				voiceParams as unknown as never,
+				config as never,
+				wrappedHandler
+			);
+
+			// Wrap the completion to update state when done
+			const wrappedCompletion = originalResponse.completion.finally(() => {
+				set({ isStreaming: false, currentAbortController: null });
+			});
+
+			return {
+				abort: () => {
+					originalResponse.abort();
+					abortController.abort();
+				},
+				completion: wrappedCompletion,
+			};
+		} else {
+			// Fall back to the non-streaming voiceLLM for backward compatibility
+			const completion = (async () => {
+				try {
+					const response = await provider.voiceLLM(
+						voiceParams as unknown as never,
+						config as never
+					);
+
+					// Simulate streaming events for compatibility
+					if (response.content) {
+						wrappedHandler({ type: 'chunk', content: response.content });
+					}
+					if (response.audioData || response.audioUrl) {
+						wrappedHandler({
+							type: 'audio',
+							audioData: response.audioData || response.audioUrl || '',
+							audioFormat: response.audioFormat,
+						});
+					}
+					if (response.object) {
+						wrappedHandler({
+							type: 'object',
+							object: Array.isArray(response.object)
+								? response.object
+								: [response.object],
+						});
+					}
+
+					// Send the done event
+					const completedItems: (string | StructuredResponseType)[] = [];
+					if (response.content) completedItems.push(response.content);
+					if (response.object) {
+						const objects = Array.isArray(response.object)
+							? response.object
+							: [response.object];
+						completedItems.push(...objects);
+					}
+
+					wrappedHandler({ type: 'done', completedItems });
+				} catch (error) {
+					if (error instanceof Error && error.name !== 'AbortError') {
+						wrappedHandler({ type: 'error', error });
+					}
+				}
+			})();
+
+			// Wrap the completion to update state when done
+			const wrappedCompletion = completion.finally(() => {
+				set({ isStreaming: false, currentAbortController: null });
+			});
+
+			return {
+				abort: () => {
+					abortController.abort();
+				},
+				completion: wrappedCompletion,
+			};
+		}
+	},
+
 	// Handle LLM response
 	handleLLMResponse: async (itemsToProcess) => {
 		const state = get();
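A call-site sketch of the new slice method for reviewers. The store accessors used here (`useCedarStore`, `voiceSettings`, `stringifyAdditionalContext`) come from elsewhere in the codebase; treat the exact shape as illustrative:

```ts
import { useCedarStore } from 'cedar-os';

// Stream a recorded clip through whichever provider is configured.
async function streamClip(recording: Blob) {
	const store = useCedarStore.getState();

	const { abort, completion } = store.voiceStreamLLM(
		{
			audioData: recording,
			voiceSettings: store.voiceSettings,
			context: store.stringifyAdditionalContext(),
			prompt: '',
		},
		(event) => {
			if (event.type === 'chunk') console.log(event.content);
		}
	);

	const timeout = setTimeout(abort, 30_000); // illustrative timeout guard
	await completion;
	clearTimeout(timeout);
}
```

Note that on the fallback path (provider without `voiceStreamLLM`), events are simulated after `voiceLLM` resolves, and `abort()` only flips the local controller; the in-flight request itself is not cancelled because `voiceLLM` takes no signal.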
diff --git a/packages/cedar-os/src/store/agentConnection/providers/mastra.ts b/packages/cedar-os/src/store/agentConnection/providers/mastra.ts
index b8f8b159..dc6eed52 100644
--- a/packages/cedar-os/src/store/agentConnection/providers/mastra.ts
+++ b/packages/cedar-os/src/store/agentConnection/providers/mastra.ts
@@ -3,11 +3,58 @@ import type {
 	MastraParams,
 	ProviderImplementation,
 	StructuredParams,
+	VoiceParams,
 } from '@/store/agentConnection/AgentConnectionTypes';
 import { handleEventStream } from '@/store/agentConnection/agentUtils';
 
 type MastraConfig = InferProviderConfig<'mastra'>;
 
+// Helper functions for voice methods
+const createVoiceHeaders = (config: MastraConfig): Record<string, string> => {
+	const headers: Record<string, string> = {};
+
+	// Only add Authorization header if apiKey is provided
+	if (config.apiKey) {
+		headers.Authorization = `Bearer ${config.apiKey}`;
+	}
+
+	return headers;
+};
+
+const resolveVoiceEndpoint = (
+	voiceSettings: VoiceParams['voiceSettings'],
+	config: MastraConfig
+): string => {
+	// Use the endpoint from voiceSettings if provided, otherwise use voiceRoute from config
+	const voiceEndpoint = voiceSettings.endpoint || config.voiceRoute || '/voice';
+	return voiceEndpoint.startsWith('http')
+		? voiceEndpoint
+		: `${config.baseURL}${voiceEndpoint}`;
+};
+
+const createVoiceFormData = (params: VoiceParams): FormData => {
+	const { audioData, voiceSettings, context, ...rest } = params;
+
+	const formData = new FormData();
+	formData.append('audio', audioData, 'recording.webm');
+	formData.append('settings', JSON.stringify(voiceSettings));
+
+	if (context) {
+		formData.append('context', context);
+	}
+
+	for (const [key, value] of Object.entries(rest)) {
+		if (value === undefined || value === null) continue;
+		if (typeof value === 'object') {
+			formData.append(key, JSON.stringify(value));
+		} else {
+			formData.append(key, String(value));
+		}
+	}
+
+	return formData;
+};
+
 export const mastraProvider: ProviderImplementation<
 	MastraParams,
 	MastraConfig
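Extracting these helpers keeps the streaming and non-streaming paths consistent. For reference, the multipart body they build for a typical call looks like the following (field values illustrative):

```ts
// What createVoiceFormData produces for a typical Mastra voice call:
const params = {
	audioData: new Blob([], { type: 'audio/webm' }), // the recorded clip
	voiceSettings: { language: 'en-US', useBrowserTTS: false },
	context: '{"selection":"..."}',
	prompt: '',
	threadId: 'thread-123', // extra primitives pass through String()
};

const formData = new FormData();
formData.append('audio', params.audioData, 'recording.webm');
formData.append('settings', JSON.stringify(params.voiceSettings));
formData.append('context', params.context);
formData.append('prompt', params.prompt);
formData.append('threadId', params.threadId);
// Object-valued extras would be JSON.stringify'd instead.
```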
@@ -135,37 +182,9 @@ export const mastraProvider: ProviderImplementation<
 	},
 
 	voiceLLM: async (params, config) => {
-		const { audioData, voiceSettings, context, ...rest } = params;
-
-		const headers: Record<string, string> = {};
-
-		// Only add Authorization header if apiKey is provided
-		if (config.apiKey) {
-			headers.Authorization = `Bearer ${config.apiKey}`;
-		}
-
-		// Use the endpoint from voiceSettings if provided, otherwise use voiceRoute from config
-		const voiceEndpoint =
-			voiceSettings.endpoint || config.voiceRoute || '/voice';
-		const fullUrl = voiceEndpoint.startsWith('http')
-			? voiceEndpoint
-			: `${config.baseURL}${voiceEndpoint}`;
-
-		const formData = new FormData();
-		formData.append('audio', audioData, 'recording.webm');
-		formData.append('settings', JSON.stringify(voiceSettings));
-		if (context) {
-			formData.append('context', context);
-		}
-
-		for (const [key, value] of Object.entries(rest)) {
-			if (value === undefined || value === null) continue;
-			if (typeof value === 'object') {
-				formData.append(key, JSON.stringify(value));
-			} else {
-				formData.append(key, String(value));
-			}
-		}
+		const headers = createVoiceHeaders(config);
+		const fullUrl = resolveVoiceEndpoint(params.voiceSettings, config);
+		const formData = createVoiceFormData(params);
 
 		const response = await fetch(fullUrl, {
 			method: 'POST',
@@ -211,6 +230,96 @@ export const mastraProvider: ProviderImplementation<
 		}
 	},
 
+	voiceStreamLLM: (params, config, handler) => {
+		const abortController = new AbortController();
+
+		const completion = (async () => {
+			try {
+				const headers = createVoiceHeaders(config);
+				const baseUrl = resolveVoiceEndpoint(params.voiceSettings, config);
+				const streamUrl = `${baseUrl}/stream`;
+				const formData = createVoiceFormData(params);
+
+				const response = await fetch(streamUrl, {
+					method: 'POST',
+					headers,
+					body: formData,
+					signal: abortController.signal,
+				});
+
+				if (!response.ok) {
+					throw new Error(`HTTP error! status: ${response.status}`);
+				}
+
+				// Use handleEventStream with voice-aware object handling
+				await handleEventStream(response, (event) => {
+					// Audio and transcription payloads arrive as object events
+					if (event.type === 'object' && event.object) {
+						const objects = Array.isArray(event.object)
+							? event.object
+							: [event.object];
+
+						// Collect objects that are not voice-specific so they are
+						// forwarded once, instead of re-emitting the full event per object
+						const passthrough: typeof objects = [];
+						for (const obj of objects) {
+							if (
+								obj &&
+								typeof obj === 'object' &&
+								'type' in obj &&
+								obj.type === 'audio'
+							) {
+								// Transform a Mastra audio object into a VoiceStreamEvent
+								const audioObj = obj as {
+									type: 'audio';
+									audioData?: string;
+									audioFormat?: string;
+								};
+								if (audioObj.audioData) {
+									handler({
+										type: 'audio',
+										audioData: audioObj.audioData,
+										audioFormat: audioObj.audioFormat,
+									});
+								}
+								// Audio objects without audioData are dropped
+							} else if (
+								obj &&
+								typeof obj === 'object' &&
+								'type' in obj &&
+								obj.type === 'transcription'
+							) {
+								const transcriptionObj = obj as {
+									type: 'transcription';
+									transcription: string;
+								};
+								handler({
+									type: 'transcription',
+									transcription: transcriptionObj.transcription,
+								});
+							} else {
+								passthrough.push(obj);
+							}
+						}
+						if (passthrough.length > 0) {
+							handler({ ...event, object: passthrough });
+						}
+					} else {
+						// Pass through all other events (chunk, done, error, metadata)
+						handler(event);
+					}
+				});
+			} catch (error) {
+				if (error instanceof Error && error.name !== 'AbortError') {
+					handler({ type: 'error', error });
+				}
+			}
+		})();
+
+		return {
+			abort: () => abortController.abort(),
+			completion,
+		};
+	},
+
 	handleResponse: async (response) => {
 		if (!response.ok) {
 			throw new Error(`HTTP error! status: ${response.status}`);
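The voice-specific payloads ride inside ordinary `object` events, so the wire framing stays whatever `handleEventStream` already parses for chat streaming. The shapes this handler special-cases, inferred from the code above (literals illustrative):

```ts
// Object payloads the Mastra voice stream handler recognizes:
type MastraVoiceObject =
	| { type: 'transcription'; transcription: string }
	| { type: 'audio'; audioData?: string; audioFormat?: string };

const examples: MastraVoiceObject[] = [
	{ type: 'transcription', transcription: 'hello there' },
	{ type: 'audio', audioData: 'SGVsbG8=', audioFormat: 'audio/mpeg' },
];
// Anything else inside event.object is forwarded unchanged as a normal
// 'object' event; audio objects missing audioData are dropped.
```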
diff --git a/packages/cedar-os/src/store/voice/voiceSlice.ts b/packages/cedar-os/src/store/voice/voiceSlice.ts
index fad34ea8..0e97e01f 100644
--- a/packages/cedar-os/src/store/voice/voiceSlice.ts
+++ b/packages/cedar-os/src/store/voice/voiceSlice.ts
@@ -1,6 +1,7 @@
 import { StateCreator } from 'zustand';
 import type { CedarStore } from '@/store/CedarOSTypes';
 import type {
+	VoiceStreamEvent,
 	StructuredResponseType,
 	VoiceLLMResponse,
 } from '@/store/agentConnection/AgentConnectionTypes';
@@ -26,6 +27,7 @@ export interface VoiceState {
 		useBrowserTTS?: boolean;
 		autoAddToMessages?: boolean;
 		endpoint?: string; // Voice endpoint URL
+		stream?: boolean; // Enable streaming voice responses
 	};
 }
 
@@ -71,6 +73,7 @@ const initialVoiceState: VoiceState = {
 		volume: 1.0,
 		useBrowserTTS: false,
 		autoAddToMessages: true, // Default to true for automatic message integration
+		stream: false, // Default to false for backward compatibility
 	},
 };
 
@@ -216,16 +219,74 @@ export const createVoiceSlice: StateCreator<CedarStore, [], [], VoiceSlice> = (
 			// Get the stringified additional context from the store
 			const contextString = get().stringifyAdditionalContext();
 
-			// Use the agent connection's voiceLLM method
-			const response = await get().voiceLLM({
+			const voiceParams = {
 				audioData,
 				voiceSettings,
 				context: contextString,
 				prompt: '',
-			});
+			};
 
-			// Handle the response using the new handleLLMVoice function
-			await get().handleLLMVoice(response);
+			// Check whether voice streaming is enabled
+			if (voiceSettings.stream) {
+				// Use the streaming voice method; follows the same pattern as streamLLM
+				const streamResponse = get().voiceStreamLLM(
+					voiceParams,
+					async (event: VoiceStreamEvent) => {
+						switch (event.type) {
+							case 'transcription':
+								// Pass the transcription to handleLLMVoice
+								await get().handleLLMVoice({
+									content: '',
+									transcription: event.transcription,
+								});
+								break;
+							case 'chunk':
+								// Pass text content directly to handleLLMVoice
+								await get().handleLLMVoice({ content: event.content });
+								break;
+							case 'audio':
+								// Pass the complete audio response to handleLLMVoice
+								await get().handleLLMVoice({
+									content: '',
+									audioData: event.audioData,
+									audioFormat: event.audioFormat,
+								});
+								break;
+							case 'object':
+								// Pass structured objects to handleLLMVoice
+								await get().handleLLMVoice({
+									content: '',
+									object: Array.isArray(event.object)
+										? event.object
+										: [event.object],
+								});
+								break;
+							case 'done':
+								// Stream completed; no additional processing needed
+								break;
+							case 'error':
+								console.error('Voice stream error:', event.error);
+								set({
+									voiceError:
+										event.error.message || 'Voice streaming error occurred',
+								});
+								break;
+						}
+					}
+				);
+
+				// Wait for streaming to complete
+				await streamResponse.completion;
+			} else {
+				// Use the non-streaming agent connection voiceLLM method
+				const response = await get().voiceLLM(voiceParams);
+
+				// Handle the response using the existing handleLLMVoice function
+				await get().handleLLMVoice(response);
+			}
+
+			// Voice processing completed successfully (streaming or non-streaming)
+			get().setIsProcessing(false);
 		} catch (error) {
 			set({
 				voiceError:
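Streaming stays opt-in. Turning it on is a one-line settings change; the sketch below uses plain Zustand `setState`, since this diff doesn't show a dedicated settings action:

```ts
import { useCedarStore } from 'cedar-os';

// Opt in to streaming voice responses; other settings keep their values.
useCedarStore.setState((state) => ({
	voiceSettings: { ...state.voiceSettings, stream: true },
}));
```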
@@ -241,6 +302,7 @@
 		try {
 			set({ isSpeaking: false });
 
+			let handled = false;
 
 			// Handle audio playback (voice-specific)
 			if (response.audioData && response.audioFormat) {
@@ -251,8 +313,10 @@
 				}
 				const audioBuffer = bytes.buffer;
 				await get().playAudioResponse(audioBuffer);
+				handled = true;
 			} else if (response.audioUrl) {
 				await get().playAudioResponse(response.audioUrl);
+				handled = true;
 			} else if (response.content && voiceSettings.useBrowserTTS) {
 				if ('speechSynthesis' in window) {
 					const utterance = new SpeechSynthesisUtterance(response.content);
@@ -265,6 +329,7 @@
 					utterance.onend = () => set({ isSpeaking: false });
 
 					speechSynthesis.speak(utterance);
+					handled = true;
 				}
 			}
 
@@ -280,20 +345,20 @@
 						timestamp: new Date().toISOString(),
 					},
 				});
+				handled = true;
 			}
 
 			// Build items array for handleLLMResponse
 			const items: (string | StructuredResponseType)[] = [];
 
-			// This should be fixed tbh. HandleLLMResponse should be able to handle this, but due to current streaming limitations.
-
 			// Add content if present
-			if (response.content) {
+			if (response.content && !handled) {
 				items.push(response.content);
+				handled = true;
 			}
 
 			// Add object if present - cast to StructuredResponseType for compatibility
-			if (response.object) {
+			if (response.object && !handled) {
 				if (Array.isArray(response.object)) {
 					items.push(...response.object);
 				} else {
@@ -307,8 +372,7 @@
 				await handleLLMResponse(items);
 			}
 
-			// Set processing state to false when voice processing completes successfully
-			get().setIsProcessing(false);
+			// Note: processing state is now cleared after streaming/non-streaming completion in streamAudioToEndpoint
 		} catch (error) {
 			set({
 				voiceError:
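One behavior worth calling out for review: `handled` never suppresses playback or message insertion; it only gates the final fallback into `handleLLMResponse`, so streamed voice turns are not double-posted as chat items. A condensed paraphrase of the new guard (types assumed re-exported from the package root):

```ts
import type { StructuredResponseType, VoiceLLMResponse } from 'cedar-os';

// alreadyHandled is true when audio was played, browser TTS spoke, or
// autoAddToMessages appended a message for this response.
function itemsToForward(
	response: VoiceLLMResponse,
	alreadyHandled: boolean
): (string | StructuredResponseType)[] {
	const items: (string | StructuredResponseType)[] = [];
	let handled = alreadyHandled;
	if (response.content && !handled) {
		items.push(response.content); // text fallback only when unclaimed
		handled = true;
	}
	if (response.object && !handled) {
		items.push(
			...(Array.isArray(response.object) ? response.object : [response.object])
		);
	}
	return items;
}
```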