[WIP] modular-pytorch-lightning, WARNING: The repository is currently under development, and is unstable.
What is modular-pytorch-Lightning-Collections⚡(LightCollections⚡️) for?
- Ever wanted to train TResNet-M models and apply TTA (test-time augmentation) or SWA (stochastic weight averaging) to enhance performance? Or apply sharpness-aware minimization to semantic segmentation models and measure the difference in calibration? LightCollections is a framework that utilizes and connects existing libraries so such experiments can be run effortlessly.
- Although many popular repositories provide great implementations of algorithms, they are often fragmented and tedious to use together. LightCollections wraps many existing repositories into components of `pytorch-lightning`. We aim to provide training procedures for various subtasks in computer vision with a collection of `LightningModule`s and utilities for easily using model architectures, metrics, datasets, and training algorithms.
- The components can be used through our system or simply imported from outside to be used in your `pytorch` or `pytorch-lightning` project.

Currently, the following frameworks are integrated into LightCollections:

- `torchvision.models` for models, `torchvision.transforms` for transforms, and optimizers and learning rate schedules from `pytorch`.
- Network architectures and weights from `timm`.
- Object detection frameworks and techniques from `mmdetection`.
- Keypoint detection (pose estimation) frameworks and techniques from `mmpose`.
- ImageNet-21K pretrained weights and a feature to load model weights from a URL / `.pth` file.
- `TTAch` for test-time augmentation.
- Metrics implemented in `torchmetrics`.
- WIP & future TODO:
  - Data augmentation from `albumentations`
  - Semantic segmentation models and weights from `mmsegmentation`
A number of algorithms and research papers are also adopted into our framework. Please refer to the examples below for more information.
%cd /content
!git clone https://github.com/krenerd/awesome-modular-pytorch-lightning
%cd awesome-modular-pytorch-lightning
!pip install -r requirements.txt -q
# (optional) use `wandb` to log progress.
!wandb login
After installing the required packages, you can run the following experiments on Colab.
- CIFAR10 image classification with ResNet18.
!python train.py --name DUMMY-CIFAR10-ResNet18 --configs \
configs/vision/classification/resnet-cifar10.yaml \
configs/vision/models/resnet/resnet18-custom.yaml \
configs/data/cifar10-kuangliu.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Transfer learning experiments on the Stanford Dogs dataset using TResNet-M.
# uncomment when using TResNet models, this takes quite long
!pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12 -q
!python train.py --name TResNetM-StanfordDogs --config \
configs/data/transfer-learning/training/colab-modification.yaml \
configs/data/transfer-learning/training/random-init.yaml \
configs/data/transfer-learning/cifar10-224.yaml \
configs/data/augmentation/randomresizecrop.yaml \
configs/vision/models/tresnetm.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Object detection based on `FasterRCNN-FPN` and `mmdetection` on the `voc0712` dataset.
# refer to: https://mmcv.readthedocs.io/en/latest/get_started/installation.html
# install dependencies: (use cu113+torch1.12 because colab has CUDA 11.3)
# install mmcv-full so that we can use CUDA operators
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
# Install mmdetection
!rm -rf mmdetection
!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
!pip install -e .
# clone MPL
%cd /content
!git clone https://github.com/krenerd/awesome-modular-pytorch-lightning
%cd awesome-modular-pytorch-lightning
!pip install -r requirements.txt -q
# setup voc07+12 dataset
!python tools/download_dataset.py --dataset-name voc0712 --save-dir data --delete --unzip
# run experiment
!python train.py --name voc0712-FasterRCNN-FPN-ResNet50 --config \
configs/vision/object-detection/mmdet/faster-rcnn-r50-fpn-voc0712.yaml \
configs/vision/object-detection/mmdet/mmdet-base.yaml \
configs/data/voc0712-mmdet-no-tta.yaml \
configs/data/voc0712-mmdet.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- 2D pose estimation based on `HRNet` and `mmpose` on the `MPII` dataset.
# train HRNet pose estimation on MPII
# refer to: https://mmcv.readthedocs.io/en/latest/get_started/installation.html
# install dependencies: (use cu113+torch1.12 because colab has CUDA 11.3)
# install mmcv-full so that we can use CUDA operators
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
# Install mmpose
!rm -rf mmpose
!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
!pip install -e .
# clone MPL
%cd ..
# setup MPII dataset
!python tools/download_dataset.py --dataset-name mpii --save-dir data/mpii --delete --unzip
# run experiment
!python train.py --name MPII-HRNet32 --config \
configs/vision/pose/mmpose/hrnet_w32_256x256.yaml \
configs/vision/pose/mmpose/mmpose-base.yaml \
configs/data/pose-2d/mpii-hrnet.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Supervised VideoPose3D 3D pose estimation on the `Human3.6M` dataset.
!python tools/download_dataset.py --dataset-name human36m_annotation --save-dir human36m --delete --unzip
!python train.py --name Temporal-baseline-bs1024-lr0.001 --config \
configs/vision/pose-lifting/temporal.yaml \
configs/data/human36/temproal-videopose3d.yaml \
configs/data/human36/normalization.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
LightCollections can also be used as a library for extending your `pytorch-lightning` code. `train.py` simply passes the config file to the `Experiment` class defined in `main.py` to build components such as datasets, dataloaders, models, and callbacks, which in turn uses components defined in `catalog`.
...
experiment = Experiment(cfg)
experiment.initialize_environment(cfg=cfg)
datasets = experiment.setup_dataset(
dataset_cfg=cfg["dataset"],
transform_cfg=cfg["transform"],
)
dataloaders = experiment.setup_dataloader(
datasets=datasets,
dataloader_cfg=cfg["dataloader"],
)
train_dataloader, val_dataloader = dataloaders["trn"], dataloaders["val"]
model = experiment.setup_model(model_cfg=cfg["model"], training_cfg=cfg["training"])
logger_and_callbacks = experiment.setup_callbacks(cfg=cfg)
...

You don't necessarily need to create every component using the `Experiment` class. For example, if you wish to use a custom dataset instead, you can skip `experiment.setup_dataset` and feed your custom dataset to `experiment.setup_dataloader`. The `Experiment` class simply manages constants shared across components, such as the label map and the normalization mean and standard deviation, along with a common log directory, to conveniently create the components.
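For example, continuing the snippet above, a custom dataset can be plugged in directly. This is a minimal sketch: the dataset class, the variable names, and the exact sample format expected by your training procedure are placeholders.

```python
from torch.utils.data import Dataset


class MyImageDataset(Dataset):
    """Hypothetical custom dataset; replace with your own data loading logic."""

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # return whatever sample format your training procedure expects
        return self.images[idx], self.labels[idx]


# skip `experiment.setup_dataset` and provide the datasets yourself
datasets = {
    "trn": MyImageDataset(train_images, train_labels),
    "val": MyImageDataset(val_images, val_labels),
}
dataloaders = experiment.setup_dataloader(
    datasets=datasets,
    dataloader_cfg=cfg["dataloader"],
)
```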
As another example, `tools/hyperparameter_sweep.py` implements hyperparameter sweeping using the `Experiment` class.
Training involves many configs. LightCollections implements a cascading config system where multiple layers of config files define different parts of the experiment. For example, the CIFAR10 experiment above combines 6 config files:
configs/vision/classification/resnet-cifar10.yaml \
configs/vision/models/resnet/resnet18-custom.yaml \
configs/data/cifar10-kuangliu.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml

Each layer defines something different, such as the dataset, network architecture, or training procedure. If we wanted to use a ResNet50 model instead, we may replace
configs/vision/models/resnet/resnet18-custom.yaml
-> configs/vision/models/resnet/resnet50-custom.yaml

These cascading config files are baked at the start of `train.py`. Configs listed first have higher priority. The baked config files are logged under `configs/logs/{experiment_name}.(yaml/pkl/json)` for logging and reproduction purposes.
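The baking step boils down to a priority-ordered deep merge of the YAML files. The snippet below only illustrates the idea; the function names are made up and it is not the repository's actual merge code.

```python
import yaml


def deep_merge(high, low):
    """Recursively merge two dicts; values from `high` take priority."""
    merged = dict(low)
    for key, value in high.items():
        if isinstance(value, dict) and isinstance(low.get(key), dict):
            merged[key] = deep_merge(value, low[key])
        else:
            merged[key] = value
    return merged


def bake_configs(paths):
    """Combine cascading config files; files listed first have the highest priority."""
    cfg = {}
    for path in paths:
        with open(path) as f:
            layer = yaml.safe_load(f) or {}
        cfg = deep_merge(cfg, layer)
    return cfg
```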
- If a feature is not implemented yet, you may take an instance of `main.py:Experiment` and override any part of it (see the sketch after this list).
- Training procedure (`LightningModule`): the available training procedures are listed in `lightning/trainers.py`.
- Model architectures:
  - Backbone models implemented in `torchvision.models` can be used.
  - Backbone models implemented in `timm` can be used.
  - Although we highly recommend using `timm`, as it provides a large variety of computer vision models and its models are thoroughly evaluated, custom implementations of some architectures are listed in `catalog/models/__init__.py`.
- Dataset:
  - Dataset: currently only `torchvision` datasets are supported by `Experiment`; however, `torchvision.datasets.ImageFolder` can be used to load a custom dataset. Alternatively, you may just use a custom dataset and combine it with the transforms, models, and training features of the repo.
  - Transformations (data augmentation): transforms must be listed in `data/transforms/vision/__init__.py`.
- Other features:
  - Optimizers
  - Metrics / loss
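As a rough illustration of the override approach mentioned above (a hypothetical subclass; only `Experiment` and the `setup_model` signature come from the earlier snippet):

```python
from main import Experiment


class MyExperiment(Experiment):
    def setup_model(self, model_cfg, training_cfg):
        # build the default LightningModule first ...
        model = super().setup_model(model_cfg=model_cfg, training_cfg=training_cfg)
        # ... then tweak or replace any part of it before training,
        # e.g. swap the loss function, freeze parameters, attach hooks, etc.
        return model
```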
The results of each experiment, such as model checkpoints, logs, and the config file used to run the experiment, are logged under awesome-modular-pytorch-lightning/results/{exp_name}. In particular, the results/{exp_name}/configs/cfg.yaml file, which contains the baked config, can be useful when reproducing experiments or rechecking hyperparameters.
For computer vision models, we recommend borrowing architecture from timm as they provide robust implementations and pretrained weights for a wide variety of architectures. We remove the classification head from the original models.
timm provides a wide variety of architectures for computer vision. timm.list_models() returns a complete list of available models in timm. Models can be created with timm.create_model().
An example of creating a resnet50 model using timm:
model = timm.create_model("resnet50", pretrained=True)
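Since the backbone is used as a feature extractor (the classification head is removed), it may help to know that `timm` can also create headless models directly. A small, hedged example; the repository's internal head-removal logic may differ:

```python
import timm
import torch

# `num_classes=0` makes timm return pooled features instead of class logits
backbone = timm.create_model("resnet50", pretrained=True, num_classes=0)
features = backbone(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 2048]), matching `out_features: 2048`
```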
To use timm models:

- set `model.backbone.name` to `TimmNetwork`.
- set `model.backbone.args.name` to the model name.
- set additional arguments in `model.backbone.cfg`.
- Refer to: `configs/vision/models/resnet/resnet50-timm.yaml`
model:
backbone:
name: "TimmNetwork"
args:
name: "resnet50"
args:
pretrained: True
out_features: 2048
`torchvision.models` also provides a number of architectures for computer vision. The list of available models can be found in the torchvision documentation.
An example of creating a resnet50 model using torchvision:
model = torchvision.models.resnet50()
To use torchvision models:

- set `model.backbone.name` to `TorchvisionNetwork`.
- set `model.backbone.args.name` to the model name.
- set additional arguments in `model.backbone.args`.
- set `model.backbone.drop_after` to only use the feature extractor (see the sketch below the config).
- Refer to: `configs/vision/models/resnet/resnet50-torchvision.yaml`
model:
backbone:
name: "TorchvisionNetwork"
args:
name: "resnet50"
args:
pretrained: False
drop_after: "avgpool"
out_features: 2048
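For intuition, `drop_after: "avgpool"` roughly means keeping every layer up to and including `avgpool` and discarding the final `fc` head. A plain-torchvision sketch of the same idea (not the repository's exact implementation):

```python
import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50(pretrained=False)

# keep children up to and including `avgpool`, dropping everything after it
kept = []
for name, module in model.named_children():
    kept.append(module)
    if name == "avgpool":
        break
backbone = nn.Sequential(*kept, nn.Flatten(1))

features = backbone(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 2048])
```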
- Paper: https://arxiv.org/abs/1409.4842
- Note: Common data augmentation strategy for ImageNet using RandomResizedCrop.
- Refer to:
configs/data/augmentation/randomresizecrop.yaml
transform: [
[
"trn",
[
{
"name": "RandomResizedCropAndInterpolation",
"args":
{
"size": 224,
"scale": [0.08, 1.0],
"ratio": [0.75, 1.3333],
"interpolation": "random",
},
},
{
"name": "TorchvisionTransform",
"args": { "name": "RandomHorizontalFlip" },
},
# more data augmentation (rand augment, auto augment, ...)
],
],
[
# standard approach to use images cropped to the central 87.5% for validation
"val,test",
[
{
"name": "Resize",
"args": { "size": [256, 256], "interpolation": "bilinear" },
},
{
"name": "TorchvisionTransform",
"args": { "name": "CenterCrop", "args": { "size": [224, 224] } },
},
# more data augmentation (rand augment, auto augment, ...)
],
],
[
"trn,val,test",
[
{ "name": "ImageToTensor", "args": {} },
{
"name": "Normalize",
"args":
{
"mean": "{const.normalization_mean}",
"std": "{const.normalization_std}",
},
},
],
],
]

- Paper: https://arxiv.org/abs/1909.13719
- Note: Commonly used data augmentation strategy for image classification.
- Refer to:
configs/data/augmentation/randaugment.yaml
transform:
...
{
"name": "TorchvisionTransform",
"args":
{
"name": "RandAugment",
"args": { "num_ops": 2, "magnitude": 9 }
},
},
...

Refer to: `configs/algorithms/data_augmentation/randaugment.yaml`
- Paper: https://arxiv.org/abs/2103.10158
- Note: Commonly used data augmentation strategy for image classification.
- Refer to:
configs/data/augmentation/trivialaugment.yaml
transform:
...
{
"name": "TorchvisionTransform",
"args":
{
"name": "TrivialAugmentWide",
"args": { "num_magnitude_bins": 31 },
},
},
...

- Paper: https://arxiv.org/abs/1710.09412
- Note: Commonly used data augmentation strategy for image classification. Mixup blends pairs of training samples, `x̃ = λ·x_i + (1−λ)·x_j`, and mixes their one-hot labels with the same `λ`; as the resulting labels are continuous values, the loss function should be modified accordingly.
- Refer to:
configs/data/augmentation/mixup/mixup.yaml
training:
mixup_cutmix:
mixup_alpha: 1.0
cutmix_alpha: 0.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: 1000
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"- Paper: https://arxiv.org/abs/2103.10158
- Note: Commonly used data augmentation strategy for image classification. As the labels are continuous values, the loss function should be modified accordingly.
- Refer to:
configs/data/augmentation/mixup/cutmix.yaml
training:
mixup_cutmix:
mixup_alpha: 0.0
cutmix_alpha: 1.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: "{const.num_classes}"
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"- Paper: https://arxiv.org/abs/1708.04552
- Note: Commonly used data augmentation strategy for image classification.
- Refer to:
configs/data/augmentation/cutout.yaml and configs/data/augmentation/cutout_multiple.yaml
transform:
...
[
"trn",
[
...
{
# CutOut!!
"name": "CutOut",
"args": { "mask_size": 0.5, "num_masks": 1 },
},
],
],
...

- Paper: https://arxiv.org/abs/2110.00476
- Note: By default, models trained with `timm` randomly switch between `Mixup` and `CutMix` data augmentation, which is found to be effective in their paper.
- Refer to:
configs/data/augmentation/mixup/mixup_cutmix.yaml
training:
mixup_cutmix:
mixup_alpha: .8
cutmix_alpha: 1.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: 1000
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"- Note: Commonly used regularization strategy.
- Refer to: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
model:
modules:
loss_fn:
name: "CrossEntropyLoss"
args:
label_smoothing: 0.1
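For intuition, `label_smoothing: 0.1` mixes the one-hot target with a uniform distribution over the classes; PyTorch's built-in loss already implements this, as the small check below illustrates:

```python
import torch
import torch.nn.functional as F

num_classes, eps = 10, 0.1
logits = torch.randn(1, num_classes)
target = torch.tensor([3])

# smoothed target: (1 - eps) * one_hot + eps / num_classes on every class
smooth = torch.full((1, num_classes), eps / num_classes)
smooth[0, target] = 1.0 - eps + eps / num_classes

loss_builtin = F.cross_entropy(logits, target, label_smoothing=eps)
loss_manual = (-smooth * F.log_softmax(logits, dim=1)).sum()
print(loss_builtin.item(), loss_manual.item())  # should match up to float precision
```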
- Note: Commonly used regularization strategy.
- Refer to: torch.optim.SGD and
configs/vision/classification/resnet-cifar10.yaml
training:
optimizer_cfg:
weight_decay: 0.0005

- Paper: https://jmlr.org/papers/v15/srivastava14a.html
- Note: Commonly used regularization strategy.
- Refer to:
configs/vision/classification/resnet-cifar10.yaml
model:
modules:
classifier:
args:
dropout: 0.2

- Paper: https://arxiv.org/abs/2106.14448
- Note: Regularization strategy that minimizes the KL-divergence between the output distributions of two sub-models sampled by dropout.
- Refer to:
configs/algorithms/rdrop.yaml
training:
rdrop:
alpha: 0.6

- Paper: https://arxiv.org/abs/2010.01412
- Note: Sharpness-aware minimization aims at finding flat minima. It is demonstrated to improve training speed, generalization, and robustness to label noise. However, two backward passes are needed at every optimization step, which doubles the training time.
- Refer to:
configs/algorithms/sharpness-aware-minimization.yaml
training:
sharpness-aware:
rho: 0.05

- Paper: https://arxiv.org/abs/2204.12511
- Note: The authors derive the Taylor expansion of the cross-entropy loss and demonstrate that modifying the coefficient of the first-order term can improve performance; Poly-1 adds `eps * (1 - p_t)` to the cross-entropy, where `p_t` is the predicted probability of the target class.
- Refer to:
configs/algorithms/loss/poly1.yaml
model:
modules:
loss_fn:
name: "PolyLoss"
file: "loss"
args:
eps: 2.0

- Paper: https://arxiv.org/abs/2110.00476
- Note: The authors show that BCE loss can be used for classification tasks and achieves similar or better performance.
- Refer to:
configs/algorithms/loss/classification_bce.yaml
model:
modules:
loss_fn:
name: "OneToAllBinaryCrossEntropy"
file: "loss"(TODO)
- Paper: https://arxiv.org/abs/1503.02531
- Note: Distill knowledge from large network to small network by minimizing the KL divergence of the teacher and student prediction.
- Refer to:
TODO
- Note: Used to prevent gradient explosion and stabilize training by clipping large gradients. It has recently been demonstrated to have a number of additional benefits.
- Refer to:
configs/algorithms/optimizer/gradient_clipping.yaml and configs/algorithms/optimizer/gradient_clipping_maulat_optimization.yaml
trainer:
gradient_clip_val: 1.0
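`gradient_clip_val` is handled by `pytorch-lightning` during the optimization step (clipping by global gradient norm by default). In plain PyTorch, the equivalent is roughly the following toy example:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# rescale gradients so their global L2 norm is at most `gradient_clip_val`
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```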
- Paper: https://arxiv.org/abs/1803.05407
- Note: Average multiple checkpoints during training for better performance. An awesome overview of the algorithm is provided by pytorch. Luckily, `pytorch-lightning` provides an easy-to-use callback that implements SWA. To train a SWA model from an existing checkpoint, you may set `swa_epoch_start: 0.0`.
- Refer to:
configs/algorithms/swa.yaml
callbacks:
StochasticWeightAveraging:
name: "StochasticWeightAveraging"
file: "lightning"
args:
swa_lrs: 0.02 # typically x0.2 ~ x0.5 of the initial lr
swa_epoch_start: 0.75
annealing_epochs: 5 # smooth the connection between lr schedule and SWA.
annealing_strategy: "cos"
avg_fn: null
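If you only want the callback without the config system, the same setup can be written directly against `pytorch-lightning` (a sketch; the argument values mirror the config above and `max_epochs` is just an example):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_lrs=0.02,           # typically 0.2x ~ 0.5x of the initial lr
    swa_epoch_start=0.75,   # start averaging after 75% of training
    annealing_epochs=5,
    annealing_strategy="cos",
)
trainer = pl.Trainer(max_epochs=100, callbacks=[swa])
# trainer.fit(model, train_dataloader, val_dataloader)
```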
- `timm` or `torchvision` models have an argument called `pretrained` for loading a pretrained feature extractor (typically ImageNet-trained).
- To load from custom checkpoints, you can specify a URL or path to the state dict in `model.backbone.weights`. For example, `configs/vision/models/imagenet21k/imagenet21k_resnet50.yaml` loads ImageNet-21K checkpoints from a URL proposed in the paper "ImageNet-21K Pretraining for the Masses". A path to a state dict can be provided in `model.backbone.weights.state_dict_path` instead of the URL. This is implemented in `lightning/base.py:L23`.
model:
backbone:
name: "TimmNetwork"
args:
name: "resnet50"
args:
pretrained: False
out_features: 2048
weights:
is_ckpt: True
url: "https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth"
# alternatively, `state_dict_path: {PATH TO STATE DICT}`

To resume from a checkpoint, provide the path to the state dict in `model.state_dict_path`. Checkpoints generated using the `ModelCheckpoint` callback contain the state dict inside the `state_dict` key, while saving with `torch.save(model.state_dict())` directly saves the state dict. The `is_ckpt` argument should be `True` if the state dict was generated through the `ModelCheckpoint` callback.
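In plain PyTorch terms, the difference between the two formats looks roughly like this (an illustrative sketch; `model` stands for your module and the path is a placeholder):

```python
import torch

ckpt = torch.load("path/to/weights.pth", map_location="cpu")
# is_ckpt: True  -> Lightning ModelCheckpoint file, weights nested under "state_dict"
# is_ckpt: False -> raw state dict saved with torch.save(model.state_dict())
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict)
```

The corresponding config entries are: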
model:
is_ckpt: True # True / False according to the type of the state dict.
state_dict_path: {PATH TO STATE DICT}

- Paper: https://arxiv.org/abs/1708.07120
- Note: Linearly increases the learning rate from 0 to the maximum for the first half of training, then linearly decreases it to 0. Commonly used learning rate schedule.
- Refer to:
configs/algorithms/lr-schedule/1cycle.yaml and torch.optim.lr_scheduler.OneCycleLR
training:
lr_scheduler:
name: "1cycle"
args:
# Original version updates per-batch but we modify it to update per-epoch.
pct_start: 0.3
max_lr: "{training.lr}"
anneal_strategy: "linear"
total_steps: "{training.epochs}"
cfg:
interval: "epoch"
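This maps onto `torch.optim.lr_scheduler.OneCycleLR`; with the per-epoch stepping noted in the config, the plain PyTorch equivalent is roughly as follows (toy model and learning rate for illustration):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epochs = 100

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,            # "{training.lr}" in the config
    total_steps=epochs,    # one scheduler step per epoch (interval: "epoch")
    pct_start=0.3,
    anneal_strategy="linear",
)

for epoch in range(epochs):
    # ... run one training epoch ...
    scheduler.step()
```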
- Paper: https://arxiv.org/abs/1608.03983
- Note: Decays the learning rate from the initial value to 0 via a cosine function. Commonly used learning rate schedule.
- Refer to:
configs/vision/classification/resnet-cifar10.yaml and torch.optim.lr_scheduler.CosineAnnealingLR
training:
lr_scheduler:
name: "cosine"
args:
T_max: "{training.epochs}"
cfg:
interval: "epoch"
- Paper: https://arxiv.org/abs/2103.10158
- Note: Commonly used strategy to stabilize training at early stages.
- Refer to:
configs/algorithms/lr-schedule/warmup.yaml
training:
lr_warmup:
multiplier: 1
total_epoch: 5

- Refer to:
configs/algorithms/tta/hvflip.yaml
model:
tta:
name: "ClassificationTTAWrapper"
args:
output_label_key: "logits"
merge_mode: "mean"
transforms:
- name: "HorizontalFlip"
- name: "VerticalFlip"- and
configs/algorithms/tta/rotation.yaml
model:
tta:
name: "ClassificationTTAWrapper"
args:
merge_mode: "mean"
transforms:
- name: "HorizontalFlip"
- name: "Rotation"
args:
angles:
- 0
- 30
- 60
- 90
- 120
- 150
- 180
- 210
- 240
- 270
- 300
- 330

Currently, `torchvision` datasets are supported by `Experiment`; however, you could use `torchvision.datasets.ImageFolder` to load a custom dataset.
Transformations (data augmentation): transforms must be listed in `data/transforms/vision/__init__.py`.
- Contact: sieunpark77@gmail.com