Skip to content

Commit f489f02

Browse files
committed
Make gcvit window size ratio-based to improve support for changing resolution (#1449). Change default weight init back to the original scheme.
1 parent c45c6ee commit f489f02

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

timm/models/gcvit.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
3131
from .fx_features import register_notrace_function
3232
from .helpers import build_model_with_cfg, named_apply
33-
from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
34-
ClassifierHead, LayerNorm2d, _assert
33+
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
34+
get_attn, get_act_layer, get_norm_layer, _assert
3535
from .registry import register_model
3636
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
3737

@@ -321,7 +321,7 @@ def __init__(
321321
depth: int,
322322
num_heads: int,
323323
feat_size: Tuple[int, int],
324-
window_size: int,
324+
window_size: Tuple[int, int],
325325
downsample: bool = True,
326326
global_norm: bool = False,
327327
stage_norm: bool = False,
@@ -347,8 +347,9 @@ def __init__(
347347
else:
348348
self.downsample = nn.Identity()
349349
self.feat_size = feat_size
350+
window_size = to_2tuple(window_size)
350351

351-
feat_levels = int(math.log2(min(feat_size) / window_size))
352+
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
352353
self.global_block = FeatureBlock(dim, feat_levels)
353354
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
354355

@@ -400,7 +401,8 @@ def __init__(
400401
num_classes: int = 1000,
401402
global_pool: str = 'avg',
402403
img_size: Tuple[int, int] = 224,
403-
window_size: Tuple[int, ...] = (7, 7, 14, 7),
404+
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
405+
window_size: Tuple[int, ...] = None,
404406
embed_dim: int = 64,
405407
depths: Tuple[int, ...] = (3, 4, 19, 5),
406408
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@@ -411,7 +413,7 @@ def __init__(
411413
proj_drop_rate: float = 0.,
412414
attn_drop_rate: float = 0.,
413415
drop_path_rate: float = 0.,
414-
weight_init='vit',
416+
weight_init='',
415417
act_layer: str = 'gelu',
416418
norm_layer: str = 'layernorm2d',
417419
norm_layer_cl: str = 'layernorm',
@@ -429,6 +431,11 @@ def __init__(
429431
self.drop_rate = drop_rate
430432
num_stages = len(depths)
431433
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
434+
if window_size is not None:
435+
window_size = to_ntuple(num_stages)(window_size)
436+
else:
437+
assert window_ratio is not None
438+
window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
432439

433440
self.stem = Stem(
434441
in_chs=in_chans,
@@ -480,7 +487,7 @@ def _init_weights(self, module, name, scheme='vit'):
480487
nn.init.zeros_(module.bias)
481488
else:
482489
if isinstance(module, nn.Linear):
483-
trunc_normal_tf_(module.weight, std=.02)
490+
nn.init.normal_(module.weight, std=.02)
484491
if module.bias is not None:
485492
nn.init.zeros_(module.bias)
486493

@@ -490,7 +497,6 @@ def no_weight_decay(self):
490497
k for k, _ in self.named_parameters()
491498
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
492499

493-
494500
@torch.jit.ignore
495501
def group_matcher(self, coarse=False):
496502
matcher = dict(
@@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
567573
model_kwargs = dict(
568574
depths=(3, 4, 19, 5),
569575
num_heads=(3, 6, 12, 24),
570-
window_size=(7, 7, 14, 7),
571576
embed_dim=96,
572577
mlp_ratio=2,
573578
layer_scale=1e-5,
@@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
580585
model_kwargs = dict(
581586
depths=(3, 4, 19, 5),
582587
num_heads=(4, 8, 16, 32),
583-
window_size=(7, 7, 14, 7),
584588
embed_dim=128,
585589
mlp_ratio=2,
586590
layer_scale=1e-5,

0 commit comments

Comments
 (0)