diff --git a/app/lib/backend/http/api/users.dart b/app/lib/backend/http/api/users.dart index 26262db021..37dbffe016 100644 --- a/app/lib/backend/http/api/users.dart +++ b/app/lib/backend/http/api/users.dart @@ -243,6 +243,18 @@ Future deletePerson(String personId) async { return response.statusCode == 204; } +Future deletePersonSpeechSample(String personId, int sampleIndex) async { + var response = await makeApiCall( + url: '${Env.apiBaseUrl}v1/users/people/$personId/speech-samples/$sampleIndex', + headers: {}, + method: 'DELETE', + body: '', + ); + if (response == null) return false; + debugPrint('deletePersonSpeechSample response: ${response.body}'); + return response.statusCode == 200; +} + Future getFollowUpQuestion({String conversationId = '0'}) async { var response = await makeApiCall( url: '${Env.apiBaseUrl}v1/joan/$conversationId/followup-question', diff --git a/app/lib/pages/settings/people.dart b/app/lib/pages/settings/people.dart index 38e0c3d145..de40c6ba75 100644 --- a/app/lib/pages/settings/people.dart +++ b/app/lib/pages/settings/people.dart @@ -8,7 +8,6 @@ import 'package:omi/providers/people_provider.dart'; import 'package:omi/providers/connectivity_provider.dart'; import 'package:omi/widgets/dialog.dart'; import 'package:omi/widgets/extensions/functions.dart'; -import 'package:just_audio/just_audio.dart'; import 'package:omi/utils/l10n_extensions.dart'; import 'package:provider/provider.dart'; @@ -161,7 +160,7 @@ class _UserPeoplePageState extends State<_UserPeoplePage> { ); } - Future _confirmDeleteSample(int peopleIdx, Person person, String url, PeopleProvider provider) async { + Future _confirmDeleteSample(int peopleIdx, Person person, int sampleIdx, PeopleProvider provider) async { final connectivityProvider = Provider.of(context, listen: false); if (!connectivityProvider.isConnected) { ConnectivityProvider.showNoInternetDialog(context); @@ -180,7 +179,7 @@ class _UserPeoplePageState extends State<_UserPeoplePage> { ); if (confirmed == true) { - provider.deletePersonSample(peopleIdx, url); + await provider.deletePersonSample(peopleIdx, sampleIdx); } } @@ -297,20 +296,11 @@ class _UserPeoplePageState extends State<_UserPeoplePage> { ), onPressed: () => provider.playPause(index, j, sample), ), - title: Text(index == 0 + title: Text(j == 0 ? 
context.l10n.speechProfile - : context.l10n.sampleNumber(index)), - onTap: () => _confirmDeleteSample(index, person, sample, provider), - subtitle: FutureBuilder( - future: AudioPlayer().setUrl(sample), - builder: (context, snapshot) { - if (snapshot.hasData) { - return Text(context.l10n.secondsCount(snapshot.data!.inSeconds)); - } else { - return Text(context.l10n.loadingDuration); - } - }, - ), + : context.l10n.sampleNumber(j)), + onTap: () => _confirmDeleteSample(index, person, j, provider), + subtitle: Text('Tap to delete'), )), ], ), diff --git a/app/lib/providers/people_provider.dart b/app/lib/providers/people_provider.dart index 1554f15854..d1b8fccd1c 100644 --- a/app/lib/providers/people_provider.dart +++ b/app/lib/providers/people_provider.dart @@ -1,5 +1,4 @@ import 'package:flutter/cupertino.dart'; -import 'package:omi/backend/http/api/speech_profile.dart'; import 'package:omi/backend/http/api/users.dart'; import 'package:omi/backend/preferences.dart'; import 'package:omi/backend/schema/person.dart'; @@ -106,21 +105,17 @@ class PeopleProvider extends BaseProvider { notifyListeners(); } - String _getFileNameFromUrl(String url) { - Uri uri = Uri.parse(url); - String fileName = uri.pathSegments.last; - return fileName.split('.').first; - } + Future deletePersonSample(int personIdx, int sampleIdx) async { + String personId = people[personIdx].id; - void deletePersonSample(int personIdx, String url) { - String name = _getFileNameFromUrl(url); - var parts = name.split('_segment_'); - String conversationId = parts[0]; - int segmentIdx = int.parse(parts[1]); - deleteProfileSample(conversationId, segmentIdx, personId: people[personIdx].id); - people[personIdx].speechSamples!.remove(url); - SharedPreferencesUtil().replaceCachedPerson(people[personIdx]); - notifyListeners(); + bool success = await deletePersonSpeechSample(personId, sampleIdx); + if (success) { + people[personIdx].speechSamples!.removeAt(sampleIdx); + SharedPreferencesUtil().replaceCachedPerson(people[personIdx]); + notifyListeners(); + } else { + debugPrint('Failed to delete speech sample at index: $sampleIdx'); + } } void deletePersonProvider(Person person) { diff --git a/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml b/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml index 258163c720..a09d67717b 100644 --- a/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml +++ b/backend/charts/backend-listen/dev_omi_backend_listen_values.yaml @@ -114,6 +114,8 @@ env: value: "http://34.172.155.20:80/v1/vad" - name: HOSTED_SPEECH_PROFILE_API_URL value: "http://34.172.155.20:80/v1/speaker-identification" + - name: HOSTED_SPEAKER_EMBEDDING_API_URL + value: "http://34.172.155.20:80" - name: PINECONE_API_KEY valueFrom: secretKeyRef: diff --git a/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml b/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml index e82ce287cb..6cf3b68fa6 100644 --- a/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml +++ b/backend/charts/backend-listen/prod_omi_backend_listen_values.yaml @@ -107,6 +107,8 @@ env: value: "http://172.16.128.101:8080/v1/vad" - name: HOSTED_SPEECH_PROFILE_API_URL value: "http://172.16.128.101:8080/v1/speaker-identification" + - name: HOSTED_SPEAKER_EMBEDDING_API_URL + value: "http://diarizer.omi.me:80" - name: PINECONE_API_KEY valueFrom: secretKeyRef: diff --git a/backend/charts/pusher/dev_omi_pusher_values.yaml b/backend/charts/pusher/dev_omi_pusher_values.yaml index 083c4eb867..5d5f2df4ab 100644 
--- a/backend/charts/pusher/dev_omi_pusher_values.yaml +++ b/backend/charts/pusher/dev_omi_pusher_values.yaml @@ -106,6 +106,8 @@ env: value: "http://34.172.155.20:80/v1/vad" - name: HOSTED_SPEECH_PROFILE_API_URL value: "http://34.172.155.20:80/v1/speaker-identification" + - name: HOSTED_SPEAKER_EMBEDDING_API_URL + value: "http://34.172.155.20:80" - name: PINECONE_API_KEY valueFrom: secretKeyRef: diff --git a/backend/charts/pusher/prod_omi_pusher_values.yaml b/backend/charts/pusher/prod_omi_pusher_values.yaml index c62369931e..d3d8e9416f 100644 --- a/backend/charts/pusher/prod_omi_pusher_values.yaml +++ b/backend/charts/pusher/prod_omi_pusher_values.yaml @@ -111,6 +111,8 @@ env: value: "http://vad.omi.me:80/v1/vad" - name: HOSTED_SPEECH_PROFILE_API_URL value: "http://vad.omi.me:80/v1/speaker-identification" + - name: HOSTED_SPEAKER_EMBEDDING_API_URL + value: "http://diarizer.omi.me:80" - name: PINECONE_API_KEY valueFrom: secretKeyRef: diff --git a/backend/database/users.py b/backend/database/users.py index e1c1986cd5..668efdf699 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -100,6 +100,127 @@ def delete_person(uid: str, person_id: str): person_ref.delete() +def add_person_speech_sample(uid: str, person_id: str, sample_path: str, max_samples: int = 5) -> bool: + """ + Append speech sample path to person's speech_samples list. + Limits to max_samples to prevent unlimited growth. + + Args: + uid: User ID + person_id: Person ID + sample_path: GCS path to the speech sample + max_samples: Maximum number of samples to keep (default 5) + + Returns: + True if sample was added, False if limit reached + """ + person_ref = db.collection('users').document(uid).collection('people').document(person_id) + person_doc = person_ref.get() + + if not person_doc.exists: + return False + + person_data = person_doc.to_dict() + current_samples = person_data.get('speech_samples', []) + + # Check if we've hit the limit + if len(current_samples) >= max_samples: + return False + + person_ref.update( + { + 'speech_samples': firestore.ArrayUnion([sample_path]), + 'updated_at': datetime.now(timezone.utc), + } + ) + return True + + +def get_person_speech_samples_count(uid: str, person_id: str) -> int: + """Get the count of speech samples for a person.""" + person_ref = db.collection('users').document(uid).collection('people').document(person_id) + person_doc = person_ref.get() + + if not person_doc.exists: + return 0 + + person_data = person_doc.to_dict() + return len(person_data.get('speech_samples', [])) + + +def remove_person_speech_sample(uid: str, person_id: str, sample_path: str) -> bool: + """ + Remove a speech sample path from person's speech_samples list. + + Args: + uid: User ID + person_id: Person ID + sample_path: GCS path to remove + + Returns: + True if removed, False if person not found + """ + person_ref = db.collection('users').document(uid).collection('people').document(person_id) + person_doc = person_ref.get() + + if not person_doc.exists: + return False + + person_ref.update({ + 'speech_samples': firestore.ArrayRemove([sample_path]), + 'updated_at': datetime.now(timezone.utc), + }) + return True + + +def set_person_speaker_embedding(uid: str, person_id: str, embedding: list) -> bool: + """ + Store speaker embedding for a person. 
+ + Args: + uid: User ID + person_id: Person ID + embedding: List of floats representing the speaker embedding + + Returns: + True if stored successfully, False if person not found + """ + person_ref = db.collection('users').document(uid).collection('people').document(person_id) + person_doc = person_ref.get() + + if not person_doc.exists: + return False + + person_ref.update( + { + 'speaker_embedding': embedding, + 'updated_at': datetime.now(timezone.utc), + } + ) + return True + + +def get_person_speaker_embedding(uid: str, person_id: str) -> Optional[list]: + """ + Get speaker embedding for a person. + + Args: + uid: User ID + person_id: Person ID + + Returns: + List of floats representing the embedding, or None if not found + """ + person_ref = db.collection('users').document(uid).collection('people').document(person_id) + person_doc = person_ref.get() + + if not person_doc.exists: + return None + + person_data = person_doc.to_dict() + return person_data.get('speaker_embedding') + + def delete_user_data(uid: str): user_ref = db.collection('users').document(uid) if not user_ref.get().exists: diff --git a/backend/routers/conversations.py b/backend/routers/conversations.py index fc8e5dee82..78f814d380 100644 --- a/backend/routers/conversations.py +++ b/backend/routers/conversations.py @@ -32,6 +32,7 @@ from utils.conversations.process_conversation import process_conversation, retrieve_in_progress_conversation from utils.conversations.search import search_conversations from utils.llm.conversation_processing import generate_summary_with_prompt +from utils.speaker_identification import extract_speaker_samples from utils.other import endpoints as auth from utils.other.storage import get_conversation_recording_if_exists from utils.app_integrations import trigger_external_integrations @@ -495,6 +496,7 @@ def set_assignee_conversation_segment( def assign_segments_bulk( conversation_id: str, data: BulkAssignSegmentsRequest, + background_tasks: BackgroundTasks, uid: str = Depends(auth.get_current_user_uid), ): conversation = _get_valid_conversation_by_id(uid, conversation_id) @@ -521,6 +523,17 @@ def assign_segments_bulk( conversations_db.update_conversation_segments( uid, conversation_id, [segment.dict() for segment in conversation.transcript_segments] ) + + # Trigger speaker sample extraction when assigning to a person + if data.assign_type == 'person_id' and value: + background_tasks.add_task( + extract_speaker_samples, + uid=uid, + person_id=value, + conversation_id=conversation_id, + segment_ids=data.segment_ids, + ) + return conversation diff --git a/backend/routers/pusher.py b/backend/routers/pusher.py index 4f8c0781b1..ea2d2c6372 100644 --- a/backend/routers/pusher.py +++ b/backend/routers/pusher.py @@ -3,6 +3,7 @@ import json import time from datetime import datetime, timezone +from typing import List from fastapi import APIRouter from fastapi.websockets import WebSocketDisconnect, WebSocket @@ -13,7 +14,11 @@ from database.redis_db import get_cached_user_geolocation from models.conversation import Conversation, ConversationStatus, Geolocation from utils.apps import is_audio_bytes_app_enabled -from utils.app_integrations import trigger_realtime_integrations, trigger_realtime_audio_bytes, trigger_external_integrations +from utils.app_integrations import ( + trigger_realtime_integrations, + trigger_realtime_audio_bytes, + trigger_external_integrations, +) from utils.conversations.location import get_google_maps_location from utils.conversations.process_conversation import 
process_conversation from utils.webhooks import ( @@ -22,9 +27,19 @@ get_audio_bytes_webhook_seconds, ) from utils.other.storage import upload_audio_chunk +from utils.speaker_identification import extract_speaker_samples router = APIRouter() +# Constants for speaker sample extraction +SPEAKER_SAMPLE_PROCESS_INTERVAL = 15.0 +SPEAKER_SAMPLE_MIN_AGE = 120.0 + +# Constants for private cloud sync +PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL = 1.0 +PRIVATE_CLOUD_CHUNK_DURATION = 5.0 +PRIVATE_CLOUD_SYNC_MAX_RETRIES = 3 + async def _process_conversation_task(uid: str, conversation_id: str, language: str, websocket: WebSocket): """Process a conversation and send result back to _listen via websocket.""" @@ -32,10 +47,7 @@ async def _process_conversation_task(uid: str, conversation_id: str, language: s conversation_data = conversations_db.get_conversation(uid, conversation_id) if not conversation_data: # Send error response - response = { - "conversation_id": conversation_id, - "error": "conversation_not_found" - } + response = {"conversation_id": conversation_id, "error": "conversation_not_found"} data = bytearray() data.extend(struct.pack("I", 201)) data.extend(bytes(json.dumps(response), "utf-8")) @@ -43,7 +55,7 @@ async def _process_conversation_task(uid: str, conversation_id: str, language: s return conversation = Conversation(**conversation_data) - + if conversation.status != ConversationStatus.processing: conversations_db.update_conversation_status(uid, conversation.id, ConversationStatus.processing) conversation.status = ConversationStatus.processing @@ -56,12 +68,8 @@ async def _process_conversation_task(uid: str, conversation_id: str, language: s conversation.geolocation = get_google_maps_location(geolocation.latitude, geolocation.longitude) # Run blocking operations in thread pool to avoid blocking event loop - conversation = await asyncio.to_thread( - process_conversation, uid, language, conversation - ) - messages = await asyncio.to_thread( - trigger_external_integrations, uid, conversation - ) + conversation = await asyncio.to_thread(process_conversation, uid, language, conversation) + messages = await asyncio.to_thread(trigger_external_integrations, uid, conversation) except Exception as e: print(f"Error processing conversation: {e}", uid, conversation_id) conversations_db.set_conversation_as_discarded(uid, conversation.id) @@ -69,21 +77,15 @@ async def _process_conversation_task(uid: str, conversation_id: str, language: s messages = [] # Send success response back (minimal - transcribe will fetch from DB) - response = { - "conversation_id": conversation_id, - "success": True - } + response = {"conversation_id": conversation_id, "success": True} data = bytearray() data.extend(struct.pack("I", 201)) data.extend(bytes(json.dumps(response), "utf-8")) await websocket.send_bytes(data) - + except Exception as e: print(f"Error in _process_conversation_task: {e}", uid, conversation_id) - response = { - "conversation_id": conversation_id, - "error": str(e) - } + response = {"conversation_id": conversation_id, "error": str(e)} data = bytearray() data.extend(struct.pack("I", 201)) data.extend(bytes(json.dumps(response), "utf-8")) @@ -117,15 +119,111 @@ async def _websocket_util_trigger( audio_bytes_trigger_delay_seconds = 4 has_audio_apps_enabled = is_audio_bytes_app_enabled(uid) private_cloud_sync_enabled = users_db.get_user_private_cloud_sync_enabled(uid) - private_cloud_sync_delay_seconds = 5 - async def save_audio_chunk(chunk_data: bytes, uid: str, conversation_id: str, timestamp: float): - 
upload_audio_chunk(chunk_data, uid, conversation_id, timestamp) + # Queue for pending speaker sample extraction requests + speaker_sample_queue: List[dict] = [] + + # Queue for pending private cloud sync chunks + private_cloud_queue: List[dict] = [] + + async def process_private_cloud_queue(): + """Background task that processes private cloud sync uploads with retry logic.""" + nonlocal websocket_active, private_cloud_queue + + while websocket_active or len(private_cloud_queue) > 0: + await asyncio.sleep(PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL) + + if not private_cloud_queue: + continue + + # Process all pending chunks + chunks_to_process = private_cloud_queue.copy() + private_cloud_queue = [] + + successful_conversation_ids = set() # Track conversations with successful uploads + + for chunk_info in chunks_to_process: + chunk_data = chunk_info['data'] + conv_id = chunk_info['conversation_id'] + timestamp = chunk_info['timestamp'] + retries = chunk_info.get('retries', 0) + + try: + await asyncio.to_thread(upload_audio_chunk, chunk_data, uid, conv_id, timestamp) + successful_conversation_ids.add(conv_id) + except Exception as e: + if retries < PRIVATE_CLOUD_SYNC_MAX_RETRIES: + # Re-queue with incremented retry count + chunk_info['retries'] = retries + 1 + private_cloud_queue.append(chunk_info) + print(f"Private cloud upload failed (retry {retries + 1}): {e}", uid, conv_id) + else: + print( + f"Private cloud upload failed after {PRIVATE_CLOUD_SYNC_MAX_RETRIES} retries, dropping chunk: {e}", + uid, + conv_id, + ) + + # Update audio_files for conversations with successful uploads + for conv_id in successful_conversation_ids: + try: + audio_files = await asyncio.to_thread(conversations_db.create_audio_files_from_chunks, uid, conv_id) + if audio_files: + await asyncio.to_thread( + conversations_db.update_conversation, + uid, + conv_id, + {'audio_files': [af.dict() for af in audio_files]}, + ) + except Exception as e: + print(f"Error updating audio files: {e}", uid, conv_id) + + async def process_speaker_sample_queue(): + """Background task that processes speaker sample extraction requests.""" + nonlocal websocket_active, speaker_sample_queue + + while websocket_active or len(speaker_sample_queue) > 0: + await asyncio.sleep(SPEAKER_SAMPLE_PROCESS_INTERVAL) + + if not speaker_sample_queue: + continue + + current_time = time.time() + + # Separate ready and pending requests + ready_requests = [] + pending_requests = [] + + for request in speaker_sample_queue: + if current_time - request['queued_at'] >= SPEAKER_SAMPLE_MIN_AGE: + ready_requests.append(request) + else: + pending_requests.append(request) + + # Keep pending requests in queue + speaker_sample_queue = pending_requests + + # Process ready requests (fire and forget) + for request in ready_requests: + person_id = request['person_id'] + conv_id = request['conversation_id'] + segment_ids = request['segment_ids'] + + try: + await extract_speaker_samples( + uid=uid, + person_id=person_id, + conversation_id=conv_id, + segment_ids=segment_ids, + sample_rate=sample_rate, + ) + except Exception as e: + print(f"Error extracting speaker samples: {e}", uid, conv_id) - # task async def receive_tasks(): nonlocal websocket_active nonlocal websocket_close_code + nonlocal speaker_sample_queue audiobuffer = bytearray() trigger_audiobuffer = bytearray() @@ -168,24 +266,49 @@ async def receive_tasks(): ) continue + # Speaker sample extraction request - queue for background processing + if header_type == 105: + res = json.loads(bytes(data[4:]).decode("utf-8")) + 
person_id = res.get('person_id') + conv_id = res.get('conversation_id') + segment_ids = res.get('segment_ids', []) + if person_id and conv_id and segment_ids: + print(f"Queued speaker sample request: person={person_id}, {len(segment_ids)} segments", uid) + speaker_sample_queue.append( + { + 'person_id': person_id, + 'conversation_id': conv_id, + 'segment_ids': segment_ids, + 'queued_at': time.time(), + } + ) + continue + # Audio bytes if header_type == 101: - audiobuffer.extend(data[4:]) - trigger_audiobuffer.extend(data[4:]) + # Parse: header(4) | timestamp(8 bytes double) | audio_data + buffer_start_timestamp = struct.unpack("d", data[4:12])[0] + audio_data = data[12:] + + audiobuffer.extend(audio_data) + trigger_audiobuffer.extend(audio_data) - # Private cloud sync + # Private cloud sync - queue chunks for background processing if private_cloud_sync_enabled and current_conversation_id: if private_cloud_chunk_start_time is None: - private_cloud_chunk_start_time = time.time() - - private_cloud_sync_buffer.extend(data[4:]) - # Save chunk every 5 seconds (sample_rate * 2 bytes per sample * 5 seconds) - if len(private_cloud_sync_buffer) >= sample_rate * 2 * private_cloud_sync_delay_seconds: - chunk_data = bytes(private_cloud_sync_buffer) - timestamp = private_cloud_chunk_start_time - conv_id = current_conversation_id - asyncio.run_coroutine_threadsafe( - save_audio_chunk(chunk_data, uid, conv_id, timestamp), loop + # Use timestamp from first buffer of this 5-second chunk + private_cloud_chunk_start_time = buffer_start_timestamp + + private_cloud_sync_buffer.extend(audio_data) + # Queue chunk every 5 seconds (sample_rate * 2 bytes per sample * 5 seconds) + if len(private_cloud_sync_buffer) >= sample_rate * 2 * PRIVATE_CLOUD_CHUNK_DURATION: + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': private_cloud_chunk_start_time, + 'retries': 0, + } ) private_cloud_sync_buffer = bytearray() private_cloud_chunk_start_time = None @@ -214,11 +337,24 @@ async def receive_tasks(): print(f'Could not process audio: error {e}') websocket_close_code = 1011 finally: + # Flush any remaining private cloud sync buffer before shutdown + if private_cloud_sync_enabled and current_conversation_id and len(private_cloud_sync_buffer) > 0: + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': private_cloud_chunk_start_time or time.time(), + 'retries': 0, + } + ) + print(f"Flushed final private cloud buffer: {len(private_cloud_sync_buffer)} bytes", uid) websocket_active = False try: receive_task = asyncio.create_task(receive_tasks()) - await asyncio.gather(receive_task) + speaker_sample_task = asyncio.create_task(process_speaker_sample_queue()) + private_cloud_task = asyncio.create_task(process_private_cloud_queue()) + await asyncio.gather(receive_task, speaker_sample_task, private_cloud_task) except Exception as e: print(f"Error during WebSocket operation: {e}") diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 57b153751f..d0e684333c 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -28,6 +28,9 @@ get_or_create_merged_audio, get_merged_audio_signed_url, ) + +# Audio constants +AUDIO_SAMPLE_RATE = 16000 from utils import encryption from utils.stt.pre_recorded import deepgram_prerecorded, postprocess_words from utils.stt.vad import vad_is_empty @@ -102,7 +105,7 @@ def parse_range_header(range_header: str, file_size: 
int) -> tuple[int, int] | N # ********************************************** -def _precache_audio_file(uid: str, conversation_id: str, audio_file: dict): +def _precache_audio_file(uid: str, conversation_id: str, audio_file: dict, fill_gaps: bool = True): """Pre-cache a single audio file.""" try: audio_file_id = audio_file.get('id') @@ -116,6 +119,8 @@ def _precache_audio_file(uid: str, conversation_id: str, audio_file: dict): audio_file_id=audio_file_id, timestamps=timestamps, pcm_to_wav_func=pcm_to_wav, + fill_gaps=fill_gaps, + sample_rate=AUDIO_SAMPLE_RATE, ) print(f"Pre-cached audio file: {audio_file_id}") except Exception as e: @@ -310,11 +315,15 @@ def download_audio_file_endpoint( audio_file_id=audio_file_id, timestamps=audio_file['chunk_timestamps'], pcm_to_wav_func=pcm_to_wav, + fill_gaps=True, + sample_rate=AUDIO_SAMPLE_RATE, ) content_type = "audio/wav" extension = "wav" else: - audio_data = download_audio_chunks_and_merge(uid, conversation_id, audio_file['chunk_timestamps']) + audio_data = download_audio_chunks_and_merge( + uid, conversation_id, audio_file['chunk_timestamps'], fill_gaps=True, sample_rate=AUDIO_SAMPLE_RATE + ) content_type = "application/octet-stream" extension = "pcm" except FileNotFoundError: diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 6c44820488..07570c12ea 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -5,11 +5,13 @@ import struct import time import uuid +import wave from datetime import datetime, timedelta, timezone from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Callable import av +import numpy as np import opuslib # type: ignore import webrtcvad # type: ignore @@ -81,12 +83,79 @@ from utils.webhooks import get_audio_bytes_webhook_seconds from utils.onboarding import OnboardingHandler +from utils.stt.speaker_embedding import ( + extract_embedding_from_bytes, + compare_embeddings, + SPEAKER_MATCH_THRESHOLD, +) + + router = APIRouter() PUSHER_ENABLED = bool(os.getenv('HOSTED_PUSHER_API_URL')) +class AudioRingBuffer: + """Circular buffer storing last N seconds of PCM16 mono audio with timestamp tracking.""" + + def __init__(self, duration_seconds: float, sample_rate: int): + self.sample_rate = sample_rate + self.bytes_per_second = sample_rate * 2 # PCM16 mono + self.capacity = int(duration_seconds * self.bytes_per_second) + self.buffer = bytearray(self.capacity) + self.write_pos = 0 + self.total_bytes_written = 0 + self.last_write_timestamp: Optional[float] = None + + def write(self, data: bytes, timestamp: float): + """Append audio data with timestamp.""" + for byte in data: + self.buffer[self.write_pos] = byte + self.write_pos = (self.write_pos + 1) % self.capacity + self.total_bytes_written += len(data) + self.last_write_timestamp = timestamp + + def get_time_range(self) -> Optional[Tuple[float, float]]: + """Return (start_ts, end_ts) of audio currently in buffer.""" + if self.last_write_timestamp is None: + return None + bytes_in_buffer = min(self.total_bytes_written, self.capacity) + buffer_duration = bytes_in_buffer / self.bytes_per_second + return (self.last_write_timestamp - buffer_duration, self.last_write_timestamp) + + def extract(self, start_ts: float, end_ts: float) -> Optional[bytes]: + """Extract audio for absolute timestamp range.""" + time_range = self.get_time_range() + if time_range is None: + return None + + buffer_start_ts, buffer_end_ts = time_range + actual_start = max(start_ts, buffer_start_ts) + actual_end = min(end_ts, buffer_end_ts) + + 
if actual_start >= actual_end: + return None + + bytes_in_buffer = min(self.total_bytes_written, self.capacity) + buffer_logical_start = (self.write_pos - bytes_in_buffer) % self.capacity + + start_offset = int((actual_start - buffer_start_ts) * self.bytes_per_second) + end_offset = int((actual_end - buffer_start_ts) * self.bytes_per_second) + + # Ensure even number of bytes (PCM16) + length = ((end_offset - start_offset) // 2) * 2 + if length <= 0: + return None + + result = bytearray(length) + for i in range(length): + pos = (buffer_logical_start + start_offset + i) % self.capacity + result[i] = self.buffer[pos] + + return bytes(result) + + class CustomSttMode(str, Enum): disabled = "disabled" enabled = "enabled" @@ -253,10 +322,21 @@ async def _listen( # Initialize segment buffers early (before onboarding handler needs them) realtime_segment_buffers = [] realtime_photo_buffers: list[ConversationPhoto] = [] - + + # === Speaker Identification State === + RING_BUFFER_DURATION = 60.0 # seconds + SPEAKER_ID_MIN_AUDIO = 2.0 + SPEAKER_ID_TARGET_AUDIO = 4.0 + + audio_ring_buffer: Optional[AudioRingBuffer] = None + speaker_id_segment_queue: asyncio.Queue[dict] = asyncio.Queue(maxsize=100) + person_embeddings_cache: Dict[str, dict] = {} # person_id -> {embedding, name} + speaker_id_enabled = False # Will be set after private_cloud_sync_enabled is known + # Onboarding handler onboarding_handler: Optional[OnboardingHandler] = None if onboarding_mode: + async def send_onboarding_event(event: dict): if websocket_active and websocket.client_state == WebSocketState.CONNECTED: try: @@ -271,7 +351,7 @@ def onboarding_stream_transcript(segments: List[dict]): onboarding_handler = OnboardingHandler(uid, send_onboarding_event, onboarding_stream_transcript) asyncio.create_task(onboarding_handler.send_current_question()) - + locked_conversation_ids: Set[str] = set() speaker_to_person_map: Dict[int, Tuple[str, str]] = {} segment_person_assignment_map: Dict[str, str] = {} @@ -281,8 +361,6 @@ def onboarding_stream_transcript(segments: List[dict]): last_usage_record_timestamp: Optional[float] = None words_transcribed_since_last_record: int = 0 last_transcript_time: Optional[float] = None - seconds_to_trim = None - seconds_to_add = None current_conversation_id = None async def _record_usage_periodically(): @@ -417,6 +495,11 @@ async def send_heartbeat(): # Create or get conversation ID early for audio chunk storage private_cloud_sync_enabled = user_db.get_user_private_cloud_sync_enabled(uid) + # Enable speaker identification if not custom STT and private cloud sync is enabled + speaker_id_enabled = not use_custom_stt and private_cloud_sync_enabled + if speaker_id_enabled: + audio_ring_buffer = AudioRingBuffer(RING_BUFFER_DURATION, sample_rate) + # Conversation timeout (to process the conversation after x seconds of silence) # Max: 4h, min 2m conversation_creation_timeout = conversation_timeout @@ -492,8 +575,6 @@ def send_last_conversation(): # Create new stub conversation for next batch async def _create_new_in_progress_conversation(): - nonlocal seconds_to_trim - nonlocal seconds_to_add nonlocal current_conversation_id conversation_source = ConversationSource.omi @@ -561,8 +642,6 @@ async def _create_new_in_progress_conversation(): redis_db.set_conversation_meeting_id(new_conversation_id, detected_meeting_id) current_conversation_id = new_conversation_id - seconds_to_trim = None - seconds_to_add = None print(f"Created new stub conversation: {new_conversation_id}", uid, session_id) @@ -583,7 +662,6 @@ async def 
_process_conversation(conversation_id: str): # Process existing conversations async def _prepare_in_progess_conversations(): - nonlocal seconds_to_add nonlocal current_conversation_id if existing_conversation := retrieve_in_progress_conversation(uid): @@ -600,14 +678,8 @@ async def _prepare_in_progess_conversations(): # Continue with the existing conversation current_conversation_id = existing_conversation['id'] - started_at = datetime.fromisoformat(existing_conversation['started_at'].isoformat()) - seconds_to_add = ( - (datetime.now(timezone.utc) - started_at).total_seconds() - if existing_conversation['transcript_segments'] - else None - ) print( - f"Resuming conversation {current_conversation_id} with {(seconds_to_add if seconds_to_add else 0):.1f}s offset. Will timeout in {conversation_creation_timeout - seconds_since_last_segment:.1f}s", + f"Resuming conversation {current_conversation_id}. Will timeout in {conversation_creation_timeout - seconds_since_last_segment:.1f}s", uid, session_id, ) @@ -634,24 +706,14 @@ def _process_speaker_assigned_segments(transcript_segments: List[TranscriptSegme segment.person_id = person_id def _update_in_progress_conversation( - conversation_id: str, segments: List[TranscriptSegment], photos: List[ConversationPhoto], finished_at: datetime + conversation: Conversation, + segments: List[TranscriptSegment], + photos: List[ConversationPhoto], + finished_at: datetime, ): - """Update the current in-progress conversation with new segments/photos.""" - conversation_data = conversations_db.get_conversation(uid, conversation_id) - if not conversation_data: - print(f"Warning: conversation {conversation_id} not found", uid, session_id) - return None, (0, 0) - - conversation = Conversation(**conversation_data) starts, ends = (0, 0) if segments: - # If conversation has no segments yet but we're adding some, update started_at - if not conversation.transcript_segments: - started_at = finished_at - timedelta(seconds=max(0, segments[-1].end)) - conversations_db.update_conversation(uid, conversation.id, {'started_at': started_at}) - conversation.started_at = started_at - conversation.transcript_segments, (starts, ends) = TranscriptSegment.combine_segments( conversation.transcript_segments, segments ) @@ -928,16 +990,19 @@ async def transcript_consume(): # Audio bytes audio_buffers = bytearray() + audio_buffer_last_received: float = None # Track when last audio was received audio_bytes_enabled = ( bool(get_audio_bytes_webhook_seconds(uid)) or is_audio_bytes_app_enabled(uid) or private_cloud_sync_enabled ) - def audio_bytes_send(audio_bytes): - nonlocal audio_buffers + def audio_bytes_send(audio_bytes: bytes, received_at: float): + nonlocal audio_buffers, audio_buffer_last_received audio_buffers.extend(audio_bytes) + audio_buffer_last_received = received_at async def _audio_bytes_flush(auto_reconnect: bool = True): nonlocal audio_buffers + nonlocal audio_buffer_last_received nonlocal pusher_ws nonlocal pusher_connected nonlocal last_synced_conversation_id @@ -964,9 +1029,16 @@ async def _audio_bytes_flush(auto_reconnect: bool = True): # Send audio bytes if pusher_connected and pusher_ws and len(audio_buffers) > 0: try: - # 101|data + # Calculate buffer start time: + # buffer_start = last_received_time - buffer_duration + # buffer_duration = buffer_length_bytes / (sample_rate * 2 bytes per sample) + buffer_duration_seconds = len(audio_buffers) / (sample_rate * 2) + buffer_start_time = (audio_buffer_last_received or time.time()) - buffer_duration_seconds + + # 101|timestamp(8 
bytes double)|audio_data data = bytearray() data.extend(struct.pack("I", 101)) + data.extend(struct.pack("d", buffer_start_time)) data.extend(audio_buffers.copy()) audio_buffers = bytearray() # reset await pusher_ws.send(data) @@ -1081,6 +1153,39 @@ async def close(code: int = 1000): if pusher_ws: await pusher_ws.close(code) + async def send_speaker_sample_request( + person_id: str, + conv_id: str, + segment_ids: List[str], + ): + """Send speaker sample extraction request to pusher with segment IDs.""" + nonlocal pusher_ws, pusher_connected + if not pusher_connected or not pusher_ws: + return + try: + data = bytearray() + data.extend(struct.pack("I", 105)) + data.extend( + bytes( + json.dumps( + { + "person_id": person_id, + "conversation_id": conv_id, + "segment_ids": segment_ids, + } + ), + "utf-8", + ) + ) + await pusher_ws.send(data) + print( + f"Sent speaker sample request to pusher: person={person_id}, {len(segment_ids)} segments", + uid, + session_id, + ) + except Exception as e: + print(f"Failed to send speaker sample request: {e}", uid, session_id) + def is_connected(): return pusher_connected @@ -1094,6 +1199,7 @@ def is_connected(): request_conversation_processing, pusher_receive, is_connected, + send_speaker_sample_request, ) transcript_send = None @@ -1105,6 +1211,7 @@ def is_connected(): request_conversation_processing = None pusher_receive = None pusher_is_connected = None + send_speaker_sample_request = None # Transcripts # @@ -1217,8 +1324,163 @@ async def conversation_lifecycle_manager(): await _process_conversation(current_conversation_id) await _create_new_in_progress_conversation() + async def speaker_identification_task(): + """Consume segment queue, accumulate per speaker, trigger match when ready.""" + nonlocal websocket_active, speaker_to_person_map + nonlocal person_embeddings_cache, audio_ring_buffer + + if not speaker_id_enabled: + return + + # Load person embeddings + try: + people = user_db.get_people(uid) + for person in people: + emb = person.get('speaker_embedding') + if emb: + person_embeddings_cache[person['id']] = { + 'embedding': np.array(emb, dtype=np.float32).reshape(1, -1), + 'name': person['name'], + } + print(f"Speaker ID: loaded {len(person_embeddings_cache)} person embeddings", uid, session_id) + except Exception as e: + print(f"Speaker ID: failed to load embeddings: {e}", uid, session_id) + return + + if not person_embeddings_cache: + print("Speaker ID: no stored embeddings, task disabled", uid, session_id) + return + + # Consume loop + while websocket_active: + try: + seg = await asyncio.wait_for(speaker_id_segment_queue.get(), timeout=2.0) + except asyncio.TimeoutError: + continue + + speaker_id = seg['speaker_id'] + + # Skip if already resolved + if speaker_id in speaker_to_person_map: + continue + + duration = seg['duration'] + if duration >= SPEAKER_ID_MIN_AUDIO: + asyncio.create_task(_match_speaker_embedding(speaker_id, seg)) + + print("Speaker ID task ended", uid, session_id) + + async def _match_speaker_embedding(speaker_id: int, segment: dict): + """Extract audio from ring buffer and match against stored embeddings.""" + nonlocal speaker_to_person_map, segment_person_assignment_map, audio_ring_buffer + + try: + seg_start = segment['abs_start'] + seg_end = segment['abs_end'] + duration = segment['duration'] + + if duration < SPEAKER_ID_MIN_AUDIO: + print(f"Speaker ID: segment too short ({duration:.1f}s)", uid, session_id) + return + + # Get buffer time range + time_range = audio_ring_buffer.get_time_range() + if time_range is None: + 
print(f"Speaker ID: buffer empty", uid, session_id) + return + + buffer_start_ts, buffer_end_ts = time_range + + # Calculate extraction range - stay within segment bounds, max 10 seconds from center + MAX_EXTRACT_DURATION = 10.0 + + if duration <= MAX_EXTRACT_DURATION: + # Segment fits within max duration, use full segment + extract_start = seg_start + extract_end = seg_end + else: + # Segment is longer than max, extract 10s from center + center = (seg_start + seg_end) / 2 + half_duration = MAX_EXTRACT_DURATION / 2 + extract_start = center - half_duration + extract_end = center + half_duration + + # Clamp to buffer availability + extract_start = max(buffer_start_ts, extract_start) + extract_end = min(buffer_end_ts, extract_end) + + if extract_end <= extract_start: + print(f"Speaker ID: no audio to extract", uid, session_id) + return + + # Extract only the needed bytes directly from ring buffer + pcm_data = audio_ring_buffer.extract(extract_start, extract_end) + if not pcm_data: + print(f"Speaker ID: failed to extract audio", uid, session_id) + return + + # Convert PCM to numpy for WAV encoding + samples = np.frombuffer(pcm_data, dtype=np.int16) + + # Convert PCM to WAV using av + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format='wav') + output_stream = output_container.add_stream('pcm_s16le', rate=sample_rate) + output_stream.layout = 'mono' + + frame = av.AudioFrame.from_ndarray(samples.reshape(1, -1), format='s16', layout='mono') + frame.rate = sample_rate + + for packet in output_stream.encode(frame): + output_container.mux(packet) + for packet in output_stream.encode(): + output_container.mux(packet) + + output_container.close() + wav_bytes = output_buffer.getvalue() + + # Extract embedding (API call) + query_embedding = await asyncio.to_thread(extract_embedding_from_bytes, wav_bytes, "query.wav") + + # Find best match + best_match = None + best_distance = float('inf') + + for person_id, data in person_embeddings_cache.items(): + distance = compare_embeddings(query_embedding, data['embedding']) + if distance < best_distance: + best_distance = distance + best_match = (person_id, data['name']) + + if best_match and best_distance < SPEAKER_MATCH_THRESHOLD: + person_id, person_name = best_match + print( + f"Speaker ID: speaker {speaker_id} -> {person_name} (distance={best_distance:.3f})", uid, session_id + ) + + # Store for session consistency + speaker_to_person_map[speaker_id] = (person_id, person_name) + + # Auto-assign processed segment + segment_person_assignment_map[segment['id']] = person_id + + # Notify client + _send_message_event( + SpeakerLabelSuggestionEvent( + speaker_id=speaker_id, + person_id=person_id, + person_name=person_name, + segment_id=segment['id'], + ) + ) + else: + print(f"Speaker ID: speaker {speaker_id} no match (best={best_distance:.3f})", uid, session_id) + + except Exception as e: + print(f"Speaker ID: match error for speaker {speaker_id}: {e}", uid, session_id) + async def stream_transcript_process(): - nonlocal websocket_active, realtime_segment_buffers, realtime_photo_buffers, websocket, seconds_to_trim + nonlocal websocket_active, realtime_segment_buffers, realtime_photo_buffers, websocket nonlocal current_conversation_id, translation_enabled, speaker_to_person_map, suggested_segments, words_transcribed_since_last_record, last_transcript_time while websocket_active or len(realtime_segment_buffers) > 0 or len(realtime_photo_buffers) > 0: @@ -1235,22 +1497,43 @@ async def stream_transcript_process(): finished_at = 
datetime.now(timezone.utc) + # Get conversation + conversation_data = conversations_db.get_conversation(uid, current_conversation_id) + if not conversation_data: + print( + f"Warning: conversation {current_conversation_id} not found during segment processing", + uid, + session_id, + ) + continue + + # Guard: first_audio_byte_timestamp must be set + if not first_audio_byte_timestamp: + print(f"Warning: first_audio_byte_timestamp not set, skipping segment processing", uid, session_id) + continue + transcript_segments = [] if segments_to_process: last_transcript_time = time.time() - if seconds_to_trim is None: - seconds_to_trim = segments_to_process[0]["start"] - - if seconds_to_add: - for i, segment in enumerate(segments_to_process): - segment["start"] += seconds_to_add - segment["end"] += seconds_to_add - segments_to_process[i] = segment - elif seconds_to_trim: - for i, segment in enumerate(segments_to_process): - segment["start"] -= seconds_to_trim - segment["end"] -= seconds_to_trim - segments_to_process[i] = segment + + # If conversation has no segments yet, set started_at based on when first speech occurred + if not conversation_data.get('transcript_segments'): + first_speech_timestamp = first_audio_byte_timestamp + segments_to_process[0]["start"] + new_started_at = datetime.fromtimestamp(first_speech_timestamp, tz=timezone.utc) + conversations_db.update_conversation(uid, current_conversation_id, {'started_at': new_started_at}) + conversation_data['started_at'] = new_started_at + + # Calculate unified time offset: audio stream start relative to conversation start + conversation_started_at = conversation_data['started_at'] + if isinstance(conversation_started_at, str): + conversation_started_at = datetime.fromisoformat(conversation_started_at) + time_offset = first_audio_byte_timestamp - conversation_started_at.timestamp() + + # Apply offset to all segments + for i, segment in enumerate(segments_to_process): + segment["start"] += time_offset + segment["end"] += time_offset + segments_to_process[i] = segment newly_processed_segments = [] for s in segments_to_process: @@ -1267,13 +1550,9 @@ async def stream_transcript_process(): current_session_segments[seg.id] = seg.speech_profile_processed transcript_segments, _ = TranscriptSegment.combine_segments([], newly_processed_segments) - if not current_conversation_id: - print("Warning: No current conversation ID", uid, session_id) - continue - - result = _update_in_progress_conversation( - current_conversation_id, transcript_segments, photos_to_process, finished_at - ) + # Update transcript segments + conversation = Conversation(**conversation_data) + result = _update_in_progress_conversation(conversation, transcript_segments, photos_to_process, finished_at) if not result or not result[0]: continue conversation, (starts, ends) = result @@ -1297,8 +1576,8 @@ async def stream_transcript_process(): if segment.person_id or segment.is_user or segment.id in suggested_segments: continue + # Session consistency speaker identification if speech_profile_complete.is_set(): - # Session consistency if segment.speaker_id in speaker_to_person_map: person_id, person_name = speaker_to_person_map[segment.speaker_id] _send_message_event( @@ -1312,6 +1591,31 @@ async def stream_transcript_process(): suggested_segments.add(segment.id) continue + # Embedding-based speaker identification + if speaker_id_enabled and person_embeddings_cache: + started_at_ts = conversation.started_at.timestamp() + if ( + segment.speaker_id is not None + and not segment.person_id + and not 
segment.is_user + and segment.speaker_id not in speaker_to_person_map + ): + try: + speaker_id_segment_queue.put_nowait( + { + 'id': segment.id, + 'speaker_id': segment.speaker_id, + 'abs_start': first_audio_byte_timestamp + + segment.start + - time_offset, # raw start/end + 'abs_end': first_audio_byte_timestamp + segment.end - time_offset, + 'duration': segment.end - segment.start, + 'text': segment.text, # TODO: remove + } + ) + except asyncio.QueueFull: + pass # Drop if queue is full + # Text-based detection detected_name = detect_speaker_from_text(segment.text) if detected_name: @@ -1389,7 +1693,7 @@ async def handle_image_chunk( async def receive_data(dg_socket, dg_profile_socket, soniox_sock, soniox_profile_sock, speechmatics_sock): nonlocal websocket_active, websocket_close_code, last_audio_received_time, last_activity_time, current_conversation_id nonlocal realtime_photo_buffers, speaker_to_person_map, first_audio_byte_timestamp, last_usage_record_timestamp - nonlocal soniox_profile_socket, deepgram_profile_socket + nonlocal soniox_profile_socket, deepgram_profile_socket, audio_ring_buffer timer_start = time.time() last_audio_received_time = timer_start @@ -1515,12 +1819,16 @@ async def close_soniox_profile(): ) continue + # Feed ring buffer for speaker identification + if audio_ring_buffer is not None: + audio_ring_buffer.write(data, last_audio_received_time) + if not use_custom_stt: stt_audio_buffer.extend(data) await flush_stt_buffer() if audio_bytes_send is not None: - audio_bytes_send(data) + audio_bytes_send(data, last_audio_received_time) elif message.get("text") is not None: try: @@ -1562,6 +1870,23 @@ async def close_soniox_profile(): print( f"Speaker {speaker_id} assigned to {person_name} ({person_id})", uid, session_id ) + + # Forward to pusher for speech sample extraction (non-blocking) + # Only for real people (not 'user') and when private cloud sync is enabled + if ( + person_id + and person_id != 'user' + and private_cloud_sync_enabled + and send_speaker_sample_request is not None + and current_conversation_id + ): + asyncio.create_task( + send_speaker_sample_request( + person_id=person_id, + conv_id=current_conversation_id, + segment_ids=segment_ids, + ) + ) else: print( "Speaker assignment ignored: no segment_ids or no speech-profile-processed segments.", @@ -1602,6 +1927,7 @@ async def close_soniox_profile(): request_conversation_processing, pusher_receive, pusher_is_connected, + send_speaker_sample_request, ) = create_pusher_task_handler() # Pusher connection @@ -1629,6 +1955,7 @@ async def close_soniox_profile(): record_usage_task = asyncio.create_task(_record_usage_periodically()) lifecycle_manager_task = asyncio.create_task(conversation_lifecycle_manager()) pending_conversations_task = asyncio.create_task(process_pending_conversations(timed_out_conversation_id)) + speaker_id_task = asyncio.create_task(speaker_identification_task()) _send_message_event(MessageServiceStatusEvent(status="ready")) @@ -1639,6 +1966,7 @@ async def close_soniox_profile(): record_usage_task, lifecycle_manager_task, pending_conversations_task, + speaker_id_task, ] + pusher_tasks # Add speech profile task to run concurrently (sends profile audio in background) @@ -1700,6 +2028,7 @@ async def close_soniox_profile(): realtime_segment_buffers.clear() realtime_photo_buffers.clear() image_chunks.clear() + person_embeddings_cache.clear() except NameError as e: # Variables might not be defined if an error occurred early print(f"Cleanup error (safe to ignore): {e}", uid, session_id) diff --git 
a/backend/routers/users.py b/backend/routers/users.py index 2afe4c1678..bd2bf19bbd 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -53,8 +53,9 @@ from utils.other import endpoints as auth from utils.other.storage import ( delete_all_conversation_recordings, - get_user_person_speech_samples, + get_speech_sample_signed_urls, delete_user_person_speech_samples, + delete_user_person_speech_sample, ) from utils.webhooks import webhook_first_time_setup @@ -242,7 +243,9 @@ def get_single_person( if not person: raise HTTPException(status_code=404, detail="Person not found") if include_speech_samples: - person['speech_samples'] = get_user_person_speech_samples(uid, person['id']) + # Convert stored GCS paths to signed URLs + stored_paths = person.get('speech_samples', []) + person['speech_samples'] = get_speech_sample_signed_urls(stored_paths) return person @@ -251,13 +254,10 @@ def get_all_people(include_speech_samples: bool = True, uid: str = Depends(auth. print('get_all_people', include_speech_samples) people = get_people(uid) if include_speech_samples: - - def single(person): - person['speech_samples'] = get_user_person_speech_samples(uid, person['id']) - - threads = [threading.Thread(target=single, args=(person,)) for person in people] - [t.start() for t in threads] - [t.join() for t in threads] + # Convert stored GCS paths to signed URLs for each person + for person in people: + stored_paths = person.get('speech_samples', []) + person['speech_samples'] = get_speech_sample_signed_urls(stored_paths) return people @@ -278,6 +278,36 @@ def delete_person_endpoint(person_id: str, uid: str = Depends(auth.get_current_u return {'status': 'ok'} +@router.delete('/v1/users/people/{person_id}/speech-samples/{sample_index}', tags=['v1']) +def delete_person_speech_sample_endpoint( + person_id: str, + sample_index: int, + uid: str = Depends(auth.get_current_user_uid), +): + """Delete a specific speech sample for a person by index.""" + person = get_person(uid, person_id) + if not person: + raise HTTPException(status_code=404, detail="Person not found") + + speech_samples = person.get('speech_samples', []) + if sample_index < 0 or sample_index >= len(speech_samples): + raise HTTPException(status_code=404, detail="Sample not found") + + path_to_delete = speech_samples[sample_index] + + # Extract filename from path for GCS deletion + filename = path_to_delete.split('/')[-1] + + # Delete from GCS + delete_user_person_speech_sample(uid, person_id, filename) + + # Remove from Firestore + from database.users import remove_person_speech_sample + remove_person_speech_sample(uid, person_id, path_to_delete) + + return {'status': 'ok'} + + # ********************************************************** # ************* RANDOM JOAN SPECIFIC FEATURES ************** # ********************************************************** diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py index 70c27c342d..8089b9a8fa 100644 --- a/backend/utils/other/storage.py +++ b/backend/utils/other/storage.py @@ -141,6 +141,31 @@ def delete_user_person_speech_samples(uid: str, person_id: str) -> None: blob.delete() +def upload_person_speech_sample_from_bytes( + audio_bytes: bytes, + uid: str, + person_id: str, + sample_rate: int = 16000, +) -> str: + """Upload PCM audio bytes as WAV speech sample. 
Returns GCS path.""" + import uuid as uuid_module + + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # 16-bit audio + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_bytes) + + bucket = storage_client.bucket(speech_profiles_bucket) + filename = f"{uuid_module.uuid4()}.wav" + path = f'{uid}/people_profiles/{person_id}/{filename}' + blob = bucket.blob(path) + blob.upload_from_string(wav_buffer.getvalue(), content_type='audio/wav') + + return path + + def get_user_people_ids(uid: str) -> List[str]: bucket = storage_client.bucket(speech_profiles_bucket) blobs = bucket.list_blobs(prefix=f'{uid}/people_profiles/') @@ -161,6 +186,27 @@ def get_user_person_speech_samples(uid: str, person_id: str, download: bool = Fa return [_get_signed_url(blob, 60) for blob in blobs] +def get_speech_sample_signed_urls(paths: List[str]) -> List[str]: + """ + Generate signed URLs for speech samples given their GCS paths. + Uses the paths stored in Firestore instead of listing GCS blobs. + + Args: + paths: List of GCS paths (e.g., '{uid}/people_profiles/{person_id}/{filename}') + + Returns: + List of signed URLs + """ + if not paths: + return [] + bucket = storage_client.bucket(speech_profiles_bucket) + signed_urls = [] + for path in paths: + blob = bucket.blob(path) + signed_urls.append(_get_signed_url(blob, 60)) + return signed_urls + + # ******************************************** # ************* POST PROCESSING ************** # ******************************************** @@ -356,7 +402,13 @@ def delete_conversation_audio_files(uid: str, conversation_id: str) -> None: blob.delete() -def download_audio_chunks_and_merge(uid: str, conversation_id: str, timestamps: List[float]) -> bytes: +def download_audio_chunks_and_merge( + uid: str, + conversation_id: str, + timestamps: List[float], + fill_gaps: bool = True, + sample_rate: int = 16000, +) -> bytes: """ Download and merge audio chunks on-demand, handling mixed encryption states. Downloads chunks in parallel. @@ -366,6 +418,9 @@ def download_audio_chunks_and_merge(uid: str, conversation_id: str, timestamps: uid: User ID conversation_id: Conversation ID timestamps: List of chunk timestamps to merge + fill_gaps: If True, insert silence (zero bytes) between chunks to maintain + continuous time-aligned audio. Default True. 
+ sample_rate: Audio sample rate in Hz (default 16000) Returns: Merged audio bytes (PCM16) @@ -416,9 +471,39 @@ def download_single_chunk(timestamp: float) -> tuple[float, bytes | None]: # Merge chunks merged_data = bytearray() - for timestamp in timestamps: - if timestamp in chunk_results: - merged_data.extend(chunk_results[timestamp]) + + if fill_gaps and timestamps and chunk_results: + # Sort timestamps to ensure proper ordering + sorted_timestamps = sorted(timestamps) + first_timestamp = sorted_timestamps[0] + current_time = first_timestamp # Track current audio end time in seconds + + for timestamp in sorted_timestamps: + if timestamp not in chunk_results: + continue + + pcm_data = chunk_results[timestamp] + + # Calculate gap from current position to this chunk's start + gap_seconds = timestamp - current_time + if gap_seconds > 0: + # Insert silence: 16-bit mono = 2 bytes per sample + gap_samples = int(gap_seconds * sample_rate) + silence_bytes = bytes(gap_samples * 2) # Zero bytes for silence + merged_data.extend(silence_bytes) + print(f"Filled {gap_seconds:.3f}s gap ({len(silence_bytes)} bytes) before chunk at {timestamp}") + + merged_data.extend(pcm_data) + + # Update current time based on chunk duration + # PCM16 mono: 2 bytes per sample + chunk_duration = len(pcm_data) / (sample_rate * 2) + current_time = timestamp + chunk_duration + else: + # Original behavior - just concatenate without gap filling + for timestamp in timestamps: + if timestamp in chunk_results: + merged_data.extend(chunk_results[timestamp]) if not merged_data: raise FileNotFoundError(f"No chunks found for conversation {conversation_id}") @@ -432,7 +517,13 @@ def get_cached_merged_audio_path(uid: str, conversation_id: str, audio_file_id: def get_or_create_merged_audio( - uid: str, conversation_id: str, audio_file_id: str, timestamps: List[float], pcm_to_wav_func + uid: str, + conversation_id: str, + audio_file_id: str, + timestamps: List[float], + pcm_to_wav_func, + fill_gaps: bool = True, + sample_rate: int = 16000, ) -> tuple[bytes, bool]: """ Get merged audio from cache or create it. @@ -444,6 +535,8 @@ def get_or_create_merged_audio( audio_file_id: Audio file ID timestamps: List of chunk timestamps pcm_to_wav_func: Function to convert PCM to WAV + fill_gaps: If True, insert silence between chunks to maintain time alignment. Default True. + sample_rate: Audio sample rate in Hz (default 16000) Returns: Tuple of (audio_data_bytes, was_cached) @@ -475,7 +568,9 @@ def get_or_create_merged_audio( print(f"Cache miss, merging audio for: {cache_path}") # Download and merge chunks - pcm_data = download_audio_chunks_and_merge(uid, conversation_id, timestamps) + pcm_data = download_audio_chunks_and_merge( + uid, conversation_id, timestamps, fill_gaps=fill_gaps, sample_rate=sample_rate + ) # Convert to WAV wav_data = pcm_to_wav_func(pcm_data) @@ -549,7 +644,9 @@ def _pcm_to_wav(pcm_data: bytes, sample_rate: int = 16000, channels: int = 1) -> return wav_buffer.getvalue() -def precache_conversation_audio(uid: str, conversation_id: str, audio_files: list) -> None: +def precache_conversation_audio( + uid: str, conversation_id: str, audio_files: list, fill_gaps: bool = True, sample_rate: int = 16000 +) -> None: """ Pre-cache all audio files for a conversation in a background thread. 
@@ -557,6 +654,8 @@ def precache_conversation_audio(uid: str, conversation_id: str, audio_files: lis uid: User ID conversation_id: Conversation ID audio_files: List of audio file dicts with 'id' and 'chunk_timestamps' + fill_gaps: If True, insert silence between chunks to maintain time alignment. Default True. + sample_rate: Audio sample rate in Hz (default 16000) """ if not audio_files: return @@ -575,6 +674,8 @@ def _cache_single(af): audio_file_id=audio_file_id, timestamps=timestamps, pcm_to_wav_func=_pcm_to_wav, + fill_gaps=fill_gaps, + sample_rate=sample_rate, ) except Exception as e: print(f"[PRECACHE] Error caching audio file {af.get('id')}: {e}") diff --git a/backend/utils/speaker_identification.py b/backend/utils/speaker_identification.py index d7ba43c677..faa4754ee7 100644 --- a/backend/utils/speaker_identification.py +++ b/backend/utils/speaker_identification.py @@ -1,5 +1,115 @@ +import asyncio +import io import re -from typing import Optional +import wave +from typing import List, Optional + +import av +import numpy as np + +from database import conversations as conversations_db +from database import users as users_db +from utils.other.storage import ( + download_audio_chunks_and_merge, + upload_person_speech_sample_from_bytes, +) +from utils.stt.speaker_embedding import extract_embedding_from_bytes + + +def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: + """ + Convert PCM16 mono audio to WAV format bytes. + + Args: + pcm_data: Raw PCM16 mono audio bytes + sample_rate: Audio sample rate in Hz + + Returns: + WAV format bytes + """ + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm_data) + return wav_buffer.getvalue() + + +def _trim_pcm_audio(pcm_data: bytes, sample_rate: int, start_sec: float, end_sec: float) -> bytes: + """ + Trim PCM16 mono audio using av for sample-accurate cutting. 
+ + Args: + pcm_data: Raw PCM16 mono audio bytes + sample_rate: Audio sample rate in Hz + start_sec: Start time in seconds (relative to pcm_data start) + end_sec: End time in seconds (relative to pcm_data start) + + Returns: + Trimmed PCM16 mono audio bytes + """ + # Create WAV container for av to read + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm_data) + wav_buffer.seek(0) + + # Use av to extract trimmed audio with sample-accurate boundaries + trimmed_samples = [] + with av.open(wav_buffer, mode='r') as container: + stream = container.streams.audio[0] + + for frame in container.decode(stream): + if frame.pts is None: + continue + + frame_time = float(frame.pts * stream.time_base) + frame_duration = frame.samples / sample_rate + frame_end_time = frame_time + frame_duration + + # Skip frames entirely before our start + if frame_end_time <= start_sec: + continue + # Stop once we're past the end + if frame_time >= end_sec: + break + + # Convert frame to numpy array + arr = frame.to_ndarray() + # For mono pcm_s16le, arr shape is (1, samples) + if arr.ndim == 2: + arr = arr[0] + + # Calculate which samples from this frame to include + frame_start_sample = 0 + frame_end_sample = len(arr) + + if frame_time < start_sec: + # Trim beginning of frame + skip_samples = int((start_sec - frame_time) * sample_rate) + frame_start_sample = skip_samples + + if frame_end_time > end_sec: + # Trim end of frame + keep_duration = end_sec - max(frame_time, start_sec) + frame_end_sample = frame_start_sample + int(keep_duration * sample_rate) + + if frame_start_sample < frame_end_sample: + trimmed_samples.append(arr[frame_start_sample:frame_end_sample]) + + if not trimmed_samples: + return b'' + + return np.concatenate(trimmed_samples).astype(np.int16).tobytes() + + +# Constants for speaker sample extraction +SPEAKER_SAMPLE_MIN_SEGMENT_DURATION = 10.0 +SPEAKER_SAMPLE_WINDOW_HALF = SPEAKER_SAMPLE_MIN_SEGMENT_DURATION / 2 # Language-specific patterns for speaker identification from text # Each pattern should have a capture group for the name. @@ -123,3 +233,205 @@ def detect_speaker_from_text(text: str) -> Optional[str]: if name and len(name) >= 2: return name.capitalize() return None + + +async def extract_speaker_samples( + uid: str, + person_id: str, + conversation_id: str, + segment_ids: List[str], + sample_rate: int = 16000, +): + """ + Extract speech samples from segments and store as speaker profiles. + Fetches conversation from DB to get started_at and segment details. + Processes each segment one by one, stops when sample limit reached. 
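Because the merged buffer handed to _trim_pcm_audio is already raw PCM16 mono at a known sample rate, the same cut can also be expressed as a plain byte slice. A simplified sketch for comparison only; this change itself decodes through av for frame-level, sample-accurate bookkeeping:

def trim_pcm_by_slice(pcm_data: bytes, sample_rate: int, start_sec: float, end_sec: float) -> bytes:
    bytes_per_sample = 2  # PCM16 mono
    start = max(0, int(start_sec * sample_rate)) * bytes_per_sample
    end = max(start, int(end_sec * sample_rate)) * bytes_per_sample
    return pcm_data[start:end]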
+ """ + try: + # Check current sample count once + sample_count = users_db.get_person_speech_samples_count(uid, person_id) + if sample_count >= 1: + print(f"Person {person_id} already has {sample_count} samples, skipping", uid, conversation_id) + return + + # Fetch conversation to get started_at and segment details + conversation = conversations_db.get_conversation(uid, conversation_id) + if not conversation: + print(f"Conversation {conversation_id} not found", uid) + return + + started_at = conversation.get('started_at') + if not started_at: + print(f"Conversation {conversation_id} has no started_at", uid) + return + + started_at_ts = started_at.timestamp() if hasattr(started_at, 'timestamp') else float(started_at) + + # Build segment lookup from conversation's transcript_segments + conv_segments = conversation.get('transcript_segments', []) + segment_map = {s.get('id'): s for s in conv_segments if s.get('id')} + + # Get chunks from audio_files instead of storage listing + audio_files = conversation.get('audio_files', []) + if not audio_files: + print(f"No audio files found for {conversation_id}, skipping speaker sample extraction", uid) + return + + # Collect all chunk timestamps from audio files + all_timestamps = [] + for af in audio_files: + timestamps = af.get('chunk_timestamps', []) + all_timestamps.extend(timestamps) + + if not all_timestamps: + print(f"No chunk timestamps found for {conversation_id}, skipping speaker sample extraction", uid) + return + + # Build chunks list in expected format + chunks = [{'timestamp': ts} for ts in sorted(set(all_timestamps))] + + samples_added = 0 + max_samples_to_add = 1 - sample_count + + # Build ordered list with index lookup for expansion + ordered_segments = [s for s in conv_segments if s.get('id')] + segment_index_map = {s.get('id'): i for i, s in enumerate(ordered_segments)} + + for seg_id in segment_ids: + if samples_added >= max_samples_to_add: + break + + seg = segment_map.get(seg_id) + if not seg: + print(f"Segment {seg_id} not found in conversation", uid, conversation_id) + continue + + segment_start = seg.get('start') + segment_end = seg.get('end') + if segment_start is None or segment_end is None: + continue + + seg_duration = segment_end - segment_start + speaker_id = seg.get('speaker_id') + + # If segment is too short, try expanding to adjacent segments with same speaker + if seg_duration < SPEAKER_SAMPLE_MIN_SEGMENT_DURATION and speaker_id is not None: + seg_idx = segment_index_map.get(seg_id) + if seg_idx is not None: + i = seg_idx - 1 + while i >= 0: + prev_seg = ordered_segments[i] + if prev_seg.get('speaker_id') != speaker_id: + break + prev_start = prev_seg.get('start') + if prev_start is not None: + segment_start = min(segment_start, prev_start) + seg_duration = segment_end - segment_start + if seg_duration >= SPEAKER_SAMPLE_MIN_SEGMENT_DURATION: + print( + f"Expanded segment to {seg_duration:.1f}s by including adjacent segments", + uid, + conversation_id, + ) + break + i -= 1 + + if seg_duration < SPEAKER_SAMPLE_MIN_SEGMENT_DURATION: + print(f"Segment too short ({seg_duration:.1f}s) even after expansion, skipping", uid, conversation_id) + continue + + # Extract centered sample window (10 seconds max from center of segment) + seg_center = (segment_start + segment_end) / 2 + sample_start = max(segment_start, seg_center - SPEAKER_SAMPLE_WINDOW_HALF) + sample_end = min(segment_end, seg_center + SPEAKER_SAMPLE_WINDOW_HALF) + + # Calculate absolute timestamps using the sample window + abs_start = started_at_ts + sample_start + 
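# Worked example of the centered-window math above (numbers are illustrative only):
#   segment 100.0-130.0 s -> seg_center = 115.0
#   sample_start = max(100.0, 115.0 - SPEAKER_SAMPLE_WINDOW_HALF) = 110.0
#   sample_end   = min(130.0, 115.0 + SPEAKER_SAMPLE_WINDOW_HALF) = 120.0
# i.e. at most SPEAKER_SAMPLE_MIN_SEGMENT_DURATION (10 s) of audio centered on the
# segment, clipped to its bounds; abs_start/abs_end then shift it by started_at_ts.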
abs_end = started_at_ts + sample_end + + # Find relevant chunks + sorted_chunks = sorted(chunks, key=lambda c: c['timestamp']) + + # Find first chunk that starts at or before abs_start + first_idx = 0 + for i, chunk in enumerate(sorted_chunks): + if chunk['timestamp'] <= abs_start: + first_idx = i + else: + break + + # Collect from first_idx up to abs_end + relevant_timestamps = [] + for chunk in sorted_chunks[first_idx:]: + if chunk['timestamp'] <= abs_end: + relevant_timestamps.append(chunk['timestamp']) + else: + break + + if not relevant_timestamps: + print(f"No relevant chunks for segment {segment_start:.1f}-{segment_end:.1f}s", uid, conversation_id) + continue + + # Download, merge, and extract + merged = await asyncio.to_thread( + download_audio_chunks_and_merge, + uid, + conversation_id, + relevant_timestamps, + fill_gaps=True, + sample_rate=sample_rate, + ) + buffer_start = min(relevant_timestamps) + + # Use av for sample-accurate trimming + trim_start = abs_start - buffer_start + trim_end = abs_end - buffer_start + sample_audio = _trim_pcm_audio(merged, sample_rate, trim_start, trim_end) + + # Ensure minimum sample length (8 seconds) + min_sample_seconds = 8.0 + min_sample_bytes = int(sample_rate * min_sample_seconds * 2) + if len(sample_audio) < min_sample_bytes: + actual_seconds = len(sample_audio) / (sample_rate * 2) + print( + f"Sample too short ({actual_seconds:.1f}s), need {min_sample_seconds}s, skipping", + uid, + conversation_id, + ) + continue + + # Upload and store + path = await asyncio.to_thread( + upload_person_speech_sample_from_bytes, sample_audio, uid, person_id, sample_rate + ) + + success = users_db.add_person_speech_sample(uid, person_id, path) + if success: + samples_added += 1 + seg_text = seg.get('text', '')[:100] # Truncate to 100 chars + print( + f"Stored speech sample {samples_added} for person {person_id}: segment_id={seg_id}, file={path}, text={seg_text}", + uid, + conversation_id, + ) + + # Extract and store speaker embedding + try: + wav_bytes = _pcm_to_wav_bytes(sample_audio, sample_rate) + embedding = await asyncio.to_thread(extract_embedding_from_bytes, wav_bytes, "sample.wav") + # Convert numpy array to list for Firestore storage + embedding_list = embedding.flatten().tolist() + users_db.set_person_speaker_embedding(uid, person_id, embedding_list) + print( + f"Stored speaker embedding for person {person_id} (dim={len(embedding_list)})", + uid, + conversation_id, + ) + except Exception as emb_err: + print(f"Failed to extract/store speaker embedding: {emb_err}", uid, conversation_id) + else: + print(f"Failed to add speech sample for person {person_id}", uid, conversation_id) + break # Likely hit limit + + except Exception as e: + print(f"Error extracting speaker samples: {e}", uid, conversation_id) diff --git a/backend/utils/stt/speaker_embedding.py b/backend/utils/stt/speaker_embedding.py new file mode 100644 index 0000000000..fe193e9799 --- /dev/null +++ b/backend/utils/stt/speaker_embedding.py @@ -0,0 +1,176 @@ +import os +from typing import Optional, Tuple + +import numpy as np +import requests +from scipy.spatial.distance import cdist + +# Cosine distance threshold for speaker matching +# Based on VoxCeleb 1 test set EER of 2.8% +SPEAKER_MATCH_THRESHOLD = 0.35 + + +def _get_api_url() -> str: + """Get the speaker embedding API URL from environment.""" + url = os.getenv('HOSTED_SPEAKER_EMBEDDING_API_URL') + if not url: + raise ValueError("HOSTED_SPEAKER_EMBEDDING_API_URL environment variable not set") + return url + + +def 
extract_embedding(audio_path: str) -> np.ndarray: + """ + Extract speaker embedding from an audio file using hosted API. + + Args: + audio_path: Path to audio file (wav format recommended) + + Returns: + numpy array of shape (1, D) where D is embedding dimension + """ + api_url = _get_api_url() + + with open(audio_path, 'rb') as f: + files = {'file': (os.path.basename(audio_path), f, 'audio/wav')} + response = requests.post(f"{api_url}/v1/embedding", files=files) + response.raise_for_status() + + result = response.json() + + # Handle both formats: direct array or {"embedding": [...]} + if isinstance(result, list): + embedding = np.array(result, dtype=np.float32) + else: + embedding = np.array(result['embedding'], dtype=np.float32) + + # Ensure shape is (1, D) + if embedding.ndim == 1: + embedding = embedding.reshape(1, -1) + + return embedding + + +def extract_embedding_from_bytes(audio_data: bytes, filename: str = "audio.wav") -> np.ndarray: + """ + Extract speaker embedding from audio bytes using hosted API. + + Args: + audio_data: Raw audio bytes (wav format) + filename: Filename to use in the request + + Returns: + numpy array of shape (1, D) where D is embedding dimension + """ + api_url = _get_api_url() + + files = {'file': (filename, audio_data, 'audio/wav')} + response = requests.post(f"{api_url}/v1/embedding", files=files) + response.raise_for_status() + + result = response.json() + + # Handle both formats: direct array or {"embedding": [...]} + if isinstance(result, list): + embedding = np.array(result, dtype=np.float32) + else: + embedding = np.array(result['embedding'], dtype=np.float32) + + # Ensure shape is (1, D) + if embedding.ndim == 1: + embedding = embedding.reshape(1, -1) + + return embedding + + +def compare_embeddings(embedding1: np.ndarray, embedding2: np.ndarray) -> float: + """ + Compare two speaker embeddings using cosine distance. + + Args: + embedding1: First embedding array (1, D) + embedding2: Second embedding array (1, D) + + Returns: + Cosine distance (0.0 = identical, 2.0 = opposite) + Lower values indicate more similar speakers + """ + distance = cdist(embedding1, embedding2, metric="cosine")[0, 0] + return float(distance) + + +def is_same_speaker( + embedding1: np.ndarray, embedding2: np.ndarray, threshold: float = SPEAKER_MATCH_THRESHOLD +) -> Tuple[bool, float]: + """ + Determine if two embeddings belong to the same speaker. + + Args: + embedding1: First embedding array + embedding2: Second embedding array + threshold: Cosine distance threshold for matching + + Returns: + Tuple of (is_match, distance) + """ + distance = compare_embeddings(embedding1, embedding2) + return distance < threshold, distance + + +def embedding_to_bytes(embedding: np.ndarray) -> bytes: + """ + Serialize embedding to bytes for storage. + + Args: + embedding: numpy array embedding + + Returns: + Bytes representation of the embedding + """ + return embedding.astype(np.float32).tobytes() + + +def bytes_to_embedding(data: bytes, dim: int = 512) -> np.ndarray: + """ + Deserialize embedding from bytes. + + Args: + data: Bytes representation of embedding + dim: Embedding dimension (default 512 for pyannote/embedding) + + Returns: + numpy array of shape (1, D) + """ + embedding = np.frombuffer(data, dtype=np.float32) + return embedding.reshape(1, -1) + + +def find_best_match( + query_embedding: np.ndarray, candidate_embeddings: list[np.ndarray], threshold: float = SPEAKER_MATCH_THRESHOLD +) -> Optional[Tuple[int, float]]: + """ + Find the best matching speaker from a list of candidates. 
+ + Args: + query_embedding: Embedding to match + candidate_embeddings: List of candidate embeddings + threshold: Maximum distance for a valid match + + Returns: + Tuple of (best_index, distance) or None if no match found + """ + if not candidate_embeddings: + return None + + best_idx = -1 + best_distance = float('inf') + + for idx, candidate in enumerate(candidate_embeddings): + distance = compare_embeddings(query_embedding, candidate) + if distance < best_distance: + best_distance = distance + best_idx = idx + + if best_distance < threshold: + return best_idx, best_distance + + return None
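The helpers in speaker_embedding.py are self-contained, so matching reduces to extracting a query embedding and comparing it against stored candidates. A short usage sketch, assuming HOSTED_SPEAKER_EMBEDDING_API_URL is set as in the chart values above; the match_person caller and the shape of the stored dict are assumptions for illustration (this diff stores embeddings as float lists via set_person_speaker_embedding and does not ship a matching call site):

from typing import Optional

import numpy as np

from utils.stt.speaker_embedding import extract_embedding_from_bytes, find_best_match


def match_person(query_wav: bytes, stored: dict[str, list[float]]) -> Optional[str]:
    """Match a WAV sample against per-person embeddings previously written to Firestore."""
    query = extract_embedding_from_bytes(query_wav, "query.wav")  # shape (1, D)
    person_ids = list(stored.keys())
    candidates = [np.array(v, dtype=np.float32).reshape(1, -1) for v in stored.values()]
    match = find_best_match(query, candidates)  # (index, cosine distance), or None if above threshold
    return person_ids[match[0]] if match else None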