From ef82dd872b92a1c0edc269ab548b788d8cbed4ba Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 7 Sep 2021 15:42:37 +0100 Subject: [PATCH 01/11] update --- .../lightning/lightning_episodic_module.py | 58 ++++++++++++------- .../lightning/lightning_protonet.py | 25 +++++++- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/learn2learn/algorithms/lightning/lightning_episodic_module.py b/learn2learn/algorithms/lightning/lightning_episodic_module.py index 2b4d4d01..378cdacf 100644 --- a/learn2learn/algorithms/lightning/lightning_episodic_module.py +++ b/learn2learn/algorithms/lightning/lightning_episodic_module.py @@ -5,6 +5,7 @@ try: from pytorch_lightning import LightningModule + from pytorch_lightning.trainer.states import TrainerFn except ImportError: from learn2learn.utils import _ImportRaiser @@ -69,6 +70,15 @@ def add_model_specific_args(parent_parser): ) return parser + @property + def should_cache_data_on_validate(self) -> bool: + # some algorithm requires to be fitted on the new labelled data. + return False + + @property + def should_fit_on_validate(self) -> bool: + return self.should_cache_data_on_validate and self.trainer.state.fn == TrainerFn.VALIDATING + def training_step(self, batch, batch_idx): train_loss, train_accuracy = self.meta_learn( batch, batch_idx, self.train_ways, self.train_shots, self.train_queries @@ -92,26 +102,34 @@ def training_step(self, batch, batch_idx): return train_loss def validation_step(self, batch, batch_idx): - valid_loss, valid_accuracy = self.meta_learn( - batch, batch_idx, self.test_ways, self.test_shots, self.test_queries - ) - self.log( - "valid_loss", - valid_loss.item(), - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.log( - "valid_accuracy", - valid_accuracy.item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return valid_loss.item() + if self.should_fit_on_validate: + # used for the algorithm to store the supports data + self.cache_on_validate_step(batch, batch_idx) + else: + valid_loss, valid_accuracy = self.meta_learn( + batch, batch_idx, self.test_ways, self.test_shots, self.test_queries + ) + self.log( + "valid_loss", + valid_loss.item(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + self.log( + "valid_accuracy", + valid_accuracy.item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return valid_loss.item() + + def validation_epoch_end(self, outputs): + if self.should_fit_on_validate: + self.fit_on_validate_epoch_end() def test_step(self, batch, batch_idx): test_loss, test_accuracy = self.meta_learn( diff --git a/learn2learn/algorithms/lightning/lightning_protonet.py b/learn2learn/algorithms/lightning/lightning_protonet.py index 6041ab64..2f2762fa 100644 --- a/learn2learn/algorithms/lightning/lightning_protonet.py +++ b/learn2learn/algorithms/lightning/lightning_protonet.py @@ -4,7 +4,7 @@ """ import numpy as np import torch - +from typing import Any from torch import nn from learn2learn.utils import accuracy from learn2learn.nn import PrototypicalClassifier @@ -97,6 +97,9 @@ def __init__(self, features, loss=None, **kwargs): self.features = torch.nn.DataParallel(self.features) self.classifier = PrototypicalClassifier(distance=self.distance_metric) + self.support = [] + self.support_labels = [] + @staticmethod def add_model_specific_args(parent_parser): parser = LightningEpisodicModule.add_model_specific_args(parent_parser) @@ -112,6 +115,10 @@ def add_model_specific_args(parent_parser): ) return parser + @property 
+ def should_cache_data_on_validate(self) -> bool: + return True + def meta_learn(self, batch, batch_idx, ways, shots, queries): self.features.train() data, labels = batch @@ -139,3 +146,19 @@ def meta_learn(self, batch, batch_idx, ways, shots, queries): eval_loss = self.loss(logits, query_labels) eval_accuracy = accuracy(logits, query_labels) return eval_loss, eval_accuracy + + def cache_on_validate_step(self, batch, batch_idx): + data, labels = batch + embeddings = self.features(data) + for e, l in zip(embeddings, labels): + self.support.append(e) + self.support_labels.append(l) + + def fit_on_validate_epoch_end(self): + self.classifier.fit_(torch.stack(self.support), torch.tensor(self.support_labels)) + self.support = [] + self.support_labels = [] + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): + embeddings = self.features(batch) + return self.classifier(embeddings) \ No newline at end of file From 282b742dc27c20f2b8a533106d06087d2cad602b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 8 Sep 2021 09:50:04 +0100 Subject: [PATCH 02/11] add test --- .../algorithms/lightning/lightning_episodic_module.py | 1 + tests/unit/algorithms/lightning_protonet_test_notravis.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/learn2learn/algorithms/lightning/lightning_episodic_module.py b/learn2learn/algorithms/lightning/lightning_episodic_module.py index 378cdacf..700601ef 100644 --- a/learn2learn/algorithms/lightning/lightning_episodic_module.py +++ b/learn2learn/algorithms/lightning/lightning_episodic_module.py @@ -161,3 +161,4 @@ def configure_optimizers(self): gamma=self.scheduler_decay, ) return [optimizer], [lr_scheduler] + diff --git a/tests/unit/algorithms/lightning_protonet_test_notravis.py b/tests/unit/algorithms/lightning_protonet_test_notravis.py index 47a471da..4d91cce7 100644 --- a/tests/unit/algorithms/lightning_protonet_test_notravis.py +++ b/tests/unit/algorithms/lightning_protonet_test_notravis.py @@ -54,6 +54,14 @@ def test_protonets(self): verbose=False, ) self.assertTrue(acc[0]["valid_accuracy"] >= 0.20) + trainer.validate( + val_dataloaders=tasksets.validation, + verbose=False, + ) + predictions = trainer.predict( + test_dataloaders=tasksets.validation, + verbose=False, + ) if __name__ == "__main__": From 50f28edc2c2195424e17d08a2795e528360108a6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 8 Sep 2021 11:17:13 +0100 Subject: [PATCH 03/11] update --- learn2learn/utils/lightning.py | 109 ++++++++++++++++++++++++++------- 1 file changed, 86 insertions(+), 23 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 19be43e4..fc07e7bd 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -3,18 +3,97 @@ """ Some utilities to interface with PyTorch Lightning. """ - +import learn2learn as l2l import pytorch_lightning as pl +from torch.utils.data import IterableDataset import sys import tqdm -class EpisodicBatcher(pl.LightningDataModule): +class Epochifier(object): """ - nc + This class is used to sample length tasks to represent an epoch. 
""" + def __init__(self, tasks, length): + self.tasks = tasks + self.length = length + + def __getitem__(self, *args, **kwargs): + return self.tasks.sample() + + def __len__(self): + return self.length + + +class TaskDataParallel(IterableDataset): + + def __init__( + self, + taskset: l2l.data.TaskDataset, + global_rank: int, + world_size: int, + num_workers: int, + epoch_length: int, + seed: int, + ): + """ + This class is used to sample tasks in a distributed setting such as DDP with multiple workers. + + Note: This won't work as expected if `num_workers = 0` and several dataloaders are being iterated on at the same time. + + Args: + taskset: Dataset used to sample task. + global_rank: Rank of the current process. + world_size: Total of number of processes. + num_workers: Number of workers to be provided to the DataLoader. + epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). + seed: The seed will be used on __iter__ call and should be the same for all processes. + + """ + self.taskset = taskset + self.global_rank = global_rank + self.world_size = world_size + self.num_workers = 1 if num_workers == 0 else num_workers + self.worker_world_size = self.world_size * self.num_workers + self.epoch_length = epoch_length + self.seed = seed + self.iteration = 0 + self.iteration = 0 + + if epoch_length % self.world_size != 0: + raise MisconfigurationException("The `epoch_length` should be divisible by `world_size`.") + + @property + def __len__(self) -> int: + return self.epoch_length // self.world_size + + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + @property + def worker_rank(self) -> int: + is_global_zero = self.global_rank == 0 + return self.global_rank + self.worker_id + int(not is_global_zero) + + def __iter__(self): + self.iteration += 1 + pl.seed_everything(self.seed + self.iteration) + return self + + def __next__(self): + task_descriptions = [] + for _ in range(self.worker_world_size): + task_descriptions.append(self.taskset.sample_task_description()) + + return self.taskset.get_task(task_descriptions[self.worker_rank]) + + +class EpisodicBatcher(pl.LightningDataModule): + def __init__( self, train_tasks, @@ -32,38 +111,22 @@ def __init__( self.test_tasks = test_tasks self.epoch_length = epoch_length - @staticmethod - def epochify(taskset, epoch_length): - class Epochifier(object): - def __init__(self, tasks, length): - self.tasks = tasks - self.length = length - - def __getitem__(self, *args, **kwargs): - return self.tasks.sample() - - def __len__(self): - return self.length - - return Epochifier(taskset, epoch_length) - def train_dataloader(self): - return EpisodicBatcher.epochify( + return Epochifier( self.train_tasks, self.epoch_length, ) def val_dataloader(self): - return EpisodicBatcher.epochify( + return Epochifier( self.validation_tasks, self.epoch_length, ) def test_dataloader(self): - length = self.epoch_length - return EpisodicBatcher.epochify( + return Epochifier( self.test_tasks, - length, + self.epoch_length, ) From 77c3a9f3bc319441e963be14d8be93b210e2d025 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 8 Sep 2021 08:56:32 -0400 Subject: [PATCH 04/11] update task data parallel --- learn2learn/utils/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index fc07e7bd..70feecdb 100644 --- a/learn2learn/utils/lightning.py +++ 
b/learn2learn/utils/lightning.py @@ -5,6 +5,7 @@ """ import learn2learn as l2l import pytorch_lightning as pl +from torch.utils.data._utils.worker import get_worker_info from torch.utils.data import IterableDataset import sys import tqdm @@ -62,7 +63,7 @@ def __init__( self.iteration = 0 self.iteration = 0 - if epoch_length % self.world_size != 0: + if epoch_length % self.worker_world_size != 0: raise MisconfigurationException("The `epoch_length` should be divisible by `world_size`.") @property @@ -77,7 +78,7 @@ def worker_id(self) -> int: @property def worker_rank(self) -> int: is_global_zero = self.global_rank == 0 - return self.global_rank + self.worker_id + int(not is_global_zero) + return self.global_rank + self.worker_id + int(not is_global_zero and self.num_workers > 1) def __iter__(self): self.iteration += 1 From 954b9e3ec3844e148f0a0db785c9fbc384f4f4d0 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 8 Sep 2021 11:45:01 -0400 Subject: [PATCH 05/11] update --- learn2learn/utils/lightning.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 70feecdb..d9ee8cfc 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -14,18 +14,18 @@ class Epochifier(object): """ - This class is used to sample length tasks to represent an epoch. + This class is used to sample meta_batch_size tasks to represent an epoch. """ - def __init__(self, tasks, length): + def __init__(self, tasks: l2l.data.TaskDataset, meta_batch_size: int): self.tasks = tasks - self.length = length + self.meta_batch_size = meta_batch_size def __getitem__(self, *args, **kwargs): return self.tasks.sample() def __len__(self): - return self.length + return self.meta_batch_size class TaskDataParallel(IterableDataset): @@ -36,7 +36,7 @@ def __init__( global_rank: int, world_size: int, num_workers: int, - epoch_length: int, + meta_batch_size: int, seed: int, ): """ @@ -49,7 +49,7 @@ def __init__( global_rank: Rank of the current process. world_size: Total of number of processes. num_workers: Number of workers to be provided to the DataLoader. - epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). + meta_batch_size: The expected epoch length. This requires to be divisible by (num_workers * world_size). seed: The seed will be used on __iter__ call and should be the same for all processes. 
""" @@ -58,17 +58,17 @@ def __init__( self.world_size = world_size self.num_workers = 1 if num_workers == 0 else num_workers self.worker_world_size = self.world_size * self.num_workers - self.epoch_length = epoch_length + self.meta_batch_size = meta_batch_size self.seed = seed self.iteration = 0 self.iteration = 0 - if epoch_length % self.worker_world_size != 0: - raise MisconfigurationException("The `epoch_length` should be divisible by `world_size`.") + if meta_batch_size % self.worker_world_size != 0: + raise MisconfigurationException("The `meta_batch_size` should be divisible by `world_size`.") @property def __len__(self) -> int: - return self.epoch_length // self.world_size + return self.meta_batch_size // self.world_size @property def worker_id(self) -> int: @@ -100,7 +100,7 @@ def __init__( train_tasks, validation_tasks=None, test_tasks=None, - epoch_length=1, + meta_batch_size=1, ): super(EpisodicBatcher, self).__init__() self.train_tasks = train_tasks @@ -110,24 +110,24 @@ def __init__( if test_tasks is None: test_tasks = validation_tasks self.test_tasks = test_tasks - self.epoch_length = epoch_length + self.meta_batch_size = meta_batch_size def train_dataloader(self): return Epochifier( self.train_tasks, - self.epoch_length, + self.meta_batch_size, ) def val_dataloader(self): return Epochifier( self.validation_tasks, - self.epoch_length, + self.meta_batch_size, ) def test_dataloader(self): return Epochifier( self.test_tasks, - self.epoch_length, + self.meta_batch_size, ) From ce3c9a6aeb5e68ff7944aac92f56a3df93e04cd5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 10 Sep 2021 17:47:10 +0100 Subject: [PATCH 06/11] renaming --- learn2learn/utils/lightning.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index d9ee8cfc..5955a5fd 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -14,18 +14,18 @@ class Epochifier(object): """ - This class is used to sample meta_batch_size tasks to represent an epoch. + This class is used to sample epoch_length tasks to represent an epoch. """ - def __init__(self, tasks: l2l.data.TaskDataset, meta_batch_size: int): + def __init__(self, tasks: l2l.data.TaskDataset, epoch_length: int): self.tasks = tasks - self.meta_batch_size = meta_batch_size + self.epoch_length = epoch_length def __getitem__(self, *args, **kwargs): return self.tasks.sample() def __len__(self): - return self.meta_batch_size + return self.epoch_length class TaskDataParallel(IterableDataset): @@ -36,8 +36,9 @@ def __init__( global_rank: int, world_size: int, num_workers: int, - meta_batch_size: int, + epoch_length: int, seed: int, + requires_divisible: bool = True, ): """ This class is used to sample tasks in a distributed setting such as DDP with multiple workers. @@ -49,7 +50,7 @@ def __init__( global_rank: Rank of the current process. world_size: Total of number of processes. num_workers: Number of workers to be provided to the DataLoader. - meta_batch_size: The expected epoch length. This requires to be divisible by (num_workers * world_size). + epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). seed: The seed will be used on __iter__ call and should be the same for all processes. 
""" @@ -58,17 +59,18 @@ def __init__( self.world_size = world_size self.num_workers = 1 if num_workers == 0 else num_workers self.worker_world_size = self.world_size * self.num_workers - self.meta_batch_size = meta_batch_size + self.epoch_length = epoch_length self.seed = seed self.iteration = 0 self.iteration = 0 + self.requires_divisible = requires_divisible - if meta_batch_size % self.worker_world_size != 0: - raise MisconfigurationException("The `meta_batch_size` should be divisible by `world_size`.") + if requires_divisible and epoch_length % self.worker_world_size != 0: + raise MisconfigurationException("The `epoch_length` should be divisible by `world_size`.") @property def __len__(self) -> int: - return self.meta_batch_size // self.world_size + return self.epoch_length // self.world_size @property def worker_id(self) -> int: @@ -100,7 +102,7 @@ def __init__( train_tasks, validation_tasks=None, test_tasks=None, - meta_batch_size=1, + epoch_length=1, ): super(EpisodicBatcher, self).__init__() self.train_tasks = train_tasks @@ -110,24 +112,24 @@ def __init__( if test_tasks is None: test_tasks = validation_tasks self.test_tasks = test_tasks - self.meta_batch_size = meta_batch_size + self.epoch_length = epoch_length def train_dataloader(self): return Epochifier( self.train_tasks, - self.meta_batch_size, + self.epoch_length, ) def val_dataloader(self): return Epochifier( self.validation_tasks, - self.meta_batch_size, + self.epoch_length, ) def test_dataloader(self): return Epochifier( self.test_tasks, - self.meta_batch_size, + self.epoch_length, ) From ac13709fed65c0e55d75420bcf01dca7113d043b Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Sun, 12 Sep 2021 14:30:10 -0400 Subject: [PATCH 07/11] update --- learn2learn/utils/lightning.py | 65 +++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 5955a5fd..4834e391 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -3,32 +3,67 @@ """ Some utilities to interface with PyTorch Lightning. """ +from typing import Optional, Callable import learn2learn as l2l import pytorch_lightning as pl from torch.utils.data._utils.worker import get_worker_info -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, Dataset +from torch.utils.data._utils.collate import default_collate import sys import tqdm +class TaskDataParallel(IterableDataset): + + def __init__( + self, + tasks: l2l.data.TaskDataset, + epoch_length: int, + devices: int = 1, + collate_fn: Optional[Callable] = None + ): + """ + This class is used to sample epoch_length tasks to represent an epoch. -class Epochifier(object): + It should be used when using DataParallel - """ - This class is used to sample epoch_length tasks to represent an epoch. - """ + Args: + taskset: Dataset used to sample task. + epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). + devices: Number of devices being used. 
+ collate_fn: The collate_fn to be applied on multiple tasks - def __init__(self, tasks: l2l.data.TaskDataset, epoch_length: int): + """ self.tasks = tasks self.epoch_length = epoch_length + self.devices = devices + + if epoch_length % devices != 0: + raise Exception("The `epoch_length` should be the number of `devices`.") - def __getitem__(self, *args, **kwargs): - return self.tasks.sample() + self.collate_fn = collate_fn + self.counter = 0 + + def __iter__(self) -> 'Epochifier': + self.counter = 0 + return self + + def __next__(self): + if self.counter >= len(self): + raise StopIteration + self.counter += self.devices + tasks = [] + for _ in range(self.devices): + for item in self.tasks.sample(): + tasks.append(item) + if self.collate_fn: + tasks = self.collate_fn(tasks) + return tasks def __len__(self): return self.epoch_length -class TaskDataParallel(IterableDataset): +class TaskDistributedDataParallel(IterableDataset): def __init__( self, @@ -64,11 +99,11 @@ def __init__( self.iteration = 0 self.iteration = 0 self.requires_divisible = requires_divisible + self.counter = 0 if requires_divisible and epoch_length % self.worker_world_size != 0: - raise MisconfigurationException("The `epoch_length` should be divisible by `world_size`.") + raise Exception("The `epoch_length` should be divisible by `world_size`.") - @property def __len__(self) -> int: return self.epoch_length // self.world_size @@ -84,15 +119,21 @@ def worker_rank(self) -> int: def __iter__(self): self.iteration += 1 + self.counter = 0 pl.seed_everything(self.seed + self.iteration) return self def __next__(self): + if self.counter >= len(self): + raise StopIteration task_descriptions = [] for _ in range(self.worker_world_size): task_descriptions.append(self.taskset.sample_task_description()) - return self.taskset.get_task(task_descriptions[self.worker_rank]) + data = self.taskset.get_task(task_descriptions[self.worker_rank]) + self.counter += 1 + return data + class EpisodicBatcher(pl.LightningDataModule): From f1726a8cbab86667b2ef832a54f3573fb06bcec5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 14 Sep 2021 10:33:12 +0100 Subject: [PATCH 08/11] adress comments --- learn2learn/utils/lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 4834e391..17dae142 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -7,8 +7,7 @@ import learn2learn as l2l import pytorch_lightning as pl from torch.utils.data._utils.worker import get_worker_info -from torch.utils.data import IterableDataset, Dataset -from torch.utils.data._utils.collate import default_collate +from torch.utils.data import IterableDataset import sys import tqdm @@ -30,7 +29,7 @@ def __init__( taskset: Dataset used to sample task. epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). devices: Number of devices being used. 
- collate_fn: The collate_fn to be applied on multiple tasks + collate_fn: The collate_fn to be applied on multiple tasks """ self.tasks = tasks From ee25b6091f35ec5861e319de5aa4d11b38f0d495 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 14 Sep 2021 18:25:09 +0100 Subject: [PATCH 09/11] update --- learn2learn/utils/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 17dae142..9afb3013 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -42,7 +42,7 @@ def __init__( self.collate_fn = collate_fn self.counter = 0 - def __iter__(self) -> 'Epochifier': + def __iter__(self) -> 'TaskDataParallel': self.counter = 0 return self From bd54c2fe0b2803db6353f0d5e9100d671c6995ae Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 20 Sep 2021 10:49:25 +0100 Subject: [PATCH 10/11] resolve logic for inference --- .../lightning/lightning_episodic_module.py | 58 +++++++------------ .../lightning/lightning_protonet.py | 18 +----- learn2learn/utils/lightning.py | 4 +- 3 files changed, 23 insertions(+), 57 deletions(-) diff --git a/learn2learn/algorithms/lightning/lightning_episodic_module.py b/learn2learn/algorithms/lightning/lightning_episodic_module.py index 700601ef..1d33226b 100644 --- a/learn2learn/algorithms/lightning/lightning_episodic_module.py +++ b/learn2learn/algorithms/lightning/lightning_episodic_module.py @@ -70,15 +70,6 @@ def add_model_specific_args(parent_parser): ) return parser - @property - def should_cache_data_on_validate(self) -> bool: - # some algorithm requires to be fitted on the new labelled data. - return False - - @property - def should_fit_on_validate(self) -> bool: - return self.should_cache_data_on_validate and self.trainer.state.fn == TrainerFn.VALIDATING - def training_step(self, batch, batch_idx): train_loss, train_accuracy = self.meta_learn( batch, batch_idx, self.train_ways, self.train_shots, self.train_queries @@ -102,34 +93,26 @@ def training_step(self, batch, batch_idx): return train_loss def validation_step(self, batch, batch_idx): - if self.should_fit_on_validate: - # used for the algorithm to store the supports data - self.cache_on_validate_step(batch, batch_idx) - else: - valid_loss, valid_accuracy = self.meta_learn( - batch, batch_idx, self.test_ways, self.test_shots, self.test_queries - ) - self.log( - "valid_loss", - valid_loss.item(), - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.log( - "valid_accuracy", - valid_accuracy.item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return valid_loss.item() - - def validation_epoch_end(self, outputs): - if self.should_fit_on_validate: - self.fit_on_validate_epoch_end() + valid_loss, valid_accuracy = self.meta_learn( + batch, batch_idx, self.test_ways, self.test_shots, self.test_queries + ) + self.log( + "valid_loss", + valid_loss.item(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + self.log( + "valid_accuracy", + valid_accuracy.item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return valid_loss.item() def test_step(self, batch, batch_idx): test_loss, test_accuracy = self.meta_learn( @@ -161,4 +144,3 @@ def configure_optimizers(self): gamma=self.scheduler_decay, ) return [optimizer], [lr_scheduler] - diff --git a/learn2learn/algorithms/lightning/lightning_protonet.py b/learn2learn/algorithms/lightning/lightning_protonet.py index 2f2762fa..3001e20e 100644 --- 
a/learn2learn/algorithms/lightning/lightning_protonet.py +++ b/learn2learn/algorithms/lightning/lightning_protonet.py @@ -115,10 +115,6 @@ def add_model_specific_args(parent_parser): ) return parser - @property - def should_cache_data_on_validate(self) -> bool: - return True - def meta_learn(self, batch, batch_idx, ways, shots, queries): self.features.train() data, labels = batch @@ -147,18 +143,6 @@ def meta_learn(self, batch, batch_idx, ways, shots, queries): eval_accuracy = accuracy(logits, query_labels) return eval_loss, eval_accuracy - def cache_on_validate_step(self, batch, batch_idx): - data, labels = batch - embeddings = self.features(data) - for e, l in zip(embeddings, labels): - self.support.append(e) - self.support_labels.append(l) - - def fit_on_validate_epoch_end(self): - self.classifier.fit_(torch.stack(self.support), torch.tensor(self.support_labels)) - self.support = [] - self.support_labels = [] - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): embeddings = self.features(batch) - return self.classifier(embeddings) \ No newline at end of file + return self.classifier(embeddings) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 9afb3013..4f67cd7a 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -11,6 +11,7 @@ import sys import tqdm + class TaskDataParallel(IterableDataset): def __init__( @@ -27,7 +28,7 @@ def __init__( Args: taskset: Dataset used to sample task. - epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). + epoch_length: The expected epoch length. This requires to be divisible by devices. devices: Number of devices being used. collate_fn: The collate_fn to be applied on multiple tasks @@ -134,7 +135,6 @@ def __next__(self): return data - class EpisodicBatcher(pl.LightningDataModule): def __init__( From 1a96333c7ed0885602b0fa6249dee807d3741047 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 20 Sep 2021 11:02:36 +0100 Subject: [PATCH 11/11] update --- .../lightning/lightning_protonet.py | 3 --- learn2learn/utils/lightning.py | 19 +++++++++++++++++-- .../lightning_protonet_test_notravis.py | 8 -------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/learn2learn/algorithms/lightning/lightning_protonet.py b/learn2learn/algorithms/lightning/lightning_protonet.py index 3001e20e..20bbb52c 100644 --- a/learn2learn/algorithms/lightning/lightning_protonet.py +++ b/learn2learn/algorithms/lightning/lightning_protonet.py @@ -97,9 +97,6 @@ def __init__(self, features, loss=None, **kwargs): self.features = torch.nn.DataParallel(self.features) self.classifier = PrototypicalClassifier(distance=self.distance_metric) - self.support = [] - self.support_labels = [] - @staticmethod def add_model_specific_args(parent_parser): parser = LightningEpisodicModule.add_model_specific_args(parent_parser) diff --git a/learn2learn/utils/lightning.py b/learn2learn/utils/lightning.py index 4f67cd7a..cd6a3b5c 100644 --- a/learn2learn/utils/lightning.py +++ b/learn2learn/utils/lightning.py @@ -7,11 +7,24 @@ import learn2learn as l2l import pytorch_lightning as pl from torch.utils.data._utils.worker import get_worker_info -from torch.utils.data import IterableDataset +from torch.utils.data._utils.collate import default_collate +from torch.utils.data import IterableDataset, Dataset import sys import tqdm +class Epochifier(Dataset): + def __init__(self, tasks, length): + self.tasks = tasks + self.length = length + + def __getitem__(self, 
*args, **kwargs): + return self.tasks.sample() + + def __len__(self): + return self.length + + class TaskDataParallel(IterableDataset): def __init__( @@ -19,7 +32,7 @@ def __init__( tasks: l2l.data.TaskDataset, epoch_length: int, devices: int = 1, - collate_fn: Optional[Callable] = None + collate_fn: Optional[Callable] = default_collate ): """ This class is used to sample epoch_length tasks to represent an epoch. @@ -155,6 +168,8 @@ def __init__( self.epoch_length = epoch_length def train_dataloader(self): + # TODO: Update the logic to use `TaskDataParallel` and `TaskDistributedDataParallel` + # along side a DataLoader return Epochifier( self.train_tasks, self.epoch_length, diff --git a/tests/unit/algorithms/lightning_protonet_test_notravis.py b/tests/unit/algorithms/lightning_protonet_test_notravis.py index 4d91cce7..47a471da 100644 --- a/tests/unit/algorithms/lightning_protonet_test_notravis.py +++ b/tests/unit/algorithms/lightning_protonet_test_notravis.py @@ -54,14 +54,6 @@ def test_protonets(self): verbose=False, ) self.assertTrue(acc[0]["valid_accuracy"] >= 0.20) - trainer.validate( - val_dataloaders=tasksets.validation, - verbose=False, - ) - predictions = trainer.predict( - test_dataloaders=tasksets.validation, - verbose=False, - ) if __name__ == "__main__":
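
The series leaves `EpisodicBatcher` on the simple `Epochifier` path (see the TODO added in PATCH 11), with `TaskDataParallel` / `TaskDistributedDataParallel` available in `learn2learn/utils/lightning.py` for DP/DDP-style task sampling and a `predict_step` added to `LightningPrototypicalNetworks`. As a rough illustration only — this wiring is not part of the patches themselves, and the dataset root, backbone, and hyper-parameters below are placeholders — the pieces could be combined along these lines, assuming the class signatures exactly as they stand after PATCH 11 and a 2021-era pytorch-lightning:

    import learn2learn as l2l
    import pytorch_lightning as pl

    from learn2learn.algorithms import LightningPrototypicalNetworks
    from learn2learn.utils.lightning import EpisodicBatcher

    # 5-way / 1-shot / 1-query episodes: each task carries shots + queries = 2 samples per class.
    tasksets = l2l.vision.benchmarks.get_tasksets(
        "mini-imagenet",
        train_ways=5, train_samples=2,
        test_ways=5, test_samples=2,
        root="~/data",  # placeholder path
    )

    # Any feature extractor returning flat embeddings works; CNN4Backbone is one option.
    features = l2l.vision.models.CNN4Backbone()
    model = LightningPrototypicalNetworks(
        features=features,
        train_ways=5, train_shots=1, train_queries=1,
        test_ways=5, test_shots=1, test_queries=1,
    )

    # EpisodicBatcher still hands Lightning an Epochifier per split (PATCH 11).
    episodic_data = EpisodicBatcher(
        tasksets.train,
        tasksets.validation,
        tasksets.test,
        epoch_length=100,
    )

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=episodic_data)
    trainer.test(model, datamodule=episodic_data)

After fitting, `predict_step` (PATCH 10) embeds a raw batch and scores it against the prototypes stored by the most recent `classifier.fit_` call inside `meta_learn`, so `trainer.predict` is only meaningful once at least one episode has been processed. Replacing the `Epochifier` path with `TaskDataParallel` / `TaskDistributedDataParallel` wrapped in a `DataLoader` is deferred, as the TODO in PATCH 11 notes.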