1- """ Auto Augment
1+ """ AutoAugment and RandAugment
22Implementation adapted from:
33 https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
4- Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
4+ Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
55
66Hacked together by Ross Wightman
77"""
@@ -288,18 +288,18 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
288288 resample = hparams ['interpolation' ] if 'interpolation' in hparams else _RANDOM_INTERPOLATION ,
289289 )
290290
291- # If magnitude_noise is > 0, we introduce some randomness
291+ # If magnitude_std is > 0, we introduce some randomness
292292 # in the usually fixed policy and sample magnitude from a normal distribution
293- # with mean `magnitude` and std-dev of `magnitude_noise `.
293+ # with mean `magnitude` and std-dev of `magnitude_std `.
294294 # NOTE This is my own hack, being tested, not in papers or reference impls.
295- self .magnitude_noise = self .hparams .get ('magnitude_noise ' , 0 )
295+ self .magnitude_std = self .hparams .get ('magnitude_std ' , 0 )
296296
297297 def __call__ (self , img ):
298298 if random .random () > self .prob :
299299 return img
300300 magnitude = self .magnitude
301- if self .magnitude_noise and self .magnitude_noise > 0 :
302- magnitude = random .gauss (magnitude , self .magnitude_noise )
301+ if self .magnitude_std and self .magnitude_std > 0 :
302+ magnitude = random .gauss (magnitude , self .magnitude_std )
303303 magnitude = min (_MAX_LEVEL , max (0 , magnitude )) # clip to valid range
304304 level_args = self .level_fn (magnitude , self .hparams ) if self .level_fn is not None else tuple ()
305305 return self .aug_fn (img , * level_args , ** self .kwargs )
@@ -464,16 +464,32 @@ def __call__(self, img):
464464
465465
466466def auto_augment_transform (config_str , hparams ):
467+ """
468+ Create a AutoAugment transform
469+
470+ :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
471+ dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
472+ The remaining sections, not order sepecific determine
473+ 'mstd' - float std deviation of magnitude noise applied
474+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
475+
476+ :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
477+
478+ :return: A PyTorch compatible Transform
479+ """
467480 config = config_str .split ('-' )
468481 policy_name = config [0 ]
469482 config = config [1 :]
470483 for c in config :
471484 cs = re .split (r'(\d.*)' , c )
472- if len (cs ) >= 2 :
473- key , val = cs [:2 ]
474- if key == 'noise' :
475- # noise param injected via hparams for now
476- hparams .setdefault ('magnitude_noise' , float (val ))
485+ if len (cs ) < 2 :
486+ continue
487+ key , val = cs [:2 ]
488+ if key == 'mstd' :
489+ # noise param injected via hparams for now
490+ hparams .setdefault ('magnitude_std' , float (val ))
491+ else :
492+ assert False , 'Unknown AutoAugment config section'
477493 aa_policy = auto_augment_policy (policy_name , hparams = hparams )
478494 return AutoAugment (aa_policy )
479495
@@ -498,6 +514,36 @@ def auto_augment_transform(config_str, hparams):
498514]
499515
500516
517+ # These experimental weights are based loosely on the relative improvements mentioned in paper.
518+ # They may not result in increased performance, but could likely be tuned to so.
519+ _RAND_CHOICE_WEIGHTS_0 = {
520+ 'Rotate' : 0.3 ,
521+ 'ShearX' : 0.2 ,
522+ 'ShearY' : 0.2 ,
523+ 'TranslateXRel' : 0.1 ,
524+ 'TranslateYRel' : 0.1 ,
525+ 'Color' : .025 ,
526+ 'Sharpness' : 0.025 ,
527+ 'AutoContrast' : 0.025 ,
528+ 'Solarize' : .005 ,
529+ 'SolarizeAdd' : .005 ,
530+ 'Contrast' : .005 ,
531+ 'Brightness' : .005 ,
532+ 'Equalize' : .005 ,
533+ 'PosterizeTpu' : 0 ,
534+ 'Invert' : 0 ,
535+ }
536+
537+
538+ def _select_rand_weights (weight_idx = 0 , transforms = None ):
539+ transforms = transforms or _RAND_TRANSFORMS
540+ assert weight_idx == 0 # only one set of weights currently
541+ rand_weights = _RAND_CHOICE_WEIGHTS_0
542+ probs = [rand_weights [k ] for k in transforms ]
543+ probs /= np .sum (probs )
544+ return probs
545+
546+
501547def rand_augment_ops (magnitude = 10 , hparams = None , transforms = None ):
502548 hparams = hparams or _HPARAMS_DEFAULT
503549 transforms = transforms or _RAND_TRANSFORMS
@@ -506,33 +552,60 @@ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
506552
507553
508554class RandAugment :
509- def __init__ (self , ops , num_layers = 2 ):
555+ def __init__ (self , ops , num_layers = 2 , choice_weights = None ):
510556 self .ops = ops
511557 self .num_layers = num_layers
558+ self .choice_weights = choice_weights
512559
513560 def __call__ (self , img ):
514- for _ in range (self .num_layers ):
515- op = random .choice (self .ops )
561+ # no replacement when using weighted choice
562+ ops = np .random .choice (
563+ self .ops , self .num_layers , replace = self .choice_weights is None , p = self .choice_weights )
564+ for op in ops :
516565 img = op (img )
517566 return img
518567
519568
520569def rand_augment_transform (config_str , hparams ):
521- magnitude = 10
522- num_layers = 2
570+ """
571+ Create a RandAugment transform
572+
573+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
574+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
575+ sections, not order sepecific determine
576+ 'm' - integer magnitude of rand augment
577+ 'n' - integer num layers (number of transform ops selected per image)
578+ 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
579+ 'mstd' - float std deviation of magnitude noise applied
580+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
581+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
582+
583+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
584+
585+ :return: A PyTorch compatible Transform
586+ """
587+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
588+ num_layers = 2 # default to 2 ops per image
589+ weight_idx = None # default to no probability weights for op choice
523590 config = config_str .split ('-' )
524591 assert config [0 ] == 'rand'
525592 config = config [1 :]
526593 for c in config :
527594 cs = re .split (r'(\d.*)' , c )
528- if len (cs ) >= 2 :
529- key , val = cs [:2 ]
530- if key == 'noise' :
531- # noise param injected via hparams for now
532- hparams .setdefault ('magnitude_noise' , float (val ))
533- elif key == 'm' :
534- magnitude = int (val )
535- elif key == 'n' :
536- num_layers = int (val )
595+ if len (cs ) < 2 :
596+ continue
597+ key , val = cs [:2 ]
598+ if key == 'mstd' :
599+ # noise param injected via hparams for now
600+ hparams .setdefault ('magnitude_std' , float (val ))
601+ elif key == 'm' :
602+ magnitude = int (val )
603+ elif key == 'n' :
604+ num_layers = int (val )
605+ elif key == 'w' :
606+ weight_idx = int (val )
607+ else :
608+ assert False , 'Unknown RandAugment config section'
537609 ra_ops = rand_augment_ops (magnitude = magnitude , hparams = hparams )
538- return RandAugment (ra_ops , num_layers )
610+ choice_weights = None if weight_idx is None else _select_rand_weights (weight_idx )
611+ return RandAugment (ra_ops , num_layers , choice_weights = choice_weights )
0 commit comments