Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions examples/eg004r__fitting_JR_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# whobpyt stuff
import whobpyt
from whobpyt.datatypes import par, Recording
from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR
from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR, JansenRit_np
from whobpyt.optimization.custom_cost_JR import CostsJR
from whobpyt.run import Model_fitting

Expand Down Expand Up @@ -168,4 +168,26 @@
ax[1].set_title('Test')
ax[2].plot(eeg_data.T)
ax[2].set_title('empirical')
plt.show()
plt.show()


# %%
# Modified JR Validation Model
# ---------------------------------------------------
#
# The modified JR model

val_sim_len = 20 # Simulation length in secs
model_validate = JansenRit_np(model.node_size, model.step_size, model.output_size, model.tr, model.sc, model.lm.detach().numpy(), model.dist.detach().numpy(), model.params)

state_hist, hE = model_validate.forward(external = u, hx = model_validate.createIC(ver = 0), hE = np.zeros((model.node_size,500)), sim_len=val_sim_len)
# %%
# Plot the EEG
plt.figure(figsize = (16, 8))
plt.title("M")
for n in range(model.node_size):
plt.plot(state_hist[0:2000, n, 0:1], label = "M Node = " + str(n)) # Plotting EEG window
#plt.plot(state_hist[0:200, n, 1:2] - state_hist[0:200, n, 2:3], label = "EEG Node = " + str(n)) # plotting E-I

plt.xlabel('Time Steps (multiply by step_size to get msec), step_size = ' + str(step_size))
#plt.legend()
3 changes: 2 additions & 1 deletion whobpyt/models/JansenRit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .jansen_rit import RNNJANSEN
from .ParamsJR import ParamsJR
from .ParamsJR import ParamsJR
from .jansen_rit_validate import JansenRit_np
165 changes: 165 additions & 0 deletions whobpyt/models/JansenRit/jansen_rit_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Simulate JR with numpy code for validation
# Sorenza Bastiaens
import numpy as np

class JansenRit_np():

def __init__(self, node_size, step_size, output_size, tr, sc, lm, dist, params):


# Initialize the JR Model
#
# INPUT
# num_regions: Int - Number of nodes in network to model
# params: Params_JR - The parameters that all nodes in the network will share
# Con_Mtx: Tensor [num_regions, num_regions] - With connectivity (eg. structural connectivity)
# step_size=0.1
self.step_size = step_size
self.tr = tr # tr ms (integration step 0.1 ms)
self.sc = sc # structural connectivity factor
self.node_size = node_size # num of ROI
self.output_size = output_size # num of EEG channels
self.params = params
self.lm = lm # leadfield matrix
self.dist = dist # distance between nodes
self.state_size = 6 #

def createIC(self, ver):
state_lb = -0.5
state_ub = 0.5

return np.random.uniform(state_lb, state_ub, (self.node_size, self.state_size))

def forward(self, external, hx, hE, sim_len):

# Runs the JR model

# Defining JR parameters as numpy
A = self.params.A.npValue()
a = self.params.a.npValue()
B = self.params.B.npValue()
b = self.params.b.npValue()
g = self.params.g.npValue()
c1 = self.params.c1.npValue()
c2 = self.params.c2.npValue()
c3 = self.params.c3.npValue()
c4 = self.params.c4.npValue()
std_in = self.params.std_in.npValue()
vmax = self.params.vmax.npValue()
v0 = self.params.v0.npValue()
r = self.params.r.npValue()
y0 = self.params.y0.npValue()
mu = self.params.mu.npValue()
k = self.params.k.npValue()
cy0 = self.params.cy0.npValue()
ki = self.params.ki.npValue()

g_f = self.params.g_f.npValue()
g_b = self.params.g_b.npValue()

# Sigmoid function
def sigmoid(x, vmax, v0, r):
return vmax / (1 + np.exp(r * (v0 - x)))

init_state = hx
sim_len = sim_len
step_size = self.step_size

state_hist = np.zeros((int(sim_len/step_size), self.node_size, 7))
M = init_state[:, 0:1]
E = init_state[:, 1:2]
I = init_state[:, 2:3]
Mv = init_state[:, 3:4]
Ev = init_state[:, 4:5]
Iv = init_state[:, 5:6]

