@@ -34,6 +34,11 @@ class iCIFAR10(DataHandler):
3434 transforms .Normalize ((0.4914 , 0.4822 , 0.4465 ), (0.2023 , 0.1994 , 0.2010 ))
3535 ]
3636
37+ def set_custom_transforms (self , transforms ):
38+ if not transforms .get ("color_jitter" ):
39+ logger .info ("Not using color jitter." )
40+ self .train_transforms .pop (- 1 )
41+
3742
3843class iCIFAR100 (iCIFAR10 ):
3944 base_dataset = datasets .cifar .CIFAR100
@@ -150,7 +155,9 @@ def set_custom_transforms(self, transforms_dict):
150155 self .train_transforms .pop (- 1 )
151156 if transforms_dict .get ("crop" ):
152157 logger .info ("Crop with padding of {}" .format (transforms_dict .get ("crop" )))
153- self .train_transforms [0 ] = transforms .RandomCrop (64 , padding = transforms_dict .get ("crop" ))
158+ self .train_transforms [0 ] = transforms .RandomCrop (
159+ 64 , padding = transforms_dict .get ("crop" )
160+ )
154161
155162 def base_dataset (self , data_path , train = True , download = False ):
156163 if train :
@@ -180,9 +187,7 @@ def _val_dataset(self, data_path):
180187 class_name : class_id
181188 for class_id , class_name in enumerate (os .listdir (os .path .join (data_path , "train" )))
182189 }
183- self .id2classes = {
184- v : k for k , v in self .classes2id .items ()
185- }
190+ self .id2classes = {v : k for k , v in self .classes2id .items ()}
186191
187192 with open (os .path .join (data_path , "val" , "val_annotations.txt" )) as f :
188193 for line in f :
0 commit comments