Skip to content

Commit f17b42b

Browse files
rwightmanchris-ha458
authored andcommitted
Blur filter no longer a buffer
1 parent 6cdeca2 commit f17b42b

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

timm/models/layers/blurpool.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,16 @@ def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
4040

4141
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
4242
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
43-
# FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
44-
# plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
45-
self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
43+
self.blur_filter = blur_filter[None, None, :, :]
44+
45+
def _apply(self, fn):
46+
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
47+
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
48+
super(BlurPool2d, self)._apply(fn)
49+
self.blur_filter = fn(self.blur_filter)
4650

4751
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
48-
if not torch.is_tensor(input_tensor):
49-
raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
50-
if not len(input_tensor.shape) == 4:
51-
raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
52-
# apply blur_filter on input
5352
return F.conv2d(
5453
self.padding(input_tensor),
55-
self.blur_filter.type(input_tensor.dtype),
56-
stride=self.stride,
57-
groups=input_tensor.shape[1])
54+
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1),
55+
stride=self.stride, groups=input_tensor.shape[1])

0 commit comments

Comments
 (0)