
Commit 95ec255

Finish timm model API for efficientformer_v2, add grad checkpointing support to both EfficientFormer models
Parent: 9d03c6f

2 files changed: 58 additions (+58), 4 deletions (−4)
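In practice, this commit lets gradient checkpointing be toggled through timm's standard model interface, as the diffs below show. A minimal usage sketch, assuming a registered EfficientFormer-V2 variant such as 'efficientformerv2_s0' (the model name and input size are illustrative, not taken from this commit):

import timm
import torch

# Model name and input size are assumptions for illustration only.
model = timm.create_model('efficientformerv2_s0', pretrained=False)
model.set_grad_checkpointing(True)  # added in this commit: checkpoints each stage's blocks
model.train()

x = torch.randn(2, 3, 224, 224)
out = model(x)
out.sum().backward()  # stage activations are recomputed here rather than kept in memory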

timm/models/efficientformer.py

Lines changed: 5 additions & 1 deletion
@@ -20,6 +20,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model

@@ -335,7 +336,10 @@ def __init__(

     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x

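checkpoint_seq comes from timm/models/_manipulate.py and wraps the stage's blocks in activation checkpointing. Conceptually it does something like the simplified sketch below (the real helper also supports grouping several blocks per checkpoint segment); this is not the actual implementation:

import torch.utils.checkpoint


def checkpoint_seq_sketch(blocks, x):
    # Apply each block under torch.utils.checkpoint: only the block inputs are stored,
    # and intermediate activations are recomputed during the backward pass.
    for block in blocks:
        x = torch.utils.checkpoint.checkpoint(block, x)
    return x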

timm/models/efficientformer_v2.py

Lines changed: 53 additions & 3 deletions
@@ -25,6 +25,7 @@
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model

@@ -498,7 +499,10 @@ def __init__(

     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x

@@ -508,6 +512,7 @@ def __init__(
             depths,
             in_chans=3,
             img_size=224,
+            global_pool='avg',
             embed_dims=None,
             downsamples=None,
             mlp_ratios=4,
@@ -522,7 +527,9 @@ def __init__(
             distillation=True,
     ):
         super().__init__()
+        assert global_pool in ('avg', '')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.feature_info = []
         img_size = to_2tuple(img_size)
         norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
@@ -583,11 +590,49 @@ def init_weights(self, m):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)

-    def forward(self, x):
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',  # stem and embed
+            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
+    def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
         x = self.norm(x)
-        x = x.mean(dim=(2, 3))
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(2, 3))
+        if pre_logits:
+            return x
         x, x_dist = self.head(x), self.head_dist(x)
         if self.distilled_training and self.training and not torch.jit.is_scripting():
             # only return separate classification predictions when training in distilled mode
@@ -596,6 +641,11 @@ def forward(self, x):
             # during standard train/finetune, inference average the classifier predictions
             return (x + x_dist) / 2

+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+

 def _cfg(url='', **kwargs):
     return {
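Together these additions give EfficientFormerV2 the usual timm model API surface: a forward_features/forward_head split, classifier reset, weight-decay exclusions, layer grouping, and per-stage gradient checkpointing. A sketch of how the new methods can be exercised; the model name and shapes are assumptions for illustration:

import timm
import torch

model = timm.create_model('efficientformerv2_s0', pretrained=False)  # name assumed for illustration
model.eval()
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)                    # unpooled NCHW features from the final stage + norm
pooled = model.forward_head(feats, pre_logits=True)  # global average pool only, classifier skipped
logits = model.forward_head(feats)                   # averaged head / head_dist predictions

model.reset_classifier(num_classes=10)               # replaces both head and head_dist
model.set_distilled_training(True)                   # return separate (head, head_dist) outputs while training
no_wd = model.no_weight_decay()                      # parameter names containing 'attention_biases'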
