Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ noise ([#58](https://github.com/ctrltz/meegsim/pull/58))
- A method for setting phase-phase coupling by adding noise to the shifted copy of input waveform ([#71](https://github.com/ctrltz/meegsim/pull/71))
- Function to convert the sources to mne.Label ([#73](https://github.com/ctrltz/meegsim/pull/73))
- Quick list-like access to the simulated sources ([#82](https://github.com/ctrltz/meegsim/pull/82))
- Control over the amplitude envelope of the coupled waveform ([#87](https://github.com/ctrltz/meegsim/pull/87))

### Changed

Expand Down
4 changes: 2 additions & 2 deletions docs/api/coupling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ Coupling methods
.. autosummary::
:toctree: ../generated/

ppc_shifted_copy_with_noise
ppc_constant_phase_shift
ppc_von_mises
constant_phase_shift
ppc_shifted_copy_with_noise
119 changes: 97 additions & 22 deletions src/meegsim/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,48 @@
from scipy.stats import vonmises
from scipy.signal import butter, filtfilt, hilbert

from meegsim._check import check_numeric
from meegsim._check import check_numeric, check_option
from meegsim.snr import get_variance, amplitude_adjustment_factor
from meegsim.utils import normalize_variance
from meegsim.waveform import narrowband_oscillation, white_noise


def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None):
def _get_envelope(waveform, envelope, sfreq, fmin=None, fmax=None, random_state=None):
check_option(
"the amplitude envelope of the coupled waveform", envelope, ["same", "random"]
)
if not np.iscomplexobj(waveform):
waveform = hilbert(waveform)

if envelope == "same":
return np.abs(waveform)

if fmin is None or fmax is None:
raise ValueError(
"Frequency limits are required for generating the envelope of the coupled waveform"
)
times = np.arange(waveform.size) / sfreq
random_waveform = narrowband_oscillation(
1, times, fmin=fmin, fmax=fmax, random_state=random_state
)
random_waveform = hilbert(random_waveform)

# TODO: here we could also mix original and random envelope with different
# values of SNR to achieve smooth control over the resulting envelope correlation
return np.abs(random_waveform)


def ppc_constant_phase_shift(
waveform,
sfreq,
phase_lag,
fmin=None,
fmax=None,
envelope="random",
m=1,
n=1,
random_state=None,
):
"""
Generate a time series that is phase coupled to the input time series with
a constant phase lag.
Expand All @@ -32,21 +67,32 @@ def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None
The input signal to be processed. It can be a real or complex time series.

sfreq : float
Sampling frequency of the signal, in Hz. This argument is not used in this
function but is accepted for consistency with other coupling methods.
Sampling frequency of the signal, in Hz.

phase_lag : float
Constant phase lag to apply to the waveform in radians.

m : int, optional
envelope : str, {"same", "random"}
Controls the amplitude envelope of the coupled waveform to be either randomly
generated (default) or to be the same as the envelope of the input waveform.

fmin : float, optional
Lower cutoff frequency for the oscillation that gives rise to the random
amplitude envelope (only if the ``envelope`` is set to ``"random"``).

fmax : float, optional
Upper cutoff frequency for the oscillation that gives rise to the random
amplitude envelope (only if the ``envelope`` is set to ``"random"``).

m : float, optional
Multiplier for the base frequency of the output oscillation, default is 1.

n : int, optional
n : float, optional
Multiplier for the base frequency of the input oscillation, default is 1.

random_state : None, optional
This parameter is accepted for consistency with other coupling functions
but not used since no randomness is involved.
Random state can be fixed to provide reproducible results if the envelope
is generated randomly. If not set, the results may differ between function calls.

Returns
-------
Expand All @@ -56,16 +102,36 @@ def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None
if not np.iscomplexobj(waveform):
waveform = hilbert(waveform)

waveform_amp = np.abs(waveform)
waveform_amp = _get_envelope(waveform, envelope, sfreq, fmin, fmax, random_state)
waveform_angle = np.angle(waveform)
waveform_coupled = waveform_amp * np.exp(
1j * m / n * waveform_angle + 1j * phase_lag
waveform_coupled = np.real(
waveform_amp * np.exp(1j * m / n * waveform_angle + 1j * phase_lag)
)
if envelope == "same":
return normalize_variance(waveform_coupled)

# NOTE: if the envelope was modified, we filter the result again in the target
# frequency range to suppress possible distortions due to merging amplitude
# envelope and phase from different time series
b, a = butter(
N=2, Wn=np.array([m / n * fmin, m / n * fmax]) / sfreq * 2, btype="bandpass"
)
return normalize_variance(np.real(waveform_coupled))
waveform_coupled = filtfilt(b, a, waveform_coupled)

return normalize_variance(waveform_coupled)


def ppc_von_mises(
waveform, sfreq, phase_lag, kappa, fmin, fmax, m=1, n=1, random_state=None
waveform,
sfreq,
phase_lag,
kappa,
fmin,
fmax,
envelope="random",
m=1,
n=1,
random_state=None,
):
"""
Generate a time series that is phase coupled to the input time series with
Expand Down Expand Up @@ -102,10 +168,14 @@ def ppc_von_mises(
fmax: float
Upper cutoff frequency of the base frequency harmonic (in Hz).

m : int, optional
envelope : str, {"same", "random"}
Controls the amplitude envelope of the coupled waveform to be either randomly
generated (default) or to be the same as the envelope of the input waveform.

m : float, optional
Multiplier for the base frequency of the output oscillation, default is 1.

n : int, optional
n : float, optional
Multiplier for the base frequency of the input oscillation, default is 1.

random_state : None (default) or int
Expand All @@ -121,23 +191,26 @@ def ppc_von_mises(
if not np.iscomplexobj(waveform):
waveform = hilbert(waveform)

waveform_amp = np.abs(waveform)
waveform_amp = _get_envelope(waveform, envelope, sfreq, fmin, fmax, random_state)
waveform_angle = np.angle(waveform)
n_samples = len(waveform)
n_samples = waveform.size

ph_distr = vonmises.rvs(
kappa, loc=phase_lag, size=n_samples, random_state=random_state
)
tmp_waveform = np.real(
waveform_coupled = np.real(
waveform_amp * np.exp(1j * m / n * waveform_angle + 1j * ph_distr)
)

# NOTE: we filter the result again in the target frequency range to suppress
# possible distortions due to separate adjustment of the phase and amplitude
# of the coupled time series
b, a = butter(
N=2, Wn=np.array([m / n * fmin, m / n * fmax]) / sfreq * 2, btype="bandpass"
)
tmp_waveform = filtfilt(b, a, tmp_waveform)
waveform_coupled = waveform_amp * np.exp(1j * np.angle(hilbert(tmp_waveform)))
waveform_coupled = filtfilt(b, a, waveform_coupled)

return normalize_variance(np.real(waveform_coupled))
return normalize_variance(waveform_coupled)


def _shifted_copy_with_noise(
Expand All @@ -148,7 +221,9 @@ def _shifted_copy_with_noise(
waveform and (2) mixing it with noise to achieve a desired level of signal-to-noise
ratio, which determines the resulting phase-phase and amplitude-amplitude coupling.
"""
shifted_waveform = constant_phase_shift(waveform, sfreq, phase_lag)
shifted_waveform = ppc_constant_phase_shift(
waveform, sfreq, phase_lag, envelope="same"
)
signal_var = get_variance(shifted_waveform, sfreq, fmin, fmax, filter=True)

# NOTE: to make coupling band-limited (substantial only in the band of interest),
Expand Down
18 changes: 7 additions & 11 deletions src/meegsim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings

from mne.io.constants import FIFF
from scipy.special import i1, i0


logger = logging.getLogger("meegsim")
Expand Down Expand Up @@ -79,15 +78,16 @@ def normalize_variance(data):

Returns
-------
data: array
data_norm: array
Normalized time series. The variance of each row is equal to 1.
"""
# NOTE: make a copy to keep the original waveform intact
data_norm = data.copy()
if data_norm.ndim == 1:
return data_norm / np.std(data_norm)

if data.ndim == 1:
return data / np.std(data)

data /= np.std(data, axis=-1)[:, np.newaxis]
return data
data_norm /= np.std(data_norm, axis=-1)[:, np.newaxis]
return data_norm


def _extract_hemi(src):
Expand Down Expand Up @@ -195,10 +195,6 @@ def unpack_vertices(vertices_lists):
return unpacked_vertices


def theoretical_plv(kappa):
return i1(kappa) / i0(kappa)


def vertices_to_mne(vertices, src):
"""
Convert the vertices to the MNE format (list of lists).
Expand Down
Loading