1414import numpy as np
1515import torch
1616import math
17+ import numbers
1718from enum import IntEnum
1819
1920
@@ -49,24 +50,33 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
4950 return input , target
5051
5152
def calc_ratio(lam, minmax=None):
    """Convert a mix coefficient into the side ratio of a cut box.

    The ratio is ``sqrt(1 - lam)``, so a box whose sides are scaled by it
    covers a ``1 - lam`` fraction of the image area.

    Args:
        lam: Mix coefficient; ``1 - lam`` must be non-negative.
        minmax: Optional clamp for the ratio. ``None`` disables clamping. A
            single number ``m`` clamps to the range spanned by ``m`` and
            ``1 - m``. A two-item sequence is used as the two bounds.

    Returns:
        The (optionally clamped) ratio as a plain Python float.
    """
    ratio = math.sqrt(1 - lam)
    if minmax is None:
        return ratio
    if isinstance(minmax, numbers.Number):
        minmax = (minmax, 1 - minmax)
    # Order the bounds explicitly: a scalar above 0.5 (or a reversed pair)
    # would otherwise hand np.clip an inverted range and pin every result
    # to the upper argument.
    low, high = sorted((minmax[0], minmax[1]))
    return float(np.clip(ratio, low, high))
60+
61+
def rand_bbox(size, ratio):
    """Sample a random box inside an image of the given size.

    Args:
        size: Sequence whose last two entries are the image height and width.
        ratio: Fraction of each side the box should span before clipping.
            Callers are expected to clamp it (for example via a ratio helper)
            if they need bounds; no clamping happens here.

    Returns:
        ``(yl, yh, xl, xh)`` as plain ints, usable as ``[yl:yh, xl:xh]``
        slice bounds. The box is centred on a uniformly drawn pixel and
        clipped to the image, so it can be smaller than requested near the
        borders.
    """
    height, width = size[-2:]
    cut_h, cut_w = int(height * ratio), int(width * ratio)
    # Draw the centre row first, then the column, so seeded runs keep the
    # same random stream order.
    cy = np.random.randint(height)
    cx = np.random.randint(width)
    yl = int(np.clip(cy - cut_h // 2, 0, height))
    yh = int(np.clip(cy + cut_h // 2, 0, height))
    xl = int(np.clip(cx - cut_w // 2, 0, width))
    xh = int(np.clip(cx + cut_w // 2, 0, width))
    return yl, yh, xl, xh
6069
6170
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
    """Apply cutmix to a batch in place and build the matching mixed targets.

    A mix coefficient is drawn from ``Beta(alpha, alpha)``; a random box is
    then overwritten with the same region taken from the batch reversed along
    dim 0.

    Args:
        input: Batch tensor indexed as ``[:, :, y, x]``; modified in place.
        target: Labels, forwarded to ``mixup_target``.
        alpha: Both shape parameters of the Beta distribution.
        num_classes: Forwarded to ``mixup_target``.
        smoothing: Forwarded to ``mixup_target``.
        disable: When true, no coefficient is drawn and ``lam`` stays at 1,
            so the input is left untouched.
        correct_lam: When true, ``lam`` is recomputed from the area of the
            box actually pasted (the box may be clipped at the borders).

    Returns:
        ``(input, target)`` where ``target`` is whatever ``mixup_target``
        returns for the final ``lam``.
    """
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    if lam != 1:
        yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
        # flip(0) returns a reversed copy, so sample k receives the patch
        # from sample (batch - 1 - k) as it was before this assignment.
        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
        if correct_lam:
            box_area = (yh - yl) * (xh - xl)
            lam = 1 - box_area / (input.shape[-2] * input.shape[-1])
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target
7282
def mix_batch(
        input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
    """Dispatch a batch to cutmix or mixup according to ``mode``.

    ``mode`` is first passed through ``_resolve_mode``. A resolved mode of
    ``MixupMode.CUTMIX`` goes to ``cutmix_batch``; anything else goes to
    ``mixup_batch``. All other arguments are forwarded positionally and
    unchanged, and the chosen function's return value is returned as is.
    """
    mode = _resolve_mode(mode)
    if mode == MixupMode.CUTMIX:
        return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
    return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
8898
8999
90100class FastCollateMixup :
@@ -99,6 +109,7 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M
99109 self .mode = MixupMode .from_str (mode ) if isinstance (mode , str ) else mode
100110 self .mixup_enabled = True
101111 self .correct_lam = False # correct lambda based on clipped area for cutmix
112+ self .ratio_minmax = None # (0.2, 0.8)
102113
103114 def _do_mix (self , tensor , batch ):
104115 batch_size = len (batch )
@@ -111,7 +122,7 @@ def _do_mix(self, tensor, batch):
111122
112123 if _resolve_mode (self .mode ) == MixupMode .CUTMIX :
113124 mixed_i , mixed_j = batch [i ][0 ].astype (np .float32 ), batch [j ][0 ].astype (np .float32 )
114- ratio = math . sqrt ( 1. - lam )
125+ ratio = calc_ratio ( lam , self . ratio_minmax )
115126 if lam != 1 :
116127 yl , yh , xl , xh = rand_bbox (tensor .size (), ratio )
117128 mixed_i [:, yl :yh , xl :xh ] = batch [j ][0 ][:, yl :yh , xl :xh ].astype (np .float32 )
@@ -132,15 +143,15 @@ def _do_mix(self, tensor, batch):
132143 np .round (mixed_j , out = mixed_j )
133144 tensor [i ] += torch .from_numpy (mixed_i .astype (np .uint8 ))
134145 tensor [j ] += torch .from_numpy (mixed_j .astype (np .uint8 ))
135- return lam_out
146+ return lam_out . unsqueeze ( 1 )
136147
def __call__(self, batch):
    """Collate ``batch`` into a mixed uint8 tensor and mixed targets.

    Args:
        batch: Sequence of ``(image, label)`` pairs. Every image must have
            the shape of ``batch[0][0]``; presumably numpy arrays, as
            ``_do_mix`` consumes them. The length must be even.

    Returns:
        ``(tensor, target)``. ``tensor`` is a uint8 tensor of shape
        ``(len(batch), *image_shape)`` filled by ``_do_mix``. ``target`` is
        the result of ``mixup_target`` evaluated on the CPU.
    """
    batch_size = len(batch)
    assert batch_size % 2 == 0, 'Batch size should be even when using this'
    tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
    lam = self._do_mix(tensor, batch)
    # _do_mix variants return either a scalar, a per-sample column, or a
    # flat per-sample vector. Only the flat vector needs a trailing axis,
    # and a scalar has no unsqueeze at all.
    if torch.is_tensor(lam) and lam.dim() == 1:
        lam = lam.unsqueeze(1)
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    target = mixup_target(labels, self.num_classes, lam, self.label_smoothing, device='cpu')
    return tensor, target
146157
def _do_mix(self, tensor, batch):
    """Mix each sample with its mirror partner using a per-sample ``lam``.

    Sample ``i`` is combined with sample ``len(batch) - 1 - i`` and the
    rounded uint8 result is added into ``tensor[i]``.

    Args:
        tensor: Output tensor, one slot per sample, accumulated in place.
            The caller is expected to pass it zero-filled.
        batch: Sequence of ``(image, label)`` pairs whose images support
            ``astype`` and ``[:, y, x]`` slicing.

    Returns:
        Tensor of shape ``(len(batch), 1)`` holding the ``lam`` used for
        each sample.
    """
    batch_size = len(batch)
    lam_out = torch.ones(batch_size)
    # Honor the clamp configured on the collate object, as the pairwise
    # variant does; falls back to no clamping when it is not set.
    ratio_minmax = getattr(self, 'ratio_minmax', None)
    for i in range(batch_size):
        j = batch_size - i - 1
        lam = 1.
        if self.mixup_enabled:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

        # Resolved once per sample on purpose: _resolve_mode is defined
        # elsewhere and may not return the same mode on every call.
        if _resolve_mode(self.mode) == MixupMode.CUTMIX:
            mixed = batch[i][0].astype(np.float32)
            if lam != 1:
                ratio = calc_ratio(lam, ratio_minmax)
                yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
                mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                if self.correct_lam:
                    # lam_out starts at 1, so subtracting the pasted area
                    # fraction yields the effective lam for this sample.
                    lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
                else:
                    lam_out[i] = lam
        else:
            mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
            lam_out[i] = lam
        np.round(mixed, out=mixed)
        tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
    return lam_out.unsqueeze(1)
181192
182193
183194class FastCollateMixupBatchwise (FastCollateMixup ):
@@ -191,25 +202,23 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M
191202
def _do_mix(self, tensor, batch):
    """Mix every sample with its mirror partner using one shared ``lam``.

    A single coefficient (and, for cutmix, a single box) is drawn for the
    whole batch. Sample ``i`` is combined with sample
    ``len(batch) - 1 - i`` and the rounded uint8 result is added into
    ``tensor[i]``.

    Args:
        tensor: Output tensor, one slot per sample, accumulated in place.
            The caller is expected to pass it zero-filled.
        batch: Sequence of ``(image, label)`` pairs whose images support
            ``astype`` and ``[:, y, x]`` slicing.

    Returns:
        The scalar ``lam`` applied to the batch. With cutmix and
        ``correct_lam`` set, this is recomputed from the area of the box
        actually pasted.
    """
    batch_size = len(batch)
    lam = 1.
    cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
    if self.mixup_enabled:
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        if cutmix:
            # Honor the clamp configured on the collate object; falls back
            # to no clamping when it is not set.
            ratio = calc_ratio(lam, getattr(self, 'ratio_minmax', None))
            yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
            if self.correct_lam:
                lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])

    for i in range(batch_size):
        j = batch_size - i - 1
        if cutmix:
            mixed = batch[i][0].astype(np.float32)
            # The box only exists when a coefficient was drawn; lam == 1
            # means there is nothing to paste.
            if lam != 1:
                mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
        else:
            mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
        np.round(mixed, out=mixed)
        tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
    return lam
0 commit comments