1414import numpy as np
1515import pandas as pd
1616
17+ from ..constants import PREFIT_ADDITIONAL_DAYS
1718from .parameters import Parameters
1819
1920
@@ -68,31 +69,42 @@ def __init__(self, p: Parameters):
6869
6970 if p .mitigation_date is None :
7071 self .i_day = 0 # seed to the full length
71- raw = self .run_projection (p , [(self .beta , p .n_days )])
72+ raw = self .run_projection (p , [
73+ (self .beta , p .n_days + PREFIT_ADDITIONAL_DAYS )])
7274 self .i_day = i_day = int (get_argmin_ds (raw ["census_hospitalized" ], p .current_hospitalized ))
7375
74- self .raw = self .run_projection (p , self .gen_policy (p ))
76+ self .raw = self .run_projection (p , self .get_policies (p ))
7577
7678 logger .info ('Set i_day = %s' , i_day )
7779 else :
78- projections = {}
7980 best_i_day = - 1
8081 best_i_day_loss = float ('inf' )
81- for i_day in range (p .n_days ):
82- self .i_day = i_day
83- raw = self .run_projection (p , self .gen_policy (p ))
82+ for self .i_day in range (p .n_days + PREFIT_ADDITIONAL_DAYS ):
83+ mitigation_day = - (p .current_date - p .mitigation_date ).days
84+ if mitigation_day < - self .i_day :
85+ mitigation_day = - self .i_day
86+
87+ total_days = self .i_day + p .n_days + PREFIT_ADDITIONAL_DAYS
88+ pre_mitigation_days = self .i_day + mitigation_day
89+ post_mitigation_days = total_days - pre_mitigation_days
90+
91+ raw = self .run_projection (p , [
92+ (self .beta , pre_mitigation_days ),
93+ (self .beta_t , post_mitigation_days ),
94+ ]
95+ )
8496
8597 # Don't fit against results that put the peak before the present day
86- if raw ["census_hospitalized" ].argmax () < i_day :
98+ if raw ["census_hospitalized" ].argmax () < self . i_day :
8799 continue
88100
89- loss = get_loss (raw ["census_hospitalized" ][i_day ], p .current_hospitalized )
101+ loss = get_loss (raw ["census_hospitalized" ][self . i_day ], p .current_hospitalized )
90102 if loss < best_i_day_loss :
91103 best_i_day_loss = loss
92- best_i_day = i_day
93- self .raw = raw
104+ best_i_day = self .i_day
94105
95106 self .i_day = best_i_day
107+ self .raw = self .run_projection (p , self .get_policies (p ))
96108
97109 logger .info (
98110 'Estimated date_first_hospitalized: %s; current_date: %s; i_day: %s' ,
@@ -127,7 +139,7 @@ def __init__(self, p: Parameters):
127139 intrinsic_growth_rate = get_growth_rate (p .doubling_time )
128140 self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
129141 self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
130- self .raw = self .run_projection (p , self .gen_policy (p ))
142+ self .raw = self .run_projection (p , self .get_policies (p ))
131143
132144 self .population = p .population
133145 else :
@@ -196,7 +208,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
196208 self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
197209 self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
198210
199- raw = self .run_projection (p , self .gen_policy (p ))
211+ raw = self .run_projection (p , self .get_policies (p ))
200212
201213 # Skip values the would put the fit past peak
202214 peak_admits_day = raw ["admits_hospitalized" ].argmax ()
@@ -210,7 +222,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
210222 min_loss = pd .Series (losses ).argmin ()
211223 return min_loss
212224
213- def gen_policy (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
225+ def get_policies (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
214226 if p .mitigation_date is not None :
215227 mitigation_day = - (p .current_date - p .mitigation_date ).days
216228 else :
0 commit comments