 import math
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, Union

 from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
+from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn

 from ._builder import build_model_with_cfg
@@ -121,11 +122,12 @@ def __init__(
             dim_out: int,
             num_heads: int,
             mlp_ratio: float = 4.0,
-            drop_path: float = 0.0,
             q_stride: Optional[Tuple[int, int]] = None,
             norm_layer: Union[nn.Module, str] = "LayerNorm",
             act_layer: Union[nn.Module, str] = "GELU",
             window_size: int = 0,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.0,
     ):
         super().__init__()
         norm_layer = get_norm_layer(norm_layer)
@@ -135,43 +137,38 @@ def __init__(
         self.dim = dim
         self.dim_out = dim_out
         self.q_stride = q_stride
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+        else:
+            self.proj = nn.Identity()
+        self.pool = None
         if self.q_stride:
-            q_pool = nn.MaxPool2d(
+            # note: make a different instance for this module so that it's not shared with attn module
+            self.pool = nn.MaxPool2d(
                 kernel_size=q_stride,
                 stride=q_stride,
                 ceil_mode=False,
             )
-        else:
-            q_pool = None

         self.norm1 = norm_layer(dim)
         self.attn = MultiScaleAttention(
             dim,
             dim_out,
             num_heads=num_heads,
-            q_pool=q_pool,
+            q_pool=deepcopy(self.pool),
         )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(
             dim_out,
             int(dim_out * mlp_ratio),
             act_layer=act_layer,
         )
-
-        if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
-        else:
-            self.proj = nn.Identity()
-        self.pool = None
-        if self.q_stride:
-            # note: make a different instance for this module so that it's not shared with attn module
-            self.pool = nn.MaxPool2d(
-                kernel_size=q_stride,
-                stride=q_stride,
-                ceil_mode=False,
-            )
+        self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x  # B, H, W, C
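
An aside on the constructor change above: q_pool=deepcopy(self.pool) hands the attention module its own pooling instance instead of sharing self.pool, as the in-code note says. A minimal standalone sketch of the difference (the names shared/copied here are illustrative, not from the patch):

import torch.nn as nn
from copy import deepcopy

pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Registering the same instance under two parents means both parents
# hold the very same submodule object.
shared = nn.ModuleDict({'a': pool, 'b': pool})
assert shared['a'] is shared['b']

# deepcopy yields an independent instance, as done for q_pool above.
copied = nn.ModuleDict({'a': pool, 'b': deepcopy(pool)})
assert copied['a'] is not copied['b']
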
@@ -206,9 +203,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = window_unpartition(x, window_size, (Hp, Wp))
             x = x[:, :H, :W, :].contiguous()  # unpad

-        x = shortcut + self.drop_path(x)
-
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = shortcut + self.drop_path1(self.ls1(x))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x

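For reference, LayerScale (imported above from timm.layers) is a learnable per-channel scale on each residual branch; with a small init_values the block starts close to an identity mapping, the CaiT-style trick for stabilizing deeper stacks. A rough sketch of the pattern, as an illustrative reimplementation rather than the library class:

import torch
import torch.nn as nn

class LayerScaleSketch(nn.Module):
    # Learnable per-channel gain, initialized small so the residual
    # branch contributes little at the start of training.
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

In the forward pass above it sits between the branch output and DropPath, i.e. shortcut + drop_path(gamma * branch(x)).
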
@@ -280,6 +276,7 @@ def __init__(
                 16,
                 20,
             ),
+            init_values: Optional[float] = None,
             weight_init: str = '',
             fix_init: bool = True,
             head_init_scale: float = 0.001,
@@ -628,7 +625,7 @@ def sam2_hiera_large(pretrained=False, **kwargs):

 @register_model
 def hieradet_small(pretrained=False, **kwargs):
-    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
+    model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
     return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))

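Because model_args is merged as dict(model_args, **kwargs), the new init_values=1e-5 default can still be overridden per call. A hypothetical usage sketch, assuming a timm build that includes this model:

import timm

# default: LayerScale enabled with init_values=1e-5
model = timm.create_model('hieradet_small', pretrained=False)

# pass init_values=None to fall back to nn.Identity (no LayerScale)
model_no_ls = timm.create_model('hieradet_small', pretrained=False, init_values=None)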