Skip to content

Commit 31453b0

Browse files
committed
Update Auto/RandAugment comments, README, more.
* Add a weighted choice option for RandAugment * Adjust magnitude noise/std naming, config
1 parent 4243f07 commit 31453b0

File tree

3 files changed

+102
-28
lines changed

3 files changed

+102
-28
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Several (less common) features that I often utilize in my projects are included.
6969
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
7070
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
7171
* An inference script that dumps output to CSV is provided as an example
72+
* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
7273

7374
## Results
7475

timm/data/auto_augment.py

Lines changed: 100 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
""" Auto Augment
1+
""" AutoAugment and RandAugment
22
Implementation adapted from:
33
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
4-
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
4+
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
55
66
Hacked together by Ross Wightman
77
"""
@@ -288,18 +288,18 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
288288
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
289289
)
290290

291-
# If magnitude_noise is > 0, we introduce some randomness
291+
# If magnitude_std is > 0, we introduce some randomness
292292
# in the usually fixed policy and sample magnitude from a normal distribution
293-
# with mean `magnitude` and std-dev of `magnitude_noise`.
293+
# with mean `magnitude` and std-dev of `magnitude_std`.
294294
# NOTE This is my own hack, being tested, not in papers or reference impls.
295-
self.magnitude_noise = self.hparams.get('magnitude_noise', 0)
295+
self.magnitude_std = self.hparams.get('magnitude_std', 0)
296296

297297
def __call__(self, img):
298298
if random.random() > self.prob:
299299
return img
300300
magnitude = self.magnitude
301-
if self.magnitude_noise and self.magnitude_noise > 0:
302-
magnitude = random.gauss(magnitude, self.magnitude_noise)
301+
if self.magnitude_std and self.magnitude_std > 0:
302+
magnitude = random.gauss(magnitude, self.magnitude_std)
303303
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
304304
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
305305
return self.aug_fn(img, *level_args, **self.kwargs)
@@ -464,16 +464,32 @@ def __call__(self, img):
464464

465465

466466
def auto_augment_transform(config_str, hparams):
467+
"""
468+
Create a AutoAugment transform
469+
470+
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
471+
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
472+
The remaining sections, not order sepecific determine
473+
'mstd' - float std deviation of magnitude noise applied
474+
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
475+
476+
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
477+
478+
:return: A PyTorch compatible Transform
479+
"""
467480
config = config_str.split('-')
468481
policy_name = config[0]
469482
config = config[1:]
470483
for c in config:
471484
cs = re.split(r'(\d.*)', c)
472-
if len(cs) >= 2:
473-
key, val = cs[:2]
474-
if key == 'noise':
475-
# noise param injected via hparams for now
476-
hparams.setdefault('magnitude_noise', float(val))
485+
if len(cs) < 2:
486+
continue
487+
key, val = cs[:2]
488+
if key == 'mstd':
489+
# noise param injected via hparams for now
490+
hparams.setdefault('magnitude_std', float(val))
491+
else:
492+
assert False, 'Unknown AutoAugment config section'
477493
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
478494
return AutoAugment(aa_policy)
479495

@@ -498,6 +514,36 @@ def auto_augment_transform(config_str, hparams):
498514
]
499515

500516

517+
# These experimental weights are based loosely on the relative improvements mentioned in paper.
518+
# They may not result in increased performance, but could likely be tuned to so.
519+
_RAND_CHOICE_WEIGHTS_0 = {
520+
'Rotate': 0.3,
521+
'ShearX': 0.2,
522+
'ShearY': 0.2,
523+
'TranslateXRel': 0.1,
524+
'TranslateYRel': 0.1,
525+
'Color': .025,
526+
'Sharpness': 0.025,
527+
'AutoContrast': 0.025,
528+
'Solarize': .005,
529+
'SolarizeAdd': .005,
530+
'Contrast': .005,
531+
'Brightness': .005,
532+
'Equalize': .005,
533+
'PosterizeTpu': 0,
534+
'Invert': 0,
535+
}
536+
537+
538+
def _select_rand_weights(weight_idx=0, transforms=None):
539+
transforms = transforms or _RAND_TRANSFORMS
540+
assert weight_idx == 0 # only one set of weights currently
541+
rand_weights = _RAND_CHOICE_WEIGHTS_0
542+
probs = [rand_weights[k] for k in transforms]
543+
probs /= np.sum(probs)
544+
return probs
545+
546+
501547
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
502548
hparams = hparams or _HPARAMS_DEFAULT
503549
transforms = transforms or _RAND_TRANSFORMS
@@ -506,33 +552,60 @@ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
506552

507553

508554
class RandAugment:
509-
def __init__(self, ops, num_layers=2):
555+
def __init__(self, ops, num_layers=2, choice_weights=None):
510556
self.ops = ops
511557
self.num_layers = num_layers
558+
self.choice_weights = choice_weights
512559

513560
def __call__(self, img):
514-
for _ in range(self.num_layers):
515-
op = random.choice(self.ops)
561+
# no replacement when using weighted choice
562+
ops = np.random.choice(
563+
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
564+
for op in ops:
516565
img = op(img)
517566
return img
518567

519568

520569
def rand_augment_transform(config_str, hparams):
521-
magnitude = 10
522-
num_layers = 2
570+
"""
571+
Create a RandAugment transform
572+
573+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
574+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
575+
sections, not order sepecific determine
576+
'm' - integer magnitude of rand augment
577+
'n' - integer num layers (number of transform ops selected per image)
578+
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
579+
'mstd' - float std deviation of magnitude noise applied
580+
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
581+
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
582+
583+
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
584+
585+
:return: A PyTorch compatible Transform
586+
"""
587+
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
588+
num_layers = 2 # default to 2 ops per image
589+
weight_idx = None # default to no probability weights for op choice
523590
config = config_str.split('-')
524591
assert config[0] == 'rand'
525592
config = config[1:]
526593
for c in config:
527594
cs = re.split(r'(\d.*)', c)
528-
if len(cs) >= 2:
529-
key, val = cs[:2]
530-
if key == 'noise':
531-
# noise param injected via hparams for now
532-
hparams.setdefault('magnitude_noise', float(val))
533-
elif key == 'm':
534-
magnitude = int(val)
535-
elif key == 'n':
536-
num_layers = int(val)
595+
if len(cs) < 2:
596+
continue
597+
key, val = cs[:2]
598+
if key == 'mstd':
599+
# noise param injected via hparams for now
600+
hparams.setdefault('magnitude_std', float(val))
601+
elif key == 'm':
602+
magnitude = int(val)
603+
elif key == 'n':
604+
num_layers = int(val)
605+
elif key == 'w':
606+
weight_idx = int(val)
607+
else:
608+
assert False, 'Unknown RandAugment config section'
537609
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
538-
return RandAugment(ra_ops, num_layers)
610+
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
611+
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)

timm/data/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def transforms_imagenet_train(
190190
)
191191
if interpolation and interpolation != 'random':
192192
aa_params['interpolation'] = _pil_interp(interpolation)
193-
if 'rand' in auto_augment:
193+
if auto_augment.startswith('rand'):
194194
tfl += [rand_augment_transform(auto_augment, aa_params)]
195195
else:
196196
tfl += [auto_augment_transform(auto_augment, aa_params)]

0 commit comments

Comments
 (0)