Skip to content

Commit 9fcc019

Browse files
authored
Merge pull request #1812 from seefun/master
add ViT for Segment-Anything Model
2 parents 960202c + e9373b1 commit 9fcc019

File tree

6 files changed

+648
-4
lines changed

6 files changed

+648
-4
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
4242
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
4343
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
44-
'eva_*', 'flexivit*', 'eva02*'
44+
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
4545
]
4646
NUM_NON_STD = len(NON_STD_FILTERS)
4747

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .patch_dropout import PatchDropout
3737
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
3838
from .pool2d_same import AvgPool2dSame, create_pool2d
39-
from .pos_embed import resample_abs_pos_embed
39+
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
4040
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
4141
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
4242
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \

timm/layers/patch_embed.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
flatten: bool = True,
3838
output_fmt: Optional[str] = None,
3939
bias: bool = True,
40+
strict_img_size: bool = True,
4041
):
4142
super().__init__()
4243
self.patch_size = to_2tuple(patch_size)
@@ -56,15 +57,26 @@ def __init__(
5657
# flatten spatial dim and transpose to channels last, kept for bwd compat
5758
self.flatten = flatten
5859
self.output_fmt = Format.NCHW
60+
self.strict_img_size = strict_img_size
5961

6062
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
6163
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
6264

6365
def forward(self, x):
6466
B, C, H, W = x.shape
6567
if self.img_size is not None:
66-
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
67-
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
68+
if self.strict_img_size:
69+
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
70+
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
71+
else:
72+
_assert(
73+
H % self.patch_size[0] == 0,
74+
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
75+
)
76+
_assert(
77+
W % self.patch_size[1] == 0,
78+
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
79+
)
6880

6981
x = self.proj(x)
7082
if self.flatten:

timm/layers/pos_embed.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,24 @@ def resample_abs_pos_embed(
5252
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
5353

5454
return posemb
55+
56+
57+
def resample_abs_pos_embed_nhwc(
58+
posemb,
59+
new_size: List[int],
60+
interpolation: str = 'bicubic',
61+
antialias: bool = True,
62+
verbose: bool = False,
63+
):
64+
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
65+
return posemb
66+
67+
# do the interpolation
68+
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
69+
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
70+
posemb = posemb.permute(0, 2, 3, 1)
71+
72+
if not torch.jit.is_scripting() and verbose:
73+
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
74+
75+
return posemb

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from .vision_transformer import *
6161
from .vision_transformer_hybrid import *
6262
from .vision_transformer_relpos import *
63+
from .vision_transformer_sam import *
6364
from .volo import *
6465
from .vovnet import *
6566
from .xception import *

0 commit comments

Comments
 (0)