|
2 | 2 |
|
3 | 3 | import random |
4 | 4 | import math |
5 | | -import numpy as np |
6 | 5 | import torch |
7 | 6 |
|
8 | 7 |
|
9 | | -class RandomErasingNumpy: |
| 8 | +def _get_patch(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): |
| 9 | + if per_pixel: |
| 10 | + return torch.empty( |
| 11 | + patch_size, dtype=dtype, device=device).normal_() |
| 12 | + elif rand_color: |
| 13 | + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() |
| 14 | + else: |
| 15 | + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) |
| 16 | + |
| 17 | + |
class RandomErasing:
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
        The input tensor is modified in place and also returned.
    Args:
         probability: The probability that the Random Erasing operation will be performed.
         sl: Minimum proportion of erased area against input image.
         sh: Maximum proportion of erased area against input image.
         min_aspect: Minimum aspect ratio of erased area (the maximum used is 1 / min_aspect).
         per_pixel: random value for each pixel in the erase region, precedence over rand_color
         rand_color: random color for whole erase region, 0 if neither this or per_pixel set
         device: device on which the fill patch is created before it is written into the image.
    """

    def __init__(
            self,
            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
            per_pixel=False, rand_color=False, device='cuda'):
        self.probability = probability
        self.sl = sl
        self.sh = sh
        self.min_aspect = min_aspect
        self.per_pixel = per_pixel  # fill with per-pixel noise drawn via normal_()
        self.rand_color = rand_color  # fill with one normal_() noise value per channel
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        """Erase one random rectangle of a single (chan, img_h, img_w) image in place.

        With probability ``1 - self.probability`` nothing happens. Otherwise up
        to 100 candidate rectangles are sampled; the first one strictly smaller
        than the image in both dimensions is filled and the search stops. If no
        candidate fits, the image is left untouched.
        """
        if random.random() > self.probability:
            return
        area = img_h * img_w
        max_aspect = 1 / self.min_aspect  # loop-invariant upper bound
        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.min_aspect, max_aspect)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < img_w and h < img_h:
                top = random.randint(0, img_h - h)
                left = random.randint(0, img_w - w)
                img[:, top:top + h, left:left + w] = _get_patch(
                    self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=self.device)
                break

    def __call__(self, input):
        """Apply random erasing to a single image or to every image of a batch.

        Args:
            input: tensor of rank 3 (one image) or rank 4 (batch of images);
                it is modified in place.

        Returns:
            The same ``input`` tensor object.

        Raises:
            ValueError: if ``input`` is neither rank 3 nor rank 4.
        """
        ndim = len(input.size())
        if ndim == 3:
            self._erase(input, *input.size(), input.dtype)
        elif ndim == 4:
            batch_size, chan, img_h, img_w = input.size()
            for i in range(batch_size):
                # Each image draws its own probability check and rectangle.
                self._erase(input[i], chan, img_h, img_w, input.dtype)
        else:
            # Previously surfaced as an opaque tuple-unpacking ValueError.
            raise ValueError(
                'RandomErasing expects a 3-d or 4-d tensor, got %d dimensions' % ndim)
        return input
0 commit comments