
Commit 00c5be7

Merge pull request #2258 from huggingface/sbb2_vit_hiera_weights
Update Hiera model for abswin & add more in12k weights for hiera & vit
2 parents 076efef + 17923a6 commit 00c5be7

File tree

12 files changed: 941 additions & 92 deletions

tests/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -52,14 +52,14 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

timm/layers/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,7 @@
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
-from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
+from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
     set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
@@ -29,6 +29,7 @@
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
 from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
 from .inplace_abn import InplaceAbn
+from .layer_scale import LayerScale, LayerScale2d
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
@@ -56,4 +57,5 @@
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
+from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
+    init_weight_jax, init_weight_vit

timm/layers/classifier.py

Lines changed: 78 additions & 1 deletion
@@ -134,7 +134,8 @@ def forward(self, x, pre_logits: bool = False):
 
 
 class NormMlpClassifierHead(nn.Module):
-
+    """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
+    """
     def __init__(
             self,
             in_features: int,
@@ -204,3 +205,79 @@ def forward(self, x, pre_logits: bool = False):
             return x
         x = self.fc(x)
         return x
+
+
+class ClNormMlpClassifierHead(nn.Module):
+    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
+    """
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm',
+            act_layer: Union[str, Callable] = 'gelu',
+            input_fmt: str = 'NHWC',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        assert pool_type in ('', 'avg', 'max', 'avgmax')
+        self.pool_type = pool_type
+        assert input_fmt in ('NHWC', 'NLC')
+        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.pre_logits = nn.Identity()
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _global_pool(self, x):
+        if self.pool_type:
+            if self.pool_type == 'avg':
+                x = x.mean(dim=self.pool_dim)
+            elif self.pool_type == 'max':
+                x = x.amax(dim=self.pool_dim)
+            elif self.pool_type == 'avgmax':
+                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self._global_pool(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
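
The new head pools channels-last features, then applies norm -> optional MLP -> classifier. Below is a minimal usage sketch, assuming a timm build that includes this commit; the shapes and hyper-parameters are illustrative, not taken from the diff.

import torch
from timm.layers import ClNormMlpClassifierHead

head = ClNormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=512,       # optional pre-logits MLP
    pool_type='avg',
    input_fmt='NHWC',      # 'NLC' is also accepted, for token sequences
)
x = torch.randn(2, 7, 7, 768)        # channels-last (NHWC) feature map
logits = head(x)                     # shape (2, 1000)
feats = head(x, pre_logits=True)     # shape (2, 512): pooled, normed, MLP features before the final fc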

timm/layers/create_act.py

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
+    name = name.lower()
     if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
@@ -117,6 +118,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
+    name = name.lower()
     if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
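
The two name.lower() additions make activation-name lookup case-insensitive in both factory functions. A small illustrative sketch, assuming a timm build with this commit:

from timm.layers import get_act_layer

# Both spellings now resolve to the same registered class; previously the
# upper-case name would miss the lower-case-keyed lookup tables.
assert get_act_layer('GELU') is get_act_layer('gelu')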

timm/layers/layer_scale.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+from torch import nn
+
+
+class LayerScale(nn.Module):
+    """ LayerScale on tensors with channels in last-dim.
+    """
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class LayerScale2d(nn.Module):
+    """ LayerScale for tensors with torch 2D NCHW layout.
+    """
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        gamma = self.gamma.view(1, -1, 1, 1)
+        return x.mul_(gamma) if self.inplace else x * gamma
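
LayerScale2d mirrors the channels-last LayerScale for NCHW tensors by broadcasting gamma over the spatial dims. A hypothetical residual block sketch (the block itself is not part of this commit):

import torch
from torch import nn
from timm.layers import LayerScale2d   # exported via the __init__.py change above

class ResidualConvBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        # gamma starts near zero, so the residual branch contributes little at init
        self.ls = LayerScale2d(dim, init_values=1e-5)

    def forward(self, x):
        return x + self.ls(self.conv(x))

x = torch.randn(2, 64, 14, 14)       # NCHW input
y = ResidualConvBlock(64)(x)         # same shape as x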

timm/layers/weight_init.py

Lines changed: 43 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import math
 import warnings
-
+from torch import nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
 
 
@@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
 def lecun_normal_(tensor):
     variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
+
+
+def init_weight_vit(
+        module: nn.Module,
+        name: str,
+        init_bias: float = 0.02,
+        head_bias: float = 0.,
+        classifier_name: str = 'head'
+):
+    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, init_bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def init_weight_jax(
+        module: nn.Module,
+        name: str,
+        head_bias: float = 0.,
+        classifier_name: str = 'head',
+):
+    if isinstance(module, nn.Linear):
+        if name.startswith(classifier_name):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
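
Both init functions take (module, name), so classifier and MLP submodules can be treated specially by name. A usage sketch (assumed, not from the diff), walking named modules directly:

import torch.nn as nn
from timm.layers import init_weight_jax   # exported via the __init__.py change above

model = nn.Sequential()
model.add_module('mlp_fc1', nn.Linear(16, 32))
model.add_module('head', nn.Linear(32, 10))

for name, module in model.named_modules():
    init_weight_jax(module, name, head_bias=0.)

# 'head' gets zeroed weights and a constant bias; 'mlp_fc1' gets xavier-uniform
# weights and a small (std=1e-6) normal bias because its name contains 'mlp'.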

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
 from .hardcorenas import *
 from .hgnet import *
 from .hiera import *
+from .hieradet_sam2 import *
 from .hrnet import *
 from .inception_next import *
 from .inception_resnet_v2 import *
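
With the new module imported, its architectures register with timm's model registry. A quick sketch to list them (the exact entry names live in hieradet_sam2.py and are not shown in this diff):

import timm

# Lists the existing hiera_* models plus the new HieraDet / SAM2 entries.
print(timm.list_models('*hiera*'))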

timm/models/efficientnet.py

Lines changed: 6 additions & 2 deletions
@@ -1290,8 +1290,12 @@ def _cfg(url='', **kwargs):
     'efficientnet_b0.ra4_e3600_r224_in1k': _cfg(
         hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
-        crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0
-    ),
+        crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0),
+    'efficientnet_b1.ra4_e3600_r240_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+        input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8),
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'efficientnet_b1.ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
         hf_hub_id='timm/',
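
The newly registered tag can be instantiated like any other timm pretrained weight; a sketch (assumes network access, weights pull from the Hugging Face Hub via the 'timm/' hub id):

import timm

model = timm.create_model('efficientnet_b1.ra4_e3600_r240_in1k', pretrained=True)
cfg = model.pretrained_cfg
print(cfg['input_size'], cfg['test_input_size'])   # (3, 240, 240) (3, 288, 288)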
