From cbc5a2692b56b3eb3ddc10611099227492289039 Mon Sep 17 00:00:00 2001 From: Harrison Siegel Date: Thu, 18 Apr 2024 18:18:53 -0400 Subject: [PATCH] Updated clean_strain() to return spike model when return_mcmcs is true, and removed extraneous lines of code --- line_cleaner.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/line_cleaner.py b/line_cleaner.py index ec2f20f..0cec340 100644 --- a/line_cleaner.py +++ b/line_cleaner.py @@ -90,10 +90,10 @@ def clean_strain(times, data, srate, f0s, bandwidths, Twindow, mcmc_seed=None, r resample_rng_key, rk = random.split(resample_rng_key) ind = random.randint(rk, (1,), 0, pred_samples['line_re'].shape[0])[0] - data_freq_residual[sel] = data_freq_re[sel] - scale_factor*pred_samples['line_re'][ind,:] + 1j*(data_freq_im[sel] - scale_factor*pred_samples['line_im'][ind,:]) + line_model_re = scale_factor*pred_samples['line_re'][ind,:] + line_model_im = scale_factor*pred_samples['line_im'][ind,:] - data_freq_re = np.real(data_freq_residual) - data_freq_im = np.imag(data_freq_residual) + data_freq_residual[sel] = data_freq_re[sel] - line_model_re + 1j*(data_freq_im[sel] - line_model_im) if return_mcmcs: mcmcs.append(mcmc) @@ -104,7 +104,6 @@ def clean_strain(times, data, srate, f0s, bandwidths, Twindow, mcmc_seed=None, r times_residual = times[window==1] if return_mcmcs: - return (times_residual, data_residual, mcmcs, pred_sampless) + return (times_residual, data_residual, mcmcs, pred_sampless, line_model_im, line_model_re, sel, data_freq_re) else: return (times_residual, data_residual) - \ No newline at end of file