|
| 1 | +""" ConvNext |
| 2 | +
|
| 3 | +Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf |
| 4 | +
|
| 5 | +Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below |
| 6 | +
|
| 7 | +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman |
| 8 | +""" |
| 9 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 10 | +# All rights reserved. |
| 11 | +# This source code is licensed under the MIT license |
| 12 | + |
| 13 | +from functools import partial |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.nn as nn |
| 17 | +import torch.nn.functional as F |
| 18 | + |
| 19 | +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| 20 | +from .fx_features import register_notrace_module |
| 21 | +from .helpers import named_apply, build_model_with_cfg |
| 22 | +from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp |
| 23 | +from .registry import register_model |
| 24 | + |
| 25 | + |
| 26 | +__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this |
| 27 | + |
| 28 | + |
| 29 | +def _cfg(url='', **kwargs): |
| 30 | + return { |
| 31 | + 'url': url, |
| 32 | + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), |
| 33 | + 'crop_pct': 0.875, 'interpolation': 'bicubic', |
| 34 | + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, |
| 35 | + 'first_conv': 'stem.0', 'classifier': 'head', |
| 36 | + **kwargs |
| 37 | + } |
| 38 | + |
| 39 | + |
| 40 | +default_cfgs = dict( |
| 41 | + convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), |
| 42 | + convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), |
| 43 | + convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), |
| 44 | + convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), |
| 45 | + |
| 46 | + convnext_tiny_hnf=_cfg(url='', classifier='head.fc'), |
| 47 | + |
| 48 | + convnext_base_in22k=_cfg( |
| 49 | + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), |
| 50 | + convnext_large_in22k=_cfg( |
| 51 | + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), |
| 52 | + convnext_xlarge_in22k=_cfg( |
| 53 | + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), |
| 54 | +) |
| 55 | + |
| 56 | + |
| 57 | +def _is_contiguous(tensor: torch.Tensor) -> bool: |
| 58 | + # jit is oh so lovely :/ |
| 59 | + # if torch.jit.is_tracing(): |
| 60 | + # return True |
| 61 | + if torch.jit.is_scripting(): |
| 62 | + return tensor.is_contiguous() |
| 63 | + else: |
| 64 | + return tensor.is_contiguous(memory_format=torch.contiguous_format) |
| 65 | + |
| 66 | + |
| 67 | +@register_notrace_module |
| 68 | +class LayerNorm2d(nn.Module): |
| 69 | + r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). |
| 70 | + """ |
| 71 | + |
| 72 | + def __init__(self, normalized_shape, eps=1e-6): |
| 73 | + super().__init__() |
| 74 | + self.weight = nn.Parameter(torch.ones(normalized_shape)) |
| 75 | + self.bias = nn.Parameter(torch.zeros(normalized_shape)) |
| 76 | + self.eps = eps |
| 77 | + self.normalized_shape = (normalized_shape,) |
| 78 | + |
| 79 | + def forward(self, x) -> torch.Tensor: |
| 80 | + if _is_contiguous(x): |
| 81 | + return F.layer_norm( |
| 82 | + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) |
| 83 | + else: |
| 84 | + s, u = torch.var_mean(x, dim=1, keepdim=True) |
| 85 | + x = (x - u) * torch.rsqrt(s + self.eps) |
| 86 | + x = x * self.weight[:, None, None] + self.bias[:, None, None] |
| 87 | + return x |
| 88 | + |
| 89 | + |
| 90 | +class ConvNeXtBlock(nn.Module): |
| 91 | + """ ConvNeXt Block |
| 92 | + There are two equivalent implementations: |
| 93 | + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) |
| 94 | + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back |
| 95 | +
|
| 96 | + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate |
| 97 | + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear |
| 98 | + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. |
| 99 | +
|
| 100 | + Args: |
| 101 | + dim (int): Number of input channels. |
| 102 | + drop_path (float): Stochastic depth rate. Default: 0.0 |
| 103 | + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. |
| 104 | + """ |
| 105 | + |
| 106 | + def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None): |
| 107 | + super().__init__() |
| 108 | + norm_layer = norm_layer or (partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)) |
| 109 | + mlp_layer = ConvMlp if conv_mlp else Mlp |
| 110 | + self.use_conv_mlp = conv_mlp |
| 111 | + self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv |
| 112 | + self.norm = norm_layer(dim) |
| 113 | + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) |
| 114 | + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None |
| 115 | + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| 116 | + |
| 117 | + def forward(self, x): |
| 118 | + shortcut = x |
| 119 | + x = self.conv_dw(x) |
| 120 | + if self.use_conv_mlp: |
| 121 | + x = self.norm(x) |
| 122 | + x = self.mlp(x) |
| 123 | + if self.gamma is not None: |
| 124 | + x.mul_(self.gamma.reshape(1, -1, 1, 1)) |
| 125 | + else: |
| 126 | + x = x.permute(0, 2, 3, 1) |
| 127 | + x = self.norm(x) |
| 128 | + x = self.mlp(x) |
| 129 | + if self.gamma is not None: |
| 130 | + x.mul_(self.gamma) |
| 131 | + x = x.permute(0, 3, 1, 2) |
| 132 | + x = self.drop_path(x) + shortcut |
| 133 | + return x |
| 134 | + |
| 135 | + |
| 136 | +class ConvNeXtStage(nn.Module): |
| 137 | + |
| 138 | + def __init__( |
| 139 | + self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=True, |
| 140 | + norm_layer=None, cl_norm_layer=None, cross_stage=False): |
| 141 | + super().__init__() |
| 142 | + |
| 143 | + if in_chs != out_chs or stride > 1: |
| 144 | + self.downsample = nn.Sequential( |
| 145 | + norm_layer(in_chs), |
| 146 | + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), |
| 147 | + ) |
| 148 | + else: |
| 149 | + self.downsample = nn.Identity() |
| 150 | + |
| 151 | + dp_rates = dp_rates or [0.] * depth |
| 152 | + self.blocks = nn.Sequential(*[ConvNeXtBlock( |
| 153 | + dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, |
| 154 | + norm_layer=norm_layer if conv_mlp else cl_norm_layer) |
| 155 | + for j in range(depth)] |
| 156 | + ) |
| 157 | + |
| 158 | + def forward(self, x): |
| 159 | + x = self.downsample(x) |
| 160 | + x = self.blocks(x) |
| 161 | + return x |
| 162 | + |
| 163 | + |
| 164 | +class ConvNeXt(nn.Module): |
| 165 | + r""" ConvNeXt |
| 166 | + A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf |
| 167 | +
|
| 168 | + Args: |
| 169 | + in_chans (int): Number of input image channels. Default: 3 |
| 170 | + num_classes (int): Number of classes for classification head. Default: 1000 |
| 171 | + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] |
| 172 | + dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] |
| 173 | + drop_rate (float): Head dropout rate |
| 174 | + drop_path_rate (float): Stochastic depth rate. Default: 0. |
| 175 | + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. |
| 176 | + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. |
| 177 | + """ |
| 178 | + |
| 179 | + def __init__( |
| 180 | + self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, |
| 181 | + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=True, |
| 182 | + head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., |
| 183 | + ): |
| 184 | + super().__init__() |
| 185 | + assert output_stride == 32 |
| 186 | + if norm_layer is None: |
| 187 | + norm_layer = partial(LayerNorm2d, eps=1e-6) |
| 188 | + cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) |
| 189 | + else: |
| 190 | + assert conv_mlp,\ |
| 191 | + 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' |
| 192 | + cl_norm_layer = norm_layer |
| 193 | + |
| 194 | + partial(LayerNorm2d, eps=1e-6) |
| 195 | + self.num_classes = num_classes |
| 196 | + self.drop_rate = drop_rate |
| 197 | + self.feature_info = [] |
| 198 | + |
| 199 | + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 |
| 200 | + self.stem = nn.Sequential( |
| 201 | + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), |
| 202 | + norm_layer(dims[0]) |
| 203 | + ) |
| 204 | + |
| 205 | + self.stages = nn.Sequential() |
| 206 | + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] |
| 207 | + curr_stride = patch_size |
| 208 | + prev_chs = dims[0] |
| 209 | + stages = [] |
| 210 | + # 4 feature resolution stages, each consisting of multiple residual blocks |
| 211 | + for i in range(4): |
| 212 | + stride = 2 if i > 0 else 1 |
| 213 | + # FIXME support dilation / output_stride |
| 214 | + curr_stride *= stride |
| 215 | + out_chs = dims[i] |
| 216 | + stages.append(ConvNeXtStage( |
| 217 | + prev_chs, out_chs, stride=stride, |
| 218 | + depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, |
| 219 | + norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) |
| 220 | + ) |
| 221 | + prev_chs = out_chs |
| 222 | + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 |
| 223 | + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] |
| 224 | + self.stages = nn.Sequential(*stages) |
| 225 | + |
| 226 | + self.num_features = prev_chs |
| 227 | + if head_norm_first: |
| 228 | + # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights) |
| 229 | + self.norm = norm_layer(self.num_features) # final norm layer |
| 230 | + self.pool = None # global pool in ClassifierHead, pool == None being used to differentiate |
| 231 | + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) |
| 232 | + else: |
| 233 | + # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) |
| 234 | + self.pool = SelectAdaptivePool2d(pool_type=global_pool) |
| 235 | + # NOTE when cl_norm_layer != norm_layer we could flatten here and use cl, but makes no performance diff |
| 236 | + self.norm = norm_layer(self.num_features) |
| 237 | + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
| 238 | + |
| 239 | + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) |
| 240 | + |
| 241 | + def get_classifier(self): |
| 242 | + return self.head.fc if self.pool is None else self.head |
| 243 | + |
| 244 | + def reset_classifier(self, num_classes=0, global_pool='avg'): |
| 245 | + if self.pool is None: |
| 246 | + # norm -> global pool -> fc ordering |
| 247 | + self.head = ClassifierHead( |
| 248 | + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) |
| 249 | + else: |
| 250 | + # pool -> norm -> fc |
| 251 | + self.pool = SelectAdaptivePool2d(pool_type=global_pool) |
| 252 | + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
| 253 | + |
| 254 | + def forward_features(self, x): |
| 255 | + x = self.stem(x) |
| 256 | + x = self.stages(x) |
| 257 | + if self.pool is None: |
| 258 | + # standard head, norm -> spatial pool -> fc |
| 259 | + # ideally, last norm is within forward_features, but can only do so if norm precedes pooling |
| 260 | + x = self.norm(x) |
| 261 | + return x |
| 262 | + |
| 263 | + def forward(self, x): |
| 264 | + x = self.forward_features(x) |
| 265 | + if self.pool is not None: |
| 266 | + # ConvNeXt head, spatial pool -> norm -> fc |
| 267 | + # FIXME clean this up |
| 268 | + x = self.pool(x) |
| 269 | + x = self.norm(x) |
| 270 | + if not self.pool.is_identity(): |
| 271 | + x = x.flatten(1) |
| 272 | + if self.drop_rate > 0: |
| 273 | + x = F.dropout(x, self.drop_rate, self.training) |
| 274 | + x = self.head(x) |
| 275 | + return x |
| 276 | + |
| 277 | + |
| 278 | +def _init_weights(module, name=None, head_init_scale=1.0): |
| 279 | + if isinstance(module, nn.Conv2d): |
| 280 | + trunc_normal_(module.weight, std=.02) |
| 281 | + nn.init.constant_(module.bias, 0) |
| 282 | + elif isinstance(module, nn.Linear): |
| 283 | + trunc_normal_(module.weight, std=.02) |
| 284 | + nn.init.constant_(module.bias, 0) |
| 285 | + if name and '.head' in name: |
| 286 | + module.weight.data.mul_(head_init_scale) |
| 287 | + module.bias.data.mul_(head_init_scale) |
| 288 | + |
| 289 | + |
| 290 | +def checkpoint_filter_fn(state_dict, model): |
| 291 | + """ Remap FB checkpoints -> timm """ |
| 292 | + if 'model' in state_dict: |
| 293 | + state_dict = state_dict['model'] |
| 294 | + out_dict = {} |
| 295 | + import re |
| 296 | + for k, v in state_dict.items(): |
| 297 | + k = k.replace('downsample_layers.0.', 'stem.') |
| 298 | + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) |
| 299 | + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) |
| 300 | + k = k.replace('dwconv', 'conv_dw') |
| 301 | + k = k.replace('pwconv', 'mlp.fc') |
| 302 | + if v.ndim == 2 and 'head' not in k: |
| 303 | + model_shape = model.state_dict()[k].shape |
| 304 | + v = v.reshape(model_shape) |
| 305 | + out_dict[k] = v |
| 306 | + return out_dict |
| 307 | + |
| 308 | + |
| 309 | +def _create_convnext(variant, pretrained=False, **kwargs): |
| 310 | + model = build_model_with_cfg( |
| 311 | + ConvNeXt, variant, pretrained, |
| 312 | + default_cfg=default_cfgs[variant], |
| 313 | + pretrained_filter_fn=checkpoint_filter_fn, |
| 314 | + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), |
| 315 | + **kwargs) |
| 316 | + return model |
| 317 | + |
| 318 | + |
| 319 | +@register_model |
| 320 | +def convnext_tiny(pretrained=False, **kwargs): |
| 321 | + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) |
| 322 | + model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) |
| 323 | + return model |
| 324 | + |
| 325 | + |
| 326 | +@register_model |
| 327 | +def convnext_tiny_hnf(pretrained=False, **kwargs): |
| 328 | + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs) |
| 329 | + model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) |
| 330 | + return model |
| 331 | + |
| 332 | + |
| 333 | +@register_model |
| 334 | +def convnext_small(pretrained=False, **kwargs): |
| 335 | + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) |
| 336 | + model = _create_convnext('convnext_small', pretrained=pretrained, **model_args) |
| 337 | + return model |
| 338 | + |
| 339 | + |
| 340 | +@register_model |
| 341 | +def convnext_base(pretrained=False, **kwargs): |
| 342 | + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) |
| 343 | + model = _create_convnext('convnext_base', pretrained=pretrained, **model_args) |
| 344 | + return model |
| 345 | + |
| 346 | + |
| 347 | +@register_model |
| 348 | +def convnext_large(pretrained=False, **kwargs): |
| 349 | + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) |
| 350 | + model = _create_convnext('convnext_large', pretrained=pretrained, **model_args) |
| 351 | + return model |
| 352 | + |
| 353 | + |
| 354 | +@register_model |
| 355 | +def convnext_base_in22k(pretrained=False, **kwargs): |
| 356 | + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) |
| 357 | + model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) |
| 358 | + return model |
| 359 | + |
| 360 | + |
| 361 | +@register_model |
| 362 | +def convnext_large_in22k(pretrained=False, **kwargs): |
| 363 | + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) |
| 364 | + model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) |
| 365 | + return model |
| 366 | + |
| 367 | + |
| 368 | +@register_model |
| 369 | +def convnext_xlarge_in22k(pretrained=False, **kwargs): |
| 370 | + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs) |
| 371 | + model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) |
| 372 | + return model |
| 373 | + |
| 374 | + |
| 375 | + |
0 commit comments