1818from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1919from .helpers import build_model_with_cfg
2020from .registry import register_model
21- from .layers import ClassifierHead , DropPath , AvgPool2dSame , ScaledStdConv2d , get_act_layer , get_attn , make_divisible
21+ from .layers import ClassifierHead , DropPath , AvgPool2dSame , ScaledStdConv2d , get_act_layer , get_attn , make_divisible , get_act_fn
2222
2323
2424def _dcfg (url = '' , ** kwargs ):
@@ -40,17 +40,17 @@ def _dcfg(url='', **kwargs):
4040 'nf_regnet_b4' : _dcfg (url = '' , input_size = (3 , 320 , 320 )),
4141 'nf_regnet_b5' : _dcfg (url = '' , input_size = (3 , 384 , 384 )),
4242
43- 'nf_resnet26d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
44- 'nf_resnet50d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
45- 'nf_resnet101d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
43+ 'nf_resnet26 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
44+ 'nf_resnet50 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
45+ 'nf_resnet101 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
4646
47- 'nf_seresnet26d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
48- 'nf_seresnet50d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
49- 'nf_seresnet101d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
47+ 'nf_seresnet26 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
48+ 'nf_seresnet50 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
49+ 'nf_seresnet101 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
5050
51- 'nf_ecaresnet26d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
52- 'nf_ecaresnet50d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
53- 'nf_ecaresnet101d ' : _dcfg (url = '' , first_conv = 'stem.conv1 ' ),
51+ 'nf_ecaresnet26 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
52+ 'nf_ecaresnet50 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
53+ 'nf_ecaresnet101 ' : _dcfg (url = '' , first_conv = 'stem.conv ' ),
5454}
5555
5656
@@ -59,6 +59,7 @@ class NfCfg:
5959 depths : Tuple [int , int , int , int ]
6060 channels : Tuple [int , int , int , int ]
6161 alpha : float = 0.2
62+ gamma_in_act : bool = False
6263 stem_type : str = '3x3'
6364 stem_chs : Optional [int ] = None
6465 group_size : Optional [int ] = 8
@@ -84,68 +85,65 @@ class NfCfg:
8485 nf_regnet_b5 = NfCfg (depths = (3 , 7 , 14 , 14 ), channels = (80 , 168 , 336 , 704 ), num_features = 2048 ),
8586
8687 # ResNet (preact, 7x7 conv + max-pool stem, avg down) defs
87- nf_resnet26d = NfCfg (
88+ nf_resnet26 = NfCfg (
8889 depths = (2 , 2 , 2 , 2 ), channels = (256 , 512 , 1024 , 2048 ),
89- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
90+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
9091 act_layer = 'relu' , attn_layer = None ,),
91- nf_resnet50d = NfCfg (
92+ nf_resnet50 = NfCfg (
9293 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
93- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
94+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
9495 act_layer = 'relu' , attn_layer = None ),
95- nf_resnet101d = NfCfg (
96+ nf_resnet101 = NfCfg (
9697 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
97- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
98+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
9899 act_layer = 'relu' , attn_layer = None ),
99100
100101
101- nf_seresnet26d = NfCfg (
102+ nf_seresnet26 = NfCfg (
102103 depths = (2 , 2 , 2 , 2 ), channels = (256 , 512 , 1024 , 2048 ),
103- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
104+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
104105 act_layer = 'relu' , attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
105- nf_seresnet50d = NfCfg (
106+ nf_seresnet50 = NfCfg (
106107 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
107- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
108+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
108109 act_layer = 'relu' , attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
109- nf_seresnet101d = NfCfg (
110+ nf_seresnet101 = NfCfg (
110111 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
111- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
112+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
112113 act_layer = 'relu' , attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
113114
114115
115- nf_ecaresnet26d = NfCfg (
116+ nf_ecaresnet26 = NfCfg (
116117 depths = (2 , 2 , 2 , 2 ), channels = (256 , 512 , 1024 , 2048 ),
117- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
118+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
118119 act_layer = 'relu' , attn_layer = 'eca' , attn_kwargs = dict ()),
119- nf_ecaresnet50d = NfCfg (
120+ nf_ecaresnet50 = NfCfg (
120121 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
121- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
122+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
122123 act_layer = 'relu' , attn_layer = 'eca' , attn_kwargs = dict ()),
123- nf_ecaresnet101d = NfCfg (
124+ nf_ecaresnet101 = NfCfg (
124125 depths = (3 , 4 , 6 , 3 ), channels = (256 , 512 , 1024 , 2048 ),
125- stem_type = 'deep ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
126+ stem_type = '7x7_pool ' , stem_chs = 64 , width_factor = 1.0 , bottle_ratio = 0.25 , efficient = False , group_size = None ,
126127 act_layer = 'relu' , attn_layer = 'eca' , attn_kwargs = dict ()),
127128
128129)
129130
130- # class NormFreeSiLU(nn.Module):
131- # _K = 1. / 0.5595
132- # def __init__(self, inplace=False):
133- # super().__init__()
134- # self.inplace = inplace
135- #
136- # def forward(self, x):
137- # return F.silu(x, inplace=self.inplace) * self._K
138- #
139- #
140- # class NormFreeReLU(nn.Module):
141- # _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
142- #
143- # def __init__(self, inplace=False):
144- # super().__init__()
145- # self.inplace = inplace
146- #
147- # def forward(self, x):
148- # return F.relu(x, inplace=self.inplace) * self._K
131+
class GammaAct(nn.Module):
    """Activation module whose output is multiplied by a constant ``gamma``.

    The activation callable is looked up once at construction time with
    ``get_act_fn(act_type)`` and is invoked as ``act_fn(x, inplace=...)``.

    Args:
        act_type: Activation identifier handed to ``get_act_fn``.
        gamma: Scalar the activation output is multiplied by.
        inplace: Forwarded to the activation callable on every call.
    """

    def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
        super().__init__()
        self.act_fn = get_act_fn(act_type)
        self.gamma = gamma
        self.inplace = inplace

    def forward(self, x):
        """Apply the activation to ``x`` and scale the result by ``gamma``."""
        activated = self.act_fn(x, inplace=self.inplace)
        return self.gamma * activated
141+
142+
def act_with_gamma(act_type, gamma: float = 1.):
    """Return a factory that builds ``GammaAct`` modules with fixed settings.

    The returned callable only accepts ``inplace``; ``act_type`` and ``gamma``
    are captured from this call, so it can be used wherever a zero-argument
    (or ``inplace``-only) activation layer constructor is expected.

    Args:
        act_type: Activation identifier passed through to ``GammaAct``.
        gamma: Scale factor passed through to ``GammaAct``.

    Returns:
        A function ``_create(inplace=False)`` producing a new ``GammaAct``.
    """
    def _create(inplace=False):
        return GammaAct(act_type, gamma=gamma, inplace=inplace)

    return _create
149147
150148
151149class DownsampleAvg (nn .Module ):
@@ -178,10 +176,9 @@ def __init__(
178176 out_chs = out_chs or in_chs
179177 # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
180178 mid_chs = make_divisible (in_chs * bottle_ratio if efficient else out_chs * bottle_ratio , ch_div )
181- groups = 1
182- if group_size is not None :
183- # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
184- groups = mid_chs // group_size
179+ groups = 1 if group_size is None else mid_chs // group_size
180+ if group_size and group_size % ch_div == 0 :
181+ mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
185182 self .alpha = alpha
186183 self .beta = beta
187184 self .attn_gain = attn_gain
@@ -229,10 +226,11 @@ def forward(self, x):
229226
230227
231228def create_stem (in_chs , out_chs , stem_type = '' , conv_layer = None ):
229+ stem_stride = 2
232230 stem = OrderedDict ()
233- assert stem_type in ('' , 'deep' , '3x3' , '7x7' )
231+ assert stem_type in ('' , 'deep' , '3x3' , '7x7' , 'deep_pool' , '3x3_pool' , '7x7_pool' )
234232 if 'deep' in stem_type :
235- # 3 deep 3x3 conv stack as in ResNet V1D models
233+ # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
236234 mid_chs = out_chs // 2
237235 stem ['conv1' ] = conv_layer (in_chs , mid_chs , kernel_size = 3 , stride = 2 )
238236 stem ['conv2' ] = conv_layer (mid_chs , mid_chs , kernel_size = 3 , stride = 1 )
@@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
244242 # 7x7 stem conv as in ResNet
245243 stem ['conv' ] = conv_layer (in_chs , out_chs , kernel_size = 7 , stride = 2 )
246244
247- return nn .Sequential (stem )
245+ if 'pool' in stem_type :
246+ stem ['pool' ] = nn .MaxPool2d (3 , stride = 2 , padding = 1 )
247+ stem_stride = 4
248+
249+ return nn .Sequential (stem ), stem_stride
248250
249251
# Gain constants keyed by activation name. The value for the configured
# activation is passed as `gamma` either to the scaled-std conv layers or,
# when `gamma_in_act` is enabled, to the gamma-scaled activation wrapper.
# The silu and relu entries are written as reciprocals (1 / x and x ** -0.5).
_nonlin_gamma = dict(
    silu=1. / .5595,
    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
    identity=1.0,
)
255257
@@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
264266 the (preact) ResNet models described earlier in the paper.
265267
266268 There are a few differences:
267- * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
269+ * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
270+ this changes channel dim and param counts slightly from the paper models
268271 * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
269272 impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
273+ * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
274+ apply it in each activation. This is slightly slower, and yields slightly different results.
270275 * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
271276 for what it is/does. Approx 8-10% throughput loss.
272277 """
@@ -275,29 +280,33 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
275280 super ().__init__ ()
276281 self .num_classes = num_classes
277282 self .drop_rate = drop_rate
278- act_layer = get_act_layer (cfg .act_layer )
279283 assert cfg .act_layer in _nonlin_gamma , f"Please add non-linearity constants for activation ({ cfg .act_layer } )."
280- conv_layer = partial (ScaledStdConv2d , bias = True , gain = True , gamma = _nonlin_gamma [cfg .act_layer ])
284+ if cfg .gamma_in_act :
285+ act_layer = act_with_gamma (cfg .act_layer , gamma = _nonlin_gamma [cfg .act_layer ])
286+ conv_layer = partial (ScaledStdConv2d , bias = True , gain = True )
287+ else :
288+ act_layer = get_act_layer (cfg .act_layer )
289+ conv_layer = partial (ScaledStdConv2d , bias = True , gain = True , gamma = _nonlin_gamma [cfg .act_layer ])
281290 attn_layer = partial (get_attn (cfg .attn_layer ), ** cfg .attn_kwargs ) if cfg .attn_layer else None
282291
283- self .feature_info = [] # FIXME fill out feature info
284-
285292 stem_chs = cfg .stem_chs or cfg .channels [0 ]
286293 stem_chs = make_divisible (stem_chs * cfg .width_factor , cfg .ch_div )
287- self .stem = create_stem (in_chans , stem_chs , cfg .stem_type , conv_layer = conv_layer )
294+ self .stem , stem_stride = create_stem (in_chans , stem_chs , cfg .stem_type , conv_layer = conv_layer )
288295
289- prev_chs = stem_chs
296+ self . feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4
290297 dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (cfg .depths )).split (cfg .depths )]
291- net_stride = 2
298+ prev_chs = stem_chs
299+ net_stride = stem_stride
292300 dilation = 1
293301 expected_var = 1.0
294302 stages = []
295303 for stage_idx , stage_depth in enumerate (cfg .depths ):
296- if net_stride >= output_stride :
297- dilation *= 2
304+ stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
305+ self .feature_info += [dict (
306+ num_chs = prev_chs , reduction = net_stride , module = f'stages.{ stage_idx } .0.act1' if stride == 2 else '' )]
307+ if net_stride >= output_stride and stride > 1 :
308+ dilation *= stride
298309 stride = 1
299- else :
300- stride = 2
301310 net_stride *= stride
302311 first_dilation = 1 if dilation in (1 , 2 ) else 2
303312
@@ -338,7 +347,10 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
338347 else :
339348 self .num_features = prev_chs
340349 self .final_conv = nn .Identity ()
350+ # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
341351 self .final_act = act_layer ()
352+ self .feature_info += [dict (num_chs = self .num_features , reduction = net_stride , module = 'final_act' )]
353+
342354 self .head = ClassifierHead (self .num_features , num_classes , pool_type = global_pool , drop_rate = self .drop_rate )
343355
344356 for n , m in self .named_modules ():
@@ -373,11 +385,14 @@ def forward(self, x):
373385
374386
def _create_normfreenet(variant, pretrained=False, **kwargs):
    """Build a ``NormalizerFreeNet`` for the named variant.

    Looks up the architecture config in ``model_cfgs`` and the default config
    in ``default_cfgs`` under ``variant`` and delegates construction to
    ``build_model_with_cfg``.

    Args:
        variant: Key present in both ``model_cfgs`` and ``default_cfgs``.
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Forwarded unchanged to ``build_model_with_cfg``.

    Returns:
        Whatever ``build_model_with_cfg`` returns for this variant.

    Raises:
        KeyError: If ``variant`` is missing from ``model_cfgs`` or ``default_cfgs``.
    """
    model_cfg = model_cfgs[variant]
    feature_cfg = dict(flatten_sequential=True)
    # pre-act models need hooks to grab feat from act1 in bottleneck blocks
    feature_cfg['feature_cls'] = 'hook'
    if 'pool' in model_cfg.stem_type:
        # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
        feature_cfg['out_indices'] = (1, 2, 3, 4)

    return build_model_with_cfg(
        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
        feature_cfg=feature_cfg, **kwargs)
382397
383398
@@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):
412427
413428
@register_model
def nf_resnet26(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_resnet26`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
417432
418433
@register_model
def nf_resnet50(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_resnet50`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
422437
423438
@register_model
def nf_seresnet26(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_seresnet26`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
427442
428443
@register_model
def nf_seresnet50(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_seresnet50`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
432447
433448
@register_model
def nf_ecaresnet26(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_ecaresnet26`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
437452
438453
@register_model
def nf_ecaresnet50(pretrained=False, **kwargs):
    """Registered entry point for the ``nf_ecaresnet50`` variant.

    Delegates to ``_create_normfreenet`` with the variant name fixed;
    ``pretrained`` and any extra keyword arguments are passed through.
    """
    return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
0 commit comments