
Commit 6441e9c

Fix memory_efficient mode for DenseNets. Add AntiAliasing (Blur) support for DenseNets and create one test model. Add lr cycle/mul params to train args.
1 parent 7df8325 commit 6441e9c

File tree

3 files changed: 27 additions, 9 deletions

    timm/models/densenet.py
    timm/scheduler/scheduler_factory.py
    train.py


timm/models/densenet.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -14,7 +14,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
 from .registry import register_model
 
 __all__ = ['DenseNet']
@@ -71,9 +71,9 @@ def any_requires_grad(self, x):
     def call_checkpoint_bottleneck(self, x):
         # type: (List[torch.Tensor]) -> torch.Tensor
         def closure(*xs):
-            return self.bottleneck_fn(*xs)
+            return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, x)
+        return cp.checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
```
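This is the memory_efficient (gradient checkpointing) fix from the commit message: `torch.utils.checkpoint` only treats tensors it receives as positional arguments as differentiable inputs, so the feature list must be unpacked at the `checkpoint` call (`*x`) and re-packed into a tuple inside the closure. A minimal sketch of the fixed calling convention, with a hypothetical `bottleneck_fn` standing in for the layer's real norm/act/conv bottleneck:

```python
import torch
import torch.utils.checkpoint as cp

# Hypothetical stand-in for DenseLayer.bottleneck_fn: takes a sequence of
# prior feature tensors and concatenates them along the channel dim.
def bottleneck_fn(xs):
    return torch.cat(xs, 1)

def call_checkpoint_bottleneck(x):
    # checkpoint() only tracks tensors passed positionally, so the list is
    # unpacked at the call site (*x) and re-packed inside the closure (xs).
    def closure(*xs):
        return bottleneck_fn(xs)
    return cp.checkpoint(closure, *x)

features = [torch.randn(2, 8, 4, 4, requires_grad=True) for _ in range(3)]
out = call_checkpoint_bottleneck(features)
out.sum().backward()  # gradients now reach every input tensor
print(features[0].grad.shape)  # torch.Size([2, 8, 4, 4])
```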
```diff
@@ -132,12 +132,15 @@ def forward(self, init_features):
 
 
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d):
+    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_act_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
             num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
-        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+        if aa_layer is not None:
+            self.add_module('pool', aa_layer(num_output_features, stride=2))
+        else:
+            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
 
 
 class DenseNet(nn.Module):
```
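The new `aa_layer` hook lets the transition's 2x2 average pool be replaced by an anti-aliasing (blur) pool that low-pass filters before subsampling. A small sketch comparing the two choices; the `BlurPool2d(channels, stride=2)` signature is assumed from the `aa_layer(num_output_features, stride=2)` call in the diff, and the import path is assumed for this version of timm:

```python
import torch
import torch.nn as nn
from timm.models.layers import BlurPool2d  # import path assumed for this timm vintage

num_out = 128
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # default transition pool
blur_pool = BlurPool2d(num_out, stride=2)         # anti-aliased alternative

x = torch.randn(1, num_out, 56, 56)
print(avg_pool(x).shape)   # torch.Size([1, 128, 28, 28])
print(blur_pool(x).shape)  # torch.Size([1, 128, 28, 28]), same downsampling, less aliasing
```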
```diff
@@ -301,6 +304,17 @@ def densenet121(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def densenetblur121d(pretrained=False, **kwargs):
+    r"""Densenet-121 model from
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    """
+    model = _densenet(
+        'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
+        aa_layer=BlurPool2d, **kwargs)
+    return model
+
+
 @register_model
 def densenet121d(pretrained=False, **kwargs):
     r"""Densenet-121 model from
```

timm/scheduler/scheduler_factory.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -23,12 +23,12 @@ def create_scheduler(args, optimizer):
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
@@ -40,11 +40,11 @@ def create_scheduler(args, optimizer):
         lr_scheduler = TanhLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
```
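Exposing `t_mul` and `cycle_limit` enables SGDR-style warm restarts from the command line: cycle `i` runs for roughly `t_initial * t_mul**i` epochs, and annealing stops after `cycle_limit` cycles. A sketch of that arithmetic (illustrative semantics only, not the scheduler's actual implementation):

```python
# Cycle i is t_mul times longer than the previous one, up to cycle_limit cycles.
def cycle_lengths(t_initial, t_mul, cycle_limit):
    return [round(t_initial * t_mul ** i) for i in range(cycle_limit)]

print(cycle_lengths(10, 2.0, 3))  # [10, 20, 40] -> e.g. --lr-cycle-mul 2.0 --lr-cycle-limit 3
print(cycle_lengths(10, 1.0, 1))  # [10]         -> the old hard-coded behaviour
```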

train.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -111,6 +111,10 @@
                     help='learning rate noise limit percent (default: 0.67)')
 parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
+parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
+                    help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
+                    help='learning rate cycle limit')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
```
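Both new flags default to the values that were previously hard-coded, so existing training commands behave exactly as before. A minimal sketch of the two arguments in isolation (the real train.py parser defines many more options):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')

args = parser.parse_args(['--lr-cycle-mul', '2.0', '--lr-cycle-limit', '3'])
print(args.lr_cycle_mul, args.lr_cycle_limit)  # 2.0 3
# The defaults (1.0 and 1) reproduce the old t_mul=1.0, cycle_limit=1 behaviour.
```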
