Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ export type StreamEvent =
| { type: 'error'; error: Error }
| { type: 'metadata'; data: unknown };

export type VoiceStreamEvent =
| StreamEvent
| { type: 'transcription'; transcription: string }
| { type: 'audio'; audioData: string; audioFormat?: string };

export type StreamHandler = (event: StreamEvent) => void | Promise<void>;
export type VoiceStreamHandler = (
event: VoiceStreamEvent
) => void | Promise<void>;

export interface StreamResponse {
abort: () => void;
Expand Down Expand Up @@ -177,6 +185,11 @@ export interface ProviderImplementation<
handler: StreamHandler
) => StreamResponse;
voiceLLM: (params: VoiceParams, config: TConfig) => Promise<VoiceLLMResponse>;
voiceStreamLLM?: (
params: VoiceParams,
config: TConfig,
handler: VoiceStreamHandler
) => StreamResponse;
handleResponse: (response: Response) => Promise<LLMResponse>;
handleStreamResponse: (chunk: string) => StreamEvent;
}
Expand Down
150 changes: 148 additions & 2 deletions packages/cedar-os/src/store/agentConnection/agentConnectionSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ import type {
ResponseProcessor,
ResponseProcessorRegistry,
StructuredResponseType,
VoiceStreamHandler,
VoiceStreamEvent,
} from '@/store/agentConnection/AgentConnectionTypes';

import { useCedarStore } from '@/store/CedarStore';
import { getCedarState } from '@/store/CedarStore';
import { sanitizeJson } from '@/utils/sanitizeJson';
Expand Down Expand Up @@ -90,8 +93,12 @@ export interface AgentConnectionSlice {
handler: StreamHandler
) => StreamResponse;

// Voice LLM method
// Voice LLM methods
voiceLLM: (params: VoiceParams) => Promise<VoiceLLMResponse>;
voiceStreamLLM: (
params: VoiceParams,
handler: VoiceStreamHandler
) => StreamResponse;

// High-level methods that use callLLM/streamLLM
sendMessage: (params?: SendMessageParams) => Promise<void>;
Expand Down Expand Up @@ -125,7 +132,7 @@ export interface AgentConnectionSlice {
// Create a typed version of the slice that knows about the provider
export type TypedAgentConnectionSlice<T extends ProviderConfig> = Omit<
AgentConnectionSlice,
'callLLM' | 'streamLLM' | 'callLLMStructured' | 'voiceLLM'
'callLLM' | 'streamLLM' | 'callLLMStructured' | 'voiceLLM' | 'voiceStreamLLM'
> & {
callLLM: (params: GetParamsForConfig<T>) => Promise<LLMResponse>;
callLLMStructured: (
Expand All @@ -136,6 +143,10 @@ export type TypedAgentConnectionSlice<T extends ProviderConfig> = Omit<
handler: StreamHandler
) => StreamResponse;
voiceLLM: (params: VoiceParams) => Promise<VoiceLLMResponse>;
voiceStreamLLM: (
params: VoiceParams,
handler: VoiceStreamHandler
) => StreamResponse;
};

export const improvePrompt = async (
Expand Down Expand Up @@ -443,6 +454,141 @@ export const createAgentConnectionSlice: StateCreator<
}
},

// Voice streaming LLM method
voiceStreamLLM: (params: VoiceParams, handler: VoiceStreamHandler) => {
const config = get().providerConfig;
if (!config) {
throw new Error('No LLM provider configured');
}

// Augment params for Mastra provider to include resourceId & threadId
let voiceParams: VoiceParams = params;
if (config.provider === 'mastra') {
const resourceId = getCedarState('userId') as string | undefined;
const threadId = getCedarState('threadId') as string | undefined;
voiceParams = {
...params,
resourceId,
threadId,
} as typeof voiceParams;
}

// Log the stream start
const streamId = get().logStreamStart(voiceParams, config.provider);

const provider = getProviderImplementation(config);
const abortController = new AbortController();

set({ currentAbortController: abortController, isStreaming: true });

// Wrap the handler to log stream events
const wrappedHandler: VoiceStreamHandler = (event: VoiceStreamEvent) => {
if (event.type === 'chunk') {
get().logStreamChunk(streamId, event.content);
} else if (event.type === 'done') {
get().logStreamEnd(streamId, event.completedItems);
} else if (event.type === 'error') {
get().logAgentError(streamId, event.error);
} else if (event.type === 'object') {
get().logStreamObject(streamId, event.object);
} else if (event.type === 'transcription') {
get().logStreamChunk(streamId, event.transcription);
} else if (event.type === 'audio') {
// Log audio event (could be enhanced with audio-specific logging)
get().logStreamChunk(
streamId,
`[AUDIO: ${event.audioFormat || 'unknown'}]`
);
}
handler(event);
};

// Check if provider has voiceStreamLLM, fallback to voiceLLM if not
const providerImplementation = provider as unknown as Record<
string,
unknown
>;
if (typeof providerImplementation.voiceStreamLLM === 'function') {
// Provider supports voice streaming
const originalResponse = providerImplementation.voiceStreamLLM(
voiceParams as unknown as never,
config as never,
wrappedHandler
);

// Wrap the completion to update state when done
const wrappedCompletion = originalResponse.completion.finally(() => {
set({ isStreaming: false, currentAbortController: null });
});

return {
abort: () => {
originalResponse.abort();
abortController.abort();
},
completion: wrappedCompletion,
};
} else {
// Fallback to non-streaming voiceLLM for backward compatibility
const completion = (async () => {
try {
const response = await provider.voiceLLM(
voiceParams as unknown as never,
config as never
);

// Simulate streaming events for compatibility
if (response.content) {
wrappedHandler({ type: 'chunk', content: response.content });
}
if (response.audioData || response.audioUrl) {
wrappedHandler({
type: 'audio',
audioData: response.audioData || response.audioUrl || '',
audioFormat: response.audioFormat,
});
Comment on lines +544 to +549
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Audio data fallback logic uses audioData || audioUrl || '' which could result in empty string for audio data if both are undefined

}
if (response.object) {
wrappedHandler({
type: 'object',
object: Array.isArray(response.object)
? response.object
: [response.object],
});
}

// Send done event
const completedItems: (string | StructuredResponseType)[] = [];
if (response.content) completedItems.push(response.content);
if (response.object) {
const objects = Array.isArray(response.object)
? response.object
: [response.object];
completedItems.push(...objects);
}

wrappedHandler({ type: 'done', completedItems });
} catch (error) {
if (error instanceof Error && error.name !== 'AbortError') {
wrappedHandler({ type: 'error', error });
}
}
})();

// Wrap the completion to update state when done
const wrappedCompletion = completion.finally(() => {
set({ isStreaming: false, currentAbortController: null });
});

return {
abort: () => {
abortController.abort();
},
completion: wrappedCompletion,
};
}
},

// Handle LLM response
handleLLMResponse: async (itemsToProcess) => {
const state = get();
Expand Down
Loading
Loading