@@ -71,18 +71,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
         module.init_weights()


-def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
-              norm_layer, act_layer, num_layers, bidirectional, union,
-              with_fc, drop=0., drop_path_rate=0., **kwargs):
+def get_stage(
+        index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
+        norm_layer, act_layer, num_layers, bidirectional, union,
+        with_fc, drop=0., drop_path_rate=0., **kwargs):
     assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
     blocks = []
     for block_idx in range(layers[index]):
         drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
-        blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
-                                  rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
-                                  act_layer=act_layer, num_layers=num_layers,
-                                  bidirectional=bidirectional, union=union, with_fc=with_fc,
-                                  drop=drop, drop_path=drop_path))
+        blocks.append(block_layer(
+            embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
+            rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
+            num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
+            drop=drop, drop_path=drop_path))

     if index < len(embed_dims) - 1:
         blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
@@ -101,9 +102,10 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:

 class RNN2DBase(nn.Module):

-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__()

         self.input_size = input_size
@@ -115,6 +117,7 @@ def __init__(self, input_size: int, hidden_size: int,
         self.with_horizontal = True
         self.with_fc = with_fc

+        self.fc = None
         if with_fc:
             if union == "cat":
                 self.fc = nn.Linear(2 * self.output_size, input_size)
@@ -159,33 +162,38 @@ def forward(self, x):
             v, _ = self.rnn_v(v)
             v = v.reshape(B, W, H, -1)
             v = v.permute(0, 2, 1, 3)
+        else:
+            v = None

         if self.with_horizontal:
             h = x.reshape(-1, W, C)
             h, _ = self.rnn_h(h)
             h = h.reshape(B, H, W, -1)
+        else:
+            h = None

-        if self.with_vertical and self.with_horizontal:
+        if v is not None and h is not None:
             if self.union == "cat":
                 x = torch.cat([v, h], dim=-1)
             else:
                 x = v + h
-        elif self.with_vertical:
+        elif v is not None:
             x = v
-        elif self.with_horizontal:
+        elif h is not None:
             x = h

-        if self.with_fc:
+        if self.fc is not None:
             x = self.fc(x)

         return x


 class LSTM2D(RNN2DBase):

-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
         if self.with_vertical:
             self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
@@ -194,10 +202,10 @@ def __init__(self, input_size: int, hidden_size: int,


 class Sequencer2DBlock(nn.Module):
-    def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
-                 norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
-                 num_layers=1, bidirectional=True, union="cat", with_fc=True,
-                 drop=0., drop_path=0.):
+    def __init__(
+            self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
+            num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
         super().__init__()
         channels_dim = int(mlp_ratio * dim)
         self.norm1 = norm_layer(dim)
@@ -255,6 +263,7 @@ def __init__(
             num_classes=1000,
             img_size=224,
             in_chans=3,
+            global_pool='avg',
             layers=[4, 3, 8, 3],
             patch_sizes=[7, 2, 1, 1],
             embed_dims=[192, 384, 384, 384],
@@ -275,7 +284,9 @@ def __init__(
             stem_norm=False,
     ):
         super().__init__()
+        assert global_pool in ('', 'avg')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_features = embed_dims[-1]  # num_features for consistency with other models
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
@@ -301,38 +312,54 @@ def init_weights(self, nlhb=False):
         head_bias = -math.log(self.num_classes) if nlhb else 0.
         named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first

+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=[
+                (r'^blocks\.(\d+)\..*\.down', (99999,)),
+                (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
+                (r'^norm', (99999,))
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
+        if global_pool is not None:
+            assert global_pool in ('', 'avg')
+            self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
         x = self.stem(x)
         x = self.blocks(x)
         x = self.norm(x)
-        x = x.mean(dim=(1, 2))
         return x

+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(1, 2))
+        return x if pre_logits else self.head(x)
+
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.head(x)
+        x = self.forward_head(x)
         return x


-def checkpoint_filter_fn(state_dict, model):
-    return state_dict
-
-
 def _create_sequencer2d(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Sequencer2D models.')

-    model = build_model_with_cfg(
-        Sequencer2D, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+    model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
     return model


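Not part of the commit above: a minimal usage sketch of the refactored head API, assuming the remaining Sequencer2D constructor defaults shown in this file.

import torch

# Assumes the Sequencer2D class from the diff above, using its default constructor arguments.
model = Sequencer2D(num_classes=1000, global_pool='avg')
x = torch.randn(2, 3, 224, 224)

feats = model.forward_features(x)                    # (B, H', W', C); pooling no longer happens here
pooled = model.forward_head(feats, pre_logits=True)  # (B, C) after global average pooling, before the classifier
logits = model(x)                                    # (B, num_classes) via forward_features -> forward_head

# reset_classifier now only changes the pooling mode when one is passed explicitly.
model.reset_classifier(num_classes=10, global_pool='avg')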