Skip to content

Commit 170a5b6

Browse files
seefun authored and rwightman committed
add tinyvit
1 parent 983310d commit 170a5b6

File tree

5 files changed

+782
-2
lines changed

5 files changed

+782
-2
lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .patch_dropout import PatchDropout
3737
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
3838
from .pool2d_same import AvgPool2dSame, create_pool2d
39-
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
39+
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, resample_relative_position_bias_table
4040
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
4141
resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
4242
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \

timm/layers/pos_embed.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,38 @@ def resample_abs_pos_embed_nhwc(
7878
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
7979

8080
return posemb
81+
82+
83+
def resample_relative_position_bias_table(
        position_bias_table,
        new_size,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False
):
    """Resample a relative position bias table to a new table length.

    Resampling scheme suggested in LeViT.
    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py

    The table is treated as a stack of square 2D grids (one per head) and
    resized with ``F.interpolate``. Interpolation runs in float32 and the
    result is cast back to the dtype of the input table.

    Args:
        position_bias_table: 2D tensor of size ``(L1, num_heads)``. ``L1`` is
            reshaped to an ``S1 x S1`` grid with ``S1 = int(L1 ** 0.5)``.
        new_size: Target ``(L2, num_heads)`` pair; the head count must equal
            that of ``position_bias_table``.
        interpolation: Interpolation mode forwarded to ``F.interpolate``.
        antialias: Antialias flag forwarded to ``F.interpolate``.
        verbose: Log the resize at info level (skipped while scripting).

    Returns:
        The input tensor itself when ``L1 == L2``; otherwise a resized table
        of size ``(L2, num_heads)`` in the input table's original dtype.
    """
    L1, nH1 = position_bias_table.size()
    L2, nH2 = new_size
    assert nH1 == nH2
    if L1 == L2:
        # Nothing to do: hand back the caller's tensor untouched.
        return position_bias_table

    orig_dtype = position_bias_table.dtype
    # Interpolate in float32; low-precision dtypes are restored afterwards.
    position_bias_table = position_bias_table.float()
    S1 = int(L1 ** 0.5)
    S2 = int(L2 ** 0.5)
    # (L1, nH) -> (1, nH, S1, S1) so each head is resized as a 2D grid.
    relative_position_bias_table_resized = F.interpolate(
        position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
        size=(S2, S2),
        mode=interpolation,
        antialias=antialias,
    )
    # (1, nH, S2, S2) -> (L2, nH)
    relative_position_bias_table_resized = \
        relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
    # Tensor.to() is not in-place: the result must be assigned, otherwise the
    # float32 tensor is returned and the original dtype is silently lost.
    relative_position_bias_table_resized = relative_position_bias_table_resized.to(orig_dtype)
    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position bias: {L1, nH1} to {L2, nH2}.')
    return relative_position_bias_table_resized

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from .swin_transformer import *
5959
from .swin_transformer_v2 import *
6060
from .swin_transformer_v2_cr import *
61+
from .tiny_vit import *
6162
from .tnt import *
6263
from .tresnet import *
6364
from .twins import *

timm/models/efficientvit_msra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch.nn as nn
1616

1717
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18-
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
18+
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
1919
from ._builder import build_model_with_cfg
2020
from ._manipulate import checkpoint_seq
2121
from ._registry import register_model, generate_default_cfgs

0 commit comments

Comments
 (0)