num_steps = int(sim_len/step_size)
dt = step_size
self.w_bb = np.zeros((self.node_size, self.node_size))
self.w_ff = np.zeros((self.node_size, self.node_size))
self.w_ll = np.zeros((self.node_size, self.node_size))
# Update the Laplacian based on the updated connection gains w_bb.
w_b = np.exp(self.w_bb) * np.array(self.sc)
w_n_b = w_b / np.linalg.norm(w_b)
self.sc_m_b = w_n_b
dg_b = -np.diag(np.sum(w_n_b, axis=1))

# Update the Laplacian based on the updated connection gains w_ff.
w_f = np.exp(self.w_ff) * np.array(self.sc)
w_n_f = w_f / np.linalg.norm(w_f)
self.sc_m_f = w_n_f
dg_f = -np.diag(np.sum(w_n_f, axis=1))

# Update the Laplacian based on the updated connection gains w_ll.
w_l = np.exp(self.w_ll) * np.array(self.sc)
w_n_l = (0.5 * (w_l + np.transpose(w_l, (1, 0)))) / np.linalg.norm(
0.5 * (w_l + np.transpose(w_l, (1, 0))))
self.sc_fitted = w_n_l
dg_l = -np.diag(np.sum(w_n_l, axis=1))



self.delays = (self.dist / mu).astype(int)

# TODO currently single node, need to add all the connections and make it multiple nodes
for i in range(num_steps):

# LEd is to include the delays from other nodes
# con_1 = 1
# Don't include boundaries so no k_lb for example and no m(x) stuff

# Basically rM inludes (LEd_l + 1 * torch.matmul(dg_l, M))
# Calculate the derivatives
# Lateral is P-P
# Forward is P-E
# Backward is P-I
Ed = np.zeros((self.node_size, self.node_size))
hE_new = hE.copy()
Ed = hE_new[1,self.delays] #hE_new.gather(1,self.delays)
LEd_b = np.reshape(np.sum(w_n_b * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1)) # Not sure if this needs to be included in validation
LEd_f = np.reshape(np.sum(w_n_f * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1))
LEd_l = np.reshape(np.sum(w_n_l * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1))
u_tms = 200 # Need to only had within a certain time frame, test again with 0
rM = k * ki * u_tms + std_in*np.random.randn(self.node_size, 1) + g * (LEd_l + 1 * np.matmul(dg_l, M))
rE = std_in*np.random.randn(self.node_size, 1) + g_f * (LEd_f + 1 * np.matmul(dg_f, E - I))
rI = std_in*np.random.randn(self.node_size, 1) + g_b * (-LEd_b - 1 * np.matmul(dg_b, E - I))

dM = dt * Mv
dE = dt * Ev
dI = dt * Iv
dMv = dt * (A*a*( rM + sigmoid(vmax,v0,r, E - I))- (2*a*Mv) - (a**(2)*M)) # BE CAREGUL rM in code has the sigmoid so only take everything else from original code
dEv = dt * (A*a*(mu + rE + (c2*sigmoid(vmax,v0,r,(c1*M)))) - (2*a*Ev) - (a**(2)*E))
dIv = dt * (B*b*(rI + c4*sigmoid(vmax,v0,r,(c3*M))) - (2*b*Iv) - (b**(2)*I))

# Update the state
dM = dM.detach().numpy()
dE = dE.detach().numpy()
dI = dI.detach().numpy()
dMv = dMv.detach().numpy()
dEv = dEv.detach().numpy()
dIv = dIv.detach().numpy()
M = M + dM
E = E + dE
I = I + dI
Mv = Mv + dMv
Ev = Ev + dEv
Iv = Iv + dIv
hE = np.concatenate((M, hE[:, :-1]), axis=1) #np.cat([M, hE[:, :-1]], axis=1) # update placeholders for pyramidal buffer

state_hist[i, :, 0:1] = M
state_hist[i, :, 1:2] = E
state_hist[i, :, 2:3] = I
state_hist[i, :, 3:4] = Mv
state_hist[i, :, 4:5] = Ev
state_hist[i, :, 5:6] = Iv

# Capture the states at every step .
#lm_t = (self.lm.T / np.sqrt(self.lm ** 2).sum(1)).T
#self.lm_t = (lm_t - 1 / self.output_size * np.matmul(np.ones((1, self.output_size)), lm_t))
#temp = cy0 * np.matmul(self.lm_t, M[:self.node_size, :]) - 1 * y0
#state_hist[i, :, 6:7] = temp # eeg_window

# Should then downsample the state_hist to the sampling rate of the EEG
return state_hist, hE