|
14 | 14 |
|
15 | 15 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
16 | 16 | from .helpers import load_pretrained |
17 | | -from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act |
| 17 | +from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d |
18 | 18 | from .registry import register_model |
19 | 19 |
|
20 | 20 | __all__ = ['DenseNet'] |
@@ -71,9 +71,9 @@ def any_requires_grad(self, x): |
71 | 71 | def call_checkpoint_bottleneck(self, x): |
72 | 72 | # type: (List[torch.Tensor]) -> torch.Tensor |
73 | 73 | def closure(*xs): |
74 | | - return self.bottleneck_fn(*xs) |
| 74 | + return self.bottleneck_fn(xs) |
75 | 75 |
|
76 | | - return cp.checkpoint(closure, x) |
| 76 | + return cp.checkpoint(closure, *x) |
77 | 77 |
|
78 | 78 | @torch.jit._overload_method # noqa: F811 |
79 | 79 | def forward(self, x): |
@@ -132,12 +132,15 @@ def forward(self, init_features): |
132 | 132 |
|
133 | 133 |
|
134 | 134 | class DenseTransition(nn.Sequential): |
135 | | - def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d): |
| 135 | + def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None): |
136 | 136 | super(DenseTransition, self).__init__() |
137 | 137 | self.add_module('norm', norm_act_layer(num_input_features)) |
138 | 138 | self.add_module('conv', nn.Conv2d( |
139 | 139 | num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) |
140 | | - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) |
| 140 | + if aa_layer is not None: |
| 141 | + self.add_module('pool', aa_layer(num_output_features, stride=2)) |
| 142 | + else: |
| 143 | + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) |
141 | 144 |
|
142 | 145 |
|
143 | 146 | class DenseNet(nn.Module): |
@@ -301,6 +304,17 @@ def densenet121(pretrained=False, **kwargs): |
301 | 304 | return model |
302 | 305 |
|
303 | 306 |
|
| 307 | +@register_model |
| 308 | +def densenetblur121d(pretrained=False, **kwargs): |
| 309 | + r"""Densenet-121 model from |
| 310 | + `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>` |
| 311 | + """ |
| 312 | + model = _densenet( |
| 313 | + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', |
| 314 | + aa_layer=BlurPool2d, **kwargs) |
| 315 | + return model |
| 316 | + |
| 317 | + |
304 | 318 | @register_model |
305 | 319 | def densenet121d(pretrained=False, **kwargs): |
306 | 320 | r"""Densenet-121 model from |
|
0 commit comments