diff --git a/examples/eg__analysis_model.py b/examples/eg__analysis_model.py new file mode 100644 index 00000000..0838892e --- /dev/null +++ b/examples/eg__analysis_model.py @@ -0,0 +1,178 @@ +""" +.. _ex-modelanalysis: + +======================================================== +Modelling TMS-EEG evoked responses +======================================================== + +This example shows the analysis from model: + +1. model parameters +2. networks +3. neural states + +""" +# %% +# First we must import the necessary packages required for the example: + +# System-based packages +import os +import sys +sys.path.append('..') + + +# Whobpyt modules taken from the whobpyt package +import whobpyt +from whobpyt.datatypes import Parameter as par, Timeseries +from whobpyt.models.jansen_rit import JansenRitModel,JansenRitParams +from whobpyt.run import ModelFitting +from whobpyt.optimization.custom_cost_JR import CostsJR +from whobpyt.datasets.fetchers import fetch_egtmseeg + +# Python Packages used for processing and displaying given analytical data (supported for .mat and Google Drive files) +import numpy as np +import pandas as pd +import scipy.io +import gdown +import pickle +import warnings +warnings.filterwarnings('ignore') +import matplotlib.pyplot as plt # Plotting library (For Visualization) +import seaborn as sns + +import mne # Neuroimaging package + + + + + + +# %% +# load in a previously completed model fitting results object +full_run_fname = os.path.join(data_dir, 'Subject_1_low_voltage_fittingresults_stim_exp.pkl') +F = pickle.load(open(full_run_fname, 'rb')) + + +### get labels for Yeo 200 +url = 'https://raw.githubusercontent.com/ThomasYeoLab/CBIG/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/MNI/Centroid_coordinates/Schaefer2018_200Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.csv' +atlas = pd.read_csv(url) +labels = atlas['ROI Name'] + +# get networks +nets = [label.split('_')[2] for label in labels] +net_names = np.unique(np.array(nets)) + +# %% +# 1. model parameters +# ----------------------------------- + +# %% +# Plots of parameter values over Training (check if converges) +fig, axs = plt.subplots(2,2, figsize = (12,8)) +paras = ['c1', 'c2', 'c3', 'c4'] +for i in range(len(paras)): + axs[i//2,i%2].plot(F.trainingStats.fit_params[paras[i]]) + axs[i//2, i%2].set_title(paras[i]) +plt.title("Select Variables Changing Over Training Epochs") + +# %% +# Plots of parameter values over Training (prior vs post) +fig, axs = plt.subplots(2,2, figsize = (12,8)) +paras = ['c1', 'c2', 'c3', 'c4'] +for i in range(len(paras)): + axs[i//2,i%2].hist(F.trainingStats.fit_params[paras[i]][:500], label='prior') + axs[i//2,i%2].hist(F.trainingStats.fit_params[paras[i]][-500:], label='post') + axs[i//2, i%2].set_title(paras[i]) +plt.title("Prior vs Post") + +# %% +# 2. Networks +# ----------------------------------- +fig, axs = plt.subplots(1,3, figsize = (12,8)) +networks_frommodels = ['p2p', 'p2e', 'p2i'] +sns.heatmap(F.model.w_p2p.detach().numpy(), cmap = 'bwr', center=0, ax=axs[0]) +axs[0].set_title(networks_frommodels[0]) +sns.heatmap(F.model.w_p2p.detach().numpy(), cmap = 'bwr', center=0, ax=axs[1]) +axs[1].set_title(networks_frommodels[1]) +sns.heatmap(F.model.w_p2p.detach().numpy(), cmap = 'bwr', center=0, ax=axs[2]) +axs[2].set_title(networks_frommodels[2]) + + +# %% +# 3. Neural states +# ----------------------------------- + +#### plot E response on each networks +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, F.lastRec['E'].npTS()[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: E') +plt.show() + +### plot I response at each networks +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, F.lastRec['I'].npTS()[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: I') +plt.show() + +### plot P response at each networks +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, F.lastRec['P'].npTS()[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: P') +plt.show() + + +### plot phase of E at each network +j = complex(0,1) +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +phase = np.angle(F.lastRec['E'].npTS()+j*F.lastRec['Ev'].npTS()) +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, phase[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: phase E') +plt.show() + +### plot I phase at each network +j = complex(0,1) +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +phase = np.angle(F.lastRec['I'].npTS()+j*F.lastRec['Iv'].npTS()) +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, phase[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: phase I') +plt.show() + +### plot P phase at each network + +j = complex(0,1) +fig, ax = plt.subplots(2,4, figsize=(12,10), sharey= True) +t = np.linspace(-0.1,0.3, 400) + +phase = np.angle(F.lastRec['P'].npTS()+j*F.lastRec['Pv'].npTS()) +for i, net in enumerate(net_names): + mask = np.array(nets) == net + ax[i//4, i%4].plot(t, phase[mask,:].mean(0).T) + ax[i//4, i%4].set_title(net) +plt.suptitle('Test: phase P') +plt.show() + diff --git a/examples/eg__tmseeg_fq.py b/examples/eg__tmseeg_fq.py new file mode 100644 index 00000000..ee9736b7 --- /dev/null +++ b/examples/eg__tmseeg_fq.py @@ -0,0 +1,146 @@ +""" +.. _ex-tmseeg: + +======================================================== +Modelling TMS-EEG evoked responses +======================================================== + +This example shows how to organize the empirical eeg data, set-up JR model with user-defined learnable model +parameters and train model. After train how to test model with new inputs (noises) to generate simulated EEG. +Furethermore, show some analysis based on uncovered neural states from the model. + +""" +# %% +# First we must import the necessary packages required for the example: + +# System-based packages +import os +import sys +sys.path.append('..') + + +# Whobpyt modules taken from the whobpyt package +import whobpyt +from whobpyt.datatypes import Parameter as par, Timeseries +from whobpyt.models.linear_fq import LINEAR_FQ, ParamsLinearFreqs +from whobpyt.run import Model_fitting_fq +from whobpyt.optimization.cost_Freq import CostsFreqs +from whobpyt.datasets.fetchers import fetch_egtmseeg + +# Python Packages used for processing and displaying given analytical data (supported for .mat and Google Drive files) +import numpy as np +import pandas as pd +import scipy.io +import gdown +import pickle +import warnings +warnings.filterwarnings('ignore') +import matplotlib.pyplot as plt # Plotting library (For Visualization) + +import mne # Neuroimaging package + + + +# %% +# Download and load example data +data_dir = fetch_egtmseeg() + +# %% +# Load EEG data +eeg_file_name = os.path.join(data_dir, 'Subject_1_low_voltage.fif') +epoched = mne.read_epochs(eeg_file_name, verbose=False); +evoked = epoched.get_data() +eeg = np.concatenate(list(evoked), axis=1) + +# %% +# Load Atlas +atlas_file_name = os.path.join(data_dir, 'Schaefer2018_200Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.txt') +atlas = pd.read_csv(atlas_file_name) +labels = atlas['ROI Name'] +coords = np.array([atlas['R'], atlas['A'], atlas['S']]).T +conduction_velocity = 5 #in ms + +# %% +# Compute the distance matrix which is used to calculate delay between regions +dist = np.zeros((coords.shape[0], coords.shape[0])) + +for roi1 in range(coords.shape[0]): + for roi2 in range(coords.shape[0]): + dist[roi1, roi2] = np.sqrt(np.sum((coords[roi1,:] - coords[roi2,:])**2, axis=0)) + dist[roi1, roi2] = np.sqrt(np.sum((coords[roi1,:] - coords[roi2,:])**2, axis=0)) + + +# %% +# Load the stim weights matrix which encode where to inject the external input +stim_weights = np.load(os.path.join(data_dir, 'stim_weights.npy')) +stim_weights_thr = stim_weights.copy() +labels[np.where(stim_weights_thr>0)[0]] + +# %% +# Load the structural connectivity matrix +sc_file = os.path.join(data_dir, 'Schaefer2018_200Parcels_7Networks_count.csv') +sc_df = pd.read_csv(sc_file, header=None, sep=' ') +sc = sc_df.values +sc = np.log1p(sc) / np.linalg.norm(np.log1p(sc)) + +u_l, s_l, v_l = np.linalg.svd(sc) + +# %% +# Load the leadfield matrix +lm_file = os.path.join(data_dir, 'Subject_1_low_voltage_lf.npy') +lm = np.load(lm_file) +print(lm.shape) +ki0 =stim_weights_thr[:,np.newaxis] +delays = dist/conduction_velocity + +# %% +# define options for JR model: batch size integration step and sampling rate of the empirical eeg +# the number of regions in the parcellation and the number of channels + +node_size = sc.shape[0] +output_size = eeg.shape[0] + +sim_psd_source, sim_freqs_source = mne.time_frequency.psd_array_welch(eeg, sfreq=1000,fmin=1,fmax=50,n_fft=1900, n_per_seg=2000) + +psd_train={} +psd_train['fq'] =[] +psd_train['psd'] = [] +epochs_size = 1500 +for i in range(epochs_size): + fq_test_low =np.random.uniform(1,50, 500) + psd_train['fq'].append(np.sort(fq_test_low)) + sim_psd_test = [] + + for w in psd_train['fq'][i]: + + ind = np.where(sim_freqs_source > w)[0][0] + #print(ind) + #print(sim_freqs_source[ind-1], w, sim_freqs_source[ind]) + per = np.abs(w-sim_freqs_source[ind-1])/(np.abs(w-sim_freqs_source[ind])+np.abs(w-sim_freqs_source[ind-1])) + #print(sim_psd_source.T[ind-1][0], ((1-per)*per*sim_psd_source.T[ind-1] +per*sim_psd_source.T[ind])[0], sim_psd_source.T[ind][0]) + sim_psd_test.append(1*((1-per)*sim_psd_source.T[ind-1] +per*sim_psd_source.T[ind])) + + psd_train['psd'].append(np.array(sim_psd_test).T) + +psd_train['fq'] = np.array(psd_train['fq']) +psd_train['psd'] = np.array(psd_train['psd']) +lm_v = 0.01*np.random.randn(output_size,200) +params = ParamsLinearFreqs(mu = par(5,5, 0.5, True), g = par(100,100,1,True), eigvals= par(s_l, s_l, .1 * np.ones((node_size,1)), True),a = par(50, 50, 1, True), \ + b = par(20,20, 0.5,True), A = par(3, 3, 0.2, True), B = par(22), C1 = par(100, 100, 1, True), C2 = par(30, 30, 1, True),c = par(0.2, 0.2, 0.001, True), + lm=par(lm, lm, .1 * np.ones((output_size, node_size))+lm_v, True),std_in= par(1000000)) + + +# %% +# call model want to fit +model = LINEAR_FQ(params, node_size =node_size, output_size=output_size, sc_eigvecs =u_l, dist =dist) +# create objective function +ObjFun = CostsFreqs( model) + +# %% +# call model fit +F = Model_fitting_fq(psd_train, epochs_size, model, ObjFun) +# %% + + +F.train() +fq_test, psd_test = F.test(sim_freqs_source) \ No newline at end of file diff --git a/whobpyt/datasets/fetchers.py b/whobpyt/datasets/fetchers.py index 3321d390..75e002ed 100644 --- a/whobpyt/datasets/fetchers.py +++ b/whobpyt/datasets/fetchers.py @@ -171,6 +171,9 @@ def fetch_egtmseeg(dest_folder=None, redownload=False): if os.path.isdir(dest_folder) and redownload == True: os.system('rm -rf %s' %dest_folder) + + total_files = len(files_dict) + # If the folder does not exist, create it and download the files if not os.path.isdir(dest_folder): @@ -178,10 +181,13 @@ def fetch_egtmseeg(dest_folder=None, redownload=False): os.chdir(dest_folder) - dlcode = osf_folder_url - pull_file(dlcode, file_name, download_method='wget') - - os.chdir(cwd) + for file_code, file_name in files_dict.items(): + dlcode = osf_url_pfx + '/' + file_code + pull_file(dlcode, file_name, download_method='wget') + + + os.chdir(cwd) + return dest_folder diff --git a/whobpyt/models/jansen_rit/jansen_rit.py b/whobpyt/models/jansen_rit/jansen_rit.py index 67029554..09161dd2 100644 --- a/whobpyt/models/jansen_rit/jansen_rit.py +++ b/whobpyt/models/jansen_rit/jansen_rit.py @@ -437,27 +437,27 @@ def forward(self, external, hx, hE): # Run through the number of specified sample points for this window for i_window in range(self.TRs_per_window): + # Collect the delayed inputs: + + # i) index the history of E + Ed = pttranspose(hE.clone().gather(1,self.delays), 0, 1) + + # ii) multiply the past states by the connectivity weights matrix, and sum over rows + LEd_p2e = ptsum(w_n_f * Ed, 1) + LEd_p2i = -ptsum(w_n_b * Ed, 1) + LEd_p2p = ptsum(w_n_l * Ed, 1) + # iii) reshape for next step + LEd_p2e = ptreshape(LEd_p2e, (n_nodes, 1)) + LEd_p2i = ptreshape(LEd_p2i, (n_nodes, 1)) + LEd_p2p = ptreshape(LEd_p2p, (n_nodes, 1)) # For each sample point, run the model by solving the differential # equations for a defined number of integration steps, # and keep only the final activity state within this set of steps for step_i in range(self.steps_per_TR): - # Collect the delayed inputs: - - # i) index the history of E - Ed = pttranspose(hE.clone().gather(1,self.delays), 0, 1) - - # ii) multiply the past states by the connectivity weights matrix, and sum over rows - LEd_p2e = ptsum(w_n_f * Ed, 1) - LEd_p2i = -ptsum(w_n_b * Ed, 1) - LEd_p2p = ptsum(w_n_l * Ed, 1) - # iii) reshape for next step - LEd_p2e = ptreshape(LEd_p2e, (n_nodes, 1)) - LEd_p2i = ptreshape(LEd_p2i, (n_nodes, 1)) - LEd_p2p = ptreshape(LEd_p2p, (n_nodes, 1)) # iv) if specified, add the laplacian component (self-connections from diagonals) if self.use_laplacian: diff --git a/whobpyt/models/linear_fq/__init__.py b/whobpyt/models/linear_fq/__init__.py new file mode 100644 index 00000000..17ec521e --- /dev/null +++ b/whobpyt/models/linear_fq/__init__.py @@ -0,0 +1 @@ +from .linear_fq import LINEAR_FQ, ParamsLinearFreqs \ No newline at end of file diff --git a/whobpyt/models/linear_fq/linear_fq.py b/whobpyt/models/linear_fq/linear_fq.py new file mode 100644 index 00000000..686f92d9 --- /dev/null +++ b/whobpyt/models/linear_fq/linear_fq.py @@ -0,0 +1,146 @@ +# PyTorch stuff +import torch +from torch.nn.parameter import Parameter as ptParameter +from torch.nn import ReLU as ptReLU +from torch.linalg import norm as ptnorm +from torch import (tensor as pttensor, float32 as ptfloat32, sum as ptsum, exp as ptexp, diag as ptdiag, + transpose as pttranspose, zeros_like as ptzeros_like, int64 as ptint64, randn as ptrandn, + matmul as ptmatmul, tanh as pttanh, matmul as ptmatmul, reshape as ptreshape, sqrt as ptsqrt, + ones as ptones, cat as ptcat) + +# Numpy stuff +from numpy.random import uniform +from numpy import ones,zeros +import numpy as np + +# WhoBPyT stuff +from ...datatypes import AbstractNeuralModel, AbstractParams, Parameter as par +from ...functions.arg_type_check import method_arg_type_check# ... + + + + +class ParamsLinearFreqs(AbstractParams): + + def __init__(self, **kwargs): + """ + Initializes the ParamsLinearFreqs object. + + Args: + **kwargs: Keyword arguments for the model parameters. + + Returns: + None + """ + + super(ParamsLinearFreqs, self).__init__(**kwargs) + param = { + "std_in": par(0.1), + + "eigvals": par(1), + "mu": par(5) + } + for var in param: + setattr(self, var, param[var]) + + for var in kwargs: + setattr(self, var, kwargs[var]) + +class LINEAR_FQ(AbstractNeuralModel): + """ + A module for Robinson model from freqency to power spectrum + Attibutes + --------- + """ + model_name = "LINEAR_FQ" + + def __init__(self, params: ParamsLinearFreqs, node_size = 200, mode_size = 20, output_size = 64, sc_eigvecs =np.ones((200,200)), \ + dist =np.ones((200,200)), use_fit_gains=False, use_fit_lfm=False): + """ + Parameters + ---------- + + param from ParamJR + """ + super(LINEAR_FQ, self).__init__(params) + + self.params = params + self.node_size = node_size + self.mode_size = mode_size + self.output_size = node_size + self.use_fit_gains = use_fit_gains + self.use_fit_lfm = use_fit_lfm + self.sc_eigvecs = sc_eigvecs + self.dist = dist + + self.setModelParameters() + + + + + + + def forward(self, input): + """ + Forward step in simulating the EEG signal. + Parameters + ---------- + input: list of frequencey + + Outputs + ------- + next_state: pws with given frequence same size as input + + """ + # Generate the ReLU module + m = torch.nn.ReLU() + # define some constants + std_in = 0.00001 + m(self.params.std_in.value()) + g = 0.00001 + m(self.params.g.value()) + a = 0.00001 + m(self.params.a.value()) + b = 0.00001 + m(self.params.b.value()) + A = 0.00001 + m(self.params.A.value()) + B = 0.00001 + m(self.params.B.value()) + C2 = 0.00001 + m(self.params.C2.value()) + C1 = 0.00001 + m(self.params.C1.value()) + c = 0.00001 + m(self.params.c.value()) + n_mode = self.mode_size + eigvals = 0 + m(self.params.eigvals.value()[:n_mode])/m(self.params.eigvals.value()[:n_mode]).max() + mu = 0.02 + m(self.params.mu.value()) + dist = torch.tensor(self.dist, dtype=torch.float32) + u_sc = torch.tensor(self.sc_eigvecs, dtype=torch.float32) + lm = self.params.lm.value() + tau_mode = m(u_sc.T @ dist/mu @ u_sc) + sc = u_sc[:,:n_mode] @ torch.diag(eigvals[:,0]) @ (u_sc[:,:n_mode]).T + + + + next_state = [] + + + for i_fq in range(input.shape[0]): + #print(i_fq) + omega = input[i_fq] * 2*np.pi + j = complex(0, 1) # imaginary number + s = omega * j + tf_e = A*a/(s**2 +2*a*s +a**2 ) + tf_i = B*b/(s**2 +2*b*s +b**2 ) + tf_ei = (1-0*C1*tf_i)*tf_e/(1+ C1*C2*tf_e*tf_i) + tf_close = (1/(s+c))*tf_ei*torch.linalg.inv(1+g* torch.exp(-s*tau_mode)*sc*tf_ei) + + """lap = torch.diag((u_sc @ (torch.diag(eigvals[:,0]) ) @ u_sc.T).sum(1)) \ + - (u_sc @ (torch.diag(eigvals) ) @ u_sc.T) + u_l, d_l, l_l = torch.svd(lap)""" + tf = std_in * tf_close.sum(0)[:,np.newaxis] + + + + + #print(torch.abs(closed_loop_g)) + lm_n = lm/torch.sqrt((lm**2).sum()) + next_state.append(torch.abs(torch.matmul((lm_n + 0*j), tf))) + + + + + return torch.cat(next_state, dim=1) \ No newline at end of file diff --git a/whobpyt/optimization/__init__.py b/whobpyt/optimization/__init__.py index 90f1d21b..7e2ed22b 100644 --- a/whobpyt/optimization/__init__.py +++ b/whobpyt/optimization/__init__.py @@ -3,4 +3,5 @@ from .cost_FC import CostsFixedFC from .cost_Mean import CostsMean from .cost_PSD import CostsPSD -from .cost_PSD import CostsFixedPSD \ No newline at end of file +from .cost_PSD import CostsFixedPSD +from .cost_Freq import CostsFreqs \ No newline at end of file diff --git a/whobpyt/optimization/cost_Freq.py b/whobpyt/optimization/cost_Freq.py new file mode 100644 index 00000000..7f48d3b4 --- /dev/null +++ b/whobpyt/optimization/cost_Freq.py @@ -0,0 +1,66 @@ +import numpy as np # for numerical operations +import torch +from torch import (Tensor as ptTensor, reshape as ptreshape, mean as ptmean, matmul as ptmatmul, transpose as pttranspose, + diag as ptdiag, reciprocal as ptreciprocal, sqrt as ptsqrt, tril as pttril, ones_like as ptones_like, + zeros_like as ptzeros_like, greater as ptgreater, masked_select as ptmasked_select, sum as ptsum, + multiply as ptmultiply, log as ptlog, device as ptdevice) + +from ..datatypes import AbstractLoss +from ..datatypes import Parameter as par +from ..functions.arg_type_check import method_arg_type_check + +class CostsFreqs(AbstractLoss): + def __init__(self, model): + self.model = model + + + def loss(self, simData: dict, empData: ptTensor): + """ + Calculate the Pearson Correlation between the simFC and empFC. + From there, compute the probability and negative log-likelihood. + + Parameters + ---------- + simData: dict of tensor with node_size X datapoint + simulated EEG + empData: tensor with node_size X datapoint + empirical EEG + """ + method_arg_type_check(self.loss) # Check that the passed arguments (excluding self) abide by their expected data types + sim = simData + emp = empData + loss_main = ptsqrt(ptmean((ptlog(sim) - ptlog(emp)) ** 2)) # + model = self.model + + # define some constants + lb = 0.001 + + w_cost = 10 + + # define the relu function + m = torch.nn.ReLU() + + exclude_param = [] + if model.use_fit_gains: + exclude_param.append('gains_con') #TODO: Is this correct? + + + + + + loss_EI = 0 + loss_prior = [] + + variables_p = [a for a in dir(model.params) if (type(getattr(model.params, a)) == par)] + + for var_name in variables_p: + # print(var) + var = getattr(self.model.params, var_name) + if var.fit_hyper: + loss_prior.append(torch.sum((lb + m(var.prior_precision)) * \ + (m(var.val) - m(var.prior_mean)) ** 2) \ + + torch.sum(-torch.log(lb + m(var.prior_precision)))) + + # total loss + loss = 200 * w_cost * loss_main + 1 * sum(loss_prior) + 1 * loss_EI + return loss, loss_main \ No newline at end of file diff --git a/whobpyt/run/__init__.py b/whobpyt/run/__init__.py index 9ae93de3..6df7ed69 100644 --- a/whobpyt/run/__init__.py +++ b/whobpyt/run/__init__.py @@ -1,3 +1,4 @@ from .model_fitting import ModelFitting from .custom_fitting import FittingFNGFPG from .batch_fitting import FittingBatch +from .model_fitting_fq import Model_fitting_fq diff --git a/whobpyt/run/model_fitting_fq.py b/whobpyt/run/model_fitting_fq.py new file mode 100644 index 00000000..41581e3e --- /dev/null +++ b/whobpyt/run/model_fitting_fq.py @@ -0,0 +1,162 @@ +import numpy as np # for numerical operations +import torch +import torch.optim as optim +from ..datatypes import Timeseries as Recording # JG: rename this to just Timeseries +from ..datatypes import AbstractNeuralModel,AbstractFitting,AbstractLoss +from ..datatypes import TrainingStats +#from whobpyt.models.RWW.RWW_np import RWW_np #This should be removed and made general +from ..functions.arg_type_check import method_arg_type_check +import pickle +from sklearn.metrics.pairwise import cosine_similarity + +class Model_fitting_fq: + """ + Using ADAM and AutoGrad to fit JansenRit to empirical EEG + Attributes + ---------- + model: instance of class RNNJANSEN + forward model JansenRit + ts: array with num_tr x node_size + empirical EEG time-series + num_epoches: int + the times for repeating trainning + Methods: + train() + train model + test() + using the optimal model parater to simulate the BOLD + """ + + # from sklearn.metrics.pairwise import cosine_similarity + def __init__(self, psd, num_epoches, model: AbstractNeuralModel, cost: AbstractLoss,): + """ + Parameters + ---------- + model: instance of class RNNJANSEN + forward model JansenRit + ts: array with num_tr x node_size + empirical EEG time-series + num_epoches: int + the times for repeating trainning + """ + self.model = model + self.num_epoches = num_epoches + # self.u = u + """if ts.shape[1] != model.node_size: + print('ts is a matrix with the number of datapoint X the number of node') + else: + self.ts = ts""" + self.fq = torch.tensor(psd['fq'], dtype=torch.float32) + self.psd = torch.tensor(psd['psd'], dtype=torch.float32) + self.cost = cost + #placeholder for output(EEG and histoty of model parameters and loss) + self.trainingStats = TrainingStats(self.model) + + def save(self, filename): + with open(filename, 'wb') as f: + pickle.dump(self, f) + + def train(self, u= 0, learningrate: float = 0.05, lr_2ndLevel: float = 0.05, lr_scheduler: bool = False): + """ + Parameters + ---------- + None + Outputs: OutputRJ + """ + + # placeholders for the history of model parameters + + loss_main_th = 1000 + + method_arg_type_check(self.train, exclude = ['u', 'empRec']) # Check that the passed arguments (excluding self) abide by their expected data types + + # Define two different optimizers for each group + modelparameter_optimizer = optim.Adam(self.model.params_fitted['modelparameter'], lr=learningrate, eps=1e-7) + hyperparameter_optimizer = optim.Adam(self.model.params_fitted['hyperparameter'], lr=lr_2ndLevel, eps=1e-7) + + + + + loss_his = [] + + # define constant 1 tensor + + con_1 = torch.tensor(1.0, dtype=torch.float32) + + for i_epoch in range(self.num_epoches): + if (loss_main_th > 1e-10): + + + psd_target = self.psd[i_epoch % self.fq.shape[0]] + fq_target = self.fq[i_epoch % self.fq.shape[0]] + # Create placeholders for the simulated EEG E I M Ev Iv and Mv of entire time series. + + + + + # Reset the gradient to zeros after update model parameters. + hyperparameter_optimizer.zero_grad() + modelparameter_optimizer.zero_grad() + + + # Use the model.forward() function to update next state and get simulated EEG in this batch. + next_batch = self.model(fq_target) + + print(((torch.log(next_batch) - torch.log(psd_target))**2).mean()) + + #loss, loss_main = 1*self.cost.cost_eff(torch.log10(next_batch), torch.log10(psd_target),self.model) + + loss, loss_main = 1*self.cost.loss(next_batch, psd_target) + loss_main_th = loss_main.detach().numpy() + loss_his.append(loss.detach().numpy()) + # print('epoch: ', i_epoch, 'batch: ', i_batch, loss.detach().numpy()) + + # Calculate gradient using backward (backpropagation) method of the loss function. + loss.backward(retain_graph=True) + + # Optimize the model based on the gradient method in updating the model parameters. + hyperparameter_optimizer.step() + modelparameter_optimizer.step() + + # Put the updated model parameters into the history placeholders. + # sc_par.append(self.model.sc[mask].copy()) + trackedParam = {} + exclude_param = ['gains_con'] #This stores SC and LF which are saved seperately + if(self.model.track_params): + for par_name in self.model.track_params: + var = getattr(self.model.params, par_name) + if (var.fit_par): + trackedParam[par_name] = var.value().detach().cpu().numpy().copy() + if var.fit_hyper: + + trackedParam[par_name + "_prior_mean"] = var.prior_mean.detach().cpu().numpy().copy() + trackedParam[par_name + "_prior_precision"] = var.prior_precision.detach().cpu().numpy().copy() + for key, value in self.model.state_dict().items(): + if key not in exclude_param: + trackedParam[key] = value.detach().cpu().numpy().ravel().copy() + self.trainingStats.appendParam(trackedParam) + + + + self.trainingStats.appendLoss(loss_his) + print('epoch: ', i_epoch, loss.detach().numpy(), loss_main.detach().numpy()) + + + + + + + + def test(self, input): + """ + Parameters + ---------- + None + Outputs: OutputRJ + """ + + + fq_target = torch.tensor(input, dtype=torch.float32) + next_batch = self.model(fq_target) + + return fq_target, next_batch \ No newline at end of file