
Commit 57f8554

support gradient checkpoint in forward_intermediates

1 parent a0a30a6 · commit 57f8554

25 files changed: +117 / -45 lines
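Every hunk below follows the same pattern: inside `forward_intermediates`, each block (or stage) is routed through activation checkpointing when `self.grad_checkpointing` is set, with a guard to skip it under TorchScript, since `torch.utils.checkpoint` is not scriptable. A minimal standalone sketch of the pattern (illustrative only; it calls `torch.utils.checkpoint.checkpoint` directly, whereas the diffs use timm's `checkpoint` wrapper from `._manipulate`, and the model here is a made-up toy):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    def __init__(self, depth: int = 4, dim: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.grad_checkpointing = False

    def forward_intermediates(self, x, take_indices=(1, 3)):
        intermediates = []
        for i, blk in enumerate(self.blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # recompute this block's activations in backward to save memory
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if i in take_indices:
                intermediates.append(x)
        return x, intermediates
```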

timm/models/beit.py

Lines changed: 4 additions & 1 deletion

@@ -451,7 +451,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x, shared_rel_pos_bias=rel_pos_bias)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
+            else:
+                x = blk(x, shared_rel_pos_bias=rel_pos_bias)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)

timm/models/byobnet.py

Lines changed: 5 additions & 2 deletions

@@ -44,7 +44,7 @@
 )
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply, checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq, named_apply
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@@ -1384,7 +1384,10 @@ def forward_intermediates(
         stages = self.stages[:max_index]
         for stage in stages:
             feat_idx += 1
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if not exclude_final_conv and feat_idx == last_idx:
                 # default feature_info for this model uses final_conv as the last feature output (if present)
                 x = self.final_conv(x)

timm/models/cait.py

Lines changed: 5 additions & 2 deletions

@@ -18,7 +18,7 @@
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

 __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
@@ -373,7 +373,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)

timm/models/davit.py

Lines changed: 5 additions & 2 deletions

@@ -25,7 +25,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['DaVit']
@@ -671,7 +671,10 @@ def forward_intermediates(
         stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 if norm and feat_idx == last_idx:
                     x_inter = self.norm_pre(x)  # applying final norm to last intermediate

timm/models/efficientnet.py

Lines changed: 5 additions & 3 deletions

@@ -210,9 +210,11 @@ def forward_intermediates(
             blocks = self.blocks
         else:
             blocks = self.blocks[:max_index]
-        for blk in blocks:
-            feat_idx += 1
-            x = blk(x)
+        for feat_idx, blk in enumerate(blocks, start=1):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if feat_idx in take_indices:
                 intermediates.append(x)
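This hunk also folds the manual `feat_idx += 1` counter into `enumerate(blocks, start=1)`. The two forms yield identical `(index, block)` pairs, as this standalone sketch shows (names are illustrative):

```python
blocks = ['stem', 'stage1', 'stage2']

# before: counter incremented at the top of each iteration
feat_idx = 0
before = []
for blk in blocks:
    feat_idx += 1
    before.append((feat_idx, blk))

# after: enumerate with start=1 produces the same pairs
after = list(enumerate(blocks, start=1))
assert before == after  # (1, 'stem'), (2, 'stage1'), (3, 'stage2')
```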

timm/models/efficientvit_mit.py

Lines changed: 9 additions & 3 deletions

@@ -19,7 +19,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -789,7 +789,10 @@ def forward_intermediates(
         stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)

@@ -943,7 +946,10 @@ def forward_intermediates(
         stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)

timm/models/efficientvit_msra.py

Lines changed: 5 additions & 2 deletions

@@ -18,7 +18,7 @@
 from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -510,7 +510,10 @@ def forward_intermediates(
         stages = self.stages[:max_index + 1]

         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)

timm/models/eva.py

Lines changed: 4 additions & 1 deletion

@@ -716,7 +716,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x, rope=rot_pos_embed)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x, rope=rot_pos_embed)
+            else:
+                x = blk(x, rope=rot_pos_embed)
             if i in take_indices:
                 intermediates.append(self.norm(x) if norm else x)

timm/models/hiera.py

Lines changed: 5 additions & 2 deletions

@@ -24,7 +24,7 @@
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -719,7 +719,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if i in take_indices:
                 x_int = self.reroll(x, i, mask=mask)
                 intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)

timm/models/levit.py

Lines changed: 5 additions & 2 deletions

@@ -34,7 +34,7 @@
 from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['Levit']
@@ -671,7 +671,10 @@ def forward_intermediates(
         else:
             stages = self.stages[:max_index + 1]
         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 if self.use_conv:
                     intermediates.append(x)
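With these changes, models that expose `forward_intermediates` respect the existing gradient checkpointing toggle there as well, not just in the regular forward path. A usage sketch (model name and tensor shapes are illustrative, not part of this commit):

```python
import timm
import torch

model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.set_grad_checkpointing(True)  # sets self.grad_checkpointing on the model

x = torch.randn(2, 3, 224, 224, requires_grad=True)
final, intermediates = model.forward_intermediates(x)

# backward now recomputes block activations instead of storing them all,
# trading extra compute for lower peak memory
loss = final.mean() + sum(t.mean() for t in intermediates)
loss.backward()
```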
