@@ -29,9 +29,9 @@ def _cfg(url='', **kwargs):
2929
3030default_cfgs = {
3131 'tresnet_m' : _cfg (
32- url = 'https://miil-public-eu.oss-eu-central-1.aliyuncs. com/model-zoo/ImageNet_21K_P/models/timm /tresnet_m_1k_miil_83_1.pth' ),
32+ url = 'https://github. com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet /tresnet_m_1k_miil_83_1-d236afcb .pth' ),
3333 'tresnet_m_miil_in21k' : _cfg (
34- url = 'https://miil-public-eu.oss-eu-central-1.aliyuncs. com/model-zoo/ImageNet_21K_P/models/timm /tresnet_m_miil_in21k.pth' , num_classes = 11221 ),
34+ url = 'https://github. com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet /tresnet_m_miil_in21k-901b6ed4 .pth' , num_classes = 11221 ),
3535 'tresnet_l' : _cfg (
3636 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth' ),
3737 'tresnet_xl' : _cfg (
@@ -44,7 +44,10 @@ def _cfg(url='', **kwargs):
4444 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth' ),
4545 'tresnet_xl_448' : _cfg (
4646 input_size = (3 , 448 , 448 ), pool_size = (14 , 14 ),
47- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth' )
47+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth' ),
48+
49+ 'tresnet_v2_l' : _cfg (
50+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth' ),
4851}
4952
5053
@@ -99,7 +102,7 @@ def forward(self, x):
99102 if self .se is not None :
100103 out = self .se (out )
101104
102- out += shortcut
105+ out = out + shortcut
103106 out = self .relu (out )
104107 return out
105108
@@ -153,7 +156,16 @@ def forward(self, x):
153156
154157
155158class TResNet (nn .Module ):
156- def __init__ (self , layers , in_chans = 3 , num_classes = 1000 , width_factor = 1.0 , global_pool = 'fast' , drop_rate = 0. ):
159+ def __init__ (
160+ self ,
161+ layers ,
162+ in_chans = 3 ,
163+ num_classes = 1000 ,
164+ width_factor = 1.0 ,
165+ v2 = False ,
166+ global_pool = 'fast' ,
167+ drop_rate = 0. ,
168+ ):
157169 self .num_classes = num_classes
158170 self .drop_rate = drop_rate
159171 super (TResNet , self ).__init__ ()
@@ -163,15 +175,19 @@ def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, globa
163175 # TResnet stages
164176 self .inplanes = int (64 * width_factor )
165177 self .planes = int (64 * width_factor )
178+ if v2 :
179+ self .inplanes = self .inplanes // 8 * 8
180+ self .planes = self .planes // 8 * 8
181+
166182 conv1 = conv2d_iabn (in_chans * 16 , self .planes , stride = 1 , kernel_size = 3 )
167183 layer1 = self ._make_layer (
168- BasicBlock , self .planes , layers [0 ], stride = 1 , use_se = True , aa_layer = aa_layer ) # 56x56
184+ Bottleneck if v2 else BasicBlock , self .planes , layers [0 ], stride = 1 , use_se = True , aa_layer = aa_layer )
169185 layer2 = self ._make_layer (
170- BasicBlock , self .planes * 2 , layers [1 ], stride = 2 , use_se = True , aa_layer = aa_layer ) # 28x28
186+ Bottleneck if v2 else BasicBlock , self .planes * 2 , layers [1 ], stride = 2 , use_se = True , aa_layer = aa_layer )
171187 layer3 = self ._make_layer (
172- Bottleneck , self .planes * 4 , layers [2 ], stride = 2 , use_se = True , aa_layer = aa_layer ) # 14x14
188+ Bottleneck , self .planes * 4 , layers [2 ], stride = 2 , use_se = True , aa_layer = aa_layer )
173189 layer4 = self ._make_layer (
174- Bottleneck , self .planes * 8 , layers [3 ], stride = 2 , use_se = False , aa_layer = aa_layer ) # 7x7
190+ Bottleneck , self .planes * 8 , layers [3 ], stride = 2 , use_se = False , aa_layer = aa_layer )
175191
176192 # body
177193 self .body = nn .Sequential (OrderedDict ([
@@ -285,6 +301,12 @@ def tresnet_l(pretrained=False, **kwargs):
285301 return _create_tresnet ('tresnet_l' , pretrained = pretrained , ** model_kwargs )
286302
287303
304+ @register_model
305+ def tresnet_v2_l (pretrained = False , ** kwargs ):
306+ model_kwargs = dict (layers = [3 , 4 , 23 , 3 ], width_factor = 1.0 , v2 = True , ** kwargs )
307+ return _create_tresnet ('tresnet_v2_l' , pretrained = pretrained , ** model_kwargs )
308+
309+
288310@register_model
289311def tresnet_xl (pretrained = False , ** kwargs ):
290312 model_kwargs = dict (layers = [4 , 5 , 24 , 3 ], width_factor = 1.3 , ** kwargs )
0 commit comments