|
| 1 | +""" |
| 2 | +InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900 |
| 3 | +
|
| 4 | +Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models |
| 5 | +""" |
| 6 | + |
| 7 | +from functools import partial |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | + |
| 12 | +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| 13 | +from timm.layers import trunc_normal_, DropPath, to_2tuple |
| 14 | +from ._manipulate import checkpoint_seq |
| 15 | +from ._registry import register_model |
| 16 | + |
| 17 | + |
| 18 | +class InceptionDWConv2d(nn.Module): |
| 19 | + """ Inception depthweise convolution |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125): |
| 23 | + super().__init__() |
| 24 | + |
| 25 | + gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch |
| 26 | + self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc) |
| 27 | + self.dwconv_w = nn.Conv2d( |
| 28 | + gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc) |
| 29 | + self.dwconv_h = nn.Conv2d( |
| 30 | + gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc) |
| 31 | + self.split_indexes = (in_channels - 3 * gc, gc, gc, gc) |
| 32 | + |
| 33 | + def forward(self, x): |
| 34 | + x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1) |
| 35 | + return torch.cat(( |
| 36 | + x_id, |
| 37 | + self.dwconv_hw(x_hw), |
| 38 | + self.dwconv_w(x_w), |
| 39 | + self.dwconv_h(x_h) |
| 40 | + ), dim=1, |
| 41 | + ) |
| 42 | + |
| 43 | + |
| 44 | +class ConvMlp(nn.Module): |
| 45 | + """ MLP using 1x1 convs that keeps spatial dims |
| 46 | + copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, |
| 51 | + norm_layer=None, bias=True, drop=0.): |
| 52 | + super().__init__() |
| 53 | + out_features = out_features or in_features |
| 54 | + hidden_features = hidden_features or in_features |
| 55 | + bias = to_2tuple(bias) |
| 56 | + |
| 57 | + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) |
| 58 | + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() |
| 59 | + self.act = act_layer() |
| 60 | + self.drop = nn.Dropout(drop) |
| 61 | + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) |
| 62 | + |
| 63 | + def forward(self, x): |
| 64 | + x = self.fc1(x) |
| 65 | + x = self.norm(x) |
| 66 | + x = self.act(x) |
| 67 | + x = self.drop(x) |
| 68 | + x = self.fc2(x) |
| 69 | + return x |
| 70 | + |
| 71 | + |
| 72 | +class MlpHead(nn.Module): |
| 73 | + """ MLP classification head |
| 74 | + """ |
| 75 | + |
| 76 | + def __init__( |
| 77 | + self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU, |
| 78 | + norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True): |
| 79 | + super().__init__() |
| 80 | + hidden_features = int(mlp_ratio * dim) |
| 81 | + self.fc1 = nn.Linear(dim, hidden_features, bias=bias) |
| 82 | + self.act = act_layer() |
| 83 | + self.norm = norm_layer(hidden_features) |
| 84 | + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) |
| 85 | + self.drop = nn.Dropout(drop) |
| 86 | + |
| 87 | + def forward(self, x): |
| 88 | + x = x.mean((2, 3)) # global average pooling |
| 89 | + x = self.fc1(x) |
| 90 | + x = self.act(x) |
| 91 | + x = self.norm(x) |
| 92 | + x = self.drop(x) |
| 93 | + x = self.fc2(x) |
| 94 | + return x |
| 95 | + |
| 96 | + |
| 97 | +class MetaNeXtBlock(nn.Module): |
| 98 | + """ MetaNeXtBlock Block |
| 99 | + Args: |
| 100 | + dim (int): Number of input channels. |
| 101 | + drop_path (float): Stochastic depth rate. Default: 0.0 |
| 102 | + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. |
| 103 | + """ |
| 104 | + |
| 105 | + def __init__( |
| 106 | + self, |
| 107 | + dim, |
| 108 | + token_mixer=nn.Identity, |
| 109 | + norm_layer=nn.BatchNorm2d, |
| 110 | + mlp_layer=ConvMlp, |
| 111 | + mlp_ratio=4, |
| 112 | + act_layer=nn.GELU, |
| 113 | + ls_init_value=1e-6, |
| 114 | + drop_path=0., |
| 115 | + |
| 116 | + ): |
| 117 | + super().__init__() |
| 118 | + self.token_mixer = token_mixer(dim) |
| 119 | + self.norm = norm_layer(dim) |
| 120 | + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer) |
| 121 | + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None |
| 122 | + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| 123 | + |
| 124 | + def forward(self, x): |
| 125 | + shortcut = x |
| 126 | + x = self.token_mixer(x) |
| 127 | + x = self.norm(x) |
| 128 | + x = self.mlp(x) |
| 129 | + if self.gamma is not None: |
| 130 | + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) |
| 131 | + x = self.drop_path(x) + shortcut |
| 132 | + return x |
| 133 | + |
| 134 | + |
| 135 | +class MetaNeXtStage(nn.Module): |
| 136 | + def __init__( |
| 137 | + self, |
| 138 | + in_chs, |
| 139 | + out_chs, |
| 140 | + ds_stride=2, |
| 141 | + depth=2, |
| 142 | + drop_path_rates=None, |
| 143 | + ls_init_value=1.0, |
| 144 | + token_mixer=nn.Identity, |
| 145 | + act_layer=nn.GELU, |
| 146 | + norm_layer=None, |
| 147 | + mlp_ratio=4, |
| 148 | + ): |
| 149 | + super().__init__() |
| 150 | + self.grad_checkpointing = False |
| 151 | + if ds_stride > 1: |
| 152 | + self.downsample = nn.Sequential( |
| 153 | + norm_layer(in_chs), |
| 154 | + nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride), |
| 155 | + ) |
| 156 | + else: |
| 157 | + self.downsample = nn.Identity() |
| 158 | + |
| 159 | + drop_path_rates = drop_path_rates or [0.] * depth |
| 160 | + stage_blocks = [] |
| 161 | + for i in range(depth): |
| 162 | + stage_blocks.append(MetaNeXtBlock( |
| 163 | + dim=out_chs, |
| 164 | + drop_path=drop_path_rates[i], |
| 165 | + ls_init_value=ls_init_value, |
| 166 | + token_mixer=token_mixer, |
| 167 | + act_layer=act_layer, |
| 168 | + norm_layer=norm_layer, |
| 169 | + mlp_ratio=mlp_ratio, |
| 170 | + )) |
| 171 | + in_chs = out_chs |
| 172 | + self.blocks = nn.Sequential(*stage_blocks) |
| 173 | + |
| 174 | + def forward(self, x): |
| 175 | + x = self.downsample(x) |
| 176 | + if self.grad_checkpointing and not torch.jit.is_scripting(): |
| 177 | + x = checkpoint_seq(self.blocks, x) |
| 178 | + else: |
| 179 | + x = self.blocks(x) |
| 180 | + return x |
| 181 | + |
| 182 | + |
| 183 | +class MetaNeXt(nn.Module): |
| 184 | + r""" MetaNeXt |
| 185 | + A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/pdf/2203.xxxxx.pdf |
| 186 | +
|
| 187 | + Args: |
| 188 | + in_chans (int): Number of input image channels. Default: 3 |
| 189 | + num_classes (int): Number of classes for classification head. Default: 1000 |
| 190 | + depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3) |
| 191 | + dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768) |
| 192 | + token_mixers: Token mixer function. Default: nn.Identity |
| 193 | + norm_layer: Normalziation layer. Default: nn.BatchNorm2d |
| 194 | + act_layer: Activation function for MLP. Default: nn.GELU |
| 195 | + mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3) |
| 196 | + head_fn: classifier head |
| 197 | + drop_rate (float): Head dropout rate |
| 198 | + drop_path_rate (float): Stochastic depth rate. Default: 0. |
| 199 | + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. |
| 200 | + """ |
| 201 | + |
| 202 | + def __init__( |
| 203 | + self, |
| 204 | + in_chans=3, |
| 205 | + num_classes=1000, |
| 206 | + depths=(3, 3, 9, 3), |
| 207 | + dims=(96, 192, 384, 768), |
| 208 | + token_mixers=nn.Identity, |
| 209 | + norm_layer=nn.BatchNorm2d, |
| 210 | + act_layer=nn.GELU, |
| 211 | + mlp_ratios=(4, 4, 4, 3), |
| 212 | + head_fn=MlpHead, |
| 213 | + drop_rate=0., |
| 214 | + drop_path_rate=0., |
| 215 | + ls_init_value=1e-6, |
| 216 | + **kwargs, |
| 217 | + ): |
| 218 | + super().__init__() |
| 219 | + |
| 220 | + num_stage = len(depths) |
| 221 | + if not isinstance(token_mixers, (list, tuple)): |
| 222 | + token_mixers = [token_mixers] * num_stage |
| 223 | + if not isinstance(mlp_ratios, (list, tuple)): |
| 224 | + mlp_ratios = [mlp_ratios] * num_stage |
| 225 | + |
| 226 | + self.num_classes = num_classes |
| 227 | + self.drop_rate = drop_rate |
| 228 | + self.stem = nn.Sequential( |
| 229 | + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), |
| 230 | + norm_layer(dims[0]) |
| 231 | + ) |
| 232 | + |
| 233 | + self.stages = nn.Sequential() |
| 234 | + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] |
| 235 | + stages = [] |
| 236 | + prev_chs = dims[0] |
| 237 | + # feature resolution stages, each consisting of multiple residual blocks |
| 238 | + for i in range(num_stage): |
| 239 | + out_chs = dims[i] |
| 240 | + stages.append(MetaNeXtStage( |
| 241 | + prev_chs, |
| 242 | + out_chs, |
| 243 | + ds_stride=2 if i > 0 else 1, |
| 244 | + depth=depths[i], |
| 245 | + drop_path_rates=dp_rates[i], |
| 246 | + ls_init_value=ls_init_value, |
| 247 | + act_layer=act_layer, |
| 248 | + token_mixer=token_mixers[i], |
| 249 | + norm_layer=norm_layer, |
| 250 | + mlp_ratio=mlp_ratios[i], |
| 251 | + )) |
| 252 | + prev_chs = out_chs |
| 253 | + self.stages = nn.Sequential(*stages) |
| 254 | + self.num_features = prev_chs |
| 255 | + self.head = head_fn(self.num_features, num_classes, drop=drop_rate) |
| 256 | + self.apply(self._init_weights) |
| 257 | + |
| 258 | + @torch.jit.ignore |
| 259 | + def set_grad_checkpointing(self, enable=True): |
| 260 | + for s in self.stages: |
| 261 | + s.grad_checkpointing = enable |
| 262 | + |
| 263 | + @torch.jit.ignore |
| 264 | + def no_weight_decay(self): |
| 265 | + return {'norm'} |
| 266 | + |
| 267 | + def forward_features(self, x): |
| 268 | + x = self.stem(x) |
| 269 | + x = self.stages(x) |
| 270 | + return x |
| 271 | + |
| 272 | + def forward_head(self, x): |
| 273 | + x = self.head(x) |
| 274 | + return x |
| 275 | + |
| 276 | + def forward(self, x): |
| 277 | + x = self.forward_features(x) |
| 278 | + x = self.forward_head(x) |
| 279 | + return x |
| 280 | + |
| 281 | + def _init_weights(self, m): |
| 282 | + if isinstance(m, (nn.Conv2d, nn.Linear)): |
| 283 | + trunc_normal_(m.weight, std=.02) |
| 284 | + if m.bias is not None: |
| 285 | + nn.init.constant_(m.bias, 0) |
| 286 | + |
| 287 | + |
| 288 | +def _cfg(url='', **kwargs): |
| 289 | + return { |
| 290 | + 'url': url, |
| 291 | + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), |
| 292 | + 'crop_pct': 0.875, 'interpolation': 'bicubic', |
| 293 | + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, |
| 294 | + 'first_conv': 'stem.0', 'classifier': 'head.fc', |
| 295 | + **kwargs |
| 296 | + } |
| 297 | + |
| 298 | + |
| 299 | +default_cfgs = dict( |
| 300 | + inception_next_tiny=_cfg( |
| 301 | + url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth', |
| 302 | + ), |
| 303 | + inception_next_small=_cfg( |
| 304 | + url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth', |
| 305 | + ), |
| 306 | + inception_next_base=_cfg( |
| 307 | + url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth', |
| 308 | + ), |
| 309 | + inception_next_base_384=_cfg( |
| 310 | + url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth', |
| 311 | + input_size=(3, 384, 384), crop_pct=1.0, |
| 312 | + ), |
| 313 | +) |
| 314 | + |
| 315 | + |
| 316 | +@register_model |
| 317 | +def inception_next_tiny(pretrained=False, **kwargs): |
| 318 | + model = MetaNeXt( |
| 319 | + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), |
| 320 | + token_mixers=InceptionDWConv2d, |
| 321 | + **kwargs |
| 322 | + ) |
| 323 | + model.default_cfg = default_cfgs['inception_next_tiny'] |
| 324 | + if pretrained: |
| 325 | + state_dict = torch.hub.load_state_dict_from_url( |
| 326 | + url=model.default_cfg['url'], map_location="cpu", check_hash=True) |
| 327 | + model.load_state_dict(state_dict) |
| 328 | + return model |
| 329 | + |
| 330 | + |
| 331 | +@register_model |
| 332 | +def inception_next_small(pretrained=False, **kwargs): |
| 333 | + model = MetaNeXt( |
| 334 | + depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), |
| 335 | + token_mixers=InceptionDWConv2d, |
| 336 | + **kwargs |
| 337 | + ) |
| 338 | + model.default_cfg = default_cfgs['inception_next_small'] |
| 339 | + if pretrained: |
| 340 | + state_dict = torch.hub.load_state_dict_from_url( |
| 341 | + url=model.default_cfg['url'], map_location="cpu", check_hash=True) |
| 342 | + model.load_state_dict(state_dict) |
| 343 | + return model |
| 344 | + |
| 345 | + |
| 346 | +@register_model |
| 347 | +def inception_next_base(pretrained=False, **kwargs): |
| 348 | + model = MetaNeXt( |
| 349 | + depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), |
| 350 | + token_mixers=InceptionDWConv2d, |
| 351 | + **kwargs |
| 352 | + ) |
| 353 | + model.default_cfg = default_cfgs['inception_next_base'] |
| 354 | + if pretrained: |
| 355 | + state_dict = torch.hub.load_state_dict_from_url( |
| 356 | + url=model.default_cfg['url'], map_location="cpu", check_hash=True) |
| 357 | + model.load_state_dict(state_dict) |
| 358 | + return model |
| 359 | + |
| 360 | + |
| 361 | +@register_model |
| 362 | +def inception_next_base_384(pretrained=False, **kwargs): |
| 363 | + model = MetaNeXt( |
| 364 | + depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], |
| 365 | + mlp_ratios=[4, 4, 4, 3], |
| 366 | + token_mixers=InceptionDWConv2d, |
| 367 | + **kwargs |
| 368 | + ) |
| 369 | + model.default_cfg = default_cfgs['inception_next_base_384'] |
| 370 | + if pretrained: |
| 371 | + state_dict = torch.hub.load_state_dict_from_url( |
| 372 | + url=model.default_cfg['url'], map_location="cpu", check_hash=True) |
| 373 | + model.load_state_dict(state_dict) |
| 374 | + return model |
0 commit comments