diff --git a/synapse/cli/__main__.py b/synapse/cli/__main__.py index 965f5a1..24f9806 100755 --- a/synapse/cli/__main__.py +++ b/synapse/cli/__main__.py @@ -18,6 +18,7 @@ streaming, taps, settings, + upload, ) from synapse.utils.discover import find_device_by_name @@ -79,6 +80,7 @@ def main(): deploy.add_commands(subparsers) build.add_commands(subparsers) settings.add_commands(subparsers) + upload.add_commands(subparsers) args = parser.parse_args() # If we need to setup the device URI, do that now diff --git a/synapse/cli/discover.py b/synapse/cli/discover.py index 86a87c4..26882d3 100644 --- a/synapse/cli/discover.py +++ b/synapse/cli/discover.py @@ -14,11 +14,12 @@ def __init__(self): self.table = Table(show_lines=True, min_width=80) self.table.title = Spinner("dots") self.table.add_column("Name", justify="left") - self.table.add_column("Host", justify="right") + self.table.add_column("Host", justify="left") + self.table.add_column("Device Type", justify="left") def add_device(self, device): self.devices.append(device) - self.table.add_row(device.name, device.host) + self.table.add_row(device.name, device.host, device.device_type) def generate_layout(device_table): diff --git a/synapse/cli/upload.py b/synapse/cli/upload.py new file mode 100644 index 0000000..5f2a8a9 --- /dev/null +++ b/synapse/cli/upload.py @@ -0,0 +1,338 @@ +import os +import csv +import json +import subprocess +from pathlib import Path +import numpy as np +import h5py +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn + +def validate_hdf5_structure(filename, console): + """Validate that the HDF5 file matches the company-enforced structure.""" + try: + with h5py.File(filename, 'r') as f: + # Check root attributes + required_root_attrs = ['lsb_uv', 'sample_rate_hz', 'session_start_time'] + for attr in required_root_attrs: + if attr not in f.attrs: + console.print(f"[bold red]Missing root attribute: {attr}[/bold red]") + return False + + # Check acquisition group and datasets + if 'acquisition' not in f: + console.print("[bold red]Missing 'acquisition' group[/bold red]") + return False + + acq = f['acquisition'] + required_datasets = [ + 'ElectricalSeries', + 'sequence_number', + 'timestamp_ns', + 'unix_timestamp_ns' + ] + + for dataset in required_datasets: + if dataset not in acq: + console.print(f"[bold red]Missing dataset in acquisition: {dataset}[/bold red]") + return False + + # Check general/device structure + if 'general' not in f or 'device' not in f['general']: + console.print("[bold red]Missing 'general/device' group[/bold red]") + return False + + if 'device_type' not in f['general']['device'].attrs: + console.print("[bold red]Missing 'device_type' attribute in general/device[/bold red]") + return False + + # Check general/extracellular_ephys/electrodes + if ('general' not in f or + 'extracellular_ephys' not in f['general'] or + 'electrodes' not in f['general']['extracellular_ephys']): + console.print("[bold red]Missing 'general/extracellular_ephys/electrodes' structure[/bold red]") + return False + + if 'id' not in f['general']['extracellular_ephys']['electrodes']: + console.print("[bold red]Missing 'id' dataset in electrodes[/bold red]") + return False + + console.print("[green]✓ HDF5 structure validation passed[/green]") + return True + + except Exception as e: + console.print(f"[bold red]Error validating HDF5 file: {e}[/bold red]") + return False + +def extract_metadata(filename, console): + """ + Extract metadata from HDF5 file and compute derived 
values. + Returns a dictionary with duration_s and num_channels. + """ + console.print(f"\n[cyan]Extracting metadata...[/cyan]") + + try: + with h5py.File(filename, 'r') as f: + # Read sample_rate_hz from root attributes + sample_rate_hz = f.attrs['sample_rate_hz'] + + # Get number of elements in timestamp_ns + timestamp_ns = f['acquisition']['timestamp_ns'] + num_timestamps = len(timestamp_ns) + + # Get number of channel IDs + channel_ids = f['general']['extracellular_ephys']['electrodes']['id'] + num_channels = len(channel_ids) + + # Calculate duration in seconds + duration_s = num_timestamps / sample_rate_hz + + console.print(f"[dim]Sample rate: {sample_rate_hz} Hz[/dim]") + console.print(f"[dim]Number of timestamps: {num_timestamps:,}[/dim]") + console.print(f"[dim]Number of channels: {num_channels}[/dim]") + console.print(f"[dim]Calculated duration: {duration_s:.2f} s[/dim]") + + metadata = { + "duration_s": duration_s, + "num_channels": int(num_channels) + } + + console.print("[green]✓ Metadata extracted[/green]") + return metadata + + except Exception as e: + console.print(f"[bold red]Error extracting metadata: {e}[/bold red]") + import traceback + console.print(f"[dim]{traceback.format_exc()}[/dim]") + return None + +def compute_spike_statistics(filename, console, num_chunks=128, threshold_std=3.0): + """ + Compute aggregate spike statistics across all channels. + Returns total spike distribution across time chunks with squaring transformation. + """ + console.print(f"\n[cyan]Computing spike distribution with {num_chunks} time chunks for UI visualization...[/cyan]") + + try: + with h5py.File(filename, 'r') as f: + # Get the electrical series data + electrical_series = f['acquisition']['ElectricalSeries'] + channel_ids = f['general']['extracellular_ephys']['electrodes']['id'][:] + num_channels = len(channel_ids) + total_samples = electrical_series.shape[0] + + console.print(f"[dim]Total samples: {total_samples:,}[/dim]") + console.print(f"[dim]Number of channels: {num_channels}[/dim]") + console.print(f"[dim]Data shape: {electrical_series.shape}[/dim]") + + # Determine data layout + if len(electrical_series.shape) == 1: + # Data is 1D - likely interleaved channels + samples_per_channel = total_samples // num_channels + is_interleaved = True + else: + # Data is 2D [samples, channels] or [channels, samples] + is_interleaved = False + if electrical_series.shape[1] == num_channels: + samples_per_channel = electrical_series.shape[0] + channel_axis = 1 + else: + samples_per_channel = electrical_series.shape[1] + channel_axis = 0 + + samples_per_chunk = samples_per_channel // num_chunks + + # Less aggressive subsampling - read ~10% of each chunk + subsample_size = max(5000, samples_per_chunk // 10) + + console.print(f"[dim]Samples per channel: {samples_per_channel:,}[/dim]") + console.print(f"[dim]Samples per chunk: {samples_per_chunk:,}[/dim]") + console.print(f"[dim]Reading ~{subsample_size} samples per chunk for estimation[/dim]") + + # Initialize aggregate spike counts across all channels + aggregate_spike_counts = np.zeros(num_chunks, dtype=np.int64) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + console=console + ) as progress: + task = progress.add_task(f"Processing {num_channels} channels...", total=num_channels) + + for ch_idx, ch_id in enumerate(channel_ids): + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * samples_per_chunk + chunk_end 
= min((chunk_idx + 1) * samples_per_chunk, samples_per_channel) + + # Read a representative sample from the middle of this chunk + sample_start = chunk_start + samples_per_chunk // 2 - subsample_size // 2 + sample_end = sample_start + subsample_size + + # Ensure we don't go out of bounds + sample_start = max(chunk_start, min(sample_start, chunk_end - subsample_size)) + sample_end = min(chunk_end, sample_start + subsample_size) + + try: + # Read data efficiently based on layout + if is_interleaved: + # For 1D interleaved: read block and extract channel + start_idx = sample_start * num_channels + ch_idx + channel_data = electrical_series[start_idx:sample_end * num_channels:num_channels] + else: + # For 2D data: slice appropriately + if channel_axis == 1: + channel_data = electrical_series[sample_start:sample_end, ch_idx] + else: + channel_data = electrical_series[ch_idx, sample_start:sample_end] + + # Convert to float for processing + channel_data = channel_data.astype(np.float32) + + # Quick preprocessing + channel_data -= np.mean(channel_data) + std = np.std(channel_data) + + if std > 0: + threshold = threshold_std * std + # Count spikes in this sample + spike_count = np.sum(np.abs(channel_data) > threshold) + # Extrapolate to full chunk + scaling_factor = samples_per_chunk / subsample_size + aggregate_spike_counts[chunk_idx] += int(spike_count * scaling_factor) + + except Exception as e: + console.print(f"[yellow]Warning: Error reading chunk {chunk_idx} for channel {ch_id}: {e}[/yellow]") + + progress.update(task, advance=1) + + # Normalize by total spike count to get distribution + total_spikes = np.sum(aggregate_spike_counts) + if total_spikes > 0: + normalized_distribution = aggregate_spike_counts / total_spikes + else: + normalized_distribution = aggregate_spike_counts.astype(np.float64) + + console.print(f"[dim]Total spikes detected (approx): {total_spikes:,}[/dim]") + console.print(f"[dim]Distribution range before transform: [{normalized_distribution.min():.6f}, {normalized_distribution.max():.6f}][/dim]") + + # Apply squaring transformation to emphasize differences + squared_distribution = normalized_distribution ** 2 + + # Renormalize after squaring + squared_sum = np.sum(squared_distribution) + if squared_sum > 0: + final_distribution = (squared_distribution / squared_sum).tolist() + else: + final_distribution = squared_distribution.tolist() + + console.print(f"[dim]Distribution range after transform: [{min(final_distribution):.6f}, {max(final_distribution):.6f}][/dim]") + console.print("[green]✓ Spike distribution computed with squaring transformation[/green]") + + return final_distribution + + except Exception as e: + console.print(f"[bold red]Error computing spike statistics: {e}[/bold red]") + import traceback + console.print(f"[dim]{traceback.format_exc()}[/dim]") + return None + +def upload(args): + console = Console() + + if not args.filename: + console.print("[bold red]Error: No filename specified[/bold red]") + return + + # Check file extension + file_path = Path(args.filename) + if file_path.suffix.lower() not in ['.hdf5', '.h5']: + console.print(f"[bold red]Error: File must be .hdf5 or .h5 format (got {file_path.suffix})[/bold red]") + return + + if not os.path.exists(args.filename): + console.print(f"[bold red]Error: File '{args.filename}' does not exist[/bold red]") + return + + # Validate HDF5 structure + console.print(f"[cyan]Validating HDF5 structure...[/cyan]") + if not validate_hdf5_structure(args.filename, console): + console.print("[bold red]HDF5 validation 
failed. Upload aborted.[/bold red]") + return + + # Extract metadata and write to JSON + metadata = extract_metadata(args.filename, console) + json_path = None + + if metadata is None: + console.print("[bold yellow]Warning: Could not extract metadata[/bold yellow]") + else: + # Write metadata to JSON file + json_path = Path(str(file_path) + '.json') + try: + with open(json_path, 'w') as json_file: + json.dump(metadata, json_file, indent=2) + console.print(f"[green]✓ Metadata written to {json_path}[/green]") + except Exception as e: + console.print(f"[bold red]Error writing JSON file: {e}[/bold red]") + json_path = None + + # Compute spike statistics + spike_distribution = compute_spike_statistics(args.filename, console) + csv_path = None + + if spike_distribution is None: + console.print("[bold yellow]Warning: Could not compute spike statistics[/bold yellow]") + else: + # Write distribution to CSV file (single row with 128 values) + csv_path = file_path.with_suffix(file_path.suffix + '.csv') + try: + with open(csv_path, 'w', newline='') as csv_file: + writer = csv.writer(csv_file) + writer.writerow(spike_distribution) + console.print(f"[green]✓ Spike distribution written to {csv_path}[/green]") + except Exception as e: + console.print(f"[bold red]Error writing CSV file: {e}[/bold red]") + csv_path = None + + file_size = os.path.getsize(args.filename) + file_size_mb = file_size / (1024 * 1024) + + uri = args.uri + remote_host = f"scifi@{uri}" + remote_dir = "~/replay" + remote_path = f"{remote_host}:{remote_dir}" + mkdir_command = ["ssh", remote_host, f"mkdir -p {remote_dir}"] + + console.print(f"\n[cyan]Uploading file:[/cyan] {args.filename}") + console.print(f"[cyan]File size:[/cyan] {file_size_mb:.2f} MB") + console.print(f"[cyan]Destination:[/cyan] {remote_path}") + + try: + # Ensure ~/replay directory exists + console.print(f"\n[cyan]Ensuring directory exists: {remote_dir}[/cyan]") + subprocess.run(mkdir_command, check=True, capture_output=True, text=True) + console.print(f"[green]✓ Directory ready[/green]") + + # Collect all files to upload + files_to_upload = [args.filename] + if json_path and os.path.exists(json_path): + files_to_upload.append(str(json_path)) + if csv_path and os.path.exists(csv_path): + files_to_upload.append(str(csv_path)) + + # Upload all files in a single SCP command + console.print(f"\n[cyan]Uploading {len(files_to_upload)} file(s) to {remote_path}...[/cyan]") + console.print("[dim]You may be prompted for a password[/dim]\n") + scp_command = ["scp"] + files_to_upload + [remote_path] + subprocess.run(scp_command, check=True) + console.print(f"[bold green]✓ Successfully uploaded all files[/bold green]") + + except subprocess.CalledProcessError as e: + console.print(f"\n[bold red]✗ Upload failed[/bold red]") + +def add_commands(subparsers): + upload_parser = subparsers.add_parser("upload", help="Upload HDF5 recordings to your Synapse device") + upload_parser.add_argument("filename", type=str, help="Path to the HDF5 file (.hdf5 or .h5) to upload") + upload_parser.set_defaults(func=upload) \ No newline at end of file diff --git a/synapse/utils/discover.py b/synapse/utils/discover.py index f8f1003..b4cf467 100644 --- a/synapse/utils/discover.py +++ b/synapse/utils/discover.py @@ -15,6 +15,7 @@ class DeviceInfo: capability: str name: str serial: str + device_type: str def discover_iter(socket_timeout_sec=1, discovery_timeout_sec=DISCOVERY_TIMEOUT_SEC): @@ -40,11 +41,20 @@ def discover_iter(socket_timeout_sec=1, discovery_timeout_sec=DISCOVERY_TIMEOUT_ else: data = 
data.decode("ascii").split() if data[0] == "ID": - if len(data) != 5: - continue - _, serial, capability, port, name = data + # Backward compatability + serial = "unknown_serial" + capability = "unknown_capability" + port = 0 + name = "unknown_name" + device_type = "unknown_device" + + if len(data) == 5: + _, serial, capability, port, name = data + elif len(data) >= 6: + _, serial, capability, port, name, device_type, *_ = data + dev_info = DeviceInfo( - server[0], int(port), capability, name, serial + server[0], int(port), capability, name, serial, device_type ) if dev_info not in devices: devices.append(dev_info)