Skip to content

Commit de5096a

Browse files
[datasets] Allow possibility to disable color jitter for cifar.
1 parent 0621230 commit de5096a

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

inclearn/lib/data/datasets.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class iCIFAR10(DataHandler):
3434
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
3535
]
3636

37+
def set_custom_transforms(self, transforms):
38+
if not transforms.get("color_jitter"):
39+
logger.info("Not using color jitter.")
40+
self.train_transforms.pop(-1)
41+
3742

3843
class iCIFAR100(iCIFAR10):
3944
base_dataset = datasets.cifar.CIFAR100
@@ -150,7 +155,9 @@ def set_custom_transforms(self, transforms_dict):
150155
self.train_transforms.pop(-1)
151156
if transforms_dict.get("crop"):
152157
logger.info("Crop with padding of {}".format(transforms_dict.get("crop")))
153-
self.train_transforms[0] = transforms.RandomCrop(64, padding=transforms_dict.get("crop"))
158+
self.train_transforms[0] = transforms.RandomCrop(
159+
64, padding=transforms_dict.get("crop")
160+
)
154161

155162
def base_dataset(self, data_path, train=True, download=False):
156163
if train:
@@ -180,9 +187,7 @@ def _val_dataset(self, data_path):
180187
class_name: class_id
181188
for class_id, class_name in enumerate(os.listdir(os.path.join(data_path, "train")))
182189
}
183-
self.id2classes = {
184-
v: k for k, v in self.classes2id.items()
185-
}
190+
self.id2classes = {v: k for k, v in self.classes2id.items()}
186191

187192
with open(os.path.join(data_path, "val", "val_annotations.txt")) as f:
188193
for line in f:

0 commit comments

Comments
 (0)