@@ -47,7 +47,7 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]


-def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
+def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
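Only the signature annotation changes in this hunk (torch.Size becomes List[int], which is friendlier to TorchScript-style typing); the body that actually resizes the mask sits below the shown lines and is untouched. As a rough, illustrative sketch of what the function does with these arguments, assuming a nearest-neighbour interpolate-based resize (this body is an assumption, not code from this commit):

    from typing import List, Optional

    import torch
    import torch.nn.functional as F

    def resize_mask_sketch(target_size: List[int], mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Bring a coarse [B, C, (t), (h), w] keep-mask to the [(T), (H), W] size of the tensor it gates.
        if mask is None:
            return None
        if list(mask.shape[2:]) != list(target_size):
            return F.interpolate(mask.float(), size=target_size)
        return mask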
@@ -59,23 +59,6 @@ def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tenso
5959 return mask
6060
6161
62- def do_masked_conv (
63- x : torch .Tensor ,
64- conv : nn .Module ,
65- mask : Optional [torch .Tensor ] = None ,
66- ) -> torch .Tensor :
67- """Zero-out the masked regions of the input before conv.
68- Prevents leakage of masked regions when using overlapping kernels.
69- """
70- if conv is None :
71- return x
72- if mask is None :
73- return conv (x )
74-
75- mask = get_resized_mask (target_size = x .shape [2 :], mask = mask )
76- return conv (x * mask .bool ())
77-
78-
7962def undo_windowing (
8063 x : torch .Tensor ,
8164 shape : List [int ],
@@ -145,7 +128,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
         """
         B, _, C = x.shape
-
         cur_size = self.size
         x = x.view(*([B] + cur_size + [C]))
@@ -332,6 +314,7 @@ def __init__(
         act_layer: nn.Module = nn.GELU,
         q_stride: int = 1,
         window_size: int = 0,
+        use_expand_proj: bool = True,
         use_mask_unit_attn: bool = False,
     ):
         super().__init__()
@@ -341,8 +324,14 @@ def __init__(

         self.norm1 = norm_layer(dim)
         if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
+            self.do_expand = True
+            if use_expand_proj:
+                self.proj = nn.Linear(dim, dim_out)
+            else:
+                assert dim_out == dim * 2
+                self.proj = None
         else:
+            self.do_expand = False
             self.proj = None
         self.attn = MaskUnitAttention(
             dim,
@@ -362,9 +351,17 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
-        if self.proj is not None:
-            x = self.proj(x_norm)
-            x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+        if self.do_expand:
+            if self.proj is not None:
+                x = self.proj(x_norm)
+                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+            else:
+                x = torch.cat([
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
+                    ],
+                    dim=-1,
+                )
         x = x + self.drop_path1(self.attn(x_norm))

         # MLP
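The new else branch is the point of the flag: when the learned projection is dropped, the q_stride pooling step itself doubles the channel count by concatenating max- and avg-pooled tokens, so no extra parameters are needed. A minimal, standalone shape check (tensor sizes here are made up for illustration):

    import torch

    B, q_stride, N, C = 2, 4, 64, 96
    x = torch.randn(B, q_stride * N, C)

    pooled = torch.cat([
        x.view(B, q_stride, -1, C).amax(dim=1),   # max over each q_stride group
        x.view(B, q_stride, -1, C).mean(dim=1),   # mean over the same group
    ], dim=-1)

    assert pooled.shape == (B, N, 2 * C)  # tokens reduced by q_stride, channels doubled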
@@ -419,7 +416,11 @@ def forward(
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = do_masked_conv(x, self.proj, mask)
+        if mask is not None:
+            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
+            x = self.proj(x * mask.to(torch.bool))
+        else:
+            x = self.proj(x)
         if self.reshape:
             x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
         return x
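This inlines the removed do_masked_conv helper: zeroing masked regions of the input before the overlapping patch-embedding conv keeps masked content from leaking into neighbouring visible positions. An illustrative, self-contained sketch for the 2D image case (the kernel, stride and mask granularity are assumptions, not taken from this commit):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(1, 3, 224, 224)                        # [B, C, H, W]
    proj = nn.Conv2d(3, 96, kernel_size=7, stride=4, padding=3)
    mask = (torch.rand(1, 1, 7, 7) > 0.6).float()          # coarse per-unit keep-mask
    mask = F.interpolate(mask, size=x.shape[2:])           # roughly what get_resized_mask does
    out = proj(x * mask.to(torch.bool))                    # masked regions contribute only zeros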
@@ -570,10 +571,10 @@ def _init_weights(self, m, init_bias=0.02):

     @torch.jit.ignore
     def no_weight_decay(self):
-        if self.sep_pos_embed:
-            return ["pos_embed_spatial", "pos_embed_temporal"]
-        else:
+        if self.pos_embed is not None:
             return ["pos_embed"]
+        else:
+            return ["pos_embed_spatial", "pos_embed_temporal"]

     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """