Skip to content

Commit 76539d9

Browse files
committed
Some transform/data/loader refactoring, hopefully didn't break things
* factor out data-related constants to their own file
* move data-related config helpers to their own file
* add a variant of RandomResizeCrop that randomizes the interpolation method
* remove the old Numpy version of RandomErasing
* clean up the torch version of RandomErasing and use it in either GPU loader batch mode or as a single-image CPU transform
1 parent e3377b0 commit 76539d9

File tree

14 files changed

+270
-219
lines changed

14 files changed

+270
-219
lines changed

data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from data.constants import *
2+
from data.config import resolve_data_config
13
from data.dataset import Dataset
24
from data.transforms import *
35
from data.loader import create_loader
4-
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

data/config.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from data.constants import *
2+
3+
4+
def resolve_data_config(model, args, default_cfg=None, verbose=True):
    """Resolve the data pre-processing configuration for a model.

    Each setting is taken from the first source that provides it:
    the command line style ``args``, then ``default_cfg``, then a built-in
    fallback.

    Args:
        model: Model object. Its ``default_cfg`` attribute, when present, is
            used if no non-empty ``default_cfg`` argument is given.
        args: Object exposing ``img_size``, ``interpolation``, ``model``,
            ``mean`` and ``std`` attributes.
        default_cfg: Optional mapping that may contain ``input_size``,
            ``interpolation``, ``mean``, ``std`` and ``crop_pct``.
        verbose: When true, print the resolved configuration.

    Returns:
        A dict with the keys ``input_size``, ``interpolation``, ``mean``,
        ``std`` and ``crop_pct``.

    Raises:
        AssertionError: If ``args.img_size`` is not an int, or if
            ``args.mean`` / ``args.std`` has a length other than 1 or the
            number of input channels.
    """
    new_config = {}
    if not default_cfg:
        # Fall back to the model's own defaults; tolerate a missing or None attribute.
        default_cfg = getattr(model, 'default_cfg', None) or {}

    # Resolve input/image size
    # FIXME grayscale/chans arg to use different # channels?
    in_chans = 3
    input_size = (in_chans, 224, 224)
    if args.img_size is not None:
        # FIXME support passing img_size as tuple, non-square
        assert isinstance(args.img_size, int)
        input_size = (in_chans, args.img_size, args.img_size)
    elif 'input_size' in default_cfg:
        input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # resolve interpolation method
    new_config['interpolation'] = 'bilinear'
    if args.interpolation:
        new_config['interpolation'] = args.interpolation
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    # resolve dataset + model mean for normalization
    new_config['mean'] = get_mean_by_model(args.model)
    if args.mean is not None:
        new_config['mean'] = _expand_to_channels(args.mean, in_chans)
    elif 'mean' in default_cfg:
        new_config['mean'] = default_cfg['mean']

    # resolve dataset + model std deviation for normalization
    new_config['std'] = get_std_by_model(args.model)
    if args.std is not None:
        new_config['std'] = _expand_to_channels(args.std, in_chans)
    elif 'std' in default_cfg:
        new_config['std'] = default_cfg['std']

    # resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    if verbose:
        print('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            print('\t%s: %s' % (n, str(v)))

    return new_config


def _expand_to_channels(values, in_chans):
    """Return ``values`` as a tuple holding one entry per input channel.

    A single value is repeated ``in_chans`` times; otherwise the length must
    already equal ``in_chans``.
    """
    values = tuple(values)
    if len(values) == 1:
        return values * in_chans
    assert len(values) == in_chans
    return values
64+
65+
66+
def get_mean_by_name(name):
    """Return the normalization mean constant selected by ``name``.

    ``'dpn'`` selects ``IMAGENET_DPN_MEAN``; ``'inception'`` or ``'le'``
    selects ``IMAGENET_INCEPTION_MEAN``; any other value selects
    ``IMAGENET_DEFAULT_MEAN``.
    """
    if name == 'dpn':
        return IMAGENET_DPN_MEAN
    if name == 'inception' or name == 'le':
        return IMAGENET_INCEPTION_MEAN
    return IMAGENET_DEFAULT_MEAN
73+
74+
75+
def get_std_by_name(name):
    """Return the normalization std constant selected by ``name``.

    ``'dpn'`` selects ``IMAGENET_DPN_STD``; ``'inception'`` or ``'le'``
    selects ``IMAGENET_INCEPTION_STD``; any other value selects
    ``IMAGENET_DEFAULT_STD``.
    """
    if name == 'dpn':
        std = IMAGENET_DPN_STD
    elif name == 'inception' or name == 'le':
        std = IMAGENET_INCEPTION_STD
    else:
        std = IMAGENET_DEFAULT_STD
    return std
82+
83+
84+
def get_mean_by_model(model_name):
    """Return the normalization mean constant for a model name.

    Matching is a case-insensitive substring test on ``model_name``:
    names containing ``'dpn'`` get ``IMAGENET_DPN_MEAN``, names containing
    ``'ception'`` or ``'nasnet'`` get ``IMAGENET_INCEPTION_MEAN``, and all
    others get ``IMAGENET_DEFAULT_MEAN``.

    Args:
        model_name: Model name string.

    Returns:
        One of the ``IMAGENET_*_MEAN`` constants.
    """
    model_name = model_name.lower()
    if 'dpn' in model_name:
        # Previously returned IMAGENET_DPN_STD, i.e. a std value used as the mean.
        return IMAGENET_DPN_MEAN
    if 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_MEAN
    return IMAGENET_DEFAULT_MEAN
92+
93+
94+
def get_std_by_model(model_name):
    """Return the normalization std constant for a model name.

    Matching is a case-insensitive substring test on ``model_name``:
    names containing ``'dpn'`` get ``IMAGENET_DPN_STD``, names containing
    ``'ception'`` or ``'nasnet'`` get ``IMAGENET_INCEPTION_STD``, and all
    others get ``IMAGENET_DEFAULT_STD``.

    Args:
        model_name: Model name string.

    Returns:
        One of the ``IMAGENET_*_STD`` constants.
    """
    model_name = model_name.lower()
    if 'dpn' in model_name:
        # Previously returned IMAGENET_DEFAULT_STD, which made this branch a
        # no-op and disagreed with get_std_by_name('dpn').
        return IMAGENET_DPN_STD
    if 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_STD
    return IMAGENET_DEFAULT_STD

data/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Crop percentage used when a model's default config does not supply 'crop_pct'.
DEFAULT_CROP_PCT = 0.875

# Per-channel normalization mean/std triples, selected by model or scheme name.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

IMAGENET_INCEPTION_MEAN = (0.5,) * 3
IMAGENET_INCEPTION_STD = (0.5,) * 3

# DPN values are written on a 0-255 scale and divided down to 0-1 here.
IMAGENET_DPN_MEAN = tuple(v / 255 for v in (124, 117, 104))
IMAGENET_DPN_STD = (1 / (.0167 * 255),) * 3

data/loader.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import torch
21
import torch.utils.data
3-
from data.random_erasing import RandomErasingTorch
42
from data.transforms import *
53
from data.distributed_sampler import OrderedDistributedSampler
64

@@ -27,7 +25,7 @@ def __init__(self,
2725
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
2826
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
2927
if rand_erase_prob > 0.:
30-
self.random_erasing = RandomErasingTorch(
28+
self.random_erasing = RandomErasing(
3129
probability=rand_erase_prob, per_pixel=rand_erase_pp)
3230
else:
3331
self.random_erasing = None

data/random_erasing.py

Lines changed: 31 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,125 +2,68 @@
22

33
import random
44
import math
5-
import numpy as np
65
import torch
76

87

9-
class RandomErasingNumpy:
8+
def _get_patch(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
9+
if per_pixel:
10+
return torch.empty(
11+
patch_size, dtype=dtype, device=device).normal_()
12+
elif rand_color:
13+
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
14+
else:
15+
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
16+
17+
18+
class RandomErasing:
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
    Args:
         probability: The probability that the Random Erasing operation will be performed.
         sl: Minimum proportion of erased area against input image.
         sh: Maximum proportion of erased area against input image.
         min_aspect: Minimum aspect ratio of erased area.
         per_pixel: random value for each pixel in the erase region, precedence over rand_color
         rand_color: random color for whole erase region, 0 if neither this or per_pixel set
         device: device on which the fill patch is created
    """

    def __init__(
            self,
            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
            per_pixel=False, rand_color=False, device='cuda'):
        self.probability = probability
        # Bounds on the erased area, as fractions of the image area.
        self.sl = sl
        self.sh = sh
        self.min_aspect = min_aspect
        # Fill mode flags, passed through to _get_patch.
        self.per_pixel = per_pixel
        self.rand_color = rand_color
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        """Erase one random rectangle of ``img`` in place, with the configured probability."""
        if random.random() > self.probability:
            return
        total_area = img_h * img_w
        # Retry until a sampled rectangle fits strictly inside the image.
        for _ in range(100):
            erase_area = random.uniform(self.sl, self.sh) * total_area
            ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
            erase_h = int(round(math.sqrt(erase_area * ratio)))
            erase_w = int(round(math.sqrt(erase_area / ratio)))
            if erase_w >= img_w or erase_h >= img_h:
                continue
            top = random.randint(0, img_h - erase_h)
            left = random.randint(0, img_w - erase_w)
            img[:, top:top + erase_h, left:left + erase_w] = _get_patch(
                self.per_pixel, self.rand_color, (chan, erase_h, erase_w),
                dtype=dtype, device=self.device)
            break

    def __call__(self, input):
        """Apply random erasing in place to a 3-dim image or a 4-dim batch and return it."""
        if len(input.size()) != 3:
            batch_size, chan, img_h, img_w = input.size()
            for idx in range(batch_size):
                self._erase(input[idx], chan, img_h, img_w, input.dtype)
        else:
            chan, img_h, img_w = input.size()
            self._erase(input, chan, img_h, img_w, input.dtype)
        return input

0 commit comments

Comments
 (0)