diff --git a/pydfc/data_loader.py b/pydfc/data_loader.py index 303b8bd..98f9cdf 100644 --- a/pydfc/data_loader.py +++ b/pydfc/data_loader.py @@ -10,6 +10,10 @@ import h5py import numpy as np +from nilearn import datasets +from nilearn.interfaces.fmriprep import load_confounds, load_confounds_strategy +from nilearn.maskers import NiftiLabelsMasker, NiftiSpheresMasker +from nilearn.plotting import find_parcellation_cut_coords from .dfc_utils import intersection, label2network from .time_series import TIME_SERIES @@ -150,13 +154,18 @@ def load_from_array(subj_id2load=None, **params): return BOLD -def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois=100): +def extract_region_signals( + nifti_file, + masker_type="NiftiLabelsMasker", + confound_strategy="none", + standardize=False, + labels_img=None, + seeds=None, + radius=None, +): """ this function uses nilearn maskers to extract BOLD signals from nifti files - For now it only works with schaefer atlas, - but you can set the number of rois to extract - {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} returns a numpy array of shape (time, roi) and labels and locs of rois @@ -167,37 +176,38 @@ def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois= 'no_motion_no_gsr': motion parameters are used and global signal regression is applied. - """ - from nilearn import datasets - from nilearn.interfaces.fmriprep import load_confounds - from nilearn.maskers import NiftiLabelsMasker - from nilearn.plotting import find_parcellation_cut_coords - - parc = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois) - atlas_filename = parc.maps - labels = parc.labels - # The list of labels does not contain ‘Background’ by default. - # To have proper indexing, you should either manually add ‘Background’ to the list of labels: - # Prepend background label - labels = np.insert(labels, 0, "Background") - - # extract locs - # test! 
- # check if order is the same as labels - locs, labels_ = find_parcellation_cut_coords( - atlas_filename, background_label=0, return_label_names=True - ) - - # create the masker for extracting time series - masker = NiftiLabelsMasker( - labels_img=atlas_filename, - labels=labels, - resampling_target="data", - standardize=standardize, - ) + 'simple': nilearn's simple preprocessing with + full motion and basic wm_csf + and high_pass - labels = np.delete(labels, 0) # remove the background label - labels = [label.decode() for label in labels] + For now it only works with NiftiLabelsMasker and NiftiSpheresMasker and not with NiftiMapsMasker + masker_type: "NiftiLabelsMasker" or "NiftiSpheresMasker" + """ + if masker_type == "NiftiSpheresMasker": + # check if seeds and radius are provided + if seeds is None or radius is None: + raise ValueError("For NiftiSpheresMasker, seeds and radius must be provided.") + # create the masker for extracting time series + masker = NiftiSpheresMasker( + seeds=seeds, + radius=radius, # radius in mm + standardize=standardize, + ) + elif masker_type == "NiftiLabelsMasker": + # check if labels_img is provided + if labels_img is None: + raise ValueError("For NiftiLabelsMasker, labels_img must be provided.") + # create the masker for extracting time series + masker = NiftiLabelsMasker( + labels_img=labels_img, + resampling_target="data", + standardize=standardize, + ) + else: + raise ValueError( + "masker_type must be 'NiftiLabelsMasker' or 'NiftiSpheresMasker', " + f"but got {masker_type}" + ) ### extract the timeseries if confound_strategy == "none": @@ -223,16 +233,146 @@ def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois= time_series = masker.fit_transform( nifti_file, confounds=confounds_simple, sample_mask=sample_mask ) + elif confound_strategy == "simple": + confounds_simple, sample_mask = load_confounds_strategy( + nifti_file, denoise_strategy="simple" + ) + time_series = masker.fit_transform( + nifti_file, confounds=confounds_simple, sample_mask=sample_mask + ) + else: + raise ValueError( + "confound_strategy must be one of 'none', 'no_motion', 'no_motion_no_gsr', or 'simple', " + f"but got {confound_strategy}" + ) + + return time_series + + +def nifti2array( + nifti_file, + masker_type="NiftiLabelsMasker", + confound_strategy="none", + standardize=False, + n_rois=100, + labels_img=None, + seeds=None, + radius=None, + region_names=None, +): + """ + this function uses nilearn maskers to extract + BOLD signals from nifti files + + returns a numpy array of shape (time, roi) + and labels and locs of rois + + confound_strategy: + 'none': no confounds are used + 'no_motion': motion parameters are used + 'no_motion_no_gsr': motion parameters are used + and global signal regression + is applied. + 'simple': nilearn's simple preprocessing with + full motion and basic wm_csf + and high_pass + + For now it only works with NiftiLabelsMasker and NiftiSpheresMasker and not with NiftiMapsMasker + masker_type: "NiftiLabelsMasker" or "NiftiSpheresMasker" + if masker_type is "NiftiLabelsMasker", + labels_img must be provided or n_rois must be provided + if masker_type is "NiftiSpheresMasker", + seeds and radius must be provided + + Note: + when not using Schaefer atlas, make sure + that the labels_img/seeds and region_names are in the same order. 
+ """ + if masker_type == "NiftiLabelsMasker": + if labels_img is None: + # in this case, we will use the schaefer atlas + # we use n_rois to determine the number of rois + assert n_rois in [ + 100, + 200, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + ], "n_rois must be one of {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}" + # fetch the schaefer atlas + parc = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois) + labels_img = parc.maps + labels = parc.labels + labels = [label.decode() for label in labels] + else: + assert ( + region_names is not None + ), "region_names must be provided if labels_img is provided" + assert type(region_names) is list, "region_names must be a list of strings" + + labels = region_names + + # extract locs from labels_img + # check if order is the same as labels + locs, labels_ = find_parcellation_cut_coords( + labels_img, background_label=0, return_label_names=True + ) # numpy.ndarray of shape (n_labels, 3) + + elif masker_type == "NiftiSpheresMasker": + + # make sure seeds is a list of tuples (x, y, z) + assert seeds is not None, "seeds must be provided for NiftiSpheresMasker" + assert radius is not None, "radius must be provided for NiftiSpheresMasker" + assert type(seeds) is list, "seeds must be a list of tuples (x, y, z)" + assert all( + isinstance(seed, tuple) and len(seed) == 3 for seed in seeds + ), "seeds must be a list of tuples (x, y, z) with 3 elements each" + + locs = np.array(seeds) # seeds should be a list of tuples (x, y, z) + + assert ( + region_names is not None + ), "region_names must be provided if seeds are provided" + assert type(region_names) is list, "region_names must be a list of strings" + + labels = region_names + + else: + raise ValueError( + "masker_type must be 'NiftiLabelsMasker' or 'NiftiSpheresMasker', " + f"but got {masker_type}" + ) + + # extract the timeseries + time_series = extract_region_signals( + nifti_file=nifti_file, + masker_type=masker_type, + confound_strategy=confound_strategy, + standardize=standardize, + labels_img=labels_img, + seeds=seeds, + radius=radius, + ) return time_series, labels, locs def nifti2timeseries( nifti_file, - n_rois, Fs, subj_id, confound_strategy="none", + masker_type="NiftiLabelsMasker", + n_rois=100, + labels_img=None, + seeds=None, + radius=None, + region_names=None, standardize=False, TS_name=None, session=None, @@ -242,15 +382,50 @@ def nifti2timeseries( it uses nilearn maskers to extract ROI signals from nifti files and returns a TIME_SERIES object - For now it only works with schaefer atlas, - but you can set the number of rois to extract - {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} + Parameters + ---------- + nifti_file : str + path to the nifti file + Fs : float + sampling frequency of the data + subj_id : str + subject ID, must start with 'sub-' + confound_strategy : str, optional + strategy for confound regression, by default "none" + masker_type : str, optional + type of masker to use, by default "NiftiLabelsMasker" + n_rois : int, optional + number of regions of interest to extract, by default 100 + labels_img : str, optional + path to the labels image, by default None + seeds : list, optional + list of tuples (x, y, z) for NiftiSpheresMasker + by default None + radius : float, optional + radius in mm for NiftiSpheresMasker, by default None + region_names : list, optional + list of region names for NiftiLabelsMasker or NiftiSpheresMasker, + by default None + standardize : bool, optional + whether to standardize the time series, by default False + TS_name 
: str, optional + name of the time series, by default None + session : str, optional + session name, by default None + + For more information on confound_strategy, masker_type, and other parameters, + see the documentation of the nifti2array function. """ time_series, labels, locs = nifti2array( nifti_file=nifti_file, confound_strategy=confound_strategy, standardize=standardize, + masker_type=masker_type, n_rois=n_rois, + labels_img=labels_img, + seeds=seeds, + radius=radius, + region_names=region_names, ) assert type(locs) is np.ndarray, "locs must be a numpy array" @@ -280,8 +455,13 @@ def nifti2timeseries( def multi_nifti2timeseries( nifti_files_list, subj_id_list, - n_rois, Fs, + masker_type="NiftiLabelsMasker", + n_rois=100, + labels_img=None, + seeds=None, + radius=None, + region_names=None, confound_strategy="none", standardize=False, TS_name=None, @@ -295,10 +475,15 @@ def multi_nifti2timeseries( if BOLD_multi is None: BOLD_multi = nifti2timeseries( nifti_file=nifti_file, - n_rois=n_rois, - Fs=Fs, subj_id=subj_id, + Fs=Fs, confound_strategy=confound_strategy, + masker_type=masker_type, + n_rois=n_rois, + labels_img=labels_img, + seeds=seeds, + radius=radius, + region_names=region_names, standardize=standardize, TS_name=TS_name, session=session, @@ -307,10 +492,15 @@ def multi_nifti2timeseries( BOLD_multi.concat_ts( nifti2timeseries( nifti_file=nifti_file, - n_rois=n_rois, - Fs=Fs, subj_id=subj_id, + Fs=Fs, confound_strategy=confound_strategy, + masker_type=masker_type, + n_rois=n_rois, + labels_img=labels_img, + seeds=seeds, + radius=radius, + region_names=region_names, standardize=standardize, TS_name=TS_name, session=session, diff --git a/pydfc/dfc.py b/pydfc/dfc.py index e02e662..6cf595e 100644 --- a/pydfc/dfc.py +++ b/pydfc/dfc.py @@ -46,6 +46,9 @@ def __init__(self, measure=None): self.measure_ = measure self.FCSs_ = None # is a dict self.FCS_idx_ = None # is a dict + self.FCS_proba_ = ( + None # is a 2D numpy array of probabilities for each FCS at each time point + ) # info of the time series used for dFC estimation self.TS_info_ = None self.TR_array_ = None @@ -89,6 +92,14 @@ def FCSs(self): def FCS_idx(self): return self.FCS_idx_ + @property + def FCS_proba(self): + """ + FCS_proba is a 2D numpy array of probabilities for each FCS at each time point + shape = (n_time, n_states) + """ + return self.FCS_proba_ + # test this @property def FCS_idx_array(self): @@ -167,13 +178,14 @@ def dFC2dict(self, TRs=None): if type(TRs) is list: TRs = np.array(TRs) TRs = TRs.astype(int) + dFC_mat = self.get_dFC_mat(TRs=TRs) + dFC_dict = {} for k, TR in enumerate(TRs): dFC_dict[f"TR{TR}"] = dFC_mat[k, :, :] return dFC_dict - # test this def get_dFC_mat(self, TRs=None, num_samples=None): """ get dFC matrices corresponding to @@ -184,7 +196,6 @@ def get_dFC_mat(self, TRs=None, num_samples=None): return picked TRs if num_samples > len(TRs) -> picks all TRs """ - if TRs is None: TRs = self.TR_array @@ -199,7 +210,8 @@ def get_dFC_mat(self, TRs=None, num_samples=None): dFC_mat = list() for TR in TRs: - dFC_mat.append(self.FCSs[self.FCS_idx[f"TR{TR}"]]) + FC_mat = self.FCSs[self.FCS_idx[f"TR{TR}"]] + dFC_mat.append(FC_mat) dFC_mat = np.array(dFC_mat) @@ -230,10 +242,11 @@ def SWed_dFC_mat(self, W=None, n_overlap=None, tapered_window=False): return dFC_mat_new - def set_dFC(self, FCSs, FCS_idx=None, TS_info=None, TR_array=None): + def set_dFC(self, FCSs, FCS_idx=None, FCS_proba=None, TS_info=None, TR_array=None): """ - FCSs: a 3D numpy array of FC matrices with shape (n_time, n_regions, n_regions) 
-        FCS_idx: a list of indices that correspond to each FC matrix in FCSs over time
+        FCSs: a 3D numpy array of FC matrices with shape (n_states, n_regions, n_regions), for state-free methods: (n_time, n_regions, n_regions)
+        FCS_idx: a list of indices that correspond to each FC matrix in FCSs over time, used for state-based methods.
+        FCS_proba: a 2D numpy array of probabilities for each FCS at each time point, shape = (n_time, n_states), used for state-based methods.
         """

         if len(FCSs.shape) == 2:
@@ -268,6 +281,21 @@
         assert np.sum(np.abs(np.sort(TR_array) - TR_array)) == 0.0, "TRs not sorted !"

+        if FCS_proba is not None and FCS_idx is not None:
+            assert FCS_proba.shape[0] == len(
+                FCS_idx
+            ), "FCS_proba shape does not match FCSs shape (n_time)."
+            assert (
+                FCS_proba.shape[1] == FCSs.shape[0]
+            ), "FCS_proba shape does not match FCSs shape (n_states)."
+            assert np.allclose(
+                FCS_proba.sum(axis=1), 1
+            ), "FCS_proba probabilities must sum to 1 for each time point."
+            assert len(TR_array) == FCS_proba.shape[0], (
+                "TR_array length does not match FCS_proba shape (n_time). "
+                f"TR_array length: {len(TR_array)}, FCS_proba shape: {FCS_proba.shape}"
+            )
+
         # the input FCS_idx is ranged from 0 to len(FCS)-1 but we shift it to 1 to len(FCS)
         self.FCSs_ = {}
         for i, FCS in enumerate(FCSs):
@@ -277,6 +305,8 @@
         for i, idx in enumerate(FCS_idx):
             self.FCS_idx_[f"TR{TR_array[i]}"] = f"FCS{idx + 1}"  # "FCS" + str(idx + 1)

+        self.FCS_proba_ = FCS_proba
+
         self.TS_info_ = TS_info
         self.n_regions_ = FCSs.shape[1]
         self.n_time_ = len(self.FCS_idx_)
diff --git a/pydfc/dfc_methods/cap.py b/pydfc/dfc_methods/cap.py
index 0ee51bf..ed0b21c 100644
--- a/pydfc/dfc_methods/cap.py
+++ b/pydfc/dfc_methods/cap.py
@@ -8,6 +8,7 @@
 import time

 import numpy as np
+from scipy.special import softmax
 from sklearn.cluster import KMeans

 from ..dfc import DFC
@@ -151,12 +152,23 @@ def estimate_dFC(self, time_series):

         act_vecs = time_series.data.T
         Z = self.kmeans_.predict(act_vecs.astype(np.float32))
+        # get distances from the cluster centers for each sample
+        distances = self.kmeans_.transform(
+            act_vecs.astype(np.float32)
+        )  # shape: (n_samples, n_clusters) = (n_time, n_states)
+        # normalize distances to pseudo-probabilities
+        rel = -distances
+        rel = rel - rel.min(axis=1, keepdims=True)  # shift min to 0
+        rel = rel / rel.sum(axis=1, keepdims=True)  # normalize
+        Z_proba = rel  # shape: (n_samples, n_clusters) = (n_time, n_states)

         # record time
         self.set_dFC_assess_time(time.time() - tic)

         dFC = DFC(measure=self)
-        dFC.set_dFC(FCSs=self.FCS_, FCS_idx=Z, TS_info=time_series.info_dict)
+        dFC.set_dFC(
+            FCSs=self.FCS_, FCS_idx=Z, FCS_proba=Z_proba, TS_info=time_series.info_dict
+        )

         return dFC
diff --git a/pydfc/dfc_methods/continuous_hmm.py b/pydfc/dfc_methods/continuous_hmm.py
index f0a8fe8..7082d3d 100644
--- a/pydfc/dfc_methods/continuous_hmm.py
+++ b/pydfc/dfc_methods/continuous_hmm.py
@@ -122,12 +122,18 @@ def estimate_dFC(self, time_series):
         tic = time.time()

         Z = self.hmm_model.predict(time_series.data.T)
+        # get probabilities for each state for each time point
+        Z_proba = self.hmm_model.predict_proba(
+            time_series.data.T
+        )  # shape: (n_samples, n_components) = (n_time, n_states)

         # record time
         self.set_dFC_assess_time(time.time() - tic)

         dFC = DFC(measure=self)
-        dFC.set_dFC(FCSs=self.FCS_, FCS_idx=Z, TS_info=time_series.info_dict)
+        dFC.set_dFC(
+            FCSs=self.FCS_, FCS_idx=Z, FCS_proba=Z_proba, TS_info=time_series.info_dict
+        )

         return dFC
diff --git a/pydfc/dfc_methods/discrete_hmm.py b/pydfc/dfc_methods/discrete_hmm.py
index 77e1a91..03bde05 100644
--- a/pydfc/dfc_methods/discrete_hmm.py
+++ b/pydfc/dfc_methods/discrete_hmm.py
@@ -183,6 +183,10 @@ def estimate_dFC(self, time_series):

         Obs_seq = FCC.FCS_idx_array.reshape(-1, 1)
         Z = self.hmm_model.predict(Obs_seq)
+        # get probabilities for each state for each time point
+        Z_proba = self.hmm_model.predict_proba(
+            Obs_seq
+        )  # shape: (n_samples, n_components) = (n_time, n_states)

         # record time
         self.set_dFC_assess_time(time.time() - tic)
@@ -191,6 +195,7 @@
         dFC.set_dFC(
             FCSs=self.FCS_,
             FCS_idx=Z,
+            FCS_proba=Z_proba,
             TS_info=time_series.info_dict,
             TR_array=FCC.TR_array,
         )
diff --git a/pydfc/dfc_methods/sliding_window_clustr.py b/pydfc/dfc_methods/sliding_window_clustr.py
index 25d7386..6674087 100644
--- a/pydfc/dfc_methods/sliding_window_clustr.py
+++ b/pydfc/dfc_methods/sliding_window_clustr.py
@@ -8,7 +8,9 @@
 import time

 import numpy as np
+from scipy.special import softmax
 from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler

 from ..dfc import DFC
 from ..dfc_utils import KMeansCustom, dFC_mat2vec, dFC_vec2mat
@@ -82,6 +84,10 @@ def __init__(self, **params):
         self.params["measure_name"] = "Clustering"
         self.params["is_state_based"] = True

+        if self.params["clstr_distance"] is None:
+            # Default clustering distance is euclidean
+            self.params["clstr_distance"] = "euclidean"
+
         assert (
             self.params["clstr_distance"] == "euclidean"
             or self.params["clstr_distance"] == "manhattan"
@@ -239,9 +245,27 @@ def estimate_dFC(self, time_series):
         if self.params["clstr_distance"] == "manhattan":
             ########### Manhattan Clustering ##############
             Z = self.kmeans_.predict(F.astype(np.float32))
+            # get distances from the cluster centers for each sample
+            distances = self.kmeans_.transform(
+                F.astype(np.float32)
+            )  # shape: (n_samples, n_clusters)
+            # normalize distances to pseudo-probabilities
+            rel = -distances
+            rel = rel - rel.min(axis=1, keepdims=True)  # shift min to 0
+            rel = rel / rel.sum(axis=1, keepdims=True)  # normalize
+            Z_proba = rel  # shape: (n_samples, n_clusters) = (n_time, n_states)
         else:
             ########### Euclidean Clustering ##############
             Z = self.kmeans_.predict(F.astype(np.float32))
+            # get distances from the cluster centers for each sample
+            distances = self.kmeans_.transform(
+                F.astype(np.float32)
+            )  # shape: (n_samples, n_clusters)
+            # normalize distances to pseudo-probabilities
+            rel = -distances
+            rel = rel - rel.min(axis=1, keepdims=True)  # shift min to 0
+            rel = rel / rel.sum(axis=1, keepdims=True)  # normalize
+            Z_proba = rel  # shape: (n_samples, n_clusters) = (n_time, n_states)

         # record time
         self.set_dFC_assess_time(time.time() - tic)
@@ -250,6 +274,7 @@
         dFC.set_dFC(
             FCSs=self.FCS_,
             FCS_idx=Z,
+            FCS_proba=Z_proba,
             TS_info=time_series.info_dict,
             TR_array=dFC_raw.TR_array,
         )
diff --git a/pydfc/dfc_methods/windowless.py b/pydfc/dfc_methods/windowless.py
index 65a5b15..2c4e722 100644
--- a/pydfc/dfc_methods/windowless.py
+++ b/pydfc/dfc_methods/windowless.py
@@ -9,6 +9,7 @@
 import numpy as np
 from ksvd import ApproximateKSVD
+from sklearn.linear_model import orthogonal_mp_gram

 from ..dfc import DFC
 from ..time_series import TIME_SERIES
@@ -110,6 +111,27 @@ def estimate_FCS(self, time_series):

         return self

+    def transform_proba(self, D, X, n_nonzero_coefs=None):
+        """
+        returns the probability of each state for each time point
+        D:
dictionary, shape = (n_states, n_regions) + X: time series data, shape = (n_time, n_regions) + n_nonzero_coefs: number of non-zero coefficients to use in orthogonal matching pursuit + Returns: + Z_proba: shape = (n_time, n_states) + """ + gram = D.dot(D.T) # shape: (n_features, n_features) = (n_states, n_states) + Xy = D.dot(X.T) # shape: (n_features, n_targets) = (n_states, n_time) + + if n_nonzero_coefs is None: + n_nonzero_coefs = D.shape[0] + + gamma = orthogonal_mp_gram(gram, Xy, n_nonzero_coefs=n_nonzero_coefs).T + + Z_proba = np.abs(gamma) / np.abs(gamma).sum(axis=1, keepdims=True) + + return Z_proba + def estimate_dFC(self, time_series): assert ( @@ -125,17 +147,36 @@ def estimate_dFC(self, time_series): # start timing tic = time.time() - gamma = self.aksvd.transform(time_series.data.T) + gamma = self.aksvd.transform(time_series.data.T) # shape: (n_time, n_states) Z = list() for i in range(time_series.n_time): Z.append(np.argwhere(gamma[i, :] != 0)[0, 0]) + # get probability for each state for each time point + Z_proba = self.transform_proba( + D=self.dictionary, + X=time_series.data.T, + n_nonzero_coefs=self.params["n_states"], + ) # shape: (n_targets, n_features) = (n_time, n_states) + + assert Z_proba.shape[0] == time_series.n_time, ( + "Z_proba shape does not match time_series.n_time. " + f"Z_proba shape: {Z_proba.shape}, time_series.n_time: {time_series.n_time}" + ) + + assert Z_proba.shape[1] == self.params["n_states"], ( + "Z_proba shape does not match n_states. " + f"Z_proba shape: {Z_proba.shape}, n_states: {self.params['n_states']}" + ) + # record time self.set_dFC_assess_time(time.time() - tic) dFC = DFC(measure=self) - dFC.set_dFC(FCSs=self.FCS_, FCS_idx=Z, TS_info=time_series.info_dict) + dFC.set_dFC( + FCSs=self.FCS_, FCS_idx=Z, FCS_proba=Z_proba, TS_info=time_series.info_dict + ) return dFC diff --git a/pydfc/dfc_utils.py b/pydfc/dfc_utils.py index e1a449c..6a351c3 100644 --- a/pydfc/dfc_utils.py +++ b/pydfc/dfc_utils.py @@ -460,11 +460,13 @@ def _custom_distance(self, p1, p2): return pairwise_distances([p1], [p2], metric=self.metric)[0][0] def _assign_clusters(self, X, centroids): - clusters = [] - for x in X: - distances = [self._custom_distance(x, c) for c in centroids] - clusters.append(np.argmin(distances)) - return clusters + if self.metric == "manhattan": + distances = np.abs(X[:, None, :] - centroids[None, :, :]).sum(axis=2) + elif self.metric == "euclidean": + distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2) + else: + distances = pairwise_distances(X, centroids, metric=self.metric) + return np.argmin(distances, axis=1) def _compute_centroids(self, X, labels): centroids = [] @@ -474,7 +476,6 @@ def _compute_centroids(self, X, labels): return np.array(centroids) def fit(self, X): - X = deepcopy(X) min_inertia = None best_centroids = None best_labels = None @@ -489,12 +490,18 @@ def fit(self, X): if np.allclose(centroids, new_centroids, atol=1e-6): break centroids = new_centroids - inertia = np.sum( - [ - self._custom_distance(x, centroids[label]) ** 2 - for x, label in zip(X, labels) - ] - ) + + if self.metric == "manhattan": + distances = np.abs(X - centroids[labels]).sum(axis=1) + elif self.metric == "euclidean": + distances = np.linalg.norm(X - centroids[labels], axis=1) + else: + distances = pairwise_distances( + X, centroids[labels], metric=self.metric + ).diagonal() + + inertia = np.sum(distances**2) + if min_inertia is None or inertia < min_inertia: min_inertia = inertia best_centroids = centroids @@ -506,9 +513,14 @@ def 
fit(self, X): return self def predict(self, X): - X = deepcopy(X) return self._assign_clusters(X, self.cluster_centers_) + def transform(self, X): + """ + Transform the data to cluster centers + """ + return pairwise_distances(X, self.cluster_centers_, metric=self.metric) + #################################################################################################### diff --git a/pydfc/time_series.py b/pydfc/time_series.py index 7eab737..bb323c6 100644 --- a/pydfc/time_series.py +++ b/pydfc/time_series.py @@ -83,6 +83,11 @@ def __init__( else: self.data_dict_[subj_id]["time_array"] = time_array + # make sure node_labels have same number of regions as self.n_regions_ + assert ( + len(node_labels) == self.n_regions_ + ), "node_labels must have the same length as the number of regions in data." + self.locs_ = locs self.node_labels_ = node_labels diff --git a/pyproject.toml b/pyproject.toml index 0b53649..c1d116a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ 'ksvd', 'matplotlib', 'networkx', - 'nilearn>=0.10.2,!=0.10.3', + 'nilearn==0.10.2', 'pycwt', 'seaborn', 'statsmodels' diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 4c1f3db..7142336 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -4,30 +4,36 @@ from pydfc.data_loader import nifti2timeseries -# @pytest.fixture(scope="session") -# def rest_file(tmp_path_factory): -# URL = "https://s3.amazonaws.com/openneuro.org/ds002785/derivatives/fmriprep/sub-0001/func/sub-0001_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz?versionId=UfCs4xtwIEPDgmb32qFbtMokl_jxLUKr" -# tmpdir = tmp_path_factory.mktemp("data") -# file_path = tmpdir / "sub-0001_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz" -# with httpx.stream("GET", URL) as response: -# with file_path.open("wb") as f: -# for chunk in response.iter_bytes(): -# f.write(chunk) -# -# return file_path - @pytest.fixture -def simulated_bold_data(tmp_path): - img = nb.Nifti1Image(np.random.rand(10, 10, 10, 100), np.eye(4)) - img.to_filename(tmp_path / "simulated_bold.nii.gz") - return tmp_path / "simulated_bold.nii.gz" +def simulated_bold_and_label(tmp_path): + # Simulated BOLD data + bold_data = np.random.rand(10, 10, 10, 100) + affine = np.eye(4) + bold_img = nb.Nifti1Image(bold_data, affine) + bold_file = tmp_path / "bold.nii.gz" + bold_img.to_filename(bold_file) + + # Simulated label image with 3 ROIs (labels 1, 2, 3) + labels = np.zeros((10, 10, 10), dtype=np.int32) + labels[1:4, 1:4, 1:4] = 1 + labels[5:7, 5:7, 5:7] = 2 + labels[7:9, 1:3, 1:3] = 3 + label_img = nb.Nifti1Image(labels, affine) + label_file = tmp_path / "labels.nii.gz" + label_img.to_filename(label_file) + + return str(bold_file), str(label_file) -def test_load(simulated_bold_data): - nifti2timeseries( - nifti_file=str(simulated_bold_data), - n_rois=100, +def test_load(simulated_bold_and_label): + bold_file, label_file = simulated_bold_and_label + ts = nifti2timeseries( + nifti_file=bold_file, + labels_img=label_file, + region_names=["1", "2", "3"], Fs=1 / 0.75, subj_id="sub-0001", ) + assert ts is not None + assert ts.n_regions == 3
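Usage sketch for the extended extraction API above (nifti2timeseries with masker_type, labels_img, seeds, radius, and region_names). The file paths, seed coordinates, and region names below are hypothetical placeholders, and any confound_strategy other than "none" assumes fmriprep-style confound files sit next to the preprocessed BOLD file:

    from pydfc.data_loader import nifti2timeseries

    # label-atlas extraction with a custom parcellation image
    bold_labels = nifti2timeseries(
        nifti_file="sub-0001_task-rest_desc-preproc_bold.nii.gz",  # hypothetical path
        masker_type="NiftiLabelsMasker",
        labels_img="custom_parcellation.nii.gz",  # hypothetical path
        region_names=["ROI-1", "ROI-2", "ROI-3"],  # same order as the labels in labels_img
        confound_strategy="none",
        Fs=1 / 0.75,
        subj_id="sub-0001",
    )

    # sphere-based extraction around seed coordinates (MNI, mm)
    bold_spheres = nifti2timeseries(
        nifti_file="sub-0001_task-rest_desc-preproc_bold.nii.gz",
        masker_type="NiftiSpheresMasker",
        seeds=[(0, -52, 18), (-46, -68, 32), (46, -69, 31)],  # hypothetical seed coordinates
        radius=8,
        region_names=["PCC", "lLP", "rLP"],
        confound_strategy="none",
        Fs=1 / 0.75,
        subj_id="sub-0001",
    )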
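A toy sketch of the distance-to-pseudo-probability normalization added in cap.py and sliding_window_clustr.py above (made-up distances, not real data); each row of the result sums to 1, which is exactly what the new FCS_proba assert in DFC.set_dFC checks:

    import numpy as np

    # distances from each sample to each cluster center, shape (n_time, n_states)
    distances = np.array([[2.0, 1.0, 4.0],
                          [0.5, 3.0, 1.5]])

    rel = -distances
    rel = rel - rel.min(axis=1, keepdims=True)  # farthest state is shifted to 0
    rel = rel / rel.sum(axis=1, keepdims=True)  # each row now sums to 1
    # caveat: a row of identical distances would give a zero denominator here
    print(rel)              # [[0.4 0.6 0.], [0.625 0. 0.375]]
    print(rel.sum(axis=1))  # [1. 1.]

The resulting array is what these state-based measures pass to set_dFC and expose as dFC.FCS_proba.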
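A standalone sketch of the orthogonal_mp_gram computation behind the new WINDOWLESS.transform_proba, using a random dictionary purely to illustrate the shapes involved (not a fitted KSVD dictionary):

    import numpy as np
    from sklearn.linear_model import orthogonal_mp_gram

    rng = np.random.default_rng(0)
    D = rng.standard_normal((4, 20))   # (n_states, n_regions) dictionary atoms
    X = rng.standard_normal((50, 20))  # (n_time, n_regions) activity patterns

    gram = D @ D.T                     # (n_states, n_states)
    Xy = D @ X.T                       # (n_states, n_time)
    gamma = orthogonal_mp_gram(gram, Xy, n_nonzero_coefs=4).T  # (n_time, n_states)

    # normalize absolute OMP coefficients into pseudo-probabilities per time point
    Z_proba = np.abs(gamma) / np.abs(gamma).sum(axis=1, keepdims=True)
    print(Z_proba.shape, np.allclose(Z_proba.sum(axis=1), 1))  # (50, 4) True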
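A quick toy check that the vectorized distance computation in the new KMeansCustom._assign_clusters and KMeansCustom.transform agrees with sklearn's pairwise_distances for the Manhattan metric:

    import numpy as np
    from sklearn.metrics import pairwise_distances

    rng = np.random.default_rng(1)
    X = rng.standard_normal((30, 5))
    centroids = rng.standard_normal((3, 5))

    # broadcasted Manhattan distances, as in _assign_clusters
    vectorized = np.abs(X[:, None, :] - centroids[None, :, :]).sum(axis=2)
    reference = pairwise_distances(X, centroids, metric="manhattan")
    print(np.allclose(vectorized, reference))                                    # True
    print(np.array_equal(vectorized.argmin(axis=1), reference.argmin(axis=1)))   # True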