55import torch
66
77
8- def _get_patch (per_pixel , rand_color , patch_size , dtype = torch .float32 , device = 'cuda' ):
8+ def _get_pixels (per_pixel , rand_color , patch_size , dtype = torch .float32 , device = 'cuda' ):
9+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
10+ # paths, flip the order so normal is run on CPU if this becomes a problem
11+ # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
912 if per_pixel :
1013 return torch .empty (
1114 patch_size , dtype = dtype , device = device ).normal_ ()
@@ -27,20 +30,29 @@ class RandomErasing:
2730 sl: Minimum proportion of erased area against input image.
2831 sh: Maximum proportion of erased area against input image.
2932 min_aspect: Minimum aspect ratio of erased area.
30- per_pixel: random value for each pixel in the erase region, precedence over rand_color
31- rand_color: random color for whole erase region, 0 if neither this or per_pixel set
33+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
34+ 'const' - erase block is constant color of 0 for all channels
35+ 'rand' - erase block is same per-cannel random (normal) color
36+ 'pixel' - erase block is per-pixel random (normal) color
3237 """
3338
3439 def __init__ (
3540 self ,
3641 probability = 0.5 , sl = 0.02 , sh = 1 / 3 , min_aspect = 0.3 ,
37- per_pixel = False , rand_color = False , device = 'cuda' ):
42+ mode = 'const' , device = 'cuda' ):
3843 self .probability = probability
3944 self .sl = sl
4045 self .sh = sh
4146 self .min_aspect = min_aspect
42- self .per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
43- self .rand_color = rand_color # per block random, bounded by [pl, ph]
47+ mode = mode .lower ()
48+ self .rand_color = False
49+ self .per_pixel = False
50+ if mode == 'rand' :
51+ self .rand_color = True # per block random normal
52+ elif mode == 'pixel' :
53+ self .per_pixel = True # per pixel random normal
54+ else :
55+ assert not mode or mode == 'const'
4456 self .device = device
4557
4658 def _erase (self , img , chan , img_h , img_w , dtype ):
@@ -55,8 +67,9 @@ def _erase(self, img, chan, img_h, img_w, dtype):
5567 if w < img_w and h < img_h :
5668 top = random .randint (0 , img_h - h )
5769 left = random .randint (0 , img_w - w )
58- img [:, top :top + h , left :left + w ] = _get_patch (
59- self .per_pixel , self .rand_color , (chan , h , w ), dtype = dtype , device = self .device )
70+ img [:, top :top + h , left :left + w ] = _get_pixels (
71+ self .per_pixel , self .rand_color , (chan , h , w ),
72+ dtype = dtype , device = self .device )
6073 break
6174
6275 def __call__ (self , input ):
0 commit comments