
Commit 3cdaf5e

Add 'mmax' config key to auto_augment for increasing the upper bound of RandAugment magnitude beyond 10. Make the AugMix uniform-sampling default no longer override an explicit config setting.
1 parent 1042b8a commit 3cdaf5e


timm/data/auto_augment.py

Lines changed: 33 additions & 22 deletions
@@ -29,9 +29,7 @@
 
 _FILL = (128, 128, 128)
 
-# This signifies the max integer that the controller RNN could predict for the
-# augmentation scheme.
-_MAX_LEVEL = 10.
+_LEVEL_DENOM = 10.  # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
 
 _HPARAMS_DEFAULT = dict(
     translate_const=250,
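The rename matters because 10 is no longer treated as a hard ceiling: every level-to-arg helper simply divides the sampled magnitude by _LEVEL_DENOM, so magnitudes above 10 scale op arguments past their classic ranges. A minimal standalone sketch of that conversion (the level values below are illustrative, not taken from the diff):

    _LEVEL_DENOM = 10.

    def rotate_degrees(level):
        # same conversion as _rotate_level_to_arg, random sign flip omitted
        return (level / _LEVEL_DENOM) * 30.

    print(rotate_degrees(10))  # 30.0 degrees, the classic RandAugment maximum
    print(rotate_degrees(15))  # 45.0 degrees, reachable once 'mmax' lifts the clamp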
@@ -179,42 +177,42 @@ def _randomly_negate(v):
 
 def _rotate_level_to_arg(level, _hparams):
     # range [-30, 30]
-    level = (level / _MAX_LEVEL) * 30.
+    level = (level / _LEVEL_DENOM) * 30.
     level = _randomly_negate(level)
     return level,
 
 
 def _enhance_level_to_arg(level, _hparams):
     # range [0.1, 1.9]
-    return (level / _MAX_LEVEL) * 1.8 + 0.1,
+    return (level / _LEVEL_DENOM) * 1.8 + 0.1,
 
 
 def _enhance_increasing_level_to_arg(level, _hparams):
     # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
-    # range [0.1, 1.9]
-    level = (level / _MAX_LEVEL) * .9
-    level = 1.0 + _randomly_negate(level)
+    # range [0.1, 1.9] if level <= _LEVEL_DENOM
+    level = (level / _LEVEL_DENOM) * .9
+    level = max(0.1, 1.0 + _randomly_negate(level))  # keep it >= 0.1
     return level,
 
 
 def _shear_level_to_arg(level, _hparams):
     # range [-0.3, 0.3]
-    level = (level / _MAX_LEVEL) * 0.3
+    level = (level / _LEVEL_DENOM) * 0.3
     level = _randomly_negate(level)
     return level,
 
 
 def _translate_abs_level_to_arg(level, hparams):
     translate_const = hparams['translate_const']
-    level = (level / _MAX_LEVEL) * float(translate_const)
+    level = (level / _LEVEL_DENOM) * float(translate_const)
     level = _randomly_negate(level)
     return level,
 
 
 def _translate_rel_level_to_arg(level, hparams):
     # default range [-0.45, 0.45]
     translate_pct = hparams.get('translate_pct', 0.45)
-    level = (level / _MAX_LEVEL) * translate_pct
+    level = (level / _LEVEL_DENOM) * translate_pct
     level = _randomly_negate(level)
     return level,

@@ -223,7 +221,7 @@ def _posterize_level_to_arg(level, _hparams):
     # As per Tensorflow TPU EfficientNet impl
     # range [0, 4], 'keep 0 up to 4 MSB of original image'
     # intensity/severity of augmentation decreases with level
-    return int((level / _MAX_LEVEL) * 4),
+    return int((level / _LEVEL_DENOM) * 4),
 
 
 def _posterize_increasing_level_to_arg(level, hparams):
@@ -237,13 +235,13 @@ def _posterize_original_level_to_arg(level, _hparams):
     # As per original AutoAugment paper description
     # range [4, 8], 'keep 4 up to 8 MSB of image'
     # intensity/severity of augmentation decreases with level
-    return int((level / _MAX_LEVEL) * 4) + 4,
+    return int((level / _LEVEL_DENOM) * 4) + 4,
 
 
 def _solarize_level_to_arg(level, _hparams):
     # range [0, 256]
     # intensity/severity of augmentation decreases with level
-    return int((level / _MAX_LEVEL) * 256),
+    return int((level / _LEVEL_DENOM) * 256),
 
 
 def _solarize_increasing_level_to_arg(level, _hparams):
@@ -254,7 +252,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
 
 def _solarize_add_level_to_arg(level, _hparams):
     # range [0, 110]
-    return int((level / _MAX_LEVEL) * 110),
+    return int((level / _LEVEL_DENOM) * 110),
 
 
 LEVEL_TO_ARG = {
@@ -334,17 +332,22 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
         # NOTE This is my own hack, being tested, not in papers or reference impls.
         # If magnitude_std is inf, we sample magnitude from a uniform distribution
         self.magnitude_std = self.hparams.get('magnitude_std', 0)
+        self.magnitude_max = self.hparams.get('magnitude_max', None)
 
     def __call__(self, img):
         if self.prob < 1.0 and random.random() > self.prob:
             return img
         magnitude = self.magnitude
-        if self.magnitude_std:
+        if self.magnitude_std > 0:
+            # magnitude randomization enabled
             if self.magnitude_std == float('inf'):
                 magnitude = random.uniform(0, magnitude)
             elif self.magnitude_std > 0:
                 magnitude = random.gauss(magnitude, self.magnitude_std)
-        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
+        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
+        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
+        upper_bound = self.magnitude_max or _LEVEL_DENOM
+        magnitude = max(0., min(magnitude, upper_bound))
         level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
         return self.aug_fn(img, *level_args, **self.kwargs)
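Taken on its own, the new __call__ flow is: optionally jitter the configured magnitude, then clip it to an upper bound that defaults to 10 but can be raised via magnitude_max. A rough standalone sketch of that logic (a plain function with hypothetical arguments, not the class method itself):

    import random

    _LEVEL_DENOM = 10.

    def sample_magnitude(magnitude, magnitude_std=0, magnitude_max=None):
        # gaussian jitter around M, or uniform in [0, M] when std is infinite
        if magnitude_std > 0:
            if magnitude_std == float('inf'):
                magnitude = random.uniform(0, magnitude)
            else:
                magnitude = random.gauss(magnitude, magnitude_std)
        # clip to [0, upper_bound]; upper_bound falls back to the 10-point scale
        upper_bound = magnitude_max or _LEVEL_DENOM
        return max(0., min(magnitude, upper_bound))

    sample_magnitude(9, magnitude_std=0.5)    # jittered around 9, never above 10
    sample_magnitude(12, magnitude_max=20)    # stays at 12 instead of being clipped to 10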

@@ -642,7 +645,8 @@ def rand_augment_transform(config_str, hparams):
         'm' - integer magnitude of rand augment
         'n' - integer num layers (number of transform ops selected per image)
         'w' - integer probability weight index (index of a set of weights to influence choice of op)
-        'mstd' - float std deviation of magnitude noise applied
+        'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
+        'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
         'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
     Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
     'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
@@ -651,7 +655,7 @@ def rand_augment_transform(config_str, hparams):
 
     :return: A PyTorch compatible Transform
     """
-    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
     num_layers = 2  # default to 2 ops per image
     weight_idx = None  # default to no probability weights for op choice
     transforms = _RAND_TRANSFORMS
@@ -664,8 +668,15 @@ def rand_augment_transform(config_str, hparams):
             continue
         key, val = cs[:2]
         if key == 'mstd':
-            # noise param injected via hparams for now
-            hparams.setdefault('magnitude_std', float(val))
+            # noise param / randomization of magnitude values
+            mstd = float(val)
+            if mstd > 100:
+                # use uniform sampling in 0 to magnitude if mstd is > 100
+                mstd = float('inf')
+            hparams.setdefault('magnitude_std', mstd)
+        elif key == 'mmax':
+            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
+            hparams.setdefault('magnitude_max', int(val))
         elif key == 'inc':
             if bool(val):
                 transforms = _RAND_INCREASING_TRANSFORMS
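With the parser change in place, the new key can be combined with the existing ones in a config string. A hedged usage example (the magnitude values are illustrative; in timm training setups the hparams dict usually also carries keys such as translate_const and img_mean):

    from timm.data.auto_augment import rand_augment_transform

    # M jittered with std 0.5 around 12 and clipped to [0, 20] rather than [0, 10]
    tfm = rand_augment_transform('rand-m12-n2-mstd0.5-mmax20', hparams={})

    # an 'mstd' above 100 is now shorthand for uniform sampling in [0, M]
    tfm_uniform = rand_augment_transform('rand-m10-mstd101', hparams={})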
@@ -794,7 +805,6 @@ def augment_and_mix_transform(config_str, hparams):
     depth = -1
     alpha = 1.
     blended = False
-    hparams['magnitude_std'] = float('inf')
     config = config_str.split('-')
     assert config[0] == 'augmix'
     config = config[1:]
@@ -818,5 +828,6 @@ def augment_and_mix_transform(config_str, hparams):
             blended = bool(val)
         else:
             assert False, 'Unknown AugMix config section'
+    hparams.setdefault('magnitude_std', float('inf'))  # default to uniform sampling (if not set via mstd arg)
     ops = augmix_ops(magnitude=magnitude, hparams=hparams)
     return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
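For AugMix, the practical effect of switching to setdefault is that an explicit 'mstd' in the config string is now honoured rather than being overwritten with inf. A hedged example (argument values are illustrative):

    from timm.data.auto_augment import augment_and_mix_transform

    # previously magnitude_std was forced to inf (uniform sampling in [0, M]);
    # now the gaussian std of 0.5 requested in the config string is kept
    tfm = augment_and_mix_transform('augmix-m5-w4-d2-mstd0.5', hparams={})

    # without an 'mstd' key the old behaviour is preserved: uniform sampling by default
    tfm_default = augment_and_mix_transform('augmix-m5-w4-d2', hparams={})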
