
Commit 17923a6

Add layer scale to hieradet
1 parent 47e6958 commit 17923a6

1 file changed: +21 −24 lines


timm/models/hieradet_sam2.py

Lines changed: 21 additions & 24 deletions
@@ -1,4 +1,5 @@
 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -8,7 +9,7 @@
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ def __init__(
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,43 +137,38 @@ def __init__(
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+        else:
+            self.proj = nn.Identity()
+        self.pool = None
         if self.q_stride:
-            q_pool = nn.MaxPool2d(
+            # note make a different instance for this Module so that it's not shared with attn module
+            self.pool = nn.MaxPool2d(
                 kernel_size=q_stride,
                 stride=q_stride,
                 ceil_mode=False,
             )
-        else:
-            q_pool = None
 
         self.norm1 = norm_layer(dim)
         self.attn = MultiScaleAttention(
             dim,
             dim_out,
             num_heads=num_heads,
-            q_pool=q_pool,
+            q_pool=deepcopy(self.pool),
         )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(
             dim_out,
             int(dim_out * mlp_ratio),
             act_layer=act_layer,
         )
-
-        if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
-        else:
-            self.proj = nn.Identity()
-        self.pool = None
-        if self.q_stride:
-            # note make a different instance for this Module so that it's not shared with attn module
-            self.pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
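
For context, the new ls1 / ls2 attributes are instances of timm.layers.LayerScale (hence the import change above). LayerScale multiplies the residual branch by a learnable per-channel gain initialized to init_values; when init_values is None the blocks fall back to nn.Identity() and behave exactly as before. A minimal sketch of the idea, not the exact timm implementation:

import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    # Per-channel residual-branch scaling (CaiT-style layer scale).
    # gamma starts at init_values, so each branch contributes very little at init.
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasts over the trailing channel dim (tensors here are B, H, W, C).
        return x * self.gamma
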
@@ -206,9 +203,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = window_unpartition(x, window_size, (Hp, Wp))
             x = x[:, :H, :W, :].contiguous()  # unpad
 
-        x = shortcut + self.drop_path(x)
-
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
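
In the reworked forward() above, LayerScale is applied before DropPath on both the attention and MLP branches, so with a small init_values each block starts out close to an identity mapping, which is the usual motivation for layer scale in deeper ViT-style networks. A quick check of that property, assuming LayerScale is importable from timm.layers as the import change indicates:

import torch
from timm.layers import LayerScale

ls = LayerScale(256, init_values=1e-5)           # per-channel gamma initialized to 1e-5
branch = torch.randn(2, 8, 8, 256)               # e.g. an attention-branch output (B, H, W, C)
out = branch + ls(branch)                        # residual update as in the new forward()
print(torch.allclose(out, branch, atol=1e-3))    # True: the block is ~identity at init
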

@@ -280,6 +276,7 @@ def __init__(
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
             weight_init: str = '',
             fix_init: bool = True,
             head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):
 
 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
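
Because _create_hiera_det merges caller kwargs over model_args, the new init_values=1e-5 default for hieradet_small can still be changed, or layer scale disabled entirely, at creation time. A usage sketch (assumes building the model without pretrained weights):

import timm

# Default config: layer scale enabled, gamma initialized to 1e-5.
model = timm.create_model('hieradet_small', pretrained=False)

# kwargs override model_args, so layer scale can be retuned or switched off.
model_no_ls = timm.create_model('hieradet_small', pretrained=False, init_values=None)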

0 commit comments
