diff --git a/line_cleaner.py b/line_cleaner.py index ec2f20f..0cec340 100644 --- a/line_cleaner.py +++ b/line_cleaner.py @@ -90,10 +90,10 @@ def clean_strain(times, data, srate, f0s, bandwidths, Twindow, mcmc_seed=None, r resample_rng_key, rk = random.split(resample_rng_key) ind = random.randint(rk, (1,), 0, pred_samples['line_re'].shape[0])[0] - data_freq_residual[sel] = data_freq_re[sel] - scale_factor*pred_samples['line_re'][ind,:] + 1j*(data_freq_im[sel] - scale_factor*pred_samples['line_im'][ind,:]) + line_model_re = scale_factor*pred_samples['line_re'][ind,:] + line_model_im = scale_factor*pred_samples['line_im'][ind,:] - data_freq_re = np.real(data_freq_residual) - data_freq_im = np.imag(data_freq_residual) + data_freq_residual[sel] = data_freq_re[sel] - line_model_re + 1j*(data_freq_im[sel] - line_model_im) if return_mcmcs: mcmcs.append(mcmc) @@ -104,7 +104,6 @@ def clean_strain(times, data, srate, f0s, bandwidths, Twindow, mcmc_seed=None, r times_residual = times[window==1] if return_mcmcs: - return (times_residual, data_residual, mcmcs, pred_sampless) + return (times_residual, data_residual, mcmcs, pred_sampless, line_model_im, line_model_re, sel, data_freq_re) else: return (times_residual, data_residual) - \ No newline at end of file