
Commit e0a5911

Merge pull request #1645 from rwightman/norm_mlp_classifier
Extract NormMlpClassifierHead from maxxvit.py
2 parents 29fda20 + b304208 commit e0a5911

5 files changed: +216 -137 lines

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
-from .classifier import ClassifierHead, create_classifier
+from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
     set_layer_config

timm/layers/classifier.py

Lines changed: 96 additions & 2 deletions
@@ -2,10 +2,17 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from torch import nn as nn
+from collections import OrderedDict
+from functools import partial
+from typing import Optional, Union, Callable
+
+import torch
+import torch.nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .create_act import get_act_layer
+from .create_norm import get_norm_layer
 
 
 def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -38,7 +45,21 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
 class ClassifierHead(nn.Module):
     """Classifier head w/ configurable global pooling and dropout."""
 
-    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            use_conv: bool = False,
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+        """
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
         self.in_features = in_features
@@ -65,3 +86,76 @@ def forward(self, x, pre_logits: bool = False):
         else:
             x = self.fc(x)
             return self.flatten(x)
+
+
+class NormMlpClassifierHead(nn.Module):
+
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm2d',
+            act_layer: Union[str, Callable] = 'tanh',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.drop_rate = drop_rate
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        self.use_conv = not pool_type
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+        self.norm = norm_layer(in_features)
+        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', linear_layer(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(self.drop_rate)
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+        self.use_conv = self.global_pool.is_identity()
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+        if self.hidden_size:
+            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
+                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
+                with torch.no_grad():
+                    new_fc = linear_layer(self.in_features, self.hidden_size)
+                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
+                    new_fc.bias.copy_(self.pre_logits.fc.bias)
+                    self.pre_logits.fc = new_fc
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.norm(x)
+        x = self.flatten(x)
+        x = self.pre_logits(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
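
Usage sketch (not part of the diff): a minimal example of the new NormMlpClassifierHead as exported from timm.layers. The feature width (768), hidden size (1024), and 7x7 map are illustrative values, not from this PR.

import torch
from timm.layers import NormMlpClassifierHead

head = NormMlpClassifierHead(
    in_features=768,        # channels coming out of the backbone (illustrative)
    num_classes=1000,
    hidden_size=1024,       # enables the pre-logits fc + act MLP
    pool_type='avg',
    norm_layer='layernorm2d',
    act_layer='gelu',
)
x = torch.randn(2, 768, 7, 7)       # NCHW feature map
logits = head(x)                    # (2, 1000): pool -> norm -> flatten -> mlp -> fc
feats = head(x, pre_logits=True)    # (2, 1024): stops before the final fc
head.reset(10, global_pool='avg')   # swaps in a fresh 10-class fc, keeps the pre-logits MLP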

timm/models/convnext.py

Lines changed: 98 additions & 57 deletions
@@ -39,13 +39,15 @@
 
 from collections import OrderedDict
 from functools import partial
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
 from ._pretrained import generate_default_cfgs
@@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):
 
     def __init__(
             self,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            output_stride=32,
-            depths=(3, 3, 9, 3),
-            dims=(96, 192, 384, 768),
-            kernel_sizes=7,
-            ls_init_value=1e-6,
-            stem_type='patch',
-            patch_size=4,
-            head_init_scale=1.,
-            head_norm_first=False,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            act_layer='gelu',
-            norm_layer=None,
-            norm_eps=None,
-            drop_rate=0.,
-            drop_path_rate=0.,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'avg',
+            output_stride: int = 32,
+            depths: Tuple[int, ...] = (3, 3, 9, 3),
+            dims: Tuple[int, ...] = (96, 192, 384, 768),
+            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
+            ls_init_value: Optional[float] = 1e-6,
+            stem_type: str = 'patch',
+            patch_size: int = 4,
+            head_init_scale: float = 1.,
+            head_norm_first: bool = False,
+            head_hidden_size: Optional[int] = None,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Union[str, Callable]] = None,
+            norm_eps: Optional[float] = None,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
     ):
         """
         Args:
-            in_chans (int): Number of input image channels (default: 3)
-            num_classes (int): Number of classes for classification head (default: 1000)
-            global_pool (str): Global pooling type (default: 'avg')
-            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
-            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
-            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
-            kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
-            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
-            stem_type (str): Type of stem (default: 'patch')
-            patch_size (int): Stem patch size for patch stem (default: 4)
-            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
-            head_norm_first (bool): Apply normalization before global pool + head (default: False)
-            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
-            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
-            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
-            act_layer (Union[str, nn.Module]): Activation Layer
-            norm_layer (Union[str, nn.Module]): Normalization Layer
-            drop_rate (float): Head dropout rate (default: 0.)
-            drop_path_rate (float): Stochastic depth rate (default: 0.)
+            in_chans: Number of input image channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Global pooling type.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            depths: Number of blocks at each stage.
+            dims: Feature dimension at each stage.
+            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            head_init_scale: Init scaling value for classifier weights and biases.
+            head_norm_first: Apply normalization before global pool + head.
+            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
+            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
+            conv_bias: Use bias layers w/ all convolutions.
+            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            drop_rate: Head pre-classifier dropout rate.
+            drop_path_rate: Stochastic depth drop rate.
         """
         super().__init__()
         assert output_stride in (8, 16, 32)
@@ -307,14 +311,26 @@ def __init__(
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
-
+        if head_norm_first:
+            assert not head_hidden_size
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+                act_layer='gelu',
+            )
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
     @torch.jit.ignore
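
As a quick check on the two head orderings above (the model name and the pass-through of head_norm_first as a create_model kwarg are illustrative; any ConvNeXt variant that does not already pin head_norm_first should behave the same way):

import timm

m_default = timm.create_model('convnext_tiny', pretrained=False)
m_norm_first = timm.create_model('convnext_tiny', pretrained=False, head_norm_first=True)

print(type(m_default.head).__name__)     # NormMlpClassifierHead: pool -> norm -> fc
print(type(m_norm_first.head).__name__)  # ClassifierHead, with norm_pre applied before pooling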
@@ -338,10 +354,7 @@ def get_classifier(self):
         return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool=global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -350,12 +363,7 @@ def forward_features(self, x):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=pre_logits)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -389,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
         if 'visual.head.proj.weight' in state_dict:
             out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        elif 'visual.head.mlp.fc1.weight' in state_dict:
+            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
+            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
+            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
         return out_dict
 
     import re
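
Read as a key mapping, the new elif branch above remaps an OpenCLIP ConvNeXt tower with an MLP head onto the timm head layout (keys taken directly from the diff):

# OpenCLIP checkpoint key        ->  timm ConvNeXt key
clip_to_timm_head = {
    'visual.head.mlp.fc1.weight': 'head.pre_logits.fc.weight',
    'visual.head.mlp.fc1.bias':   'head.pre_logits.fc.bias',
    'visual.head.mlp.fc2.weight': 'head.fc.weight',
}
# head.fc.bias is filled with zeros of matching shape, since the CLIP
# projection in the checkpoint carries no bias term.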
@@ -708,6 +721,22 @@ def _cfgv2(url='', **kwargs):
 
     'convnextv2_small.untrained': _cfg(),
 
+    # CLIP weights, fine-tuned on in1k or in12k + in1k
+    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
+    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
+    ),
+
+
     # CLIP based weights, original image tower weights and fine-tunes
     'convnext_base.clip_laion2b': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
@@ -734,6 +763,11 @@ def _cfgv2(url='', **kwargs):
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
 })
 
 
@@ -846,6 +880,13 @@ def convnext_large(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def convnext_large_mlp(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
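
Putting it together, a rough usage sketch of the new entrypoint (pretrained=False and the 256x256 input are illustrative choices; the 1536-dim pre-logits size comes from head_hidden_size above):

import torch
import timm

model = timm.create_model('convnext_large_mlp', pretrained=False, num_classes=1000)

x = torch.randn(1, 3, 256, 256)
logits = model(x)                                                        # (1, 1000)
feats = model.forward_head(model.forward_features(x), pre_logits=True)  # (1, 1536)

# reset_classifier() now simply delegates to the head's reset()
model.reset_classifier(num_classes=10, global_pool='avg')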
