@@ -40,18 +40,16 @@ def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
4040
4141 blur_matrix = (np .poly1d ((0.5 , 0.5 )) ** (blur_filter_size - 1 )).coeffs
4242 blur_filter = torch .Tensor (blur_matrix [:, None ] * blur_matrix [None , :])
43- # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
44- # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
45- self .register_buffer ('blur_filter' , blur_filter [None , None , :, :].repeat ((self .channels , 1 , 1 , 1 )))
43+ self .blur_filter = blur_filter [None , None , :, :]
44+
45+ def _apply (self , fn ):
46+ # override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
47+ # this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
48+ super (BlurPool2d , self )._apply (fn )
49+ self .blur_filter = fn (self .blur_filter )
4650
4751 def forward (self , input_tensor : torch .Tensor ) -> torch .Tensor : # type: ignore
48- if not torch .is_tensor (input_tensor ):
49- raise TypeError ("Input input type is not a torch.Tensor. Got {}" .format (type (input_tensor )))
50- if not len (input_tensor .shape ) == 4 :
51- raise ValueError ("Invalid input shape, we expect BxCxHxW. Got: {}" .format (input_tensor .shape ))
52- # apply blur_filter on input
5352 return F .conv2d (
5453 self .padding (input_tensor ),
55- self .blur_filter .type (input_tensor .dtype ),
56- stride = self .stride ,
57- groups = input_tensor .shape [1 ])
54+ self .blur_filter .type (input_tensor .dtype ).expand (C , - 1 , - 1 , - 1 ),
55+ stride = self .stride , groups = input_tensor .shape [1 ])
0 commit comments