diff --git a/seas/domains.py b/seas/domains.py index b4351d6..fd7a913 100644 --- a/seas/domains.py +++ b/seas/domains.py @@ -15,7 +15,9 @@ from seas.ica import rebuild_mean_roi_timecourse, filter_mean from seas.rois import make_mask from seas.colormaps import save_colorbar, REGION_COLORMAP, DEFAULT_COLORMAP +from seas.signalanalysis import butterworth +from skimage.morphology import remove_small_objects def get_domain_map(components: dict, blur: int = 21, @@ -95,7 +97,7 @@ def get_domain_map(components: dict, blur += 1 eigenbrain = np.empty(shape) - eigenbrain[:] = np.NAN + eigenbrain[:] = np.nan for index in range(eig_vec.shape[1]): @@ -115,7 +117,7 @@ def get_domain_map(components: dict, if roimask is not None: domain_ROIs = np.empty(shape) - domain_ROIs[:] = np.NAN + domain_ROIs[:] = np.nan domain_ROIs.flat[maskind] = domain_ROIs_vector else: @@ -619,54 +621,41 @@ def write_frame(frame): print('Saving Colorbar to:' + cbarpath) save_colorbar(scale, cbarpath, colormap=colormap) - def threshold_by_domains(components: dict, blur: int = 1, - min_size_ratio: float = 0.1, - map_only: bool = True, - apply_filter_mean: bool = True, - max_loops: int = 2, - ignore_small: bool = True): + min_mask_size: int = 64, + thresh_type: str = 'max', + thresh_param: float = None): ''' - Creates a domain map from extracted independent components. A pixelwise maximum projection of the blurred signal components is taken through the n_components axis, to create a flattened representation of where a domain was maximally significant across the cortical surface. Components with multiple noncontiguous significant regions are counted as two distinct domains. + Function based on modified get_domain_map(). Thresholds ICs using a variety of methods for selective rebuild. Arguments: components: - The dictionary of components returned from seas.ica.project. Domains are most interesting if artifacts has already been assigned through seas.gui.run_gui. + The dictionary of components returned from seas.ica.project. ROIs are most interesting if artifacts has already been assigned through seas.gui.run_gui. blur: - An odd integer kernel Gaussian blur to run before segmenting. Domains look smoother with larger blurs, but you can lose some smaller domains. - map_only: - If true, compute the map only, do not rebuild time courses under each domain. - apply_filter_mean: - Whether to compute the filtered mean when calculating ROI rebuild timecourses. - min_size_ratio: - The minimum size ratio of the mean component size to allow for a component. If a the size of a component is under (min_size_ratio x mean_domain_size), and the next most significant domain over the pixel would result in a larger size domain, this next domain is chosen. - max_loops: - The number of times to check if the next most significant domain would result in a larger domain size. To entirely disable this, set max_loops to 0. - ignore_small: - If True, assign undersize domains that were not reassigned during max_loops to np.nan. + An odd integer kernel Gaussian blur to run before segmenting. ROIs look smoother with larger blurs, but you can lose some smaller domains. + min_mask_size: + An integer determining the minimum ROIs passed from each thresholded IC. + thresh_type: + A string used to determine IC threshold method. Choose from either 'max', 'z-score' or 'percentile'. + thresh_param: + A float used to determine the parameter for the given thresh_type. For 'z-score', this is the z-score threshold (eg; 2.0 for 2std). 
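# Illustrative sketch (not part of this diff) of how the three thresh_type modes
# described above each reduce to a boolean mask over eig_vec (pixels x components).
# Shapes and threshold values below are hypothetical.
import numpy as np

rng = np.random.default_rng(0)
eig_vec = rng.normal(size=(1000, 5))    # toy pixels x components matrix

# 'max': keep, per pixel, only the component with the largest absolute loading
max_mask = np.zeros(eig_vec.shape, dtype=bool)
max_mask[np.arange(eig_vec.shape[0]), np.argmax(np.abs(eig_vec), axis=1)] = True

# 'z-score': keep loadings whose |z| (computed per component) exceeds thresh_param
z = (eig_vec - eig_vec.mean(axis=0)) / eig_vec.std(axis=0)
z_mask = np.abs(z) > 2.0                # thresh_param = 2.0

# 'percentile': keep loadings above a per-component percentile cutoff
# (np.percentile expects a value in [0, 100], e.g. 95 for the 95th percentile)
cutoff = np.percentile(eig_vec, 95, axis=0)
pct_mask = eig_vec > cutoff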
For 'percentile', this is the percentile passed to np.percentile (e.g. 95 for the 95th percentile).

     Returns:
         output: a dictionary containing the results of the operation, containing the following keys
             domain_blur:
                 The Gaussian blur value used when generating the map
-            component_assignment:
-                A map showing the index of which *component* was maximally significant over a given pixel. Here,
-                This is in contrast to the domain map, where each domain is a unique integer.
-            domain_ROIs:
-                The computed np.array domain map (x,y). Each domain is represented by a unique integer, and represents a discrete continuous unit. Values that are masked, or where large enough domains were not detected are set to np.nan.
-
-        if not map_only, the following are also included in the output dictionary:
-            ROI_timecourses:
-                The time courses rebuilt from the video under each ROI. The frame mean is not included in this calculation, and must be re-added from mean_filtered.
-            mean_filtered:
-                The frame mean, filtered by the default method.
+            eig_vec:
+                The thresholded eigenvectors (ICs).
+            thresh_masks:
+                The boolean masks used to threshold eig_vec.
     '''
     print('\nExtracting Domain ROIs\n-----------------------')
     output = {}
     output['domain_blur'] = blur

     eig_vec = components['eig_vec'].copy()
+
     shape = components['shape']
     shape = (shape[1], shape[2])
@@ -690,36 +679,70 @@ def threshold_by_domains(components: dict,
            print('no noise components found')
            signal_indices = np.where(artifact_components == 0)[0]
        # eig_vec = eig_vec[:, signal_indices] # Don't change number of ICs, we're updating back to dict
+
+    mask = np.zeros_like(eig_vec, dtype=bool)
+    match thresh_type:
+        case 'max':
+            # Return indices across each eig_vec (loading vector for component) where loading is max
+            threshold_ROIs_vector = np.argmax(np.abs(eig_vec), axis=1)
+            # Then threshold by clearing eig_vec outside of max indices
+            mask[np.arange(eig_vec.shape[0]), threshold_ROIs_vector] = True
+        case 'z-score':
+            mean_ROIs_vector = np.nanmean(eig_vec, axis=0)
+            std_ROIs_vector = np.nanstd(eig_vec, axis=0)
+            z_ROIs_vector = (eig_vec - mean_ROIs_vector) / std_ROIs_vector
+            for i in np.arange(eig_vec.shape[0]):
+                mask[i, :] = np.abs(z_ROIs_vector[i]) > thresh_param
+        case 'percentile':
+            flipped = components['flipped']
+            # Flip ICs where necessary using flipped from dict
+            flipped_threshold_vec = np.multiply(flipped, eig_vec)
+            # Calculate the thresh_param percentile cutoff for each IC
+            cutoff_vector = np.percentile(flipped_threshold_vec, thresh_param, axis=0)
+            # Mask all values above each IC's cutoff
+            for i in np.arange(eig_vec.shape[0]):
+                mask[i, :] = flipped_threshold_vec[i] > cutoff_vector
+        case _:
+            print("Threshold type must be 'max', 'z-score' or 'percentile'.")
+
+    # Filter small mask ROIs and smooth using blur
     if blur:
         print('blurring domains...')
         assert type(blur) is int, 'blur was not valid'
         if blur % 2 != 1:
             blur += 1

+        eigenmask = np.zeros(shape, dtype=bool)
         eigenbrain = np.empty(shape)
-        eigenbrain[:] = np.NAN
+        eigenbrain[:] = np.nan

-        for index in range(eig_vec.shape[1]):
+        for index in range(mask.shape[1]):

             if roimask is not None:
-                eigenbrain.flat[maskind] = eig_vec.T[index]
+                eigenmask.flat[maskind] = mask.T[index]
+                # Remove small mask objects
+                filtered = remove_small_objects(eigenmask, min_size=min_mask_size, connectivity=1)
+                filtered_float = filtered.astype(np.float64)
+                eigenbrain.flat[maskind] = filtered_float.flat[maskind]
+                # Then blur
+                blurred = cv2.GaussianBlur(eigenbrain, (blur, blur), 0)
-                eig_vec.T[index] = blurred.flat[maskind]
+                mask.T[index] = blurred.flat[maskind]

             else:
-
eigenbrain.flat = eig_vec.T[index] + eigenbrain.flat = mask.T[index] + filtered = remove_small_objects(eigenbrain, min_size=min_mask_size, connectivity=1) + filtered_float = filtered.astype(np.float64) + eigenbrain.flat[maskind] = filtered_float.flat blurred = cv2.GaussianBlur(eigenbrain, (blur, blur), 0) - eig_vec.T[index] = blurred.flat - - # This is the money section, return indices across each eig_vec (loading vector for component) where loading is max - domain_ROIs_vector = np.argmax(np.abs(eig_vec), axis=1) - # Then threshold by clearing eig_vec outside of max indices - mask = np.zeros_like(eig_vec, dtype=bool) - mask[np.arange(eig_vec.shape[0]), domain_ROIs_vector] = True - eig_vec[~mask] = 0 + mask.T[index] = blurred.flat + mask_bool = mask.astype(bool) + eig_vec[~mask_bool] = 0 + + output['thresh_masks'] = mask + # output['thresh_vec'] = eig_vec output['eig_vec'] = eig_vec - + return output # if blur: @@ -727,7 +750,7 @@ def threshold_by_domains(components: dict, # if roimask is not None: # domain_ROIs = np.empty(shape) - # domain_ROIs[:] = np.NAN + # domain_ROIs[:] = np.nan # domain_ROIs.flat[maskind] = domain_ROIs_vector # else: diff --git a/seas/experiment.py b/seas/experiment.py index 85124dd..20bfcc4 100644 --- a/seas/experiment.py +++ b/seas/experiment.py @@ -8,11 +8,13 @@ from seas.filemanager import sort_experiments, get_exp_span_string, read_yaml from seas.rois import roi_loader, make_mask, get_masked_region, insert_masked_region, draw_bounding_box from seas.hdf5manager import hdf5manager -from seas.ica import project, filter_mean +from seas.ica import project, filter_mean, rebuild_eigenbrain, threshold_by_domains, filter_components, threshold_components, rebuild from seas.signalanalysis import sort_noise, lag_n_autocorr from seas.waveletAnalysis import waveletAnalysis +from seas.domains import get_domain_map from typing import List +import tifffile as tif class Experiment: @@ -112,7 +114,7 @@ def __init__(self, if np.any(np.isnan(movie)): # If the video was already masked - roimask = np.zeros(movie[0].shape, dtype='uisnt8') + roimask = np.zeros(movie[0].shape, dtype='uint8') roimask[np.where(~np.isnan(movie[0]))] = 1 self.roimask = roimask @@ -461,3 +463,122 @@ def ica_project(self, f.print() return components + +def export_event_masks(components: dict, + outpath: str, + blur: int = 3, + thresh_type: str = 'z-score', + thresh_param: float = 7, + schematic: bool = False) -> None: + components_copy = components.copy() + threshold = threshold_by_domains(components_copy, + blur = blur, + thresh_type = thresh_type, + thresh_param = thresh_param, + schematic = schematic) + components_copy.update(threshold) + artifacts_bool = components_copy['artifact_components'].astype(bool) + event_components = components_copy['eig_vec'][:, ~artifacts_bool] + event_masks = rebuild_eigenbrain(event_components, + roimask = components_copy['roimask'], + bulk = True) + event_masks = np.where(np.abs(event_masks) > 0, 255, 0) + event_masks = np.where(np.mean(event_masks, axis = 0) == 255, 0, event_masks) + tif.imwrite(outpath, event_masks.astype(np.float32),imagej=True) + +def export_event_video(components: dict, + outpath: str, + artifact_components: np.ndarray = None, + t_start: int = None, + t_stop: int = None, + apply_mean_filter: bool = True, + cthresh: float = 2.0, + apply_masked_mean: bool = False, + filter_method: str = 'constant', + include_noise: bool = True) -> None: + components_copy = components.copy() + threshold = threshold_by_domains(components_copy, + blur = 3, + thresh_type = 
'z-score', + thresh_param = 7) + eig_mix = filter_components(components_copy['eig_mix']) + eig_mix = threshold_components(eig_mix, + thresh_param = cthresh) + components_copy.update(threshold) + components_copy['eig_mix'] = eig_mix + rebuilt = rebuild(components_copy, + artifact_components = artifact_components, + t_start = t_start, + t_stop = t_stop, + apply_mean_filter = apply_mean_filter, + cthresh = cthresh, + apply_masked_mean = apply_masked_mean, + filter_method = filter_method, + include_noise = include_noise) + tif.imwrite(outpath, rebuilt.astype(np.float32), imagej=True) + +def sort_components(components: dict, sort_by_noise: bool = True): + eig_vec = components['eig_vec'] + eig_mix = components['eig_mix'] + lag1 = components['lag1'] + lag1_full = components['lag1_full'] + noise = components['noise_components'] + + if sort_by_noise: + ev_sort = np.argsort(lag1) # Sorting by lag1 auto-correlation + else: + ev_sort = np.argsort(eig_mix.std(axis=0)) # Sorting by timecourse standard deviation. + + eig_vec = eig_vec[:, ev_sort][:, ::-1] + eig_mix = eig_mix[:, ev_sort][:, ::-1] + lag1 = lag1[ev_sort][::-1] + lag1_full = lag1_full[ev_sort][::-1] + noise = noise[ev_sort][::-1] + + if 'artifact_components' in components: + artifacts = components['artifact_components'] + artifacts = artifacts[ev_sort][::-1] + components['artifact_components'] = artifacts + + # Save sorted values + components['eig_vec'] = eig_vec + components['eig_mix'] = eig_mix + components['lag1'] = lag1 + components['lag1_full'] = lag1_full + components['noise_components'] = noise + + # Derive from sorted values + components['timecourses'] = eig_mix.T + + # Recalculation calls (how PySEAS does it originally) + #noise, cutoff = sort_noise(eig_mix.T) + #components['cutoff'] = cutoff + #components['lag1'] = lag_n_autocorr(components['timecourses'], 1) + + # Recalculate domain map (doesn't work for some reason) + # domain_map = get_domain_map(components, map_only = False) + # components.update(domain_map) + + return components + +def flip_negative_components(components: dict): + n_components = components['n_components'] + eig_vec = components['eig_vec'] + eig_mix = components['eig_mix'] + + # Track component orientation and ensure positive spatial patterns + flipped = np.ones(n_components) + for i in range(n_components): + # Find the index of maximum absolute value + max_idx = np.argmax(np.abs(eig_vec[:, i])) + # If that maximum value is negative, flip the component + if eig_vec[max_idx, i] < 0: + eig_vec[:, i] *= -1 + eig_mix[:, i] *= -1 + flipped[i] = -1 + + components['flipped'] = flipped + components['eig_vec'] = eig_vec + components['eig_mix'] = eig_mix + + return components \ No newline at end of file diff --git a/seas/gui.py b/seas/gui.py index 76b214a..fb7cad4 100644 --- a/seas/gui.py +++ b/seas/gui.py @@ -115,11 +115,11 @@ def run_gui(components: dict, print('initializing artifact_components toggle') toggle = np.zeros((n_components,), dtype='uint8') - if 'flipped' in components: - flipped = components['flipped'] + # if 'flipped' in components: + # flipped = components['flipped'] - timecourses = timecourses * flipped[:, None] - eig_vec = eig_vec * flipped + # timecourses = timecourses * flipped[:, None] + # eig_vec = eig_vec * flipped if 'domain_ROIs' in components: domain_ROIs = components['domain_ROIs'] @@ -469,7 +469,7 @@ def update(self, component_id): if component_id is None: # Clear image. 
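# Context for the commented-out 'flipped' block above: flip_negative_components()
# (added in experiment.py) already multiplies both eig_vec and eig_mix by `flipped`,
# so re-applying the flip in the GUI would undo it. A minimal sketch of the sign
# convention, on toy arrays (illustrative only, not part of this diff):
import numpy as np

eig_vec = np.array([[0.2, -0.9], [0.1, -0.4]])   # toy pixels x components
eig_mix = np.array([[1.0, 2.0], [0.5, -1.0]])    # toy timepoints x components

flipped = np.ones(eig_vec.shape[1])
for i in range(eig_vec.shape[1]):
    # Flip a component when its largest-magnitude spatial loading is negative
    if eig_vec[np.argmax(np.abs(eig_vec[:, i])), i] < 0:
        eig_vec[:, i] *= -1
        eig_mix[:, i] *= -1
        flipped[i] = -1

# Because the flip is applied to the spatial and temporal factors together,
# eig_vec @ eig_mix.T (and therefore any rebuilt movie) is unchanged; only the
# displayed orientation of each component changes.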
im = np.empty((x, y)) - im[:] = np.NAN + im[:] = np.nan self.imgplot = self.ax.imshow(im) self.canvas.draw() return () diff --git a/seas/ica.py b/seas/ica.py index a00721e..453b380 100644 --- a/seas/ica.py +++ b/seas/ica.py @@ -12,11 +12,18 @@ from seas.hdf5manager import hdf5manager from seas.video import rotate, save, rescale, play, scale_video +import cv2 +from skimage.morphology import remove_small_objects +from skimage import draw, measure +from scipy import ndimage +import tifffile as tif + def project(vector: np.ndarray, shape: Tuple[int, int, int], roimask: np.ndarray = None, n_components: int = None, + crop_excess_noise: bool = True, svd_multiplier: float = 5, calc_residuals: bool = True, max_iter: int = 1000): @@ -183,23 +190,26 @@ def project(vector: np.ndarray, reduced_n_components = int((noise.size - noise.sum()) * 1.25) print('reduced_n_components:', reduced_n_components) - - if reduced_n_components < n_components: - print('Cropping', n_components, 'to', reduced_n_components) - - ev_sort = np.argsort(eig_mix.std(axis=0)) - eig_vec = eig_vec[:, ev_sort][:, ::-1] - eig_mix = eig_mix[:, ev_sort][:, ::-1] - noise = noise[ev_sort][::-1] - - eig_vec = eig_vec[:, :reduced_n_components] - eig_mix = eig_mix[:, :reduced_n_components] - n_components = reduced_n_components - noise = noise[:reduced_n_components] - - components['lag1_full'] = components['lag1_full'][ev_sort][::-1] + + if crop_excess_noise: + if reduced_n_components < n_components: + print('Cropping', n_components, 'to', reduced_n_components) + + ev_sort = np.argsort(eig_mix.std(axis=0)) + eig_vec = eig_vec[:, ev_sort][:, ::-1] + eig_mix = eig_mix[:, ev_sort][:, ::-1] + noise = noise[ev_sort][::-1] + + eig_vec = eig_vec[:, :reduced_n_components] + eig_mix = eig_mix[:, :reduced_n_components] + n_components = reduced_n_components + noise = noise[:reduced_n_components] + + components['lag1_full'] = components['lag1_full'][ev_sort][::-1] + else: + print('Less than 75% signal. Not cropping excess noise.') else: - print('Less than 75% signal. Not cropping excess noise.') + print('Noise retention enabled. Not cropping excess noise.') components['noise_components'] = noise components['cutoff'] = cutoff @@ -261,7 +271,7 @@ def project(vector: np.ndarray, vector = vector.astype('float64') rebuilt = rebuild(components, artifact_components='none', - vector=True).T + apply_mean_filter=False).T rebuilt -= rebuilt.mean(axis=0) vector -= vector.mean(axis=0) @@ -298,13 +308,19 @@ def project(vector: np.ndarray, print('\n') return components - def rebuild(components: dict, artifact_components: np.ndarray = None, t_start: int = None, t_stop: int = None, apply_mean_filter: bool = True, - filter_method: str = 'wavelet', + mlow: float = 0.5, + mhigh: float = 1.0, + apply_component_filter: bool = False, + chigh: float = 1.0, + apply_component_threshold: bool = False, + cthresh: float = 2.0, + apply_masked_mean: bool = False, + filter_method: str = 'butterworth_highpass', fps: float = 7.5, include_noise: bool = True): ''' @@ -324,8 +340,24 @@ def rebuild(components: dict, The frame to stop rebuilding the movie at. If none is provided, the rebuilt movie ends at the last frame apply_mean_filter: Whether to apply a filter to the mean signal. - filter_method:; - The filter method to apply (see filter_mean function). + mlow: + A float determining the highpass cutoff for the mean filter, if used. + mhigh: + A float determining the lowpass cutoff for the mean filter, if used. 
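# Independent illustration (not the seas.signalanalysis implementation) of what the
# mlow/mhigh cutoffs above control for the mean trace; the fps, cutoffs and toy
# signal below are assumptions taken from the defaults named in this docstring.
import numpy as np
from scipy import signal

fps, mlow, mhigh = 7.5, 0.5, 1.0
t = np.arange(0, 60, 1 / fps)
mean_trace = np.sin(2 * np.pi * 0.1 * t) + 0.2 * np.sin(2 * np.pi * 0.8 * t)

nyq = fps / 2
# 'butterworth_highpass': remove slow drift below mlow
b, a = signal.butter(4, mlow / nyq, btype='highpass')
mean_hp = signal.filtfilt(b, a, mean_trace)

# 'butterworth_bandpass': keep only the mlow-mhigh band
b, a = signal.butter(4, [mlow / nyq, mhigh / nyq], btype='bandpass')
mean_bp = signal.filtfilt(b, a, mean_trace)

# 'constant': replace the mean trace with its scalar average (see filter_mean below)
mean_const = np.full_like(mean_trace, mean_trace.mean())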
+ apply_component_filter: + Whether to apply a butterworth_lowpass filter to IC timecourses before rebuild. + chigh: + A float determining the lowpass cutoff for the component filter, if used. + apply_component_threshold: + Whether to apply a z-score threshold on the component timeseries. + cthresh: + A float determining the z-score threshold for the component threshold, if used. + apply_masked_mean: + If True, only re-adds the mean signal to pixels where at least one IC is defined. To be used for thresholded ICs. + filter_method: + The filter method to apply to the mean. Choose from 'butterworth_bandpass', 'butterworth_lowpass', 'butterworth_highpass', or 'constant'. Behaviour for 'wavelet' as yet undefined. + fps: + A float determining the fps for the source video. include_noise: Whether to include noise components when rebuilding. If noise_components should not be included in the rebuilt movie, set this to False @@ -341,6 +373,7 @@ def rebuild(components: dict, assert type(components) is dict, 'Components were not in format expected' eig_vec = components['eig_vec'] + eig_mix = components['eig_mix'] roimask = components['roimask'] shape = components['shape'] mean = components['mean'] @@ -358,7 +391,8 @@ def rebuild(components: dict, elif artifact_components == 'none': print('including all components') artifact_components = np.zeros(n_components) - elif ((not include_noise) and ('noise_components' in components.keys())): + + if ((not include_noise) and ('noise_components' in components.keys())): print('Not rebuilding noise components') artifact_components += components['noise_components'] artifact_components[np.where(artifact_components > 1)] = 1 @@ -384,7 +418,15 @@ def rebuild(components: dict, assert eig_vec[:,0].size == maskind[0].size, \ "Eigenvector size is not compatible with the masked region's size" - eig_mix = components['eig_mix'] + # Filter component timecourses + if apply_component_filter: + lpf_eig_mix = filter_components(eig_mix, fps=fps, high_cutoff=chigh) + eig_mix = lpf_eig_mix + + # Threshold component timecourses + if apply_component_threshold: + thresh_eig_mix = threshold_components(eig_mix, thresh_param=cthresh) + eig_mix = thresh_eig_mix if (t_start == None): t_start = 0 @@ -406,14 +448,35 @@ def rebuild(components: dict, data_r = np.dot(eig_vec[:, reconstruct_indices], eig_mix[t_start:t_stop, reconstruct_indices].T).T - if apply_mean_filter: - mean_filtered = filter_mean(mean, filter_method, fps=fps) - data_r += mean_filtered[t_start:t_stop, None] + if apply_masked_mean: + masks = components['thresh_masks'] + assert masks is not None, \ + "Masks have not been assigned to dictionary" + # Apply mean to masks only, zeroing unmasked pixels + if apply_mean_filter: + combined_mask = np.any(masks[:, reconstruct_indices], axis=1) + mean_to_add = np.zeros_like(data_r) + mean_filtered = filter_mean(mean, filter_method, low_cutoff=mlow, high_cutoff=mhigh, fps=fps) + mean_to_add[:, combined_mask] = mean_filtered[t_start:t_stop, None] + data_r += mean_to_add + else: + print('Not filtering mean') + combined_mask = np.any(masks[:, reconstruct_indices], axis=1) + mean_to_add = np.zeros_like(data_r) + mean_filtered = None + mean_to_add[:, combined_mask] = mean[t_start:t_stop, None] + data_r += mean_to_add else: - print('Not filtering mean') - mean_filtered = None - data_r += mean[t_start:t_stop, None] + # Run original readdition of mean + if apply_mean_filter: + mean_filtered = filter_mean(mean, filter_method, low_cutoff=mlow, high_cutoff=mhigh, fps=fps) + data_r += 
mean_filtered[t_start:t_stop, None] + + else: + print('Not filtering mean') + mean_filtered = None + data_r += mean[t_start:t_stop, None] print('Done!') @@ -508,6 +571,12 @@ def filter_mean(mean: np.ndarray, wavelet = waveletAnalysis(mean.astype('float64'), fps=fps) mean_filtered = wavelet.noiseFilter(upperPeriod=1 / low_cutoff) + elif filter_method == 'constant': + mean_template = np.zeros_like(mean) + meanest_mean = np.mean(mean) + mean_filtered = mean_template + meanest_mean + print('Mean set as constant: dfof = ' + str(meanest_mean)) + else: raise Exception("Filter method '" + str(filter_method)\ + "' not supported!\n\t Supported methods: butterworth, butterworth_bandpass, wavelet") @@ -515,6 +584,225 @@ def filter_mean(mean: np.ndarray, return mean_filtered +def filter_components(eig_mix: np.ndarray, + fps: float = 7.5, + high_cutoff: float = 0.5): + ''' + Applies a butterworth low pass filter to the IC timecourses. + + Arguments: + eig_mix: + The mixing matrix containing IC timecourses. + fps: + Sampling rate of the video. + high_cutoff: + The frequency cutoff to apply the low pass filter at. + + Returns: + lpf_eig_mix: The filtered IC timecourses reconstructed as the eig_mix matrix. + ''' + + print('Filtering component timecourses using butterworth_lowpass at '+ str(high_cutoff) +'Hz...') + timecourses = eig_mix.T + lpf_timecourses = np.zeros_like(timecourses) + for index in range(timecourses.shape[0]): + lpf_timecourses[index] = butterworth(timecourses[index], fps=fps, high=high_cutoff) + lpf_eig_mix = lpf_timecourses.T + + return lpf_eig_mix + +def threshold_components(eig_mix: np.ndarray, + thresh_param: float): + ''' + Applies a z-score threshold to the IC timecourses. + + Arguments: + eig_mix: + The mixing matrix containing IC timecourses. + thresh_param: + Z-score thresholding parameter (standard deviations). + + Returns: + thresh_eig_mix: The thresholded IC timecourses reconstructed as the eig_mix matrix. + ''' + + print('Thresholding component timecourses using z-score: >' + str(thresh_param) +'s.d.') + timecourses = eig_mix.T + thresh_timecourses = np.zeros_like(timecourses) + for index in range(timecourses.shape[0]): + timecourse = timecourses[index] + mean = np.mean(timecourse) + std = np.std(timecourse) + threshold = mean + thresh_param*std + timecourse[np.abs(timecourse) < np.abs(threshold)] = 0 + thresh_timecourses[index] = timecourse + thresh_eig_mix = thresh_timecourses.T + + return thresh_eig_mix + +def threshold_by_domains(components: dict, + blur: int = 1, + min_mask_size: int = 64, + thresh_type: str = 'max', + thresh_param: float = None, + schematic: bool = False): + ''' + Function based on modified get_domain_map(). Thresholds ICs using a variety of methods for selective rebuild. + + Arguments: + components: + The dictionary of components returned from seas.ica.project. ROIs are most interesting if artifacts has already been assigned through seas.gui.run_gui. + blur: + An odd integer kernel Gaussian blur to run before segmenting. ROIs look smoother with larger blurs, but you can lose some smaller domains. + min_mask_size: + An integer determining the minimum ROIs passed from each thresholded IC. + thresh_type: + A string used to determine IC threshold method. Choose from either 'max', 'z-score' or 'percentile'. + thresh_param: + A float used to determine the parameter for the given thresh_type. For 'z-score', this is the z-score threshold (eg; 2.0 for 2std). For 'percentile' this is the percentile used to threshold (eg; 95th percentile = 0.95). 
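# Usage sketch for the two helpers defined above (filter_components and
# threshold_components); the data here is hypothetical, and the call signatures
# are the ones introduced in this diff.
import numpy as np
from seas.ica import filter_components, threshold_components

rng = np.random.default_rng(1)
eig_mix = rng.normal(size=(450, 10))    # toy timepoints x components mixing matrix

# Low-pass each IC timecourse at 1 Hz (wraps seas.signalanalysis.butterworth)
lpf_eig_mix = filter_components(eig_mix, fps=7.5, high_cutoff=1.0)

# Zero samples whose |value| falls below |mean + 2*std| for that IC, keeping only
# large excursions before a selective rebuild
thresh_eig_mix = threshold_components(lpf_eig_mix, thresh_param=2.0)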
+
+    Returns:
+        output: a dictionary containing the results of the operation, containing the following keys
+            domain_blur:
+                The Gaussian blur value used when generating the map
+            eig_vec:
+                The thresholded eigenvectors (ICs).
+            thresh_masks:
+                The boolean masks used to threshold eig_vec.
+    '''
+    print('\nExtracting Domain ROIs\n-----------------------')
+    output = {}
+    output['domain_blur'] = blur
+
+    eig_vec = components['eig_vec'].copy()
+
+    shape = components['shape']
+    shape = (shape[1], shape[2])
+
+    if 'roimask' in components.keys() and components['roimask'] is not None:
+        roimask = components['roimask']
+        maskind = np.where(roimask.flat == 1)[0]
+    else:
+        roimask = None
+
+    if 'artifact_components' in components.keys():
+        artifact_components = components['artifact_components']
+
+        print('Switching to signal indices only for domain detection')
+
+        if 'noise_components' in components.keys():
+            noise_components = components['noise_components']
+
+            signal_indices = np.where((artifact_components +
+                                       noise_components) == 0)[0]
+        else:
+            print('no noise components found')
+            signal_indices = np.where(artifact_components == 0)[0]
+        # eig_vec = eig_vec[:, signal_indices] # Don't change number of ICs, we're updating back to dict
+
+    mask = np.zeros_like(eig_vec, dtype=bool)
+
+    match thresh_type:
+        case 'max':
+            # Return indices across each eig_vec (loading vector for component) where loading is max
+            threshold_ROIs_vector = np.argmax(np.abs(eig_vec), axis=1)
+            # Then threshold by clearing eig_vec outside of max indices
+            mask[np.arange(eig_vec.shape[0]), threshold_ROIs_vector] = True
+        case 'z-score':
+            mean_ROIs_vector = np.nanmean(eig_vec, axis=0)
+            std_ROIs_vector = np.nanstd(eig_vec, axis=0)
+            z_ROIs_vector = (eig_vec - mean_ROIs_vector) / std_ROIs_vector
+            for i in np.arange(eig_vec.shape[0]):
+                abs_z = np.abs(z_ROIs_vector[i])
+                mask[i, :] = abs_z > thresh_param
+                # event = abs_z[mask[i, :]]
+                # Deprecated but produced an interesting result
+                # if schematic and event.size != 0:
+                #     schem_thresh = np.percentile(event, 75)
+                #     mask[i, :] = abs_z > schem_thresh
+        case 'percentile':
+            flipped = components['flipped']
+            # Flip ICs where necessary using flipped from dict
+            flipped_threshold_vec = np.multiply(flipped, eig_vec)
+            # Calculate the thresh_param percentile cutoff for each IC
+            cutoff_vector = np.percentile(flipped_threshold_vec, thresh_param, axis=0)
+            # Mask all values above each IC's cutoff
+            for i in np.arange(eig_vec.shape[0]):
+                mask[i, :] = flipped_threshold_vec[i] > cutoff_vector
+        case _:
+            print("Threshold type must be 'max', 'z-score' or 'percentile'.")
+
+    # Filter small mask ROIs and smooth using blur
+    if blur:
+        print('blurring domains...')
+        assert type(blur) is int, 'blur was not valid'
+        if blur % 2 != 1:
+            blur += 1
+
+        eigenmask = np.zeros(shape, dtype=bool)
+        eigenbrain = np.empty(shape)
+        eigenbrain[:] = np.nan
+
+        for index in range(mask.shape[1]):
+
+            if roimask is not None:
+                eigenmask.flat[maskind] = mask.T[index]
+                # Remove small mask objects
+                filtered = remove_small_objects(eigenmask, min_size=min_mask_size, connectivity=1)
+                filtered_float = filtered.astype(np.float64)
+                eigenbrain.flat[maskind] = filtered_float.flat[maskind]
+                # Then blur
+                blurred = cv2.GaussianBlur(eigenbrain, (blur, blur), 0)
+                mask.T[index] = blurred.flat[maskind]
+            else:
+                eigenmask.flat = mask.T[index]
+                filtered = remove_small_objects(eigenmask, min_size=min_mask_size, connectivity=1)
+                filtered_float = filtered.astype(np.float64)
+                eigenbrain.flat = filtered_float.flat
+                blurred = cv2.GaussianBlur(eigenbrain,
(blur, blur), 0) + mask.T[index] = blurred.flat + + if schematic: + eigenmask = np.zeros(shape, dtype=np.uint8) + eigenbrain = np.empty(shape) + eigenbrain[:] = np.nan + + for i in range(mask.shape[1]): + event_schematic = np.zeros(shape, dtype=np.uint8) + eigenmask.flat[maskind] = mask.T[i] + eigenbrain.flat[maskind] = eig_vec.T[i] + # print("i is:", i) + # print(eigenmask) + if eigenmask.any(): + # tif.imwrite("/home/apluff/dev/test_data/eigenmasks/sub-070_eigenmask"+str(i)+".tif", eigenmask, imagej=True) + labelled, num_features = ndimage.label(eigenmask) + # print(labelled) + # print("labelled contains values:", np.unique(labelled)) + # print("num_features is:", num_features) + for j in range(1, num_features + 1): + centroid = ndimage.center_of_mass(eigenmask, + labels = labelled, + index = j) + # print("j is:", j) + # print("centroid is:", centroid) + int_centroid = tuple(int(x) for x in centroid) + event_size = np.sum(labelled, where = labelled == j)/j + schem_radius = int(np.sqrt(event_size/np.pi)) + rr, cc = draw.disk(int_centroid, schem_radius, shape = shape) + event_schematic[rr, cc] = 255 + # print("mask shape is:", mask.shape) + # print("event_schematics shape is:", event_schematic.shape) + mask.T[i] = event_schematic.flat[maskind] + + mask_bool = mask.astype(bool) + eig_vec[~mask_bool] = 0 + + output['thresh_masks'] = mask + # output['thresh_vec'] = eig_vec + output['eig_vec'] = eig_vec + + return output + def rebuild_mean_roi_timecourse(components: np.ndarray, mask: np.ndarray, include_zero: bool = True, @@ -637,7 +925,7 @@ def rebuild_eigenbrain(eig_vec: np.ndarray, else: eigenbrains = np.empty( (roimask.shape[0], roimask.shape[1], eig_vec.shape[1])) - eigenbrains[:] = np.NAN + eigenbrains[:] = np.nan eigenbrains[x, y, :] = eig_vec eigenbrains = np.swapaxes(eigenbrains, 0, 2) eigenbrains = np.swapaxes(eigenbrains, 1, 2) @@ -654,12 +942,11 @@ def rebuild_eigenbrain(eig_vec: np.ndarray, eigenbrain = eigenbrain.reshape(eigb_shape) else: eigenbrain = np.empty(roimask.shape) - eigenbrain[:] = np.NAN + eigenbrain[:] = np.nan eigenbrain.flat[maskind] = eig_vec.T[index] return eigenbrain - def filter_comparison(components: dict, downsample: int = 4, savepath: str = None,
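# For reference, the schematic branch above replaces each contiguous blob in an IC's
# mask with a disk of equal area centred on the blob's centroid. A standalone sketch
# of that geometry (illustrative only; it uses a plain pixel count for the blob area
# rather than the label-sum arithmetic above):
import numpy as np
from scipy import ndimage
from skimage import draw

mask2d = np.zeros((128, 128), dtype=bool)
mask2d[30:40, 50:70] = True                   # toy 10 x 20 pixel blob

labelled, num_features = ndimage.label(mask2d)
schematic = np.zeros(mask2d.shape, dtype=np.uint8)
for j in range(1, num_features + 1):
    centroid = ndimage.center_of_mass(mask2d, labels=labelled, index=j)
    area = np.count_nonzero(labelled == j)    # blob area in pixels
    radius = int(np.sqrt(area / np.pi))       # radius of the equal-area disk
    rr, cc = draw.disk(tuple(int(c) for c in centroid), radius, shape=mask2d.shape)
    schematic[rr, cc] = 255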